desugar.c 23.1 KiB raw
1
#include <stdio.h>
2
#include <stdlib.h>
3
#include <string.h>
4
5
#include "ast.h"
6
#include "desugar.h"
7
#include "io.h"
8
#include "module.h"
9
#include "parser.h"
10
#include "resolver.h"
11
#include "symtab.h"
12
13
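/*
 * AST desugaring pass.
 *
 * This pass runs before resolving and transforms the AST, lowering surface
 * constructs (`while`, `while let`, `for`, guard `let ... else`, `if case`,
 * and the short-circuiting `and`/`or` operators) into a smaller core made of
 * `loop`, `if`, `if let`, `match` and block/variable nodes.
 */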
18
19
/* Forward declarations */
20
static node_t *desugar_node(desugar_t *d, module_t *mod, node_t *n);
21
static node_t *desugar_and_operator(desugar_t *d, module_t *mod, node_t *binop);
22
static node_t *desugar_or_operator(desugar_t *d, module_t *mod, node_t *binop);
23
static node_t *desugar_guard_stmt(
24
    desugar_t *d, module_t *mod, node_t *block, usize ix, node_t *guard
25
);
26
static node_t *desugar_block(desugar_t *d, module_t *mod, node_t *block);
27
28
/* Allocate a new AST node using the module's parser */
29
static node_t *node(module_t *mod, nodeclass_t cls, node_t *original) {
30
    parser_t *p = &mod->parser;
31
    if (p->nnodes >= MAX_NODES) {
32
        bail("maximum number of AST nodes reached");
33
    }
34
    node_t *n = &p->nodes[p->nnodes++];
35
    n->cls    = cls;
36
    n->type   = NULL;
37
    n->sym    = NULL;
38
    n->offset = original ? original->offset : 0;
39
    n->length = original ? original->length : 0;
40
41
    return n;
42
}
43
44
/* Build a boolean literal node holding `b`, located at `loc`. */
static node_t *node_bool(module_t *mod, bool b, node_t *loc) {
    node_t *result = node(mod, NODE_BOOL, loc);
    result->val.bool_lit = b;
    return result;
}
50
51
/* Create an empty block node. */
52
static node_t *node_block(module_t *mod, node_t *original) {
53
    node_t *block          = node(mod, NODE_BLOCK, original);
54
    block->val.block.stmts = (nodespan_t){ 0 };
55
    block->val.block.scope = NULL;
56
57
    return block;
58
}
59
60
/* Transform a while-let loop into:
61
 *
62
 *   loop {
63
 *     if let var = expr; guard {
64
 *       body;
65
 *     } else {
66
 *       rbranch;
67
 *       break;
68
 *     }
69
 *   }
70
 */
71
static node_t *desugar_while_let(
72
    desugar_t *d, module_t *mod, node_t *while_let_node
73
) {
74
    /* Create the loop node */
75
    node_t *loop_node = node(mod, NODE_LOOP, while_let_node);
76
    /* Create the if-let node */
77
    node_t *if_let_node = node(mod, NODE_IF_LET, while_let_node);
78
    /* Create the break statement */
79
    node_t *break_node = node(mod, NODE_BREAK, while_let_node);
80
81
    /* Handle the else clause */
82
    node_t *else_clause = node_block(mod, while_let_node);
83
    if (while_let_node->val.while_let_stmt.rbranch) {
84
        node_block_add_stmt(
85
            mod,
86
            else_clause,
87
            desugar_node(d, mod, while_let_node->val.while_let_stmt.rbranch)
88
        );
89
    }
90
    node_block_add_stmt(mod, else_clause, break_node);
91
92
    node_t *loop_body = node_block(mod, while_let_node);
93
    node_block_add_stmt(mod, loop_body, if_let_node);
94
95
    /* Set up the if-let statement */
96
    if_let_node->val.if_let_stmt.var = while_let_node->val.while_let_stmt.var;
97
    if_let_node->val.if_let_stmt.expr =
98
        desugar_node(d, mod, while_let_node->val.while_let_stmt.expr);
99
    if_let_node->val.if_let_stmt.guard =
100
        while_let_node->val.while_let_stmt.guard
101
            ? desugar_node(d, mod, while_let_node->val.while_let_stmt.guard)
102
            : NULL;
103
    if_let_node->val.if_let_stmt.lbranch =
104
        desugar_node(d, mod, while_let_node->val.while_let_stmt.body);
105
    if_let_node->val.if_let_stmt.rbranch = else_clause;
106
    if_let_node->val.if_let_stmt.scope =
107
        NULL; /* Will be set during resolving */
108
109
    /* Set up the loop statement */
110
    loop_node->val.loop_stmt.body = loop_body;
111
112
    return loop_node;
113
}
114
115
/* Transform a while loop into:
116
 *
117
 *   loop {
118
 *     if (condition) {
119
 *       body;
120
 *     } else {
121
 *       else_clause;
122
 *       break
123
 *     }
124
 *   }
125
 */
126
static node_t *desugar_while(desugar_t *d, module_t *mod, node_t *while_node) {
127
    /* Create the loop node */
128
    node_t *loop_node = node(mod, NODE_LOOP, while_node);
129
    /* Create the condition check */
130
    node_t *if_node = node(mod, NODE_IF, while_node);
131
    /* Create the break statement */
132
    node_t *break_node = node(mod, NODE_BREAK, while_node);
133
134
    /* Handle the else clause */
135
    node_t *else_clause = node_block(mod, while_node);
136
    if (while_node->val.while_stmt.rbranch) {
137
        node_block_add_stmt(
138
            mod,
139
            else_clause,
140
            desugar_node(d, mod, while_node->val.while_stmt.rbranch)
141
        );
142
    }
143
    node_block_add_stmt(mod, else_clause, break_node);
144
145
    node_t *loop_body = node_block(mod, while_node->val.while_stmt.body);
146
    node_block_add_stmt(mod, loop_body, if_node);
147
148
    /* Set up the if statement. */
149
    if_node->val.if_stmt.cond =
150
        desugar_node(d, mod, while_node->val.while_stmt.cond);
151
    if_node->val.if_stmt.lbranch =
152
        desugar_node(d, mod, while_node->val.while_stmt.body);
153
    if_node->val.if_stmt.rbranch = else_clause;
154
155
    /* Set up the loop statement */
156
    loop_node->val.loop_stmt.body = loop_body;
157
158
    return loop_node;
159
}
160
161
/* Build an identifier node for `name` at `loc`. The string is referenced,
 * not copied, so it must outlive the AST (callers in this file pass
 * literals). */
static node_t *node_ident(module_t *mod, const char *name, node_t *loc) {
    node_t *id = node(mod, NODE_IDENT, loc);
    id->val.ident.length = strlen(name);
    id->val.ident.name   = name;
    return id;
}
168
169
static node_t *node_var(
170
    module_t *mod,
171
    node_t   *ident,
172
    node_t   *typ,
173
    node_t   *val,
174
    bool      mut,
175
    node_t   *loc
176
) {
177
    node_t *var = node(mod, NODE_VAR, loc);
178
179
    var->val.var.ident   = ident;
180
    var->val.var.type    = typ;
181
    var->val.var.value   = val;
182
    var->val.var.mutable = mut;
183
184
    return var;
185
}
186
187
/* Build a numeric literal node whose spelling is `text`. The text is
 * referenced, not copied. */
static node_t *node_number(module_t *mod, const char *text, node_t *loc) {
    node_t *lit = node(mod, NODE_NUMBER, loc);
    lit->val.number.text_len = strlen(text);
    lit->val.number.text     = text;
    return lit;
}
194
195
/* Build a type node of class `tc` (e.g. TYPE_U32) located at `loc`. */
static node_t *node_type(module_t *mod, typeclass_t tc, node_t *loc) {
    node_t *t = node(mod, NODE_TYPE, loc);
    t->val.type.tclass = tc;
    return t;
}
201
202
static node_t *node_access(
203
    module_t *mod, node_t *lval, node_t *rval, node_t *loc
204
) {
205
    node_t *access          = node(mod, NODE_ACCESS, loc);
206
    access->val.access.lval = lval;
207
    access->val.access.rval = rval;
208
209
    return access;
210
}
211
212
static node_t *node_access_str(
213
    module_t *mod, node_t *lval, const char *field, node_t *loc
214
) {
215
    return node_access(mod, lval, node_ident(mod, field, loc), loc);
216
}
217
218
static node_t *node_binop(
219
    module_t *mod, binop_t op, node_t *left, node_t *right, node_t *loc
220
) {
221
    node_t *binop          = node(mod, NODE_BINOP, loc);
222
    binop->val.binop.op    = op;
223
    binop->val.binop.left  = left;
224
    binop->val.binop.right = right;
225
226
    return binop;
227
}
228
229
static node_t *node_increment(
230
    module_t *mod, node_t *lval_ident, node_t *expr_ident, node_t *loc
231
) {
232
    node_t *assign          = node(mod, NODE_ASSIGN, loc);
233
    assign->val.assign.lval = lval_ident;
234
    assign->val.assign.rval =
235
        node_binop(mod, OP_ADD, expr_ident, node_number(mod, "1", loc), loc);
236
237
    return assign;
238
}
239
240
/* Build a `match` node over `expr` with no cases yet. */
static node_t *node_match(module_t *mod, node_t *expr, node_t *loc) {
    node_t *m = node(mod, NODE_MATCH, loc);
    m->val.match_stmt.cases = (nodespan_t){ 0 };
    m->val.match_stmt.expr  = expr;
    return m;
}
247
248
static node_t *node_match_case(
249
    module_t *mod, node_t *pattern, node_t *guard, node_t *body, node_t *loc
250
) {
251
    node_t *swtch_case                  = node(mod, NODE_MATCH_CASE, loc);
252
    swtch_case->val.match_case.patterns = (nodespan_t){ 0 };
253
    if (pattern != NULL) {
254
        nodespan_push(
255
            &mod->parser, &swtch_case->val.match_case.patterns, pattern
256
        );
257
    }
258
    swtch_case->val.match_case.body     = body;
259
    swtch_case->val.match_case.guard    = guard;
260
    swtch_case->val.match_case.variable = NULL;
261
262
    return swtch_case;
263
}
264
265
/*
266
 * Transform guard statements into their desugared control flow.
267
 *
268
 * For example, `let value = opt else { handle(); }; rest;` becomes:
269
 *
270
 *     if let value = opt {
271
 *         rest;
272
 *     } else {
273
 *         handle();
274
 *     }
275
 *
276
 * Likewise, `let case Pattern(x) = expr else { ... };` becomes an
277
 * equivalent `if case` construct with the suffix statements placed in
278
 * the success branch so they retain access to bound names.
279
 */
280
static node_t *desugar_guard_stmt(
281
    desugar_t *d, module_t *mod, node_t *block, usize index, node_t *guard
282
) {
283
    node_t *success = node_block(mod, guard);
284
    node_t *if_stmt = NULL;
285
286
    /* Add the rest of the surrounding block into the success branch. */
287
    node_t **stmts = nodespan_ptrs(&mod->parser, block->val.block.stmts);
288
    for (usize j = index + 1; j < block->val.block.stmts.len; j++) {
289
        node_block_add_stmt(mod, success, stmts[j]);
290
    }
291
    if (guard->cls == NODE_GUARD_LET) {
292
        if_stmt                          = node(mod, NODE_IF_LET, guard);
293
        if_stmt->val.if_let_stmt.var     = guard->val.guard_let_stmt.var;
294
        if_stmt->val.if_let_stmt.expr    = guard->val.guard_let_stmt.expr;
295
        if_stmt->val.if_let_stmt.guard   = NULL;
296
        if_stmt->val.if_let_stmt.lbranch = success;
297
        if_stmt->val.if_let_stmt.rbranch = guard->val.guard_let_stmt.rbranch;
298
        if_stmt->val.if_let_stmt.scope   = NULL;
299
    } else {
300
        if_stmt                           = node(mod, NODE_IF_CASE, guard);
301
        if_stmt->val.if_case_stmt.pattern = guard->val.guard_case_stmt.pattern;
302
        if_stmt->val.if_case_stmt.expr    = guard->val.guard_case_stmt.expr;
303
        if_stmt->val.if_case_stmt.guard   = guard->val.guard_case_stmt.guard;
304
        if_stmt->val.if_case_stmt.lbranch = success;
305
        if_stmt->val.if_case_stmt.rbranch = guard->val.guard_case_stmt.rbranch;
306
    }
307
    block->val.block.stmts.len = index + 1;
308
309
    return desugar_node(d, mod, if_stmt);
310
}
311
312
/* Desugar each statement of `block` in place and return the block.
 *
 * When a guard statement (`let ... else` / `let case ... else`) is reached,
 * desugar_guard_stmt takes over the rest of the block. It folds the
 * remaining statements into the generated conditional and desugars them
 * there, so iteration stops at that point.
 */
static node_t *desugar_block(desugar_t *d, module_t *mod, node_t *block) {
    node_t **slots = nodespan_ptrs(&mod->parser, block->val.block.stmts);
    usize    pos   = 0;

    while (pos < block->val.block.stmts.len) {
        node_t *current = slots[pos];

        if (current->cls == NODE_GUARD_LET || current->cls == NODE_GUARD_CASE) {
            slots[pos] = desugar_guard_stmt(d, mod, block, pos, current);
            break;
        }
        slots[pos] = desugar_node(d, mod, current);
        pos++;
    }
    return block;
}
327
328
static void node_match_add_case(
329
    module_t *mod, node_t *swtch, node_t *swtch_case
330
) {
331
    nodespan_push(&mod->parser, &swtch->val.match_stmt.cases, swtch_case);
332
}
333
334
static node_t *desugar_for_range(
335
    desugar_t *d, module_t *mod, node_t *for_node
336
) {
337
    node_t *range = for_node->val.for_stmt.iter;
338
339
    node_t *index_name = for_node->val.for_stmt.idx
340
                             ? for_node->val.for_stmt.idx
341
                             : node_ident(mod, "$i", for_node);
342
    node_t *end_name   = node_ident(mod, "$end", for_node);
343
    node_t *start_expr = range->val.range.start
344
                             ? desugar_node(d, mod, range->val.range.start)
345
                             : node_number(mod, "0", range);
346
    node_t *index_typ  = node_type(mod, TYPE_U32, for_node);
347
    node_t *index_var =
348
        node_var(mod, index_name, index_typ, start_expr, true, for_node);
349
350
    node_t *end_expr = desugar_node(d, mod, range->val.range.end);
351
    node_t *end_typ  = node_type(mod, TYPE_U32, for_node);
352
    node_t *end_var  = node_var(mod, end_name, end_typ, end_expr, false, range);
353
354
    node_t *cond = node_binop(
355
        mod, OP_LT, index_var->val.var.ident, end_var->val.var.ident, for_node
356
    );
357
    node_t *loop_body = node_block(mod, for_node);
358
    node_t *loop_var  = node_var(
359
        mod,
360
        for_node->val.for_stmt.var,
361
        NULL,
362
        index_var->val.var.ident,
363
        false,
364
        for_node->val.for_stmt.var
365
    );
366
    node_block_add_stmt(mod, loop_body, loop_var);
367
    node_block_add_stmt(
368
        mod, loop_body, desugar_node(d, mod, for_node->val.for_stmt.body)
369
    );
370
371
    node_t *increment = node_increment(
372
        mod, index_var->val.var.ident, index_var->val.var.ident, for_node
373
    );
374
    node_block_add_stmt(mod, loop_body, increment);
375
376
    node_t *while_node              = node(mod, NODE_WHILE, for_node);
377
    while_node->val.while_stmt.cond = cond;
378
    while_node->val.while_stmt.body = loop_body;
379
    while_node->val.while_stmt.rbranch =
380
        for_node->val.for_stmt.rbranch
381
            ? desugar_node(d, mod, for_node->val.for_stmt.rbranch)
382
            : NULL;
383
384
    node_t *wrapper = node_block(mod, for_node);
385
    node_block_add_stmt(mod, wrapper, index_var);
386
    node_block_add_stmt(mod, wrapper, end_var);
387
    node_block_add_stmt(mod, wrapper, desugar_while(d, mod, while_node));
388
389
    return wrapper;
390
}
391
392
/* Transform a for loop into a while loop:
393
 *
394
 *   for var in (iter) {
395
 *     body;
396
 *   } else {
397
 *     rbranch;
398
 *   }
399
 *
400
 * becomes:
401
 *
402
 *   {
403
 *     let $i: u32 = 0;
404
 *     let $len: u32 = iter.len;
405
 *     while ($i < $len) {
406
 *       let var = iter[$i];
407
 *       body;
408
 *       $i = $i + 1;
409
 *     } else {
410
 *       rbranch;
411
 *     }
412
 *   }
413
 */
414
static node_t *desugar_for(desugar_t *d, module_t *mod, node_t *for_node) {
415
    if (for_node->val.for_stmt.iter->cls == NODE_RANGE) {
416
        return desugar_for_range(d, mod, for_node);
417
    }
418
    /* Use simple temporary variable names or user-provided index variable */
419
    node_t *index_name  = for_node->val.for_stmt.idx
420
                              ? for_node->val.for_stmt.idx
421
                              : node_ident(mod, "$i", for_node);
422
    node_t *length_name = node_ident(mod, "$len", for_node);
423
424
    /* Create index variable: let $i: u32 = 0; */
425
    node_t *index_val = node_number(mod, "0", for_node);
426
    node_t *index_typ = node_type(mod, TYPE_U32, for_node);
427
    node_t *index_var =
428
        node_var(mod, index_name, index_typ, index_val, true, for_node);
429
430
    /* Create length variable: let $len: u32 = iter.len; */
431
    node_t *len_field = node_access_str(
432
        mod,
433
        desugar_node(d, mod, for_node->val.for_stmt.iter),
434
        "len",
435
        for_node->val.for_stmt.iter
436
    );
437
    node_t *length_typ = node_type(mod, TYPE_U32, for_node);
438
    node_t *length_var =
439
        node_var(mod, length_name, length_typ, len_field, false, for_node);
440
441
    /* Create while condition: $i < $len */
442
    node_t *cond = node_binop(
443
        mod,
444
        OP_LT,
445
        index_var->val.var.ident,
446
        length_var->val.var.ident,
447
        for_node
448
    );
449
450
    /* Create array index access: iter[$i] */
451
    node_t *array_idx = node(mod, NODE_ARRAY_INDEX, for_node);
452
    array_idx->val.access.lval =
453
        desugar_node(d, mod, for_node->val.for_stmt.iter);
454
    array_idx->val.access.rval = index_var->val.var.ident;
455
456
    /* Create loop variable assignment: let var = iter[$i]; */
457
    node_t *var_name = for_node->val.for_stmt.var;
458
    node_t *loop_var = node_var(
459
        mod, var_name, NULL, array_idx, false, for_node->val.for_stmt.var
460
    );
461
462
    /* Create increment statement: $i = $i + 1; */
463
    node_t *increment = node_increment(
464
        mod, index_var->val.var.ident, index_var->val.var.ident, for_node
465
    );
466
467
    /* Create while body */
468
    node_t *body = node_block(mod, for_node);
469
    node_block_add_stmt(mod, body, loop_var);
470
    node_block_add_stmt(
471
        mod, body, desugar_node(d, mod, for_node->val.for_stmt.body)
472
    );
473
    node_block_add_stmt(mod, body, increment);
474
475
    /* Create while node */
476
    node_t *while_node              = node(mod, NODE_WHILE, for_node);
477
    while_node->val.while_stmt.cond = cond;
478
    while_node->val.while_stmt.body = body;
479
    while_node->val.while_stmt.rbranch =
480
        for_node->val.for_stmt.rbranch
481
            ? desugar_node(d, mod, for_node->val.for_stmt.rbranch)
482
            : NULL;
483
484
    /* Create wrapper block containing the initialization and while loop */
485
    node_t *wrapper = node_block(mod, for_node);
486
    node_block_add_stmt(mod, wrapper, index_var);
487
    node_block_add_stmt(mod, wrapper, length_var);
488
    node_block_add_stmt(mod, wrapper, desugar_while(d, mod, while_node));
489
490
    return wrapper;
491
}
492
493
/* Transform `x and y` into:
494
 *
495
 *   if (x) {
496
 *     y
497
 *   } else {
498
 *     false
499
 *   }
500
 */
501
static node_t *desugar_and_operator(
502
    desugar_t *d, module_t *mod, node_t *binop
503
) {
504
    node_t *if_node   = node(mod, NODE_IF, binop);
505
    node_t *false_lit = node_bool(mod, false, binop);
506
507
    if_node->val.if_stmt.cond    = desugar_node(d, mod, binop->val.binop.left);
508
    if_node->val.if_stmt.lbranch = desugar_node(d, mod, binop->val.binop.right);
509
    if_node->val.if_stmt.rbranch = false_lit;
510
511
    return if_node;
512
}
513
514
/* Transform `x or y` into:
515
 *
516
 *   if (x) {
517
 *     true
518
 *   } else {
519
 *     y
520
 *   }
521
 */
522
static node_t *desugar_or_operator(desugar_t *d, module_t *mod, node_t *binop) {
523
    node_t *if_node  = node(mod, NODE_IF, binop);
524
    node_t *true_lit = node_bool(mod, true, binop);
525
526
    if_node->val.if_stmt.cond    = desugar_node(d, mod, binop->val.binop.left);
527
    if_node->val.if_stmt.lbranch = true_lit;
528
    if_node->val.if_stmt.rbranch = desugar_node(d, mod, binop->val.binop.right);
529
530
    return if_node;
531
}
532
533
/* Recursively desugar a node and its children */
534
static node_t *desugar_node(desugar_t *d, module_t *mod, node_t *n) {
535
    if (!n)
536
        return NULL;
537
538
    switch (n->cls) {
539
    case NODE_WHILE:
540
        return desugar_while(d, mod, n);
541
542
    case NODE_WHILE_LET:
543
        return desugar_while_let(d, mod, n);
544
545
    case NODE_MOD_BODY:
546
    case NODE_BLOCK:
547
        return desugar_block(d, mod, n);
548
549
    case NODE_IF:
550
        n->val.if_stmt.cond    = desugar_node(d, mod, n->val.if_stmt.cond);
551
        n->val.if_stmt.lbranch = desugar_node(d, mod, n->val.if_stmt.lbranch);
552
        if (n->val.if_stmt.rbranch) {
553
            n->val.if_stmt.rbranch =
554
                desugar_node(d, mod, n->val.if_stmt.rbranch);
555
        }
556
        return n;
557
558
    case NODE_IF_LET:
559
        n->val.if_let_stmt.expr = desugar_node(d, mod, n->val.if_let_stmt.expr);
560
        if (n->val.if_let_stmt.guard) {
561
            n->val.if_let_stmt.guard =
562
                desugar_node(d, mod, n->val.if_let_stmt.guard);
563
        }
564
        n->val.if_let_stmt.lbranch =
565
            desugar_node(d, mod, n->val.if_let_stmt.lbranch);
566
        if (n->val.if_let_stmt.rbranch) {
567
            n->val.if_let_stmt.rbranch =
568
                desugar_node(d, mod, n->val.if_let_stmt.rbranch);
569
        }
570
        return n;
571
572
    case NODE_IF_CASE: {
573
        node_t *pattern = desugar_node(d, mod, n->val.if_case_stmt.pattern);
574
        node_t *expr    = desugar_node(d, mod, n->val.if_case_stmt.expr);
575
        node_t *guard   = NULL;
576
577
        if (n->val.if_case_stmt.guard) {
578
            guard = desugar_node(d, mod, n->val.if_case_stmt.guard);
579
        }
580
        node_t *then_block = desugar_node(d, mod, n->val.if_case_stmt.lbranch);
581
        node_t *swtch      = node_match(mod, expr, n);
582
583
        node_t *case_node = node_match_case(mod, pattern, guard, then_block, n);
584
585
        node_match_add_case(mod, swtch, case_node);
586
587
        if (n->val.if_case_stmt.rbranch) {
588
            node_t *else_body =
589
                desugar_node(d, mod, n->val.if_case_stmt.rbranch);
590
            node_t *default_case =
591
                node_match_case(mod, NULL, NULL, else_body, n);
592
593
            node_match_add_case(mod, swtch, default_case);
594
        }
595
        return swtch;
596
    }
597
598
    case NODE_LOOP:
599
        n->val.loop_stmt.body = desugar_node(d, mod, n->val.loop_stmt.body);
600
        return n;
601
602
    case NODE_FN:
603
        n->val.fn_decl.body = desugar_node(d, mod, n->val.fn_decl.body);
604
        return n;
605
606
    case NODE_BINOP:
607
        /* Handle logical operators with short-circuit evaluation */
608
        if (n->val.binop.op == OP_AND) {
609
            return desugar_and_operator(d, mod, n);
610
        }
611
        if (n->val.binop.op == OP_OR) {
612
            return desugar_or_operator(d, mod, n);
613
        }
614
        /* For other binary operators, recursively desugar operands */
615
        n->val.binop.left  = desugar_node(d, mod, n->val.binop.left);
616
        n->val.binop.right = desugar_node(d, mod, n->val.binop.right);
617
        return n;
618
619
    case NODE_UNOP:
620
        n->val.unop.expr = desugar_node(d, mod, n->val.unop.expr);
621
        return n;
622
623
    case NODE_CALL: {
624
        node_t **args = nodespan_ptrs(&mod->parser, n->val.call.args);
625
        for (usize i = 0; i < n->val.call.args.len; i++) {
626
            args[i] = desugar_node(d, mod, args[i]);
627
        }
628
        return n;
629
    }
630
631
    case NODE_BUILTIN: {
632
        node_t **args = nodespan_ptrs(&mod->parser, n->val.builtin.args);
633
        for (usize i = 0; i < n->val.builtin.args.len; i++) {
634
            args[i] = desugar_node(d, mod, args[i]);
635
        }
636
        return n;
637
    }
638
639
    case NODE_RETURN:
640
        if (n->val.return_stmt.value) {
641
            n->val.return_stmt.value =
642
                desugar_node(d, mod, n->val.return_stmt.value);
643
        }
644
        return n;
645
646
    case NODE_VAR:
647
        if (n->val.var.value) {
648
            n->val.var.value = desugar_node(d, mod, n->val.var.value);
649
        }
650
        return n;
651
652
    case NODE_ASSIGN:
653
        n->val.assign.lval = desugar_node(d, mod, n->val.assign.lval);
654
        n->val.assign.rval = desugar_node(d, mod, n->val.assign.rval);
655
        return n;
656
657
    case NODE_EXPR_STMT:
658
        n->val.expr_stmt = desugar_node(d, mod, n->val.expr_stmt);
659
        return n;
660
661
    case NODE_FOR:
662
        return desugar_for(d, mod, n);
663
664
    case NODE_MATCH: {
665
        n->val.match_stmt.expr = desugar_node(d, mod, n->val.match_stmt.expr);
666
        node_t **cases = nodespan_ptrs(&mod->parser, n->val.match_stmt.cases);
667
        for (usize i = 0; i < n->val.match_stmt.cases.len; i++) {
668
            cases[i] = desugar_node(d, mod, cases[i]);
669
        }
670
        return n;
671
    }
672
673
    case NODE_MATCH_CASE: {
674
        node_t **patterns =
675
            nodespan_ptrs(&mod->parser, n->val.match_case.patterns);
676
        for (usize i = 0; i < n->val.match_case.patterns.len; i++) {
677
            patterns[i] = desugar_node(d, mod, patterns[i]);
678
        }
679
        if (n->val.match_case.guard) {
680
            n->val.match_case.guard =
681
                desugar_node(d, mod, n->val.match_case.guard);
682
        }
683
        n->val.match_case.body = desugar_node(d, mod, n->val.match_case.body);
684
        return n;
685
    }
686
687
    case NODE_ARRAY_INDEX:
688
    case NODE_ARRAY_LIT:
689
    case NODE_ARRAY_REPEAT_LIT:
690
    case NODE_RECORD_LIT:
691
    case NODE_CALL_ARG:
692
    case NODE_REF:
693
    case NODE_ACCESS:
694
    case NODE_NUMBER:
695
    case NODE_CHAR:
696
    case NODE_STRING:
697
    case NODE_BOOL:
698
    case NODE_NIL:
699
    case NODE_UNDEF:
700
    case NODE_SCOPE:
701
    case NODE_IDENT:
702
    case NODE_PLACEHOLDER:
703
    case NODE_BREAK:
704
    case NODE_USE:
705
    case NODE_AS:
706
    case NODE_CONST:
707
    case NODE_STATIC:
708
    case NODE_MOD:
709
    case NODE_UNION:
710
    case NODE_RECORD:
711
    case NODE_PANIC:
712
    case NODE_TYPE:
713
    case NODE_RECORD_TYPE:
714
        return n;
715
716
    case NODE_THROW:
717
        n->val.throw_stmt.expr = desugar_node(d, mod, n->val.throw_stmt.expr);
718
        return n;
719
720
    case NODE_TRY: {
721
        n->val.try_expr.expr = desugar_node(d, mod, n->val.try_expr.expr);
722
        node_t **handlers =
723
            nodespan_ptrs(&mod->parser, n->val.try_expr.handlers);
724
        for (usize i = 0; i < n->val.try_expr.handlers.len; i++) {
725
            handlers[i] = desugar_node(d, mod, handlers[i]);
726
        }
727
        n->val.try_expr.catch_expr =
728
            desugar_node(d, mod, n->val.try_expr.catch_expr);
729
        return n;
730
    }
731
    case NODE_CATCH:
732
        n->val.catch_clause.body =
733
            desugar_node(d, mod, n->val.catch_clause.body);
734
        return n;
735
736
    default:
737
        bail("unsupported node type %s", node_names[n->cls]);
738
        return NULL;
739
    }
740
}
741
742
/* Public entry point: desugar `ast` and everything beneath it.
 * Returns the (possibly replaced) root node, or NULL when `ast` is NULL. */
node_t *desugar_run(desugar_t *d, module_t *mod, node_t *ast) {
    node_t *root = desugar_node(d, mod, ast);
    return root;
}