aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
authornsfisis <nsfisis@gmail.com>2025-09-13 13:17:28 +0900
committernsfisis <nsfisis@gmail.com>2025-09-13 13:17:28 +0900
commita7c9c3407582f0d8d66539cf90e86fe3100787c5 (patch)
treeb8b4132ee0c116fae50f45107f3ae89a9a1911a6
parent8de7fa9da5fd8015f4fcc826b9270061b7b89478 (diff)
downloadducc-a7c9c3407582f0d8d66539cf90e86fe3100787c5.tar.gz
ducc-a7c9c3407582f0d8d66539cf90e86fe3100787c5.tar.zst
ducc-a7c9c3407582f0d8d66539cf90e86fe3100787c5.zip
feat: implement cast expression
-rw-r--r--src/ast.c7
-rw-r--r--src/ast.h2
-rw-r--r--src/codegen.c42
-rw-r--r--src/parse.c3
-rw-r--r--tests/test_cast_expressions.sh154
5 files changed, 206 insertions, 2 deletions
diff --git a/src/ast.c b/src/ast.c
index 84fe38e..d698543 100644
--- a/src/ast.c
+++ b/src/ast.c
@@ -282,6 +282,13 @@ AstNode* ast_new_member_access_expr(AstNode* obj, const char* name) {
return e;
}
+AstNode* ast_new_cast_expr(AstNode* operand, Type* result_ty) {
+ AstNode* e = ast_new(AstNodeKind_cast_expr);
+ e->node_operand = operand;
+ e->ty = result_ty;
+ return e;
+}
+
int type_sizeof_struct(Type* ty) {
int next_offset = 0;
int struct_align = 0;
diff --git a/src/ast.h b/src/ast.h
index e0d8410..ce84739 100644
--- a/src/ast.h
+++ b/src/ast.h
@@ -89,6 +89,7 @@ typedef enum {
AstNodeKind_assign_expr,
AstNodeKind_binary_expr,
AstNodeKind_break_stmt,
+ AstNodeKind_cast_expr,
AstNodeKind_cond_expr,
AstNodeKind_continue_stmt,
AstNodeKind_deref_expr,
@@ -178,5 +179,6 @@ AstNode* ast_new_assign_sub_expr(AstNode* lhs, AstNode* rhs);
AstNode* ast_new_ref_expr(AstNode* operand);
AstNode* ast_new_deref_expr(AstNode* operand);
AstNode* ast_new_member_access_expr(AstNode* obj, const char* name);
+AstNode* ast_new_cast_expr(AstNode* operand, Type* result_ty);
#endif
diff --git a/src/codegen.c b/src/codegen.c
index 000126b..dccabe4 100644
--- a/src/codegen.c
+++ b/src/codegen.c
@@ -134,6 +134,46 @@ static void codegen_deref_expr(CodeGen* g, AstNode* ast, GenMode gen_mode) {
}
}
+static void codegen_cast_expr(CodeGen* g, AstNode* ast) {
+ codegen_expr(g, ast->node_operand, GenMode_rval);
+
+ int src_size = type_sizeof(ast->node_operand->ty);
+ int dst_size = type_sizeof(ast->ty);
+
+ if (src_size == dst_size)
+ return;
+
+ fprintf(g->out, " pop rax\n");
+
+ if (dst_size == 1) {
+ fprintf(g->out, " movsx rax, al\n");
+ } else if (dst_size == 2) {
+ if (src_size == 1) {
+ fprintf(g->out, " movsx rax, al\n");
+ } else {
+ fprintf(g->out, " movsx rax, ax\n");
+ }
+ } else if (dst_size == 4) {
+ if (src_size == 1) {
+ fprintf(g->out, " movsx rax, al\n");
+ } else if (src_size == 2) {
+ fprintf(g->out, " movsx rax, ax\n");
+ } else {
+ fprintf(g->out, " movsxd rax, eax\n");
+ }
+ } else if (dst_size == 8) {
+ if (src_size == 1) {
+ fprintf(g->out, " movsx rax, al\n");
+ } else if (src_size == 2) {
+ fprintf(g->out, " movsx rax, ax\n");
+ } else if (src_size == 4) {
+ fprintf(g->out, " movsxd rax, eax\n");
+ }
+ }
+
+ fprintf(g->out, " push rax\n");
+}
+
static void codegen_logical_expr(CodeGen* g, AstNode* ast) {
int label = codegen_new_label(g);
@@ -406,6 +446,8 @@ static void codegen_expr(CodeGen* g, AstNode* ast, GenMode gen_mode) {
codegen_ref_expr(g, ast, gen_mode);
} else if (ast->kind == AstNodeKind_deref_expr) {
codegen_deref_expr(g, ast, gen_mode);
+ } else if (ast->kind == AstNodeKind_cast_expr) {
+ codegen_cast_expr(g, ast);
} else if (ast->kind == AstNodeKind_binary_expr) {
codegen_binary_expr(g, ast, gen_mode);
} else if (ast->kind == AstNodeKind_cond_expr) {
diff --git a/src/parse.c b/src/parse.c
index b6fa043..b65a6b3 100644
--- a/src/parse.c
+++ b/src/parse.c
@@ -577,8 +577,7 @@ static AstNode* parse_cast_expr(Parser* p) {
// TODO: check whether the original type can be casted to the result type.
AstNode* e = parse_cast_expr(p);
- e->ty = ty;
- return e;
+ return ast_new_cast_expr(e, ty);
}
return parse_prefix_expr(p);
}
diff --git a/tests/test_cast_expressions.sh b/tests/test_cast_expressions.sh
new file mode 100644
index 0000000..f6824ad
--- /dev/null
+++ b/tests/test_cast_expressions.sh
@@ -0,0 +1,154 @@
+cat <<'EOF' > expected
+65
+65
+127
+1
+42
+99
+10
+EOF
+test_diff <<'EOF'
+int printf(const char*, ...);
+
+int main() {
+ char c = 65;
+ int i = (int)c;
+ printf("%d\n", i);
+
+ int i2 = 321;
+ char c2 = (char)i2;
+ printf("%d\n", c2);
+
+ short s = 127;
+ int i3 = (int)s;
+ printf("%d\n", i3);
+
+ int i4 = 65537;
+ short s2 = (short)i4;
+ printf("%d\n", s2);
+
+ long l = 42;
+ int i5 = (int)l;
+ printf("%d\n", i5);
+
+ int i6 = 99;
+ long l2 = (long)i6;
+ printf("%d\n", (int)l2);
+
+ char c3 = 10;
+ short s3 = (short)c3;
+ int i7 = (int)s3;
+ long l3 = (long)i7;
+ printf("%d\n", (int)l3);
+
+ return 0;
+}
+EOF
+
+cat <<'EOF' > expected
+Result: 130
+EOF
+test_diff <<'EOF'
+int printf(const char*, ...);
+
+int main() {
+ char c = 65;
+ int result = (int)c + (int)c;
+ printf("Result: %d\n", result);
+ return 0;
+}
+EOF
+
+cat <<'EOF' > expected
+10
+20
+30
+EOF
+test_diff <<'EOF'
+int printf(const char*, ...);
+
+int main() {
+ char a = 5;
+ char b = 5;
+ int sum = (int)a + (int)b;
+ printf("%d\n", sum);
+
+ short s1 = 10;
+ short s2 = 10;
+ int sum2 = (int)s1 + (int)s2;
+ printf("%d\n", sum2);
+
+ long l1 = 15;
+ long l2 = 15;
+ int sum3 = (int)(l1 + l2);
+ printf("%d\n", sum3);
+
+ return 0;
+}
+EOF
+
+cat <<'EOF' > expected
+10
+EOF
+test_diff <<'EOF'
+int printf(const char*, ...);
+
+int main() {
+ char c = -10;
+ int i = (int)c;
+ printf("%d\n", -i);
+ return 0;
+}
+EOF
+
+cat <<'EOF' > expected
+Char: 65
+Int: 65
+EOF
+test_diff <<'EOF'
+int printf(const char*, ...);
+
+char get_char() {
+ return 65;
+}
+
+int main() {
+ char c = get_char();
+ int i = (int)get_char();
+ printf("Char: %d\n", c);
+ printf("Int: %d\n", i);
+ return 0;
+}
+EOF
+
+cat <<'EOF' > expected
+Equal
+EOF
+test_diff <<'EOF'
+int printf(const char*, ...);
+
+int main() {
+ char c = 42;
+ int i = 42;
+ if ((int)c == i) {
+ printf("Equal\n");
+ } else {
+ printf("Not equal\n");
+ }
+ return 0;
+}
+EOF
+
+cat <<'EOF' > expected
+55
+EOF
+test_diff <<'EOF'
+int printf(const char*, ...);
+
+int main() {
+ long l = 55;
+ char c = (char)(short)(int)l;
+ printf("%d\n", c);
+ return 0;
+}
+EOF