aboutsummaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--src/ast.h3
-rw-r--r--src/codegen.c95
-rw-r--r--src/parse.c55
-rw-r--r--tests/test_switch.sh212
4 files changed, 362 insertions, 3 deletions
diff --git a/src/ast.h b/src/ast.h
index 3c92eed..5f4367d 100644
--- a/src/ast.h
+++ b/src/ast.h
@@ -91,9 +91,11 @@ typedef enum {
AstNodeKind_assign_expr,
AstNodeKind_binary_expr,
AstNodeKind_break_stmt,
+ AstNodeKind_case_label,
AstNodeKind_cast_expr,
AstNodeKind_cond_expr,
AstNodeKind_continue_stmt,
+ AstNodeKind_default_label,
AstNodeKind_deref_expr,
AstNodeKind_do_while_stmt,
AstNodeKind_enum_def,
@@ -118,6 +120,7 @@ typedef enum {
AstNodeKind_struct_decl,
AstNodeKind_struct_def,
AstNodeKind_struct_member,
+ AstNodeKind_switch_stmt,
AstNodeKind_type,
AstNodeKind_typedef_decl,
AstNodeKind_unary_expr,
diff --git a/src/codegen.c b/src/codegen.c
index cc3f9a7..91f8cc6 100644
--- a/src/codegen.c
+++ b/src/codegen.c
@@ -14,6 +14,7 @@ typedef struct {
int next_label;
int* loop_labels;
AstNode* current_func;
+ int switch_label;
} CodeGen;
static CodeGen* codegen_new(FILE* out) {
@@ -21,6 +22,7 @@ static CodeGen* codegen_new(FILE* out) {
g->out = out;
g->next_label = 1;
g->loop_labels = calloc(1024, sizeof(int));
+ g->switch_label = -1;
return g;
}
@@ -552,8 +554,12 @@ static void codegen_do_while_stmt(CodeGen* g, AstNode* ast) {
}
static void codegen_break_stmt(CodeGen* g, AstNode* ast) {
- int label = *g->loop_labels;
- fprintf(g->out, " jmp .Lend%d\n", label);
+ if (g->switch_label != -1) {
+ fprintf(g->out, " jmp .Lend%d\n", g->switch_label);
+ } else {
+ int label = *g->loop_labels;
+ fprintf(g->out, " jmp .Lend%d\n", label);
+ }
}
static void codegen_continue_stmt(CodeGen* g, AstNode* ast) {
@@ -561,6 +567,86 @@ static void codegen_continue_stmt(CodeGen* g, AstNode* ast) {
fprintf(g->out, " jmp .Lcontinue%d\n", label);
}
+// Helper to collect case values from the switch body
+static void collect_cases(AstNode* stmt, int* case_values, int* case_labels, int* n_cases) {
+ if (!stmt)
+ return;
+
+ if (stmt->kind == AstNodeKind_case_label) {
+ case_values[*n_cases] = stmt->node_int_value;
+ case_labels[*n_cases] = *n_cases + 1;
+ (*n_cases)++;
+ collect_cases(stmt->node_body, case_values, case_labels, n_cases);
+ } else if (stmt->kind == AstNodeKind_default_label) {
+ collect_cases(stmt->node_body, case_values, case_labels, n_cases);
+ } else if (stmt->kind == AstNodeKind_list) {
+ for (int i = 0; i < stmt->node_len; i++) {
+ collect_cases(stmt->node_items + i, case_values, case_labels, n_cases);
+ }
+ }
+}
+
+static bool codegen_switch_body(CodeGen* g, AstNode* stmt, int* case_values, int* case_labels, int n_cases) {
+ if (!stmt)
+ return false;
+
+ if (stmt->kind == AstNodeKind_case_label) {
+ int value = stmt->node_int_value;
+ for (int i = 0; i < n_cases; i++) {
+ if (case_values[i] == value) {
+ fprintf(g->out, ".Lcase%d_%d:\n", g->switch_label, case_labels[i]);
+ break;
+ }
+ }
+ codegen_stmt(g, stmt->node_body);
+ return false;
+ } else if (stmt->kind == AstNodeKind_default_label) {
+ fprintf(g->out, ".Ldefault%d:\n", g->switch_label);
+ codegen_stmt(g, stmt->node_body);
+ return true;
+ } else if (stmt->kind == AstNodeKind_list) {
+ bool default_label_emitted = false;
+ for (int i = 0; i < stmt->node_len; i++) {
+ default_label_emitted |= codegen_switch_body(g, stmt->node_items + i, case_values, case_labels, n_cases);
+ }
+ return default_label_emitted;
+ } else {
+ codegen_stmt(g, stmt);
+ return false;
+ }
+}
+
+static void codegen_switch_stmt(CodeGen* g, AstNode* ast) {
+ int switch_label = codegen_new_label(g);
+ int prev_switch_label = g->switch_label;
+ g->switch_label = switch_label;
+
+ // Collect all case values and assign labels
+ int case_values[256];
+ int case_labels[256];
+ int n_cases = 0;
+ collect_cases(ast->node_body, case_values, case_labels, &n_cases);
+
+ // Generate jump instructions.
+ codegen_expr(g, ast->node_expr, GenMode_rval);
+ fprintf(g->out, " pop rax\n");
+ for (int i = 0; i < n_cases; i++) {
+ fprintf(g->out, " cmp rax, %d\n", case_values[i]);
+ fprintf(g->out, " je .Lcase%d_%d\n", switch_label, case_labels[i]);
+ }
+ fprintf(g->out, " jmp .Ldefault%d\n", switch_label);
+
+ // Generate the switch body with labels.
+ bool default_label_emitted = codegen_switch_body(g, ast->node_body, case_values, case_labels, n_cases);
+
+ if (!default_label_emitted) {
+ fprintf(g->out, ".Ldefault%d:\n", switch_label);
+ }
+ fprintf(g->out, ".Lend%d:\n", switch_label);
+
+ g->switch_label = prev_switch_label;
+}
+
static void codegen_expr_stmt(CodeGen* g, AstNode* ast) {
codegen_expr(g, ast->node_expr, GenMode_rval);
// TODO: the expression on the stack can be more than 8 bytes.
@@ -587,6 +673,8 @@ static void codegen_stmt(CodeGen* g, AstNode* ast) {
codegen_return_stmt(g, ast);
} else if (ast->kind == AstNodeKind_if_stmt) {
codegen_if_stmt(g, ast);
+ } else if (ast->kind == AstNodeKind_switch_stmt) {
+ codegen_switch_stmt(g, ast);
} else if (ast->kind == AstNodeKind_for_stmt) {
codegen_for_stmt(g, ast);
} else if (ast->kind == AstNodeKind_do_while_stmt) {
@@ -601,6 +689,9 @@ static void codegen_stmt(CodeGen* g, AstNode* ast) {
codegen_var_decl(g, ast);
} else if (ast->kind == AstNodeKind_nop) {
codegen_nop(g, ast);
+ } else if (ast->kind == AstNodeKind_case_label || ast->kind == AstNodeKind_default_label) {
+ // They are handled by codegen_switch_stmt().
+ unreachable();
} else {
unreachable();
}
diff --git a/src/parse.c b/src/parse.c
index ba9b7e5..dbbb61d 100644
--- a/src/parse.c
+++ b/src/parse.c
@@ -149,6 +149,7 @@ typedef struct {
AstNode* typedefs;
StrArray str_literals;
int anonymous_user_type_counter;
+ AstNode* current_switch;
} Parser;
static Parser* parser_new(TokenArray* tokens) {
@@ -1214,6 +1215,31 @@ static AstNode* parse_do_while_stmt(Parser* p) {
return stmt;
}
+static AstNode* parse_switch_stmt(Parser* p) {
+ expect(p, TokenKind_keyword_switch);
+ expect(p, TokenKind_paren_l);
+ AstNode* expr = parse_expr(p);
+ expect(p, TokenKind_paren_r);
+
+ AstNode* tmp_var = generate_temporary_lvar(p, expr->ty);
+ AstNode* assignment = ast_new_assign_expr(TokenKind_assign, tmp_var, expr);
+ AstNode* assign_stmt = ast_new(AstNodeKind_expr_stmt);
+ assign_stmt->node_expr = assignment;
+
+ AstNode* switch_stmt = ast_new(AstNodeKind_switch_stmt);
+ switch_stmt->node_expr = tmp_var;
+
+ AstNode* prev_switch = p->current_switch;
+ p->current_switch = switch_stmt;
+ switch_stmt->node_body = parse_stmt(p);
+ p->current_switch = prev_switch;
+
+ AstNode* list = ast_new_list(2);
+ ast_append(list, assign_stmt);
+ ast_append(list, switch_stmt);
+ return list;
+}
+
static AstNode* parse_break_stmt(Parser* p) {
expect(p, TokenKind_keyword_break);
expect(p, TokenKind_semicolon);
@@ -1254,10 +1280,37 @@ static AstNode* parse_empty_stmt(Parser* p) {
static AstNode* parse_stmt(Parser* p) {
Token* t = peek_token(p);
- if (t->kind == TokenKind_keyword_return) {
+
+ if (t->kind == TokenKind_keyword_case) {
+ if (!p->current_switch) {
+ fatal_error("%s:%d: 'case' label not within a switch statement", t->loc.filename, t->loc.line);
+ }
+ expect(p, TokenKind_keyword_case);
+ AstNode* value = parse_constant_expression(p);
+ expect(p, TokenKind_colon);
+ AstNode* stmt = parse_stmt(p);
+
+ AstNode* case_label = ast_new(AstNodeKind_case_label);
+ case_label->node_int_value = eval(value);
+ case_label->node_body = stmt;
+ return case_label;
+ } else if (t->kind == TokenKind_keyword_default) {
+ if (!p->current_switch) {
+ fatal_error("%s:%d: 'default' label not within a switch statement", t->loc.filename, t->loc.line);
+ }
+ expect(p, TokenKind_keyword_default);
+ expect(p, TokenKind_colon);
+ AstNode* stmt = parse_stmt(p);
+
+ AstNode* default_label = ast_new(AstNodeKind_default_label);
+ default_label->node_body = stmt;
+ return default_label;
+ } else if (t->kind == TokenKind_keyword_return) {
return parse_return_stmt(p);
} else if (t->kind == TokenKind_keyword_if) {
return parse_if_stmt(p);
+ } else if (t->kind == TokenKind_keyword_switch) {
+ return parse_switch_stmt(p);
} else if (t->kind == TokenKind_keyword_for) {
return parse_for_stmt(p);
} else if (t->kind == TokenKind_keyword_while) {
diff --git a/tests/test_switch.sh b/tests/test_switch.sh
new file mode 100644
index 0000000..d5d3dae
--- /dev/null
+++ b/tests/test_switch.sh
@@ -0,0 +1,212 @@
+#!/bin/bash
+
+test_exit_code 0 <<'EOF'
+#include "../../helpers.h"
+
+int main() {
+ int x = 2;
+ int result = 0;
+
+ switch (x) {
+ case 1:
+ result = 10;
+ break;
+ case 2:
+ result = 20;
+ break;
+ case 3:
+ result = 30;
+ break;
+ }
+
+ ASSERT_EQ(20, result);
+}
+EOF
+
+test_exit_code 0 <<'EOF'
+#include "../../helpers.h"
+
+int main() {
+ int x = 5;
+ int result = 0;
+
+ switch (x) {
+ case 1:
+ result = 10;
+ break;
+ case 2:
+ result = 20;
+ break;
+ default:
+ result = 99;
+ break;
+ }
+
+ ASSERT_EQ(99, result);
+}
+EOF
+
+test_exit_code 0 <<'EOF'
+#include "../../helpers.h"
+
+int main() {
+ int x = 2;
+ int result = 0;
+
+ switch (x) {
+ case 1:
+ result = result + 10;
+ case 2:
+ result = result + 20;
+ case 3:
+ result = result + 30;
+ break;
+ }
+
+ ASSERT_EQ(50, result); // 20 + 30 due to fall-through
+}
+EOF
+
+test_exit_code 0 <<'EOF'
+#include "../../helpers.h"
+
+int main() {
+ int x = 1;
+ int y = 2;
+ int result = 0;
+
+ switch (x) {
+ case 1:
+ switch (y) {
+ case 1:
+ result = 11;
+ break;
+ case 2:
+ result = 12;
+ break;
+ }
+ break;
+ case 2:
+ result = 20;
+ break;
+ }
+
+ ASSERT_EQ(12, result);
+}
+EOF
+
+test_exit_code 0 <<'EOF'
+#include "../../helpers.h"
+
+int main() {
+ int a = 3;
+ int b = 2;
+ int result = 0;
+
+ switch (a + b) {
+ case 4:
+ result = 40;
+ break;
+ case 5:
+ result = 50;
+ break;
+ case 6:
+ result = 60;
+ break;
+ }
+
+ ASSERT_EQ(50, result);
+}
+EOF
+
+test_exit_code 0 <<'EOF'
+#include "../../helpers.h"
+
+int main() {
+ int x = 2;
+ int result = 0;
+ int temp = 0;
+
+ switch (x) {
+ case 1:
+ temp = 5;
+ result = temp * 2;
+ break;
+ case 2:
+ temp = 10;
+ result = temp * 2;
+ break;
+ case 3:
+ temp = 15;
+ result = temp * 2;
+ break;
+ }
+
+ ASSERT_EQ(20, result);
+ ASSERT_EQ(10, temp);
+}
+EOF
+
+test_exit_code 0 <<'EOF'
+#include "../../helpers.h"
+
+int main() {
+ int x = 1;
+ int result = 0;
+
+ switch (x) {
+ case 1: {
+ int local = 100;
+ result = local;
+ break;
+ }
+ case 2: {
+ int local = 200;
+ result = local;
+ break;
+ }
+ }
+
+ ASSERT_EQ(100, result);
+}
+EOF
+
+test_exit_code 0 <<'EOF'
+#include "../../helpers.h"
+
+int main() {
+ int x = 10;
+ int result = 42;
+
+ switch (x) {
+ case 1:
+ result = 10;
+ break;
+ case 2:
+ result = 20;
+ break;
+ }
+
+ ASSERT_EQ(42, result);
+}
+EOF
+
+cat <<'EOF' > expected
+main.c:2: 'case' label not within a switch statement
+EOF
+test_compile_error <<'EOF'
+int main() {
+ case 1:
+ return 0;
+}
+EOF
+
+cat <<'EOF' > expected
+main.c:2: 'default' label not within a switch statement
+EOF
+test_compile_error <<'EOF'
+int main() {
+ default:
+ return 0;
+}
+EOF