From f5536961e8182951376e320fafe4340c3a8e12b3 Mon Sep 17 00:00:00 2001 From: Rui Ueyama Date: Mon, 5 Aug 2019 23:30:35 +0900 Subject: [PATCH] Support arrays including multi-dimensional ones --- chibicc.h | 6 +++++- codegen.c | 22 +++++++++++++++------- main.c | 5 +++-- parse.c | 23 +++++++++++++++++++---- test.sh | 14 ++++++++++++++ tokenize.c | 2 +- type.c | 28 +++++++++++++++++++++++----- 7 files changed, 80 insertions(+), 20 deletions(-) diff --git a/chibicc.h b/chibicc.h index 515e94e..827ba97 100644 --- a/chibicc.h +++ b/chibicc.h @@ -1,3 +1,4 @@ +#include #include #include #include @@ -135,15 +136,18 @@ Function *program(); // typing.c // -typedef enum { TY_INT, TY_PTR } TypeKind; +typedef enum { TY_INT, TY_PTR, TY_ARRAY } TypeKind; struct Type { TypeKind kind; Type *base; + int array_size; }; Type *int_type(); Type *pointer_to(Type *base); +Type *array_of(Type *base, int size); +int size_of(Type *ty); void add_type(Function *prog); diff --git a/codegen.c b/codegen.c index 7323373..c23c40a 100644 --- a/codegen.c +++ b/codegen.c @@ -22,6 +22,12 @@ void gen_addr(Node *node) { error_tok(node->tok, "not an lvalue"); } +void gen_lval(Node *node) { + if (node->ty->kind == TY_ARRAY) + error_tok(node->tok, "not an lvalue"); + gen_addr(node); +} + void load() { printf(" pop rax\n"); printf(" mov rax, [rax]\n"); @@ -49,10 +55,11 @@ void gen(Node *node) { return; case ND_VAR: gen_addr(node); - load(); + if (node->ty->kind != TY_ARRAY) + load(); return; case ND_ASSIGN: - gen_addr(node->lhs); + gen_lval(node->lhs); gen(node->rhs); store(); return; @@ -61,7 +68,8 @@ void gen(Node *node) { return; case ND_DEREF: gen(node->lhs); - load(); + if (node->ty->kind != TY_ARRAY) + load(); return; case ND_IF: { int seq = labelseq++; @@ -163,13 +171,13 @@ void gen(Node *node) { switch (node->kind) { case ND_ADD: - if (node->ty->kind == TY_PTR) - printf(" imul rdi, 8\n"); + if (node->ty->base) + printf(" imul rdi, %d\n", size_of(node->ty->base)); printf(" add rax, rdi\n"); break; case ND_SUB: - if (node->ty->kind == TY_PTR) - printf(" imul rdi, 8\n"); + if (node->ty->base) + printf(" imul rdi, %d\n", size_of(node->ty->base)); printf(" sub rax, rdi\n"); break; case ND_MUL: diff --git a/main.c b/main.c index f29b51c..bf6c0e4 100644 --- a/main.c +++ b/main.c @@ -14,8 +14,9 @@ int main(int argc, char **argv) { for (Function *fn = prog; fn; fn = fn->next) { int offset = 0; for (VarList *vl = fn->locals; vl; vl = vl->next) { - offset += 8; - vl->var->offset = offset; + Var *var = vl->var; + offset += size_of(var->ty); + var->offset = offset; } fn->stack_size = offset; } diff --git a/parse.c b/parse.c index 39bd6dc..7ce425c 100644 --- a/parse.c +++ b/parse.c @@ -91,10 +91,22 @@ Type *basetype() { return ty; } +Type *read_type_suffix(Type *base) { + if (!consume("[")) + return base; + int sz = expect_number(); + expect("]"); + base = read_type_suffix(base); + return array_of(base, sz); +} + VarList *read_func_param() { - VarList *vl = calloc(1, sizeof(VarList)); Type *ty = basetype(); - vl->var = push_var(expect_ident(), ty); + char *name = expect_ident(); + ty = read_type_suffix(ty); + + VarList *vl = calloc(1, sizeof(VarList)); + vl->var = push_var(name, ty); return vl; } @@ -140,11 +152,13 @@ Function *function() { return fn; } -// declaration = basetype ident ("=" expr) ";" +// declaration = basetype ident ("[" num "]")* ("=" expr) ";" Node *declaration() { Token *tok = token; Type *ty = basetype(); - Var *var = push_var(expect_ident(), ty); + char *name = expect_ident(); + ty = read_type_suffix(ty); + Var *var = push_var(name, ty); if (consume(";")) return new_node(ND_NULL, tok); @@ -231,6 +245,7 @@ Node *stmt() { return node; } + if (tok = peek("int")) return declaration(); diff --git a/test.sh b/test.sh index a036b57..156f3e7 100755 --- a/test.sh +++ b/test.sh @@ -105,4 +105,18 @@ assert 7 'int main() { int x=3; int y=5; *(&x+1)=7; return y; }' assert 7 'int main() { int x=3; int y=5; *(&y-1)=7; return x; }' assert 8 'int main() { int x=3; int y=5; return foo(&x, y); } int foo(int *x, int y) { return *x + y; }' +assert 3 'int main() { int x[2]; int *y=&x; *y=3; return *x; }' + +assert 3 'int main() { int x[3]; *x=3; *(x+1)=4; *(x+2)=5; return *x; }' +assert 4 'int main() { int x[3]; *x=3; *(x+1)=4; *(x+2)=5; return *(x+1); }' +assert 5 'int main() { int x[3]; *x=3; *(x+1)=4; *(x+2)=5; return *(x+2); }' + +assert 0 'int main() { int x[2][3]; int *y=x; *y=0; return **x; }' +assert 1 'int main() { int x[2][3]; int *y=x; *(y+1)=1; return *(*x+1); }' +assert 2 'int main() { int x[2][3]; int *y=x; *(y+2)=2; return *(*x+2); }' +assert 3 'int main() { int x[2][3]; int *y=x; *(y+3)=3; return **(x+1); }' +assert 4 'int main() { int x[2][3]; int *y=x; *(y+4)=4; return *(*(x+1)+1); }' +assert 5 'int main() { int x[2][3]; int *y=x; *(y+5)=5; return *(*(x+1)+2); }' +assert 6 'int main() { int x[2][3]; int *y=x; *(y+6)=6; return **(x+2); }' + echo OK diff --git a/tokenize.c b/tokenize.c index 223045c..8c28506 100644 --- a/tokenize.c +++ b/tokenize.c @@ -170,7 +170,7 @@ Token *tokenize() { } // Single-letter punctuator - if (strchr("+-*/()<>;={},&", *p)) { + if (strchr("+-*/()<>;={},&[]", *p)) { cur = new_token(TK_RESERVED, cur, p++, 1); continue; } diff --git a/type.c b/type.c index d6a373e..dd11e08 100644 --- a/type.c +++ b/type.c @@ -13,6 +13,21 @@ Type *pointer_to(Type *base) { return ty; } +Type *array_of(Type *base, int size) { + Type *ty = calloc(1, sizeof(Type)); + ty->kind = TY_ARRAY; + ty->base = base; + ty->array_size = size; + return ty; +} + +int size_of(Type *ty) { + if (ty->kind == TY_INT || ty->kind == TY_PTR) + return 8; + assert(ty->kind == TY_ARRAY); + return size_of(ty->base) * ty->array_size; +} + void visit(Node *node) { if (!node) return; @@ -45,17 +60,17 @@ void visit(Node *node) { node->ty = node->var->ty; return; case ND_ADD: - if (node->rhs->ty->kind == TY_PTR) { + if (node->rhs->ty->base) { Node *tmp = node->lhs; node->lhs = node->rhs; node->rhs = tmp; } - if (node->rhs->ty->kind == TY_PTR) + if (node->rhs->ty->base) error_tok(node->tok, "invalid pointer arithmetic operands"); node->ty = node->lhs->ty; return; case ND_SUB: - if (node->rhs->ty->kind == TY_PTR) + if (node->rhs->ty->base) error_tok(node->tok, "invalid pointer arithmetic operands"); node->ty = node->lhs->ty; return; @@ -63,10 +78,13 @@ void visit(Node *node) { node->ty = node->lhs->ty; return; case ND_ADDR: - node->ty = pointer_to(node->lhs->ty); + if (node->lhs->ty->kind == TY_ARRAY) + node->ty = pointer_to(node->lhs->ty->base); + else + node->ty = pointer_to(node->lhs->ty); return; case ND_DEREF: - if (node->lhs->ty->kind != TY_PTR) + if (!node->lhs->ty->base) error_tok(node->tok, "invalid pointer dereference"); node->ty = node->lhs->ty->base; return; -- GitLab