diff --git a/chibicc.h b/chibicc.h index fe9e67193a03e30ceee47891e53a1945527eff1c..4d38c5ec596d03231e6e86842b250eb534de2062 100644 --- a/chibicc.h +++ b/chibicc.h @@ -115,6 +115,8 @@ typedef enum { ND_IF, // "if" ND_WHILE, // "while" ND_FOR, // "for" + ND_SWITCH, // "switch" + ND_CASE, // "case" ND_SIZEOF, // "sizeof" ND_BLOCK, // { ... } ND_BREAK, // "break" @@ -162,8 +164,14 @@ struct Node { // Goto or labeled statement char *label_name; - Var *var; // Used if kind == ND_VAR - long val; // Used if kind == ND_NUM + // Switch-cases + Node *case_next; + Node *default_case; + int case_label; + int case_end_label; + + Var *var; + long val; }; typedef struct Function Function; diff --git a/codegen.c b/codegen.c index d430e2643eec7ae7636ca19771a379ea683797b6..a0049e0fd2e337f12e6c3c5dbd1957213925b43a 100644 --- a/codegen.c +++ b/codegen.c @@ -5,7 +5,7 @@ char *argreg2[] = {"di", "si", "dx", "cx", "r8w", "r9w"}; char *argreg4[] = {"edi", "esi", "edx", "ecx", "r8d", "r9d"}; char *argreg8[] = {"rdi", "rsi", "rdx", "rcx", "r8", "r9"}; -int labelseq; +int labelseq = 1; int brkseq; int contseq; char *funcname; @@ -341,6 +341,41 @@ void gen(Node *node) { contseq = cont; return; } + case ND_SWITCH: { + int seq = labelseq++; + int brk = brkseq; + brkseq = seq; + node->case_label = seq; + + gen(node->cond); + printf(" pop rax\n"); + + for (Node *n = node->case_next; n; n = n->case_next) { + n->case_label = labelseq++; + n->case_end_label = seq; + printf(" cmp rax, %ld\n", n->val); + printf(" je .L.case.%d\n", n->case_label); + } + + if (node->default_case) { + int i = labelseq++; + node->default_case->case_end_label = seq; + node->default_case->case_label = i; + printf(" jmp .L.case.%d\n", i); + } + + printf(" jmp .L.break.%d\n", seq); + gen(node->then); + printf(".L.break.%d:\n", seq); + + brkseq = brk; + return; + } + case ND_CASE: + printf(".L.case.%d:\n", node->case_label); + gen(node->lhs); + printf(" jmp .L.break.%d\n", node->case_end_label); + return; case ND_BLOCK: case ND_STMT_EXPR: for (Node *n = node->body; n; n = n->next) diff --git a/parse.c b/parse.c index 2a6c4744eb65f481bd22f632f59adee7d25bf460..a793e6b3b598b822901fdc372d1f8ac51f0a13ab 100644 --- a/parse.c +++ b/parse.c @@ -33,6 +33,8 @@ VarScope *var_scope; TagScope *tag_scope; int scope_depth; +Node *current_switch; + Scope *enter_scope() { Scope *sc = calloc(1, sizeof(Scope)); sc->var_scope = var_scope; @@ -651,6 +653,9 @@ bool is_typename() { // stmt = "return" expr ";" // | "if" "(" expr ")" stmt ("else" stmt)? +// | "switch" "(" expr ")" stmt +// | "case" num ":" stmt +// | "default" ":" stmt // | "while" "(" expr ")" stmt // | "for" "(" (expr? ";" | declaration) expr? ";" expr? ")" stmt // | "{" stmt* "}" @@ -679,6 +684,42 @@ Node *stmt() { return node; } + if (tok = consume("switch")) { + Node *node = new_node(ND_SWITCH, tok); + expect("("); + node->cond = expr(); + expect(")"); + + Node *sw = current_switch; + current_switch = node; + node->then = stmt(); + current_switch = sw; + return node; + } + + if (tok = consume("case")) { + if (!current_switch) + error_tok(tok, "stray case"); + int val = expect_number(); + expect(":"); + + Node *node = new_unary(ND_CASE, stmt(), tok); + node->val = val; + node->case_next = current_switch->case_next; + current_switch->case_next = node; + return node; + } + + if (tok = consume("default")) { + if (!current_switch) + error_tok(tok, "stray default"); + expect(":"); + + Node *node = new_unary(ND_CASE, stmt(), tok); + current_switch->default_case = node; + return node; + } + if (tok = consume("while")) { Node *node = new_node(ND_WHILE, tok); expect("("); diff --git a/tests b/tests index c6e713bec4d4895a348f2bde24775fe1353ad671..ee835fb8a318cd95e6352a24a01a85624c6cd975 100644 --- a/tests +++ b/tests @@ -431,6 +431,13 @@ int main() { assert(2, ({ int i=0; goto e; d: i++; e: i++; f: i++; i; }), "int i=0; goto d; d: i++; e: i++; f: i++; i;"); assert(1, ({ int i=0; goto i; g: i++; h: i++; i: i++; i; }), "int i=0; goto g; h: i++; i: i++; j: i++; i;"); + assert(5, ({ int i=0; switch(0) { case 0:i=5;break; case 1:i=6;break; case 2:i=7;break; } i; }), "int i=0; switch(0) { case 0:i=5;break; case 1:i=6;break; case 2:i=7;break; } i;"); + assert(6, ({ int i=0; switch(1) { case 0:i=5;break; case 1:i=6;break; case 2:i=7;break; } i; }), "int i=0; switch(1) { case 0:i=5;break; case 1:i=6;break; case 2:i=7;break; } i;"); + assert(7, ({ int i=0; switch(2) { case 0:i=5;break; case 1:i=6;break; case 2:i=7;break; } i; }), "int i=0; switch(2) { case 0:i=5;break; case 1:i=6;break; case 2:i=7;break; } i;"); + assert(0, ({ int i=0; switch(3) { case 0:i=5;break; case 1:i=6;break; case 2:i=7;break; } i; }), "int i=0; switch(3) { case 0:i=5;break; case 1:i=6;break; case 2:i=7;break; } i;"); + assert(5, ({ int i=0; switch(0) { case 0:i=5;break; default:i=7; } i; }), "int i=0; switch(0) { case 0:i=5;break; default:i=7; } i;"); + assert(7, ({ int i=0; switch(1) { case 0:i=5;break; default:i=7; } i; }), "int i=0; switch(1) { case 0:i=5;break; default:i=7; } i;"); + printf("OK\n"); return 0; } diff --git a/tokenize.c b/tokenize.c index 9dc71fd30d075b2dae9ddd05d27d8e3bfc228877..41c5969d7faa2cad5430f896d73c0bc793d9708f 100644 --- a/tokenize.c +++ b/tokenize.c @@ -154,7 +154,7 @@ char *starts_with_reserved(char *p) { static char *kw[] = {"return", "if", "else", "while", "for", "int", "char", "sizeof", "struct", "typedef", "short", "long", "void", "_Bool", "enum", "static", "break", - "continue", "goto"}; + "continue", "goto", "switch", "case", "default"}; for (int i = 0; i < sizeof(kw) / sizeof(*kw); i++) { int len = strlen(kw[i]);