diff options
| author | nsfisis <nsfisis@gmail.com> | 2025-09-27 21:08:48 +0900 |
|---|---|---|
| committer | nsfisis <nsfisis@gmail.com> | 2025-09-28 10:36:35 +0900 |
| commit | 74dbe1fc92a6bb3f03f5582280f4e02b9158a523 (patch) | |
| tree | 9bc0eaca54895ac1784356846b14b6cff0281d6f | |
| parent | 931cbe657ccdcfefe4077cd7371f1ea4ad4e5b53 (diff) | |
| download | ducc-74dbe1fc92a6bb3f03f5582280f4e02b9158a523.tar.gz ducc-74dbe1fc92a6bb3f03f5582280f4e02b9158a523.tar.zst ducc-74dbe1fc92a6bb3f03f5582280f4e02b9158a523.zip | |
feat: implement switch statement
| -rw-r--r-- | src/ast.h | 3 | ||||
| -rw-r--r-- | src/codegen.c | 95 | ||||
| -rw-r--r-- | src/parse.c | 55 | ||||
| -rw-r--r-- | tests/test_switch.sh | 212 |
4 files changed, 362 insertions, 3 deletions
@@ -91,9 +91,11 @@ typedef enum { AstNodeKind_assign_expr, AstNodeKind_binary_expr, AstNodeKind_break_stmt, + AstNodeKind_case_label, AstNodeKind_cast_expr, AstNodeKind_cond_expr, AstNodeKind_continue_stmt, + AstNodeKind_default_label, AstNodeKind_deref_expr, AstNodeKind_do_while_stmt, AstNodeKind_enum_def, @@ -118,6 +120,7 @@ typedef enum { AstNodeKind_struct_decl, AstNodeKind_struct_def, AstNodeKind_struct_member, + AstNodeKind_switch_stmt, AstNodeKind_type, AstNodeKind_typedef_decl, AstNodeKind_unary_expr, diff --git a/src/codegen.c b/src/codegen.c index cc3f9a7..91f8cc6 100644 --- a/src/codegen.c +++ b/src/codegen.c @@ -14,6 +14,7 @@ typedef struct { int next_label; int* loop_labels; AstNode* current_func; + int switch_label; } CodeGen; static CodeGen* codegen_new(FILE* out) { @@ -21,6 +22,7 @@ static CodeGen* codegen_new(FILE* out) { g->out = out; g->next_label = 1; g->loop_labels = calloc(1024, sizeof(int)); + g->switch_label = -1; return g; } @@ -552,8 +554,12 @@ static void codegen_do_while_stmt(CodeGen* g, AstNode* ast) { } static void codegen_break_stmt(CodeGen* g, AstNode* ast) { - int label = *g->loop_labels; - fprintf(g->out, " jmp .Lend%d\n", label); + if (g->switch_label != -1) { + fprintf(g->out, " jmp .Lend%d\n", g->switch_label); + } else { + int label = *g->loop_labels; + fprintf(g->out, " jmp .Lend%d\n", label); + } } static void codegen_continue_stmt(CodeGen* g, AstNode* ast) { @@ -561,6 +567,86 @@ static void codegen_continue_stmt(CodeGen* g, AstNode* ast) { fprintf(g->out, " jmp .Lcontinue%d\n", label); } +// Helper to collect case values from the switch body +static void collect_cases(AstNode* stmt, int* case_values, int* case_labels, int* n_cases) { + if (!stmt) + return; + + if (stmt->kind == AstNodeKind_case_label) { + case_values[*n_cases] = stmt->node_int_value; + case_labels[*n_cases] = *n_cases + 1; + (*n_cases)++; + collect_cases(stmt->node_body, case_values, case_labels, n_cases); + } else if (stmt->kind == AstNodeKind_default_label) { + collect_cases(stmt->node_body, case_values, case_labels, n_cases); + } else if (stmt->kind == AstNodeKind_list) { + for (int i = 0; i < stmt->node_len; i++) { + collect_cases(stmt->node_items + i, case_values, case_labels, n_cases); + } + } +} + +static bool codegen_switch_body(CodeGen* g, AstNode* stmt, int* case_values, int* case_labels, int n_cases) { + if (!stmt) + return false; + + if (stmt->kind == AstNodeKind_case_label) { + int value = stmt->node_int_value; + for (int i = 0; i < n_cases; i++) { + if (case_values[i] == value) { + fprintf(g->out, ".Lcase%d_%d:\n", g->switch_label, case_labels[i]); + break; + } + } + codegen_stmt(g, stmt->node_body); + return false; + } else if (stmt->kind == AstNodeKind_default_label) { + fprintf(g->out, ".Ldefault%d:\n", g->switch_label); + codegen_stmt(g, stmt->node_body); + return true; + } else if (stmt->kind == AstNodeKind_list) { + bool default_label_emitted = false; + for (int i = 0; i < stmt->node_len; i++) { + default_label_emitted |= codegen_switch_body(g, stmt->node_items + i, case_values, case_labels, n_cases); + } + return default_label_emitted; + } else { + codegen_stmt(g, stmt); + return false; + } +} + +static void codegen_switch_stmt(CodeGen* g, AstNode* ast) { + int switch_label = codegen_new_label(g); + int prev_switch_label = g->switch_label; + g->switch_label = switch_label; + + // Collect all case values and assign labels + int case_values[256]; + int case_labels[256]; + int n_cases = 0; + collect_cases(ast->node_body, case_values, case_labels, &n_cases); + + // Generate jump instructions. + codegen_expr(g, ast->node_expr, GenMode_rval); + fprintf(g->out, " pop rax\n"); + for (int i = 0; i < n_cases; i++) { + fprintf(g->out, " cmp rax, %d\n", case_values[i]); + fprintf(g->out, " je .Lcase%d_%d\n", switch_label, case_labels[i]); + } + fprintf(g->out, " jmp .Ldefault%d\n", switch_label); + + // Generate the switch body with labels. + bool default_label_emitted = codegen_switch_body(g, ast->node_body, case_values, case_labels, n_cases); + + if (!default_label_emitted) { + fprintf(g->out, ".Ldefault%d:\n", switch_label); + } + fprintf(g->out, ".Lend%d:\n", switch_label); + + g->switch_label = prev_switch_label; +} + static void codegen_expr_stmt(CodeGen* g, AstNode* ast) { codegen_expr(g, ast->node_expr, GenMode_rval); // TODO: the expression on the stack can be more than 8 bytes. @@ -587,6 +673,8 @@ static void codegen_stmt(CodeGen* g, AstNode* ast) { codegen_return_stmt(g, ast); } else if (ast->kind == AstNodeKind_if_stmt) { codegen_if_stmt(g, ast); + } else if (ast->kind == AstNodeKind_switch_stmt) { + codegen_switch_stmt(g, ast); } else if (ast->kind == AstNodeKind_for_stmt) { codegen_for_stmt(g, ast); } else if (ast->kind == AstNodeKind_do_while_stmt) { @@ -601,6 +689,9 @@ static void codegen_stmt(CodeGen* g, AstNode* ast) { codegen_var_decl(g, ast); } else if (ast->kind == AstNodeKind_nop) { codegen_nop(g, ast); + } else if (ast->kind == AstNodeKind_case_label || ast->kind == AstNodeKind_default_label) { + // They are handled by codegen_switch_stmt(). + unreachable(); } else { unreachable(); } diff --git a/src/parse.c b/src/parse.c index ba9b7e5..dbbb61d 100644 --- a/src/parse.c +++ b/src/parse.c @@ -149,6 +149,7 @@ typedef struct { AstNode* typedefs; StrArray str_literals; int anonymous_user_type_counter; + AstNode* current_switch; } Parser; static Parser* parser_new(TokenArray* tokens) { @@ -1214,6 +1215,31 @@ static AstNode* parse_do_while_stmt(Parser* p) { return stmt; } +static AstNode* parse_switch_stmt(Parser* p) { + expect(p, TokenKind_keyword_switch); + expect(p, TokenKind_paren_l); + AstNode* expr = parse_expr(p); + expect(p, TokenKind_paren_r); + + AstNode* tmp_var = generate_temporary_lvar(p, expr->ty); + AstNode* assignment = ast_new_assign_expr(TokenKind_assign, tmp_var, expr); + AstNode* assign_stmt = ast_new(AstNodeKind_expr_stmt); + assign_stmt->node_expr = assignment; + + AstNode* switch_stmt = ast_new(AstNodeKind_switch_stmt); + switch_stmt->node_expr = tmp_var; + + AstNode* prev_switch = p->current_switch; + p->current_switch = switch_stmt; + switch_stmt->node_body = parse_stmt(p); + p->current_switch = prev_switch; + + AstNode* list = ast_new_list(2); + ast_append(list, assign_stmt); + ast_append(list, switch_stmt); + return list; +} + static AstNode* parse_break_stmt(Parser* p) { expect(p, TokenKind_keyword_break); expect(p, TokenKind_semicolon); @@ -1254,10 +1280,37 @@ static AstNode* parse_empty_stmt(Parser* p) { static AstNode* parse_stmt(Parser* p) { Token* t = peek_token(p); - if (t->kind == TokenKind_keyword_return) { + + if (t->kind == TokenKind_keyword_case) { + if (!p->current_switch) { + fatal_error("%s:%d: 'case' label not within a switch statement", t->loc.filename, t->loc.line); + } + expect(p, TokenKind_keyword_case); + AstNode* value = parse_constant_expression(p); + expect(p, TokenKind_colon); + AstNode* stmt = parse_stmt(p); + + AstNode* case_label = ast_new(AstNodeKind_case_label); + case_label->node_int_value = eval(value); + case_label->node_body = stmt; + return case_label; + } else if (t->kind == TokenKind_keyword_default) { + if (!p->current_switch) { + fatal_error("%s:%d: 'default' label not within a switch statement", t->loc.filename, t->loc.line); + } + expect(p, TokenKind_keyword_default); + expect(p, TokenKind_colon); + AstNode* stmt = parse_stmt(p); + + AstNode* default_label = ast_new(AstNodeKind_default_label); + default_label->node_body = stmt; + return default_label; + } else if (t->kind == TokenKind_keyword_return) { return parse_return_stmt(p); } else if (t->kind == TokenKind_keyword_if) { return parse_if_stmt(p); + } else if (t->kind == TokenKind_keyword_switch) { + return parse_switch_stmt(p); } else if (t->kind == TokenKind_keyword_for) { return parse_for_stmt(p); } else if (t->kind == TokenKind_keyword_while) { diff --git a/tests/test_switch.sh b/tests/test_switch.sh new file mode 100644 index 0000000..d5d3dae --- /dev/null +++ b/tests/test_switch.sh @@ -0,0 +1,212 @@ +#!/bin/bash + +test_exit_code 0 <<'EOF' +#include "../../helpers.h" + +int main() { + int x = 2; + int result = 0; + + switch (x) { + case 1: + result = 10; + break; + case 2: + result = 20; + break; + case 3: + result = 30; + break; + } + + ASSERT_EQ(20, result); +} +EOF + +test_exit_code 0 <<'EOF' +#include "../../helpers.h" + +int main() { + int x = 5; + int result = 0; + + switch (x) { + case 1: + result = 10; + break; + case 2: + result = 20; + break; + default: + result = 99; + break; + } + + ASSERT_EQ(99, result); +} +EOF + +test_exit_code 0 <<'EOF' +#include "../../helpers.h" + +int main() { + int x = 2; + int result = 0; + + switch (x) { + case 1: + result = result + 10; + case 2: + result = result + 20; + case 3: + result = result + 30; + break; + } + + ASSERT_EQ(50, result); // 20 + 30 due to fall-through +} +EOF + +test_exit_code 0 <<'EOF' +#include "../../helpers.h" + +int main() { + int x = 1; + int y = 2; + int result = 0; + + switch (x) { + case 1: + switch (y) { + case 1: + result = 11; + break; + case 2: + result = 12; + break; + } + break; + case 2: + result = 20; + break; + } + + ASSERT_EQ(12, result); +} +EOF + +test_exit_code 0 <<'EOF' +#include "../../helpers.h" + +int main() { + int a = 3; + int b = 2; + int result = 0; + + switch (a + b) { + case 4: + result = 40; + break; + case 5: + result = 50; + break; + case 6: + result = 60; + break; + } + + ASSERT_EQ(50, result); +} +EOF + +test_exit_code 0 <<'EOF' +#include "../../helpers.h" + +int main() { + int x = 2; + int result = 0; + int temp = 0; + + switch (x) { + case 1: + temp = 5; + result = temp * 2; + break; + case 2: + temp = 10; + result = temp * 2; + break; + case 3: + temp = 15; + result = temp * 2; + break; + } + + ASSERT_EQ(20, result); + ASSERT_EQ(10, temp); +} +EOF + +test_exit_code 0 <<'EOF' +#include "../../helpers.h" + +int main() { + int x = 1; + int result = 0; + + switch (x) { + case 1: { + int local = 100; + result = local; + break; + } + case 2: { + int local = 200; + result = local; + break; + } + } + + ASSERT_EQ(100, result); +} +EOF + +test_exit_code 0 <<'EOF' +#include "../../helpers.h" + +int main() { + int x = 10; + int result = 42; + + switch (x) { + case 1: + result = 10; + break; + case 2: + result = 20; + break; + } + + ASSERT_EQ(42, result); +} +EOF + +cat <<'EOF' > expected +main.c:2: 'case' label not within a switch statement +EOF +test_compile_error <<'EOF' +int main() { + case 1: + return 0; +} +EOF + +cat <<'EOF' > expected +main.c:2: 'default' label not within a switch statement +EOF +test_compile_error <<'EOF' +int main() { + default: + return 0; +} +EOF |
