diff --git a/main.c b/main.c index 1d2e9bdd410e6cc6fbc20482bab5a972e3b0c1e8..da07a2b4eb1ece671b9e330930012618addb435a 100644 --- a/main.c +++ b/main.c @@ -5,6 +5,10 @@ #include #include +// +// Tokenizer +// + typedef enum { TK_RESERVED, // Keywords or punctuators TK_NUM, // Integer literals @@ -101,7 +105,7 @@ Token *tokenize() { } // Punctuator - if (*p == '+' || *p == '-') { + if (strchr("+-*/()", *p)) { cur = new_token(TK_RESERVED, cur, p++); continue; } @@ -113,40 +117,151 @@ Token *tokenize() { continue; } - error_at(p, "expected a number"); + error_at(p, "invalid token"); } new_token(TK_EOF, cur, p); return head.next; } +// +// Parser +// + +typedef enum { + ND_ADD, // + + ND_SUB, // - + ND_MUL, // * + ND_DIV, // / + ND_NUM, // Integer +} NodeKind; + +// AST node type +typedef struct Node Node; +struct Node { + NodeKind kind; // Node kind + Node *lhs; // Left-hand side + Node *rhs; // Right-hand side + int val; // Used if kind == ND_NUM +}; + +Node *new_node(NodeKind kind) { + Node *node = calloc(1, sizeof(Node)); + node->kind = kind; + return node; +} + +Node *new_binary(NodeKind kind, Node *lhs, Node *rhs) { + Node *node = new_node(kind); + node->lhs = lhs; + node->rhs = rhs; + return node; +} + +Node *new_num(int val) { + Node *node = new_node(ND_NUM); + node->val = val; + return node; +} + +Node *expr(); +Node *mul(); +Node *primary(); + +// expr = mul ("+" mul | "-" mul)* +Node *expr() { + Node *node = mul(); + + for (;;) { + if (consume('+')) + node = new_binary(ND_ADD, node, mul()); + else if (consume('-')) + node = new_binary(ND_SUB, node, mul()); + else + return node; + } +} + +// mul = primary ("*" primary | "/" primary)* +Node *mul() { + Node *node = primary(); + + for (;;) { + if (consume('*')) + node = new_binary(ND_MUL, node, primary()); + else if (consume('/')) + node = new_binary(ND_DIV, node, primary()); + else + return node; + } +} + +// primary = "(" expr ")" | num +Node *primary() { + if (consume('(')) { + Node *node = expr(); + expect(')'); + return node; + } + + return new_num(expect_number()); +} + +// +// Code generator +// + +void gen(Node *node) { + if (node->kind == ND_NUM) { + printf(" push %d\n", node->val); + return; + } + + gen(node->lhs); + gen(node->rhs); + + printf(" pop rdi\n"); + printf(" pop rax\n"); + + switch (node->kind) { + case ND_ADD: + printf(" add rax, rdi\n"); + break; + case ND_SUB: + printf(" sub rax, rdi\n"); + break; + case ND_MUL: + printf(" imul rax, rdi\n"); + break; + case ND_DIV: + printf(" cqo\n"); + printf(" idiv rdi\n"); + break; + } + + printf(" push rax\n"); +} + int main(int argc, char **argv) { - if (argc != 2) { + if (argc != 2) error("%s: invalid number of arguments", argv[0]); - return 1; - } + // Tokenize and parse. user_input = argv[1]; token = tokenize(); + Node *node = expr(); + // Print out the first half of assembly. printf(".intel_syntax noprefix\n"); printf(".global main\n"); printf("main:\n"); - // The first token must be a number - printf(" mov rax, %d\n", expect_number()); - - // ... followed by either `+ ` or `- `. - while (!at_eof()) { - if (consume('+')) { - printf(" add rax, %d\n", expect_number()); - continue; - } - - expect('-'); - printf(" sub rax, %d\n", expect_number()); - } + // Traverse the AST to emit assembly. + gen(node); + // A result must be at the top of the stack, so pop it + // to RAX to make it a program exit code. + printf(" pop rax\n"); printf(" ret\n"); return 0; } diff --git a/test.sh b/test.sh index 823139e3ac4ad5a04ef2267c834e049df648e540..8f8782dece543b63d4eb7a7b604a02da9ce1df53 100755 --- a/test.sh +++ b/test.sh @@ -20,5 +20,8 @@ assert 0 0 assert 42 42 assert 21 '5+20-4' assert 41 ' 12 + 34 - 5 ' +assert 47 '5+6*7' +assert 15 '5*(9-6)' +assert 4 '(3+5)/2' echo OK