#include #include #include #include "ast.h" #include "desugar.h" #include "io.h" #include "module.h" #include "parser.h" #include "resolver.h" #include "symtab.h" /* * AST desugaring pass * * This pass runs before resolving and transforms the AST. */ /* Forward declarations */ static node_t *desugar_node(desugar_t *d, module_t *mod, node_t *n); static node_t *desugar_and_operator(desugar_t *d, module_t *mod, node_t *binop); static node_t *desugar_or_operator(desugar_t *d, module_t *mod, node_t *binop); static node_t *desugar_guard_stmt( desugar_t *d, module_t *mod, node_t *block, usize ix, node_t *guard ); static node_t *desugar_block(desugar_t *d, module_t *mod, node_t *block); /* Allocate a new AST node using the module's parser */ static node_t *node(module_t *mod, nodeclass_t cls, node_t *original) { parser_t *p = &mod->parser; if (p->nnodes >= MAX_NODES) { bail("maximum number of AST nodes reached"); } node_t *n = &p->nodes[p->nnodes++]; n->cls = cls; n->type = NULL; n->sym = NULL; n->offset = original ? original->offset : 0; n->length = original ? original->length : 0; return n; } static node_t *node_bool(module_t *mod, bool b, node_t *loc) { node_t *lit = node(mod, NODE_BOOL, loc); lit->val.bool_lit = b; return lit; } /* Create an empty block node. */ static node_t *node_block(module_t *mod, node_t *original) { node_t *block = node(mod, NODE_BLOCK, original); block->val.block.stmts = (nodespan_t){ 0 }; block->val.block.scope = NULL; return block; } /* Transform a while-let loop into: * * loop { * if let var = expr; guard { * body; * } else { * rbranch; * break; * } * } */ static node_t *desugar_while_let( desugar_t *d, module_t *mod, node_t *while_let_node ) { /* Create the loop node */ node_t *loop_node = node(mod, NODE_LOOP, while_let_node); /* Create the if-let node */ node_t *if_let_node = node(mod, NODE_IF_LET, while_let_node); /* Create the break statement */ node_t *break_node = node(mod, NODE_BREAK, while_let_node); /* Handle the else clause */ node_t *else_clause = node_block(mod, while_let_node); if (while_let_node->val.while_let_stmt.rbranch) { node_block_add_stmt( mod, else_clause, desugar_node(d, mod, while_let_node->val.while_let_stmt.rbranch) ); } node_block_add_stmt(mod, else_clause, break_node); node_t *loop_body = node_block(mod, while_let_node); node_block_add_stmt(mod, loop_body, if_let_node); /* Set up the if-let statement */ if_let_node->val.if_let_stmt.var = while_let_node->val.while_let_stmt.var; if_let_node->val.if_let_stmt.expr = desugar_node(d, mod, while_let_node->val.while_let_stmt.expr); if_let_node->val.if_let_stmt.guard = while_let_node->val.while_let_stmt.guard ? desugar_node(d, mod, while_let_node->val.while_let_stmt.guard) : NULL; if_let_node->val.if_let_stmt.lbranch = desugar_node(d, mod, while_let_node->val.while_let_stmt.body); if_let_node->val.if_let_stmt.rbranch = else_clause; if_let_node->val.if_let_stmt.scope = NULL; /* Will be set during resolving */ /* Set up the loop statement */ loop_node->val.loop_stmt.body = loop_body; return loop_node; } /* Transform a while loop into: * * loop { * if (condition) { * body; * } else { * else_clause; * break * } * } */ static node_t *desugar_while(desugar_t *d, module_t *mod, node_t *while_node) { /* Create the loop node */ node_t *loop_node = node(mod, NODE_LOOP, while_node); /* Create the condition check */ node_t *if_node = node(mod, NODE_IF, while_node); /* Create the break statement */ node_t *break_node = node(mod, NODE_BREAK, while_node); /* Handle the else clause */ node_t *else_clause = node_block(mod, while_node); if (while_node->val.while_stmt.rbranch) { node_block_add_stmt( mod, else_clause, desugar_node(d, mod, while_node->val.while_stmt.rbranch) ); } node_block_add_stmt(mod, else_clause, break_node); node_t *loop_body = node_block(mod, while_node->val.while_stmt.body); node_block_add_stmt(mod, loop_body, if_node); /* Set up the if statement. */ if_node->val.if_stmt.cond = desugar_node(d, mod, while_node->val.while_stmt.cond); if_node->val.if_stmt.lbranch = desugar_node(d, mod, while_node->val.while_stmt.body); if_node->val.if_stmt.rbranch = else_clause; /* Set up the loop statement */ loop_node->val.loop_stmt.body = loop_body; return loop_node; } static node_t *node_ident(module_t *mod, const char *name, node_t *loc) { node_t *ident = node(mod, NODE_IDENT, loc); ident->val.ident.name = name; ident->val.ident.length = strlen(name); return ident; } static node_t *node_var( module_t *mod, node_t *ident, node_t *typ, node_t *val, bool mut, node_t *loc ) { node_t *var = node(mod, NODE_VAR, loc); var->val.var.ident = ident; var->val.var.type = typ; var->val.var.value = val; var->val.var.mutable = mut; return var; } static node_t *node_number(module_t *mod, const char *text, node_t *loc) { node_t *n = node(mod, NODE_NUMBER, loc); n->val.number.text = text; n->val.number.text_len = strlen(text); return n; } static node_t *node_type(module_t *mod, typeclass_t tc, node_t *loc) { node_t *typ = node(mod, NODE_TYPE, loc); typ->val.type.tclass = tc; return typ; } static node_t *node_access( module_t *mod, node_t *lval, node_t *rval, node_t *loc ) { node_t *access = node(mod, NODE_ACCESS, loc); access->val.access.lval = lval; access->val.access.rval = rval; return access; } static node_t *node_access_str( module_t *mod, node_t *lval, const char *field, node_t *loc ) { return node_access(mod, lval, node_ident(mod, field, loc), loc); } static node_t *node_binop( module_t *mod, binop_t op, node_t *left, node_t *right, node_t *loc ) { node_t *binop = node(mod, NODE_BINOP, loc); binop->val.binop.op = op; binop->val.binop.left = left; binop->val.binop.right = right; return binop; } static node_t *node_increment( module_t *mod, node_t *lval_ident, node_t *expr_ident, node_t *loc ) { node_t *assign = node(mod, NODE_ASSIGN, loc); assign->val.assign.lval = lval_ident; assign->val.assign.rval = node_binop(mod, OP_ADD, expr_ident, node_number(mod, "1", loc), loc); return assign; } static node_t *node_match(module_t *mod, node_t *expr, node_t *loc) { node_t *swtch = node(mod, NODE_MATCH, loc); swtch->val.match_stmt.expr = expr; swtch->val.match_stmt.cases = (nodespan_t){ 0 }; return swtch; } static node_t *node_match_case( module_t *mod, node_t *pattern, node_t *guard, node_t *body, node_t *loc ) { node_t *swtch_case = node(mod, NODE_MATCH_CASE, loc); swtch_case->val.match_case.patterns = (nodespan_t){ 0 }; if (pattern != NULL) { nodespan_push( &mod->parser, &swtch_case->val.match_case.patterns, pattern ); } swtch_case->val.match_case.body = body; swtch_case->val.match_case.guard = guard; swtch_case->val.match_case.variable = NULL; return swtch_case; } /* * Transform guard statements into their desugared control flow. * * For example, `let value = opt else { handle(); }; rest;` becomes: * * if let value = opt { * rest; * } else { * handle(); * } * * Likewise, `let case Pattern(x) = expr else { ... };` becomes an * equivalent `if case` construct with the suffix statements placed in * the success branch so they retain access to bound names. */ static node_t *desugar_guard_stmt( desugar_t *d, module_t *mod, node_t *block, usize index, node_t *guard ) { node_t *success = node_block(mod, guard); node_t *if_stmt = NULL; /* Add the rest of the surrounding block into the success branch. */ node_t **stmts = nodespan_ptrs(&mod->parser, block->val.block.stmts); for (usize j = index + 1; j < block->val.block.stmts.len; j++) { node_block_add_stmt(mod, success, stmts[j]); } if (guard->cls == NODE_GUARD_LET) { if_stmt = node(mod, NODE_IF_LET, guard); if_stmt->val.if_let_stmt.var = guard->val.guard_let_stmt.var; if_stmt->val.if_let_stmt.expr = guard->val.guard_let_stmt.expr; if_stmt->val.if_let_stmt.guard = NULL; if_stmt->val.if_let_stmt.lbranch = success; if_stmt->val.if_let_stmt.rbranch = guard->val.guard_let_stmt.rbranch; if_stmt->val.if_let_stmt.scope = NULL; } else { if_stmt = node(mod, NODE_IF_CASE, guard); if_stmt->val.if_case_stmt.pattern = guard->val.guard_case_stmt.pattern; if_stmt->val.if_case_stmt.expr = guard->val.guard_case_stmt.expr; if_stmt->val.if_case_stmt.guard = guard->val.guard_case_stmt.guard; if_stmt->val.if_case_stmt.lbranch = success; if_stmt->val.if_case_stmt.rbranch = guard->val.guard_case_stmt.rbranch; } block->val.block.stmts.len = index + 1; return desugar_node(d, mod, if_stmt); } static node_t *desugar_block(desugar_t *d, module_t *mod, node_t *block) { node_t **stmts = nodespan_ptrs(&mod->parser, block->val.block.stmts); for (usize i = 0; i < block->val.block.stmts.len; i++) { node_t *stmt = stmts[i]; /* Guard statements fold the rest of the block under the success * branch of the generated `if` statement, therefore we continue * processing the block inside the guard statement desugar. */ if (stmt->cls == NODE_GUARD_LET || stmt->cls == NODE_GUARD_CASE) { stmts[i] = desugar_guard_stmt(d, mod, block, i, stmt); return block; } stmts[i] = desugar_node(d, mod, stmt); } return block; } static void node_match_add_case( module_t *mod, node_t *swtch, node_t *swtch_case ) { nodespan_push(&mod->parser, &swtch->val.match_stmt.cases, swtch_case); } static node_t *desugar_for_range( desugar_t *d, module_t *mod, node_t *for_node ) { node_t *range = for_node->val.for_stmt.iter; node_t *index_name = for_node->val.for_stmt.idx ? for_node->val.for_stmt.idx : node_ident(mod, "$i", for_node); node_t *end_name = node_ident(mod, "$end", for_node); node_t *start_expr = range->val.range.start ? desugar_node(d, mod, range->val.range.start) : node_number(mod, "0", range); node_t *index_typ = node_type(mod, TYPE_U32, for_node); node_t *index_var = node_var(mod, index_name, index_typ, start_expr, true, for_node); node_t *end_expr = desugar_node(d, mod, range->val.range.end); node_t *end_typ = node_type(mod, TYPE_U32, for_node); node_t *end_var = node_var(mod, end_name, end_typ, end_expr, false, range); node_t *cond = node_binop( mod, OP_LT, index_var->val.var.ident, end_var->val.var.ident, for_node ); node_t *loop_body = node_block(mod, for_node); node_t *loop_var = node_var( mod, for_node->val.for_stmt.var, NULL, index_var->val.var.ident, false, for_node->val.for_stmt.var ); node_block_add_stmt(mod, loop_body, loop_var); node_block_add_stmt( mod, loop_body, desugar_node(d, mod, for_node->val.for_stmt.body) ); node_t *increment = node_increment( mod, index_var->val.var.ident, index_var->val.var.ident, for_node ); node_block_add_stmt(mod, loop_body, increment); node_t *while_node = node(mod, NODE_WHILE, for_node); while_node->val.while_stmt.cond = cond; while_node->val.while_stmt.body = loop_body; while_node->val.while_stmt.rbranch = for_node->val.for_stmt.rbranch ? desugar_node(d, mod, for_node->val.for_stmt.rbranch) : NULL; node_t *wrapper = node_block(mod, for_node); node_block_add_stmt(mod, wrapper, index_var); node_block_add_stmt(mod, wrapper, end_var); node_block_add_stmt(mod, wrapper, desugar_while(d, mod, while_node)); return wrapper; } /* Transform a for loop into a while loop: * * for var in (iter) { * body; * } else { * rbranch; * } * * becomes: * * { * let $i: u32 = 0; * let $len: u32 = iter.len; * while ($i < $len) { * let var = iter[$i]; * body; * $i = $i + 1; * } else { * rbranch; * } * } */ static node_t *desugar_for(desugar_t *d, module_t *mod, node_t *for_node) { if (for_node->val.for_stmt.iter->cls == NODE_RANGE) { return desugar_for_range(d, mod, for_node); } /* Use simple temporary variable names or user-provided index variable */ node_t *index_name = for_node->val.for_stmt.idx ? for_node->val.for_stmt.idx : node_ident(mod, "$i", for_node); node_t *length_name = node_ident(mod, "$len", for_node); /* Create index variable: let $i: u32 = 0; */ node_t *index_val = node_number(mod, "0", for_node); node_t *index_typ = node_type(mod, TYPE_U32, for_node); node_t *index_var = node_var(mod, index_name, index_typ, index_val, true, for_node); /* Create length variable: let $len: u32 = iter.len; */ node_t *len_field = node_access_str( mod, desugar_node(d, mod, for_node->val.for_stmt.iter), "len", for_node->val.for_stmt.iter ); node_t *length_typ = node_type(mod, TYPE_U32, for_node); node_t *length_var = node_var(mod, length_name, length_typ, len_field, false, for_node); /* Create while condition: $i < $len */ node_t *cond = node_binop( mod, OP_LT, index_var->val.var.ident, length_var->val.var.ident, for_node ); /* Create array index access: iter[$i] */ node_t *array_idx = node(mod, NODE_ARRAY_INDEX, for_node); array_idx->val.access.lval = desugar_node(d, mod, for_node->val.for_stmt.iter); array_idx->val.access.rval = index_var->val.var.ident; /* Create loop variable assignment: let var = iter[$i]; */ node_t *var_name = for_node->val.for_stmt.var; node_t *loop_var = node_var( mod, var_name, NULL, array_idx, false, for_node->val.for_stmt.var ); /* Create increment statement: $i = $i + 1; */ node_t *increment = node_increment( mod, index_var->val.var.ident, index_var->val.var.ident, for_node ); /* Create while body */ node_t *body = node_block(mod, for_node); node_block_add_stmt(mod, body, loop_var); node_block_add_stmt( mod, body, desugar_node(d, mod, for_node->val.for_stmt.body) ); node_block_add_stmt(mod, body, increment); /* Create while node */ node_t *while_node = node(mod, NODE_WHILE, for_node); while_node->val.while_stmt.cond = cond; while_node->val.while_stmt.body = body; while_node->val.while_stmt.rbranch = for_node->val.for_stmt.rbranch ? desugar_node(d, mod, for_node->val.for_stmt.rbranch) : NULL; /* Create wrapper block containing the initialization and while loop */ node_t *wrapper = node_block(mod, for_node); node_block_add_stmt(mod, wrapper, index_var); node_block_add_stmt(mod, wrapper, length_var); node_block_add_stmt(mod, wrapper, desugar_while(d, mod, while_node)); return wrapper; } /* Transform `x and y` into: * * if (x) { * y * } else { * false * } */ static node_t *desugar_and_operator( desugar_t *d, module_t *mod, node_t *binop ) { node_t *if_node = node(mod, NODE_IF, binop); node_t *false_lit = node_bool(mod, false, binop); if_node->val.if_stmt.cond = desugar_node(d, mod, binop->val.binop.left); if_node->val.if_stmt.lbranch = desugar_node(d, mod, binop->val.binop.right); if_node->val.if_stmt.rbranch = false_lit; return if_node; } /* Transform `x or y` into: * * if (x) { * true * } else { * y * } */ static node_t *desugar_or_operator(desugar_t *d, module_t *mod, node_t *binop) { node_t *if_node = node(mod, NODE_IF, binop); node_t *true_lit = node_bool(mod, true, binop); if_node->val.if_stmt.cond = desugar_node(d, mod, binop->val.binop.left); if_node->val.if_stmt.lbranch = true_lit; if_node->val.if_stmt.rbranch = desugar_node(d, mod, binop->val.binop.right); return if_node; } /* Recursively desugar a node and its children */ static node_t *desugar_node(desugar_t *d, module_t *mod, node_t *n) { if (!n) return NULL; switch (n->cls) { case NODE_WHILE: return desugar_while(d, mod, n); case NODE_WHILE_LET: return desugar_while_let(d, mod, n); case NODE_MOD_BODY: case NODE_BLOCK: return desugar_block(d, mod, n); case NODE_IF: n->val.if_stmt.cond = desugar_node(d, mod, n->val.if_stmt.cond); n->val.if_stmt.lbranch = desugar_node(d, mod, n->val.if_stmt.lbranch); if (n->val.if_stmt.rbranch) { n->val.if_stmt.rbranch = desugar_node(d, mod, n->val.if_stmt.rbranch); } return n; case NODE_IF_LET: n->val.if_let_stmt.expr = desugar_node(d, mod, n->val.if_let_stmt.expr); if (n->val.if_let_stmt.guard) { n->val.if_let_stmt.guard = desugar_node(d, mod, n->val.if_let_stmt.guard); } n->val.if_let_stmt.lbranch = desugar_node(d, mod, n->val.if_let_stmt.lbranch); if (n->val.if_let_stmt.rbranch) { n->val.if_let_stmt.rbranch = desugar_node(d, mod, n->val.if_let_stmt.rbranch); } return n; case NODE_IF_CASE: { node_t *pattern = desugar_node(d, mod, n->val.if_case_stmt.pattern); node_t *expr = desugar_node(d, mod, n->val.if_case_stmt.expr); node_t *guard = NULL; if (n->val.if_case_stmt.guard) { guard = desugar_node(d, mod, n->val.if_case_stmt.guard); } node_t *then_block = desugar_node(d, mod, n->val.if_case_stmt.lbranch); node_t *swtch = node_match(mod, expr, n); node_t *case_node = node_match_case(mod, pattern, guard, then_block, n); node_match_add_case(mod, swtch, case_node); if (n->val.if_case_stmt.rbranch) { node_t *else_body = desugar_node(d, mod, n->val.if_case_stmt.rbranch); node_t *default_case = node_match_case(mod, NULL, NULL, else_body, n); node_match_add_case(mod, swtch, default_case); } return swtch; } case NODE_LOOP: n->val.loop_stmt.body = desugar_node(d, mod, n->val.loop_stmt.body); return n; case NODE_FN: n->val.fn_decl.body = desugar_node(d, mod, n->val.fn_decl.body); return n; case NODE_BINOP: /* Handle logical operators with short-circuit evaluation */ if (n->val.binop.op == OP_AND) { return desugar_and_operator(d, mod, n); } if (n->val.binop.op == OP_OR) { return desugar_or_operator(d, mod, n); } /* For other binary operators, recursively desugar operands */ n->val.binop.left = desugar_node(d, mod, n->val.binop.left); n->val.binop.right = desugar_node(d, mod, n->val.binop.right); return n; case NODE_UNOP: n->val.unop.expr = desugar_node(d, mod, n->val.unop.expr); return n; case NODE_CALL: { node_t **args = nodespan_ptrs(&mod->parser, n->val.call.args); for (usize i = 0; i < n->val.call.args.len; i++) { args[i] = desugar_node(d, mod, args[i]); } return n; } case NODE_BUILTIN: { node_t **args = nodespan_ptrs(&mod->parser, n->val.builtin.args); for (usize i = 0; i < n->val.builtin.args.len; i++) { args[i] = desugar_node(d, mod, args[i]); } return n; } case NODE_RETURN: if (n->val.return_stmt.value) { n->val.return_stmt.value = desugar_node(d, mod, n->val.return_stmt.value); } return n; case NODE_VAR: if (n->val.var.value) { n->val.var.value = desugar_node(d, mod, n->val.var.value); } return n; case NODE_ASSIGN: n->val.assign.lval = desugar_node(d, mod, n->val.assign.lval); n->val.assign.rval = desugar_node(d, mod, n->val.assign.rval); return n; case NODE_EXPR_STMT: n->val.expr_stmt = desugar_node(d, mod, n->val.expr_stmt); return n; case NODE_FOR: return desugar_for(d, mod, n); case NODE_MATCH: { n->val.match_stmt.expr = desugar_node(d, mod, n->val.match_stmt.expr); node_t **cases = nodespan_ptrs(&mod->parser, n->val.match_stmt.cases); for (usize i = 0; i < n->val.match_stmt.cases.len; i++) { cases[i] = desugar_node(d, mod, cases[i]); } return n; } case NODE_MATCH_CASE: { node_t **patterns = nodespan_ptrs(&mod->parser, n->val.match_case.patterns); for (usize i = 0; i < n->val.match_case.patterns.len; i++) { patterns[i] = desugar_node(d, mod, patterns[i]); } if (n->val.match_case.guard) { n->val.match_case.guard = desugar_node(d, mod, n->val.match_case.guard); } n->val.match_case.body = desugar_node(d, mod, n->val.match_case.body); return n; } case NODE_ARRAY_INDEX: case NODE_ARRAY_LIT: case NODE_ARRAY_REPEAT_LIT: case NODE_RECORD_LIT: case NODE_CALL_ARG: case NODE_REF: case NODE_ACCESS: case NODE_NUMBER: case NODE_CHAR: case NODE_STRING: case NODE_BOOL: case NODE_NIL: case NODE_UNDEF: case NODE_SCOPE: case NODE_IDENT: case NODE_PLACEHOLDER: case NODE_BREAK: case NODE_USE: case NODE_AS: case NODE_CONST: case NODE_STATIC: case NODE_MOD: case NODE_UNION: case NODE_RECORD: case NODE_PANIC: case NODE_TYPE: case NODE_RECORD_TYPE: return n; case NODE_THROW: n->val.throw_stmt.expr = desugar_node(d, mod, n->val.throw_stmt.expr); return n; case NODE_TRY: { n->val.try_expr.expr = desugar_node(d, mod, n->val.try_expr.expr); node_t **handlers = nodespan_ptrs(&mod->parser, n->val.try_expr.handlers); for (usize i = 0; i < n->val.try_expr.handlers.len; i++) { handlers[i] = desugar_node(d, mod, handlers[i]); } n->val.try_expr.catch_expr = desugar_node(d, mod, n->val.try_expr.catch_expr); return n; } case NODE_CATCH: n->val.catch_clause.body = desugar_node(d, mod, n->val.catch_clause.body); return n; default: bail("unsupported node type %s", node_names[n->cls]); return NULL; } } node_t *desugar_run(desugar_t *d, module_t *mod, node_t *ast) { return desugar_node(d, mod, ast); }