aboutsummaryrefslogtreecommitdiffhomepage
path: root/src/codegen.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/codegen.c')
-rw-r--r--src/codegen.c95
1 files changed, 93 insertions, 2 deletions
diff --git a/src/codegen.c b/src/codegen.c
index cc3f9a7..91f8cc6 100644
--- a/src/codegen.c
+++ b/src/codegen.c
@@ -14,6 +14,7 @@ typedef struct {
int next_label;
int* loop_labels;
AstNode* current_func;
+ int switch_label;
} CodeGen;
static CodeGen* codegen_new(FILE* out) {
@@ -21,6 +22,7 @@ static CodeGen* codegen_new(FILE* out) {
g->out = out;
g->next_label = 1;
g->loop_labels = calloc(1024, sizeof(int));
+ g->switch_label = -1;
return g;
}
@@ -552,8 +554,12 @@ static void codegen_do_while_stmt(CodeGen* g, AstNode* ast) {
}
static void codegen_break_stmt(CodeGen* g, AstNode* ast) {
- int label = *g->loop_labels;
- fprintf(g->out, " jmp .Lend%d\n", label);
+ if (g->switch_label != -1) {
+ fprintf(g->out, " jmp .Lend%d\n", g->switch_label);
+ } else {
+ int label = *g->loop_labels;
+ fprintf(g->out, " jmp .Lend%d\n", label);
+ }
}
static void codegen_continue_stmt(CodeGen* g, AstNode* ast) {
@@ -561,6 +567,86 @@ static void codegen_continue_stmt(CodeGen* g, AstNode* ast) {
fprintf(g->out, " jmp .Lcontinue%d\n", label);
}
+// Helper to collect case values from the switch body
+static void collect_cases(AstNode* stmt, int* case_values, int* case_labels, int* n_cases) {
+ if (!stmt)
+ return;
+
+ if (stmt->kind == AstNodeKind_case_label) {
+ case_values[*n_cases] = stmt->node_int_value;
+ case_labels[*n_cases] = *n_cases + 1;
+ (*n_cases)++;
+ collect_cases(stmt->node_body, case_values, case_labels, n_cases);
+ } else if (stmt->kind == AstNodeKind_default_label) {
+ collect_cases(stmt->node_body, case_values, case_labels, n_cases);
+ } else if (stmt->kind == AstNodeKind_list) {
+ for (int i = 0; i < stmt->node_len; i++) {
+ collect_cases(stmt->node_items + i, case_values, case_labels, n_cases);
+ }
+ }
+}
+
+static bool codegen_switch_body(CodeGen* g, AstNode* stmt, int* case_values, int* case_labels, int n_cases) {
+ if (!stmt)
+ return false;
+
+ if (stmt->kind == AstNodeKind_case_label) {
+ int value = stmt->node_int_value;
+ for (int i = 0; i < n_cases; i++) {
+ if (case_values[i] == value) {
+ fprintf(g->out, ".Lcase%d_%d:\n", g->switch_label, case_labels[i]);
+ break;
+ }
+ }
+ codegen_stmt(g, stmt->node_body);
+ return false;
+ } else if (stmt->kind == AstNodeKind_default_label) {
+ fprintf(g->out, ".Ldefault%d:\n", g->switch_label);
+ codegen_stmt(g, stmt->node_body);
+ return true;
+ } else if (stmt->kind == AstNodeKind_list) {
+ bool default_label_emitted = false;
+ for (int i = 0; i < stmt->node_len; i++) {
+ default_label_emitted |= codegen_switch_body(g, stmt->node_items + i, case_values, case_labels, n_cases);
+ }
+ return default_label_emitted;
+ } else {
+ codegen_stmt(g, stmt);
+ return false;
+ }
+}
+
+static void codegen_switch_stmt(CodeGen* g, AstNode* ast) {
+ int switch_label = codegen_new_label(g);
+ int prev_switch_label = g->switch_label;
+ g->switch_label = switch_label;
+
+ // Collect all case values and assign labels
+ int case_values[256];
+ int case_labels[256];
+ int n_cases = 0;
+ collect_cases(ast->node_body, case_values, case_labels, &n_cases);
+
+ // Generate jump instructions.
+ codegen_expr(g, ast->node_expr, GenMode_rval);
+ fprintf(g->out, " pop rax\n");
+ for (int i = 0; i < n_cases; i++) {
+ fprintf(g->out, " cmp rax, %d\n", case_values[i]);
+ fprintf(g->out, " je .Lcase%d_%d\n", switch_label, case_labels[i]);
+ }
+ fprintf(g->out, " jmp .Ldefault%d\n", switch_label);
+
+ // Generate the switch body with labels.
+ bool default_label_emitted = codegen_switch_body(g, ast->node_body, case_values, case_labels, n_cases);
+
+ if (!default_label_emitted) {
+ fprintf(g->out, ".Ldefault%d:\n", switch_label);
+ }
+ fprintf(g->out, ".Lend%d:\n", switch_label);
+
+ g->switch_label = prev_switch_label;
+}
+
static void codegen_expr_stmt(CodeGen* g, AstNode* ast) {
codegen_expr(g, ast->node_expr, GenMode_rval);
// TODO: the expression on the stack can be more than 8 bytes.
@@ -587,6 +673,8 @@ static void codegen_stmt(CodeGen* g, AstNode* ast) {
codegen_return_stmt(g, ast);
} else if (ast->kind == AstNodeKind_if_stmt) {
codegen_if_stmt(g, ast);
+ } else if (ast->kind == AstNodeKind_switch_stmt) {
+ codegen_switch_stmt(g, ast);
} else if (ast->kind == AstNodeKind_for_stmt) {
codegen_for_stmt(g, ast);
} else if (ast->kind == AstNodeKind_do_while_stmt) {
@@ -601,6 +689,9 @@ static void codegen_stmt(CodeGen* g, AstNode* ast) {
codegen_var_decl(g, ast);
} else if (ast->kind == AstNodeKind_nop) {
codegen_nop(g, ast);
+ } else if (ast->kind == AstNodeKind_case_label || ast->kind == AstNodeKind_default_label) {
+ // They are handled by codegen_switch_stmt().
+ unreachable();
} else {
unreachable();
}