Support multiple error types per function

c8c83e2bcd385a6c004c74096f5712256d607c5825a4bc75e7bed6ba3141f255
Each error type is assigned a globally unique tag at lowering time.
Alexis Sellier committed ago 1 parent fcb3c9b3
lib/std/arch/rv64/tests/error.multi.basic.rad added +45 -0
1 +
//! Test basic multi-error throw and catch.
2 +
3 +
union ErrA { A }
4 +
union ErrB { B }
5 +
6 +
fn failA() -> i32 throws (ErrA, ErrB) {
7 +
    throw ErrA::A();
8 +
}
9 +
10 +
fn failB() -> i32 throws (ErrA, ErrB) {
11 +
    throw ErrB::B();
12 +
}
13 +
14 +
fn succeed() -> i32 throws (ErrA, ErrB) {
15 +
    return 42;
16 +
}
17 +
18 +
@default fn main() -> i32 {
19 +
    // Test throwing ErrA.
20 +
    let mut caught: i32 = 0;
21 +
    try failA() catch {
22 +
        caught = 1;
23 +
    };
24 +
    if caught != 1 {
25 +
        return 1;
26 +
    }
27 +
28 +
    // Test throwing ErrB.
29 +
    caught = 0;
30 +
    try failB() catch {
31 +
        caught = 2;
32 +
    };
33 +
    if caught != 2 {
34 +
        return 2;
35 +
    }
36 +
37 +
    // Test success path.
38 +
    let val = try succeed() catch {
39 +
        return 3;
40 +
    };
41 +
    if val != 42 {
42 +
        return 4;
43 +
    }
44 +
    return 0;
45 +
}
lib/std/arch/rv64/tests/error.multi.catch.rad added +49 -0
1 +
//! Test catch {} (no binding) with multi-error callee.
2 +
3 +
union ErrA { A }
4 +
union ErrB { B }
5 +
6 +
fn failA() -> i32 throws (ErrA, ErrB) {
7 +
    throw ErrA::A();
8 +
}
9 +
10 +
fn failB() -> i32 throws (ErrA, ErrB) {
11 +
    throw ErrB::B();
12 +
}
13 +
14 +
fn succeed() -> i32 throws (ErrA, ErrB) {
15 +
    return 99;
16 +
}
17 +
18 +
@default fn main() -> i32 {
19 +
    // Catch ErrA without binding.
20 +
    let mut handled: i32 = 0;
21 +
    try failA() catch {
22 +
        handled = 1;
23 +
    };
24 +
    if handled != 1 {
25 +
        return 1;
26 +
    }
27 +
28 +
    // Catch ErrB without binding.
29 +
    handled = 0;
30 +
    try failB() catch {
31 +
        handled = 2;
32 +
    };
33 +
    if handled != 2 {
34 +
        return 2;
35 +
    }
36 +
37 +
    // Success path should not trigger catch.
38 +
    handled = 0;
39 +
    let val = try succeed() catch {
40 +
        handled = 99;
41 +
    };
42 +
    if handled != 0 {
43 +
        return 3;
44 +
    }
45 +
    if val != 99 {
46 +
        return 4;
47 +
    }
48 +
    return 0;
49 +
}
lib/std/arch/rv64/tests/error.multi.catch.typed.binding.rad added +43 -0
1 +
//! Test typed multi-catch with payload extraction via bindings.
2 +
3 +
union ErrA { A(i32) }
4 +
union ErrB { B(i32) }
5 +
6 +
fn failA(v: i32) -> i32 throws (ErrA, ErrB) {
7 +
    throw ErrA::A(v);
8 +
}
9 +
10 +
fn failB(v: i32) -> i32 throws (ErrA, ErrB) {
11 +
    throw ErrB::B(v);
12 +
}
13 +
14 +
@default fn main() -> i32 {
15 +
    // Extract ErrA payload.
16 +
    let mut got: i32 = 0;
17 +
    try failA(100) catch e as ErrA {
18 +
        let case ErrA::A(v) = e else {
19 +
            return 1;
20 +
        };
21 +
        got = v;
22 +
    } catch e as ErrB {
23 +
        return 2;
24 +
    };
25 +
    if got != 100 {
26 +
        return 3;
27 +
    }
28 +
29 +
    // Extract ErrB payload.
30 +
    got = 0;
31 +
    try failB(200) catch e as ErrA {
32 +
        return 4;
33 +
    } catch e as ErrB {
34 +
        let case ErrB::B(v) = e else {
35 +
            return 5;
36 +
        };
37 +
        got = v;
38 +
    };
39 +
    if got != 200 {
40 +
        return 6;
41 +
    }
42 +
    return 0;
43 +
}
lib/std/arch/rv64/tests/error.multi.catch.typed.catchall.rad added +56 -0
1 +
//! Test typed catch with a catch-all fallback.
2 +
3 +
union ErrA { A(i32) }
4 +
union ErrB { B(i32) }
5 +
union ErrC { C }
6 +
7 +
fn failA() -> i32 throws (ErrA, ErrB, ErrC) {
8 +
    throw ErrA::A(10);
9 +
}
10 +
11 +
fn failB() -> i32 throws (ErrA, ErrB, ErrC) {
12 +
    throw ErrB::B(20);
13 +
}
14 +
15 +
fn failC() -> i32 throws (ErrA, ErrB, ErrC) {
16 +
    throw ErrC::C();
17 +
}
18 +
19 +
@default fn main() -> i32 {
20 +
    // Typed catch for ErrA, catch-all for the rest.
21 +
    let mut result: i32 = 0;
22 +
    try failA() catch e as ErrA {
23 +
        let case ErrA::A(v) = e else {
24 +
            return 10;
25 +
        };
26 +
        result = v;
27 +
    } catch {
28 +
        return 1;
29 +
    };
30 +
    if result != 10 {
31 +
        return 2;
32 +
    }
33 +
34 +
    // ErrB should fall into catch-all.
35 +
    let mut caught_all: i32 = 0;
36 +
    try failB() catch e as ErrA {
37 +
        return 3;
38 +
    } catch {
39 +
        caught_all = 1;
40 +
    };
41 +
    if caught_all != 1 {
42 +
        return 4;
43 +
    }
44 +
45 +
    // ErrC should also fall into catch-all.
46 +
    caught_all = 0;
47 +
    try failC() catch e as ErrA {
48 +
        return 5;
49 +
    } catch {
50 +
        caught_all = 1;
51 +
    };
52 +
    if caught_all != 1 {
53 +
        return 6;
54 +
    }
55 +
    return 0;
56 +
}
lib/std/arch/rv64/tests/error.multi.catch.typed.rad added +59 -0
1 +
//! Test typed multi-catch: catch e as ErrA {} catch e as ErrB {}.
2 +
3 +
union ErrA { A(i32) }
4 +
union ErrB { B(i32) }
5 +
6 +
fn failA() -> i32 throws (ErrA, ErrB) {
7 +
    throw ErrA::A(10);
8 +
}
9 +
10 +
fn failB() -> i32 throws (ErrA, ErrB) {
11 +
    throw ErrB::B(20);
12 +
}
13 +
14 +
fn succeed() -> i32 throws (ErrA, ErrB) {
15 +
    return 42;
16 +
}
17 +
18 +
@default fn main() -> i32 {
19 +
    // Catch ErrA specifically with typed binding.
20 +
    let mut result: i32 = 0;
21 +
    let val1 = try failA() catch e as ErrA {
22 +
        let case ErrA::A(v) = e else {
23 +
            return 10;
24 +
        };
25 +
        result = v;
26 +
        0
27 +
    } catch e as ErrB {
28 +
        return 1;
29 +
    };
30 +
    if result != 10 {
31 +
        return 2;
32 +
    }
33 +
34 +
    // Catch ErrB specifically with typed binding.
35 +
    result = 0;
36 +
    let val2 = try failB() catch e as ErrA {
37 +
        return 3;
38 +
    } catch e as ErrB {
39 +
        let case ErrB::B(v) = e else {
40 +
            return 11;
41 +
        };
42 +
        result = v;
43 +
        0
44 +
    };
45 +
    if result != 20 {
46 +
        return 4;
47 +
    }
48 +
49 +
    // Success path should skip all catches.
50 +
    let val3 = try succeed() catch e as ErrA {
51 +
        return 5;
52 +
    } catch e as ErrB {
53 +
        return 6;
54 +
    };
55 +
    if val3 != 42 {
56 +
        return 7;
57 +
    }
58 +
    return 0;
59 +
}
lib/std/arch/rv64/tests/error.multi.propagate.multi.rad added +54 -0
1 +
//! Test propagating a multi-throw callee via plain `try`.
2 +
3 +
union ErrA { A }
4 +
union ErrB { B }
5 +
6 +
fn inner(flag: i32) -> i32 throws (ErrA, ErrB) {
7 +
    if flag == 1 {
8 +
        throw ErrA::A();
9 +
    }
10 +
    if flag == 2 {
11 +
        throw ErrB::B();
12 +
    }
13 +
    return 100;
14 +
}
15 +
16 +
fn outer(flag: i32) -> i32 throws (ErrA, ErrB) {
17 +
    let val = try inner(flag);
18 +
    return val + 1;
19 +
}
20 +
21 +
@default fn main() -> i32 {
22 +
    // Propagate ErrA through outer.
23 +
    let mut result: i32 = 0;
24 +
    try outer(1) catch e as ErrA {
25 +
        result = 10;
26 +
    } catch e as ErrB {
27 +
        result = 20;
28 +
    };
29 +
    if result != 10 {
30 +
        return 1;
31 +
    }
32 +
33 +
    // Propagate ErrB through outer.
34 +
    result = 0;
35 +
    try outer(2) catch e as ErrA {
36 +
        result = 10;
37 +
    } catch e as ErrB {
38 +
        result = 20;
39 +
    };
40 +
    if result != 20 {
41 +
        return 2;
42 +
    }
43 +
44 +
    // Success path.
45 +
    let val = try outer(0) catch e as ErrA {
46 +
        return 3;
47 +
    } catch e as ErrB {
48 +
        return 4;
49 +
    };
50 +
    if val != 101 {
51 +
        return 5;
52 +
    }
53 +
    return 0;
54 +
}
lib/std/arch/rv64/tests/error.multi.propagate.rad added +51 -0
1 +
//! Test error propagation with multi-error and subset throw lists.
2 +
3 +
union ErrA { A }
4 +
union ErrB { B }
5 +
6 +
fn innerA() -> i32 throws (ErrA) {
7 +
    throw ErrA::A();
8 +
}
9 +
10 +
fn innerB() -> i32 throws (ErrB) {
11 +
    throw ErrB::B();
12 +
}
13 +
14 +
fn middle(flag: i32) -> i32 throws (ErrA, ErrB) {
15 +
    if flag == 1 {
16 +
        return try innerA();
17 +
    }
18 +
    if flag == 2 {
19 +
        return try innerB();
20 +
    }
21 +
    return 100;
22 +
}
23 +
24 +
@default fn main() -> i32 {
25 +
    // Propagate ErrA through middle.
26 +
    let mut caught: i32 = 0;
27 +
    try middle(1) catch {
28 +
        caught = 1;
29 +
    };
30 +
    if caught != 1 {
31 +
        return 1;
32 +
    }
33 +
34 +
    // Propagate ErrB through middle.
35 +
    caught = 0;
36 +
    try middle(2) catch {
37 +
        caught = 2;
38 +
    };
39 +
    if caught != 2 {
40 +
        return 2;
41 +
    }
42 +
43 +
    // Success path.
44 +
    let val = try middle(0) catch {
45 +
        return 3;
46 +
    };
47 +
    if val != 100 {
48 +
        return 4;
49 +
    }
50 +
    return 0;
51 +
}
lib/std/arch/rv64/tests/error.multi.try.optional.rad added +33 -0
1 +
//! Test try? with multi-error callee.
2 +
3 +
union ErrA { A }
4 +
union ErrB { B }
5 +
6 +
fn failA() -> i32 throws (ErrA, ErrB) {
7 +
    throw ErrA::A();
8 +
}
9 +
10 +
fn succeed() -> i32 throws (ErrA, ErrB) {
11 +
    return 77;
12 +
}
13 +
14 +
@default fn main() -> i32 {
15 +
    // try? on failure should produce nil.
16 +
    let r1 = try? failA();
17 +
    if r1 != nil {
18 +
        return 1;
19 +
    }
20 +
21 +
    // try? on success should produce the value.
22 +
    let r2 = try? succeed();
23 +
    if r2 == nil {
24 +
        return 2;
25 +
    }
26 +
    let val = r2 else {
27 +
        return 3;
28 +
    };
29 +
    if val != 77 {
30 +
        return 4;
31 +
    }
32 +
    return 0;
33 +
}
lib/std/lang/ast.rad +15 -5
351 351
352 352
/// Try expression metadata.
353 353
pub record Try {
354 354
    /// Expression evaluated with implicit error propagation.
355 355
    expr: *Node,
356 -
    /// Fallback evaluated when the try fails.
357 -
    catchExpr: ?*Node,
358 -
    /// Optional identifier binding for the error value in the catch clause.
359 -
    catchBinding: ?*Node,
356 +
    /// Catch clauses. Empty for propagation (`try`), `try!`, or `try?`.
357 +
    catches: NodeList,
360 358
    /// Whether the try should panic instead of returning an error.
361 359
    shouldPanic: bool,
362 360
    /// Whether the try should return an optional instead of propagating error.
363 361
    returnsOptional: bool,
364 362
}
365 363
364 +
/// A single catch clause in a `try ... catch` expression.
365 +
pub record CatchClause {
366 +
    /// Optional identifier binding for the error value (eg. `e`).
367 +
    binding: ?*Node,
368 +
    /// Optional type annotation after `as` (eg. `IoError`).
369 +
    typeNode: ?*Node,
370 +
    /// Block body executed when this clause matches.
371 +
    body: *Node,
372 +
}
373 +
366 374
/// `for` loop metadata.
367 375
pub record For {
368 376
    /// Loop variable binding.
369 377
    binding: *Node,
370 378
    /// Optional index binding for enumeration loops.
783 791
    RecordLitField(Arg),
784 792
    /// Alignment specifier.
785 793
    Align {
786 794
        /// Alignment value.
787 795
        value: *Node,
788 -
    }
796 +
    },
797 +
    /// Catch clause within a try expression.
798 +
    CatchClause(CatchClause),
789 799
}
790 800
791 801
/// Full AST node with shared metadata and variant-specific payload.
792 802
pub record Node {
793 803
    /// Unique identifier for this node.
lib/std/lang/ast/printer.rad +21 -1
319 319
            }
320 320
        }
321 321
        case super::NodeValue::Try(t) => {
322 322
            let mut head = "try";
323 323
            if t.shouldPanic { head = "try!"; }
324 -
            return sexpr::list(a, head, &[toExpr(a, t.expr), toExprOrNull(a, t.catchExpr)]);
324 +
            if t.catches.len > 0 {
325 +
                let catches = nodeListToExprs(a, &t.catches);
326 +
                return sexpr::list(a, head, &[toExpr(a, t.expr), sexpr::block(a, "catches", &[], catches)]);
327 +
            }
328 +
            return sexpr::list(a, head, &[toExpr(a, t.expr)]);
329 +
        }
330 +
        case super::NodeValue::CatchClause(clause) => {
331 +
            let mut head = "catch";
332 +
            let mut children: [sexpr::Expr; 3] = undefined;
333 +
            let mut len: u32 = 0;
334 +
            if let b = clause.binding {
335 +
                children[len] = toExpr(a, b);
336 +
                len = len + 1;
337 +
            }
338 +
            if let t = clause.typeNode {
339 +
                children[len] = toExpr(a, t);
340 +
                len = len + 1;
341 +
            }
342 +
            children[len] = toExpr(a, clause.body);
343 +
            len = len + 1;
344 +
            return sexpr::list(a, head, &children[..len]);
325 345
        }
326 346
        case super::NodeValue::Block(blk) => {
327 347
            let children = nodeListToExprs(a, &blk.statements);
328 348
            return sexpr::block(a, "block", &[], children);
329 349
        }
lib/std/lang/lower.rad +209 -33
98 98
use std::lang::resolver;
99 99
100 100
// TODO: Search for all `_ as i32` to ensure that casts from u32 to i32 don't
101 101
// happen, since they are potentially truncating values.
102 102
103 -
// TODO: Support more than one thrown error type per function.
104 103
// TODO: Support constant union lowering.
105 104
// TODO: Void unions should be passed by value.
106 105
107 106
////////////////////
108 107
// Error Handling //
210 209
const INITIAL_PREDS: u32 = 4;
211 210
/// Initial capacity for global data.
212 211
const INITIAL_DATA: u32 = 8;
213 212
/// Initial capacity for functions.
214 213
const INITIAL_FNS: u32 = 16;
214 +
/// Initial capacity for the error tag table.
215 +
const INITIAL_ERR_TAGS: u32 = 8;
215 216
/// Maximum nesting depth of loops.
216 217
const MAX_LOOP_DEPTH: u32 = 16;
217 218
/// Maximum number of function symbols to track.
218 219
const MAX_FN_SYMS: u32 = 2048;
220 +
/// Maximum number of `catch` clauses per `try`.
221 +
const MAX_CATCH_CLAUSES: u32 = 32;
219 222
220 223
// Slice Layout
221 224
//
222 225
// A slice is a fat pointer consisting of a data pointer and length.
223 226
// `{ ptr: u32, len: u32 }`.
281 284
    fnCount: u32,
282 285
    /// Map of function symbols to qualified names.
283 286
    fnSyms: *mut [FnSymEntry],
284 287
    /// Number of entries in fnSyms.
285 288
    fnSymsLen: u32,
289 +
    /// Global error type tag table. Maps nominal types to unique tags.
290 +
    errTags: *mut [ErrTagEntry],
291 +
    /// Number of entries in the error tag table.
292 +
    errTagsLen: u32,
293 +
    /// Next error tag to assign (starts at 1; 0 = success).
294 +
    errTagCounter: u32,
286 295
    /// Lowering options.
287 296
    options: LowerOptions,
288 297
}
289 298
290 299
/// Entry mapping a function symbol to its qualified name.
291 300
record FnSymEntry {
292 301
    sym: *resolver::Symbol,
293 302
    qualName: *[u8],
294 303
}
295 304
305 +
/// Entry in the global error tag table.
306 +
record ErrTagEntry {
307 +
    /// The type of this error, identified by its interned pointer.
308 +
    ty: resolver::Type,
309 +
    /// The globally unique tag assigned to this error type (non-zero).
310 +
    tag: u32,
311 +
}
312 +
313 +
/// Compute the maximum size of any error type in a throw list.
314 +
fn maxErrSize(throwList: *[*resolver::Type]) -> u32 {
315 +
    let mut maxSize: u32 = 0;
316 +
    for i in 0..throwList.len {
317 +
        let size = resolver::getTypeLayout(*throwList[i]).size;
318 +
        if size > maxSize {
319 +
            maxSize = size;
320 +
        }
321 +
    }
322 +
    return maxSize;
323 +
}
324 +
325 +
/// Get or assign a globally unique error tag for the given error type.
326 +
/// Tag `0` is reserved for success; error tags start at `1`.
327 +
fn getOrAssignErrorTag(self: *mut Lowerer, errType: resolver::Type) -> u32 {
328 +
    for i in 0..self.errTagsLen {
329 +
        if self.errTags[i].ty == errType {
330 +
            return self.errTags[i].tag;
331 +
        }
332 +
    }
333 +
    let tag = self.errTagCounter;
334 +
    self.errTagCounter = self.errTagCounter + 1;
335 +
336 +
    if self.errTagsLen >= self.errTags.len {
337 +
        self.errTags = try! alloc::growSlice(
338 +
            self.arena,
339 +
            self.errTags as *mut [opaque],
340 +
            self.errTagsLen,
341 +
            @sizeOf(ErrTagEntry),
342 +
            @alignOf(ErrTagEntry)
343 +
        ) as *mut [ErrTagEntry];
344 +
    }
345 +
    self.errTags[self.errTagsLen] = ErrTagEntry { ty: errType, tag };
346 +
    self.errTagsLen = self.errTagsLen + 1;
347 +
348 +
    return tag;
349 +
}
350 +
296 351
/// Builder for accumulating data values during constant lowering.
297 352
record DataValueBuilder {
298 353
    arena: *mut alloc::Arena,
299 354
    values: *mut [il::DataValue],
300 355
    len: u32,
679 734
    arena: *mut alloc::Arena
680 735
) -> il::Program throws (LowerError) {
681 736
    let data = try! alloc::allocSlice(arena, @sizeOf(il::Data), @alignOf(il::Data), INITIAL_DATA) as *mut [il::Data];
682 737
    let fns = try! alloc::allocSlice(arena, @sizeOf(*il::Fn), @alignOf(*il::Fn), INITIAL_FNS) as *mut [*il::Fn];
683 738
    let fnSyms = try! alloc::allocSlice(arena, @sizeOf(FnSymEntry), @alignOf(FnSymEntry), INITIAL_FNS) as *mut [FnSymEntry];
739 +
    let errTags = try! alloc::allocSlice(arena, @sizeOf(ErrTagEntry), @alignOf(ErrTagEntry), INITIAL_ERR_TAGS) as *mut [ErrTagEntry];
684 740
    let mut low = Lowerer {
685 741
        arena: arena,
686 742
        resolver: res,
687 743
        moduleGraph: nil,
688 744
        pkgName,
691 747
        dataCount: 0,
692 748
        fns,
693 749
        fnCount: 0,
694 750
        fnSyms,
695 751
        fnSymsLen: 0,
752 +
        errTags,
753 +
        errTagsLen: 0,
754 +
        errTagCounter: 1,
696 755
        options: LowerOptions { debug: false, buildTest: false },
697 756
    };
698 757
    let defaultFnIdx = try lowerDecls(&mut low, root, true);
699 758
700 759
    return il::Program {
726 785
727 786
    let fnSyms = try! alloc::allocSlice(
728 787
        arena, @sizeOf(FnSymEntry), @alignOf(FnSymEntry), MAX_FN_SYMS
729 788
    ) as *mut [FnSymEntry];
730 789
790 +
    let errTags = try! alloc::allocSlice(
791 +
        arena, @sizeOf(ErrTagEntry), @alignOf(ErrTagEntry), INITIAL_ERR_TAGS
792 +
    ) as *mut [ErrTagEntry];
793 +
731 794
    return Lowerer {
732 795
        arena,
733 796
        resolver: res,
734 797
        moduleGraph: graph,
735 798
        pkgName,
738 801
        dataCount: 0,
739 802
        fns,
740 803
        fnCount: 0,
741 804
        fnSyms,
742 805
        fnSymsLen: 0,
806 +
        errTags,
807 +
        errTagsLen: 0,
808 +
        errTagCounter: 1,
743 809
        options,
744 810
    };
745 811
}
746 812
747 813
/// Lower a module's AST into the lowerer accumulator.
3812 3878
    tag: i64,
3813 3879
    payload: ?il::Val,
3814 3880
    payloadType: resolver::Type
3815 3881
) -> il::Val throws (LowerError) {
3816 3882
    let successType = *self.fnType.returnType;
3817 -
    let errType = *self.fnType.throwList[0]; // TODO: Support more errors.
3818 -
    let layout = resolver::getResultLayout(successType, errType);
3819 -
3883 +
    let layout = resolver::getResultLayout(
3884 +
        successType, &self.fnType.throwList[..self.fnType.throwListLen]
3885 +
    );
3820 3886
    return try buildTagged(self, layout, tag, payload, payloadType, 8, RESULT_VAL_OFFSET);
3821 3887
}
3822 3888
3823 3889
/// Build a slice aggregate from a data pointer and length.
3824 3890
fn buildSliceValue(
5170 5236
5171 5237
/// Compute the size of the return buffer for the current function.
5172 5238
fn retBufSize(self: *mut FnLowerer) -> u32 {
5173 5239
    if self.fnType.throwListLen > 0 {
5174 5240
        let successType = *self.fnType.returnType;
5175 -
        let errType = *self.fnType.throwList[0];
5176 5241
5177 -
        return resolver::getResultLayout(successType, errType).size;
5242 +
        return resolver::getResultLayout(successType, &self.fnType.throwList[..self.fnType.throwListLen]).size;
5178 5243
    }
5179 5244
    return resolver::getTypeLayout(*self.fnType.returnType).size;
5180 5245
}
5181 5246
5182 5247
/// Lower a return statement.
5191 5256
5192 5257
/// Lower a throw statement.
5193 5258
fn lowerThrowStmt(self: *mut FnLowerer, expr: *ast::Node) throws (LowerError) {
5194 5259
    debug::assert(self.fnType.throwListLen > 0);
5195 5260
5196 -
    // TODO: Adapt to multiple throw types.
5197 -
    let errType = *self.fnType.throwList[0];
5261 +
    let errType = resolver::typeFor(self.low.resolver, expr) else {
5262 +
        throw LowerError::MissingType(expr);
5263 +
    };
5264 +
    let tag = getOrAssignErrorTag(self.low, errType) as i64;
5198 5265
    let errVal = try lowerExpr(self, expr);
5199 -
    let resultVal = try buildResult(self, 1, errVal, errType);
5266 +
    let resultVal = try buildResult(self, tag, errVal, errType);
5200 5267
5201 5268
    try emitRetVal(self, resultVal);
5202 5269
}
5203 5270
5204 5271
/// Ensure a value is in a register (eg. for branch conditions).
5637 5704
        throw LowerError::MissingType(callExpr.callee);
5638 5705
    };
5639 5706
    let case resolver::Type::Fn(calleeInfo) = calleeTy else {
5640 5707
        throw LowerError::ExpectedFunction;
5641 5708
    };
5642 -
    let errTy = *calleeInfo.throwList[0];
5643 5709
    let okValueTy = *calleeInfo.returnType; // The type of the success payload.
5644 5710
5645 5711
    // Type of the try expression, which is either the return type of the function
5646 5712
    // if successful, or an optional of it, if using `try?`.
5647 5713
    let tryExprTy = resolver::typeFor(self.low.resolver, node) else {
5696 5762
        if let slot = resultSlot {
5697 5763
            let errVal = try buildNilOptional(self, tryExprTy);
5698 5764
            try emitStore(self, slot, 0, tryExprTy, errVal);
5699 5765
        }
5700 5766
        try emitMergeIfUnterminated(self, &mut mergeBlock);
5701 -
    } else if let catchExpr = t.catchExpr {
5767 +
    } else if t.catches.len > 0 {
5702 5768
        // `try ... catch` -- handle the error.
5703 -
        // If there's an error binding, create a variable for it.
5704 -
        let savedVarsLen = enterVarScope(self);
5705 -
        if let binding = t.catchBinding {
5706 -
            let case ast::NodeValue::Ident(name) = binding.value else {
5707 -
                throw LowerError::ExpectedIdentifier;
5708 -
            };
5709 -
            let errVal = tvalPayloadVal(self, base, errTy, RESULT_VAL_OFFSET);
5710 -
            let _ = newVar(self, name, ilType(self.low, errTy), false, errVal);
5769 +
        let firstNode = t.catches.list[0];
5770 +
        let case ast::NodeValue::CatchClause(first) = firstNode.value
5771 +
            else panic "lowerTry: expected CatchClause";
5772 +
5773 +
        if first.typeNode != nil or t.catches.len > 1 {
5774 +
            // Typed multi-catch: switch on global error tag.
5775 +
            try lowerMultiCatch(self, t.catches, calleeInfo, base, tagReg, &mut mergeBlock);
5776 +
        } else {
5777 +
            // Single untyped catch clause.
5778 +
            let savedVarsLen = enterVarScope(self);
5779 +
            if let binding = first.binding {
5780 +
                let case ast::NodeValue::Ident(name) = binding.value else {
5781 +
                    throw LowerError::ExpectedIdentifier;
5782 +
                };
5783 +
                let errTy = *calleeInfo.throwList[0];
5784 +
                let errVal = tvalPayloadVal(self, base, errTy, RESULT_VAL_OFFSET);
5785 +
                let _ = newVar(self, name, ilType(self.low, errTy), false, errVal);
5786 +
            }
5787 +
            try lowerBlock(self, first.body);
5788 +
            try emitMergeIfUnterminated(self, &mut mergeBlock);
5789 +
            exitVarScope(self, savedVarsLen);
5711 5790
        }
5712 -
        try lowerBlock(self, catchExpr);
5713 -
        try emitMergeIfUnterminated(self, &mut mergeBlock);
5714 -
        exitVarScope(self, savedVarsLen);
5715 5791
    } else if t.shouldPanic {
5716 5792
        // `try!` -- panic on error, emit unreachable since control won't continue.
5717 5793
        // TODO: We should have some kind of `panic` instruction?
5718 5794
        emit(self, il::Instr::Unreachable);
5719 5795
    } else {
5720 5796
        // Plain `try` -- propagate the error to the caller by returning early.
5721 -
        // The error type must match the current function's declared error type.
5722 -
        let currentErrType = *self.fnType.throwList[0];
5723 -
        debug::assert(currentErrType == errTy);
5797 +
        // Forward the callee's global error tag and payload directly.
5798 +
        let callerLayout = resolver::getResultLayout(
5799 +
            *self.fnType.returnType, &self.fnType.throwList[..self.fnType.throwListLen]
5800 +
        );
5801 +
        let calleeErrSize = maxErrSize(&calleeInfo.throwList[..calleeInfo.throwListLen]);
5802 +
        let dst = try emitReserveLayout(self, callerLayout);
5724 5803
5725 -
        // Extract the error payload and wrap it in a result tagged union,
5726 -
        // then return to the caller.
5727 -
        let errVal = tvalPayloadVal(self, base, errTy, RESULT_VAL_OFFSET);
5728 -
        let layout = resolver::getResultLayout(*self.fnType.returnType, currentErrType);
5729 -
        let retVal = try buildTagged(self, layout, 1, errVal, currentErrType, 8, RESULT_VAL_OFFSET);
5804 +
        emitStoreW64At(self, il::Val::Reg(tagReg), dst, TVAL_TAG_OFFSET);
5805 +
        let srcPayload = emitPtrOffset(self, base, RESULT_VAL_OFFSET);
5806 +
        let dstPayload = emitPtrOffset(self, dst, RESULT_VAL_OFFSET);
5807 +
        emit(self, il::Instr::Blit { dst: dstPayload, src: srcPayload, size: calleeErrSize });
5730 5808
5731 -
        try emitRetVal(self, retVal);
5809 +
        try emitRetVal(self, il::Val::Reg(dst));
5732 5810
    }
5733 5811
5734 5812
    // Switch to the merge block if one was created. If all paths diverged
5735 5813
    // (e.g both success and error returned), there's no merge block.
5736 5814
    if let blk = mergeBlock {
5748 5826
    } else { // Void return.
5749 5827
        return il::Val::Undef;
5750 5828
    }
5751 5829
}
5752 5830
5831 +
/// Lower typed multi-catch clauses.
5832 +
///
5833 +
/// Emits a switch on the global error tag to dispatch to the correct catch
5834 +
/// clause. Each typed clause extracts the error payload for its specific type
5835 +
/// and binds it to the clause's identifier.
5836 +
fn lowerMultiCatch(
5837 +
    self: *mut FnLowerer,
5838 +
    catches: ast::NodeList,
5839 +
    calleeInfo: *resolver::FnType,
5840 +
    base: il::Reg,
5841 +
    tagReg: il::Reg,
5842 +
    mergeBlock: *mut ?BlockId
5843 +
) throws (LowerError) {
5844 +
    let entry = currentBlock(self);
5845 +
5846 +
    // First pass: create blocks, resolve error types, and build switch cases.
5847 +
    let mut blocks: [BlockId; MAX_CATCH_CLAUSES] = undefined;
5848 +
    let mut errTypes: [?resolver::Type; MAX_CATCH_CLAUSES] = undefined;
5849 +
    let cases = try! alloc::allocSlice(
5850 +
        self.low.arena,
5851 +
        @sizeOf(il::SwitchCase),
5852 +
        @alignOf(il::SwitchCase),
5853 +
        catches.len
5854 +
    ) as *mut [il::SwitchCase];
5855 +
    let mut caseIdx: u32 = 0;
5856 +
    let mut defaultIdx: ?u32 = nil;
5857 +
5858 +
    for i in 0..catches.len {
5859 +
        let clauseNode = catches.list[i];
5860 +
        let case ast::NodeValue::CatchClause(clause) = clauseNode.value
5861 +
            else panic "lowerMultiCatch: expected CatchClause";
5862 +
5863 +
        blocks[i] = try createBlock(self, "catch");
5864 +
        try addPredecessor(self, blocks[i], entry);
5865 +
5866 +
        if let typeNode = clause.typeNode {
5867 +
            let errTy = resolver::typeFor(self.low.resolver, typeNode) else {
5868 +
                throw LowerError::MissingType(typeNode);
5869 +
            };
5870 +
            errTypes[i] = errTy;
5871 +
5872 +
            cases[caseIdx] = il::SwitchCase {
5873 +
                value: getOrAssignErrorTag(self.low, errTy) as i64,
5874 +
                target: blocks[i].n,
5875 +
                args: &mut []
5876 +
            };
5877 +
            caseIdx = caseIdx + 1;
5878 +
        } else {
5879 +
            errTypes[i] = nil;
5880 +
            defaultIdx = i;
5881 +
        }
5882 +
    }
5883 +
5884 +
    // Emit switch. Default target is the catch-all block, or an unreachable block.
5885 +
    let mut defaultTarget: BlockId = undefined;
5886 +
    if let idx = defaultIdx {
5887 +
        defaultTarget = blocks[idx];
5888 +
    } else {
5889 +
        defaultTarget = try createBlock(self, "unreachable");
5890 +
        try addPredecessor(self, defaultTarget, entry);
5891 +
    }
5892 +
    emit(self, il::Instr::Switch {
5893 +
        val: il::Val::Reg(tagReg),
5894 +
        defaultTarget: defaultTarget.n,
5895 +
        defaultArgs: &mut [],
5896 +
        cases: cases[..caseIdx]
5897 +
    });
5898 +
5899 +
    // Second pass: emit each catch clause body.
5900 +
    for i in 0..catches.len {
5901 +
        let clauseNode = catches.list[i];
5902 +
        let case ast::NodeValue::CatchClause(clause) = clauseNode.value
5903 +
            else panic "lowerMultiCatch: expected CatchClause";
5904 +
5905 +
        try switchToAndSeal(self, blocks[i]);
5906 +
        let savedVarsLen = enterVarScope(self);
5907 +
5908 +
        if let binding = clause.binding {
5909 +
            let case ast::NodeValue::Ident(name) = binding.value else {
5910 +
                throw LowerError::ExpectedIdentifier;
5911 +
            };
5912 +
            let errTy = errTypes[i] else panic "lowerMultiCatch: catch-all with binding";
5913 +
            let errVal = tvalPayloadVal(self, base, errTy, RESULT_VAL_OFFSET);
5914 +
5915 +
            newVar(self, name, ilType(self.low, errTy), false, errVal);
5916 +
        }
5917 +
        try lowerBlock(self, clause.body);
5918 +
        try emitMergeIfUnterminated(self, mergeBlock);
5919 +
5920 +
        exitVarScope(self, savedVarsLen);
5921 +
    }
5922 +
5923 +
    // Emit unreachable block if no catch-all.
5924 +
    if defaultIdx == nil {
5925 +
        try switchToAndSeal(self, defaultTarget);
5926 +
        emit(self, il::Instr::Unreachable);
5927 +
    }
5928 +
}
5929 +
5753 5930
/// Lower a call expression, which may be a function call or type constructor.
5754 5931
fn lowerCallOrCtor(self: *mut FnLowerer, node: *ast::Node, call: ast::Call) -> il::Val throws (LowerError) {
5755 5932
    if let sym = resolver::nodeData(self.low.resolver, call.callee).sym {
5756 5933
        if let case resolver::SymbolData::Type(nominal) = sym.data {
5757 5934
            let case resolver::NominalType::Record(_) = *nominal else {
5868 6045
5869 6046
    // Allocate the return buffer when needed.
5870 6047
    if needsRetBuf {
5871 6048
        if isThrowing {
5872 6049
            let successType = *fnInfo.returnType;
5873 -
            let errType = *fnInfo.throwList[0];
5874 -
            let layout = resolver::getResultLayout(successType, errType);
6050 +
            let layout = resolver::getResultLayout(successType, &fnInfo.throwList[..fnInfo.throwListLen]);
5875 6051
5876 6052
            args[0] = il::Val::Reg(try emitReserveLayout(self, layout));
5877 6053
        } else {
5878 6054
            args[0] = il::Val::Reg(try emitReserve(self, retTy));
5879 6055
        }
lib/std/lang/lower/tests/multi.throw.basic.rad added +13 -0
1 +
/// Multi-throw: function that can throw two different error types.
2 +
union ErrA { A }
3 +
union ErrB { B }
4 +
5 +
fn fallible(flag: i32) -> i32 throws (ErrA, ErrB) {
6 +
    if flag == 1 {
7 +
        throw ErrA::A();
8 +
    }
9 +
    if flag == 2 {
10 +
        throw ErrB::B();
11 +
    }
12 +
    return 42;
13 +
}
lib/std/lang/lower/tests/multi.throw.basic.ril added +28 -0
1 +
fn w64 $fallible(w64 %0, w32 %1) {
2 +
  @entry0
3 +
    br.eq w32 %1 1 @then1 @merge2;
4 +
  @then1
5 +
    reserve %2 1 1;
6 +
    store w8 0 %2 0;
7 +
    reserve %3 12 8;
8 +
    store w64 1 %3 0;
9 +
    store w8 %2 %3 8;
10 +
    blit %0 %3 12;
11 +
    ret %0;
12 +
  @merge2
13 +
    br.eq w32 %1 2 @then3 @merge4;
14 +
  @then3
15 +
    reserve %4 1 1;
16 +
    store w8 0 %4 0;
17 +
    reserve %5 12 8;
18 +
    store w64 2 %5 0;
19 +
    store w8 %4 %5 8;
20 +
    blit %0 %5 12;
21 +
    ret %0;
22 +
  @merge4
23 +
    reserve %6 12 8;
24 +
    store w64 0 %6 0;
25 +
    store w32 42 %6 8;
26 +
    blit %0 %6 12;
27 +
    ret %0;
28 +
}
lib/std/lang/lower/tests/multi.throw.catch.typed.rad added +23 -0
1 +
/// Multi-error typed catch: switch dispatch on error tag.
2 +
union ErrA { A }
3 +
union ErrB { B }
4 +
5 +
fn fallible(flag: i32) -> i32 throws (ErrA, ErrB) {
6 +
    if flag == 1 {
7 +
        throw ErrA::A();
8 +
    }
9 +
    if flag == 2 {
10 +
        throw ErrB::B();
11 +
    }
12 +
    return 42;
13 +
}
14 +
15 +
fn caller(flag: i32) -> i32 {
16 +
    let mut r: i32 = 0;
17 +
    try fallible(flag) catch e as ErrA {
18 +
        r = 1;
19 +
    } catch e as ErrB {
20 +
        r = 2;
21 +
    };
22 +
    return r;
23 +
}
lib/std/lang/lower/tests/multi.throw.catch.typed.ril added +54 -0
1 +
fn w64 $fallible(w64 %0, w32 %1) {
2 +
  @entry0
3 +
    br.eq w32 %1 1 @then1 @merge2;
4 +
  @then1
5 +
    reserve %2 1 1;
6 +
    store w8 0 %2 0;
7 +
    reserve %3 12 8;
8 +
    store w64 1 %3 0;
9 +
    store w8 %2 %3 8;
10 +
    blit %0 %3 12;
11 +
    ret %0;
12 +
  @merge2
13 +
    br.eq w32 %1 2 @then3 @merge4;
14 +
  @then3
15 +
    reserve %4 1 1;
16 +
    store w8 0 %4 0;
17 +
    reserve %5 12 8;
18 +
    store w64 2 %5 0;
19 +
    store w8 %4 %5 8;
20 +
    blit %0 %5 12;
21 +
    ret %0;
22 +
  @merge4
23 +
    reserve %6 12 8;
24 +
    store w64 0 %6 0;
25 +
    store w32 42 %6 8;
26 +
    blit %0 %6 12;
27 +
    ret %0;
28 +
}
29 +
30 +
fn w32 $caller(w32 %0) {
31 +
  @entry0
32 +
    reserve %1 12 8;
33 +
    call w64 %2 $fallible(%1, %0);
34 +
    load w64 %3 %2 0;
35 +
    reserve %4 4 4;
36 +
    br.ne w32 %3 0 @err2 @ok1;
37 +
  @ok1
38 +
    sload w32 %5 %2 8;
39 +
    store w32 %5 %4 0;
40 +
    jmp @merge3(0);
41 +
  @err2
42 +
    switch %3 (1 @catch4) (2 @catch5) @unreachable6;
43 +
  @merge3(w32 %9)
44 +
    sload w32 %8 %4 0;
45 +
    ret %9;
46 +
  @catch4
47 +
    load w8 %6 %2 8;
48 +
    jmp @merge3(1);
49 +
  @catch5
50 +
    load w8 %7 %2 8;
51 +
    jmp @merge3(2);
52 +
  @unreachable6
53 +
    unreachable;
54 +
}
lib/std/lang/lower/tests/multi.throw.propagate.rad added +15 -0
1 +
/// Multi-throw propagation: caller propagates errors from callee with subset throw list.
2 +
union ErrA { A }
3 +
union ErrB { B }
4 +
5 +
fn inner(flag: bool) -> i32 throws (ErrA) {
6 +
    if flag {
7 +
        throw ErrA::A();
8 +
    }
9 +
    return 10;
10 +
}
11 +
12 +
fn outer(flag: bool) -> i32 throws (ErrA, ErrB) {
13 +
    let val = try inner(flag);
14 +
    return val + 1;
15 +
}
lib/std/lang/lower/tests/multi.throw.propagate.ril added +47 -0
1 +
fn w64 $inner(w64 %0, w8 %1) {
2 +
  @entry0
3 +
    br.ne w32 %1 0 @then1 @merge2;
4 +
  @then1
5 +
    reserve %2 1 1;
6 +
    store w8 0 %2 0;
7 +
    reserve %3 12 8;
8 +
    store w64 1 %3 0;
9 +
    store w8 %2 %3 8;
10 +
    blit %0 %3 12;
11 +
    ret %0;
12 +
  @merge2
13 +
    reserve %4 12 8;
14 +
    store w64 0 %4 0;
15 +
    store w32 10 %4 8;
16 +
    blit %0 %4 12;
17 +
    ret %0;
18 +
}
19 +
20 +
fn w64 $outer(w64 %0, w8 %1) {
21 +
  @entry0
22 +
    reserve %2 12 8;
23 +
    call w64 %3 $inner(%2, %1);
24 +
    load w64 %4 %3 0;
25 +
    reserve %5 4 4;
26 +
    br.ne w32 %4 0 @err2 @ok1;
27 +
  @ok1
28 +
    sload w32 %6 %3 8;
29 +
    store w32 %6 %5 0;
30 +
    jmp @merge3;
31 +
  @err2
32 +
    reserve %7 12 8;
33 +
    store w64 %4 %7 0;
34 +
    add w64 %8 %3 8;
35 +
    add w64 %9 %7 8;
36 +
    blit %9 %8 1;
37 +
    blit %0 %7 12;
38 +
    ret %0;
39 +
  @merge3
40 +
    sload w32 %10 %5 0;
41 +
    add w32 %11 %10 1;
42 +
    reserve %12 12 8;
43 +
    store w64 0 %12 0;
44 +
    store w32 %11 %12 8;
45 +
    blit %0 %12 12;
46 +
    ret %0;
47 +
}
lib/std/lang/lower/tests/try.basic.ril +12 -11
27 27
  @ok1
28 28
    sload w32 %6 %3 8;
29 29
    store w32 %6 %5 0;
30 30
    jmp @merge3;
31 31
  @err2
32 -
    load w8 %7 %3 8;
33 -
    reserve %8 12 8;
34 -
    store w64 1 %8 0;
35 -
    store w8 %7 %8 8;
36 -
    blit %0 %8 12;
32 +
    reserve %7 12 8;
33 +
    store w64 %4 %7 0;
34 +
    add w64 %8 %3 8;
35 +
    add w64 %9 %7 8;
36 +
    blit %9 %8 1;
37 +
    blit %0 %7 12;
37 38
    ret %0;
38 39
  @merge3
39 -
    sload w32 %9 %5 0;
40 -
    add w32 %10 %9 1;
41 -
    reserve %11 12 8;
42 -
    store w64 0 %11 0;
43 -
    store w32 %10 %11 8;
44 -
    blit %0 %11 12;
40 +
    sload w32 %10 %5 0;
41 +
    add w32 %11 %10 1;
42 +
    reserve %12 12 8;
43 +
    store w64 0 %12 0;
44 +
    store w32 %11 %12 8;
45 +
    blit %0 %12 12;
45 46
    ret %0;
46 47
}
lib/std/lang/lower/tests/void.throw.ril +9 -8
23 23
    load w64 %4 %3 0;
24 24
    br.ne w32 %4 0 @err2 @ok1;
25 25
  @ok1
26 26
    jmp @merge3;
27 27
  @err2
28 -
    load w8 %5 %3 8;
29 -
    reserve %6 9 8;
30 -
    store w64 1 %6 0;
31 -
    store w8 %5 %6 8;
32 -
    blit %0 %6 9;
28 +
    reserve %5 9 8;
29 +
    store w64 %4 %5 0;
30 +
    add w64 %6 %3 8;
31 +
    add w64 %7 %5 8;
32 +
    blit %7 %6 1;
33 +
    blit %0 %5 9;
33 34
    ret %0;
34 35
  @merge3
35 -
    reserve %7 9 8;
36 -
    store w64 0 %7 0;
37 -
    blit %0 %7 9;
36 +
    reserve %8 9 8;
37 +
    store w64 0 %8 0;
38 +
    blit %0 %8 9;
38 39
    ret %0;
39 40
}
lib/std/lang/parser.rad +18 -8
1369 1369
    try expect(p, scanner::TokenKind::Continue, "expected `continue`");
1370 1370
1371 1371
    return node(p, ast::NodeValue::Continue);
1372 1372
}
1373 1373
1374 -
/// Parse a `try` expression with optional `catch`.
1374 +
/// Parse a `try` expression with optional `catch` clause(s).
1375 1375
fn parseTryExpr(p: *mut Parser) -> *ast::Node
1376 1376
    throws (ParseError)
1377 1377
{
1378 1378
    try expect(p, scanner::TokenKind::Try, "expected `try`");
1379 1379
1380 1380
    let shouldPanic = consume(p, scanner::TokenKind::Bang);
1381 1381
    let returnsOptional = consume(p, scanner::TokenKind::Question);
1382 1382
    let expr = try parsePrimary(p);
1383 -
    let mut catchExpr: ?*ast::Node = nil;
1384 -
    let mut catchBinding: ?*ast::Node = nil;
1383 +
    let mut catches = ast::nodeList(p.arena, 4);
1385 1384
1386 -
    if consume(p, scanner::TokenKind::Catch) {
1387 -
        // Check for optional error binding: `catch err { ... }`.
1385 +
    while consume(p, scanner::TokenKind::Catch) {
1386 +
        let mut binding: ?*ast::Node = nil;
1387 +
        let mut typeNode: ?*ast::Node = nil;
1388 +
1389 +
        // Check for optional error binding: `catch ident { ... }` or
1390 +
        // `catch ident as Type { ... }`.
1388 1391
        if check(p, scanner::TokenKind::Ident) {
1389 -
            catchBinding = try parseIdent(p, "expected identifier after `catch`");
1392 +
            binding = try parseIdent(p, "expected identifier after `catch`");
1393 +
            if consume(p, scanner::TokenKind::As) {
1394 +
                typeNode = try parseType(p);
1395 +
            }
1390 1396
        }
1391 1397
        if not check(p, scanner::TokenKind::LBrace) {
1392 1398
            throw failParsing(p, "expected `{` after `catch`");
1393 1399
        }
1394 -
        catchExpr = try parseBlock(p);
1400 +
        let body = try parseBlock(p);
1401 +
        let clause = node(p, ast::NodeValue::CatchClause(
1402 +
            ast::CatchClause { binding, typeNode, body }
1403 +
        ));
1404 +
        ast::nodeListPush(&mut catches, clause);
1395 1405
    }
1396 1406
    return node(p, ast::NodeValue::Try(
1397 -
        ast::Try { expr, catchExpr, catchBinding, shouldPanic, returnsOptional }
1407 +
        ast::Try { expr, catches, shouldPanic, returnsOptional }
1398 1408
    ));
1399 1409
}
1400 1410
1401 1411
/// Parse an `if` expression, with optional `else` or `else if` clauses.
1402 1412
///
lib/std/lang/parser/tests.rad +7 -4
1343 1343
    let root = try! parseExprStr("try value");
1344 1344
    let case ast::NodeValue::Try(node) = root.value
1345 1345
        else throw testing::TestError::Failed;
1346 1346
1347 1347
    try expectIdent(node.expr, "value");
1348 -
    try testing::expect(node.catchExpr == nil);
1348 +
    try testing::expect(node.catches.len == 0);
1349 1349
    try testing::expect(not node.shouldPanic);
1350 1350
}
1351 1351
1352 1352
/// Test parsing a `try!` expression that panics on error.
1353 1353
@test fn testParseTryBang() throws (testing::TestError) {
1354 1354
    let root = try! parseExprStr("try! value");
1355 1355
    let case ast::NodeValue::Try(node) = root.value
1356 1356
        else throw testing::TestError::Failed;
1357 1357
1358 1358
    try expectIdent(node.expr, "value");
1359 -
    try testing::expect(node.catchExpr == nil);
1359 +
    try testing::expect(node.catches.len == 0);
1360 1360
    try testing::expect(node.shouldPanic);
1361 1361
}
1362 1362
1363 1363
/// Test parsing a `try` expression with a `catch` block.
1364 1364
@test fn testParseTryCatchBlock() throws (testing::TestError) {
1365 1365
    let root = try! parseExprStr("try value catch { alt; }");
1366 1366
    let case ast::NodeValue::Try(node) = root.value
1367 1367
        else throw testing::TestError::Failed;
1368 1368
1369 1369
    try expectIdent(node.expr, "value");
1370 +
    try testing::expect(node.catches.len == 1);
1370 1371
1371 -
    let catchExpr = node.catchExpr
1372 +
    let case ast::NodeValue::CatchClause(clause) = node.catches.list[0].value
1372 1373
        else throw testing::TestError::Failed;
1373 -
    try expectBlockExprStmt(catchExpr, ast::NodeValue::Ident("alt"));
1374 +
    try testing::expect(clause.binding == nil);
1375 +
    try testing::expect(clause.typeNode == nil);
1376 +
    try expectBlockExprStmt(clause.body, ast::NodeValue::Ident("alt"));
1374 1377
}
1375 1378
1376 1379
/// Test that `catch` without a block is rejected.
1377 1380
@test fn testParseTryCatchExprRejected() throws (testing::TestError) {
1378 1381
    let parsed: ?*ast::Node = try? parseExprStr("try value catch alternate");
lib/std/lang/resolver.rad +138 -37
440 440
    ThrowRequiresThrows,
441 441
    /// `throw` used with an error type not declared by the enclosing function.
442 442
    ThrowIncompatibleError,
443 443
    /// `try` applied to an expression that cannot throw.
444 444
    TryNonThrowing,
445 +
    /// Inferred catch binding used with multi-error callee.
446 +
    TryCatchMultiError,
447 +
    /// Duplicate error type in typed catch clauses.
448 +
    TryCatchDuplicateType,
449 +
    /// Typed catch clauses do not cover all error types.
450 +
    TryCatchNonExhaustive,
445 451
    /// Called a fallible function without using `try`.
446 452
    MissingTry,
447 453
    /// `else` branch of a `let` guard must not fall through.
448 454
    ElseBranchMustDiverge,
449 455
    /// Cannot use opaque type in this context.
1362 1368
        }
1363 1369
    }
1364 1370
}
1365 1371
1366 1372
/// Get the layout of a result aggregate with a tag and the larger payload.
1367 -
pub fn getResultLayout(payload: Type, err: Type) -> Layout {
1373 +
pub fn getResultLayout(payload: Type, throwList: *[*Type]) -> Layout {
1368 1374
    let payloadLayout = getTypeLayout(payload);
1369 -
    let errLayout = getTypeLayout(err);
1370 -
    let maxSize = max(payloadLayout.size, errLayout.size);
1371 -
    let maxAlign = max(payloadLayout.alignment, errLayout.alignment);
1375 +
    let mut maxSize = payloadLayout.size;
1376 +
    let mut maxAlign = payloadLayout.alignment;
1372 1377
1378 +
    for i in 0..throwList.len {
1379 +
        let errLayout = getTypeLayout(*throwList[i]);
1380 +
        maxSize = max(maxSize, errLayout.size);
1381 +
        maxAlign = max(maxAlign, errLayout.alignment);
1382 +
    }
1373 1383
    return Layout {
1374 1384
        size: PTR_SIZE + maxSize,
1375 1385
        alignment: max(PTR_SIZE, maxAlign),
1376 1386
    };
1377 1387
}
4947 4957
        if let case Type::Optional(_) = resultTy {
4948 4958
            // Already optional, no wrapping needed.
4949 4959
        } else {
4950 4960
            tryResultTy = Type::Optional(allocType(self, resultTy));
4951 4961
        }
4952 -
    } else if let expr = tryExpr.catchExpr {
4953 -
        // If there's an error binding, create a scope and add the variable.
4954 -
        if let binding = tryExpr.catchBinding {
4955 -
            enterScope(self, node);
4956 -
4957 -
            let errTy = *calleeInfo.throwList[0];
4958 -
            try bindValueIdent(self, binding, binding, errTy, false, 0, 0);
4959 -
        }
4960 -
        let catchTy = try visit(self, expr, resultTy);
4961 -
4962 -
        if let _ = tryExpr.catchBinding {
4963 -
            exitScope(self);
4964 -
        }
4965 -
        // If the catch expression is `nil` and the result type is not already
4966 -
        // optional, lift the result type to an optional. This allows
4967 -
        // `try? fail()` to yield `?T` where `fail` returns `T`.
4968 -
        if catchTy == Type::Nil {
4969 -
            if let case Type::Optional(_) = resultTy {
4970 -
                // Already optional, `nil` is directly assignable.
4971 -
            } else {
4972 -
                tryResultTy = Type::Optional(allocType(self, resultTy));
4973 -
            }
4974 -
        // TODO: We should be able to just check the optional with the variant here (bug).
4975 -
        } else if hint != Type::Unknown {
4976 -
            // In expression context, the `catch` type must be assignable to the
4977 -
            // result type. In statement context, this check is skipped since
4978 -
            // the value is discarded anyway.
4979 -
            if let case Type::Void = hint {
4980 -
                // Statement context, skip check.
4981 -
            } else {
4982 -
                let _ = try checkAssignable(self, expr, resultTy);
4983 -
            }
4984 -
        }
4962 +
    } else if tryExpr.catches.len > 0 {
4963 +
        // `try ... catch` -- one or more catch clauses.
4964 +
        tryResultTy = try resolveTryCatches(self, node, tryExpr.catches, calleeInfo, resultTy, hint);
4985 4965
    } else if not tryExpr.shouldPanic {
4986 4966
        let fnInfo = self.currentFn
4987 4967
            else throw emitError(self, node, ErrorKind::TryRequiresThrows);
4988 4968
        if fnInfo.throwListLen == 0 {
4989 4969
            throw emitError(self, node, ErrorKind::TryRequiresThrows);
5006 4986
        }
5007 4987
    }
5008 4988
    return try setNodeType(self, node, tryResultTy);
5009 4989
}
5010 4990
4991 +
4992 +
/// Check that a `catch` body is assignable to the expected result type, but only
4993 +
/// in expression context (`hint` is neither `Unknown` nor `Void`).
4994 +
fn checkCatchBody(self: *mut Resolver, body: *ast::Node, resultTy: Type, hint: Type)
4995 +
    throws (ResolveError)
4996 +
{
4997 +
    if hint != Type::Unknown and hint != Type::Void {
4998 +
        try checkAssignable(self, body, resultTy);
4999 +
    }
5000 +
}
5001 +
5002 +
/// Resolve catch clauses for a `try ... catch` expression.
5003 +
///
5004 +
/// For a single untyped catch (with or without binding), resolves the catch
5005 +
/// body and returns the result type. Multi-error callees with inferred bindings
5006 +
/// are rejected; you must use typed catches.
5007 +
fn resolveTryCatches(
5008 +
    self: *mut Resolver,
5009 +
    node: *ast::Node,
5010 +
    catches: ast::NodeList,
5011 +
    calleeInfo: *FnType,
5012 +
    resultTy: Type,
5013 +
    hint: Type
5014 +
) -> Type throws (ResolveError) {
5015 +
    let firstNode = catches.list[0];
5016 +
    let case ast::NodeValue::CatchClause(first) = firstNode.value else
5017 +
        throw emitError(self, node, ErrorKind::UnexpectedNode(firstNode));
5018 +
5019 +
    // Typed catches: dispatch to dedicated handler.
5020 +
    if first.typeNode != nil {
5021 +
        return try resolveTypedCatches(self, node, catches, calleeInfo, resultTy, hint);
5022 +
    }
5023 +
    // Single untyped catch clause.
5024 +
    if let binding = first.binding {
5025 +
        if calleeInfo.throwListLen > 1 {
5026 +
            throw emitError(self, binding, ErrorKind::TryCatchMultiError);
5027 +
        }
5028 +
        enterScope(self, node);
5029 +
5030 +
        let errTy = *calleeInfo.throwList[0];
5031 +
        try bindValueIdent(self, binding, binding, errTy, false, 0, 0);
5032 +
    }
5033 +
    try visit(self, first.body, resultTy);
5034 +
5035 +
    if let _ = first.binding {
5036 +
        exitScope(self);
5037 +
    }
5038 +
    try checkCatchBody(self, first.body, resultTy, hint);
5039 +
5040 +
    return resultTy;
5041 +
}
5042 +
5043 +
/// Resolve typed catch clauses (`catch e as T {..} catch e as S {..}`).
5044 +
///
5045 +
/// Validates that each type annotation is in the callee's throw list, that
5046 +
/// there are no duplicate catch types, and that the clauses are exhaustive.
5047 +
fn resolveTypedCatches(
5048 +
    self: *mut Resolver,
5049 +
    node: *ast::Node,
5050 +
    catches: ast::NodeList,
5051 +
    calleeInfo: *FnType,
5052 +
    resultTy: Type,
5053 +
    hint: Type
5054 +
) -> Type throws (ResolveError) {
5055 +
    // Track which of the callee's throw types have been covered.
5056 +
    let mut covered: [bool; MAX_FN_THROWS] = [false; MAX_FN_THROWS];
5057 +
    let mut hasCatchAll = false;
5058 +
5059 +
    for i in 0..catches.len {
5060 +
        let clauseNode = catches.list[i];
5061 +
        let case ast::NodeValue::CatchClause(clause) = clauseNode.value else
5062 +
            throw emitError(self, node, ErrorKind::UnexpectedNode(clauseNode));
5063 +
5064 +
        if let typeNode = clause.typeNode {
5065 +
            // Typed catch clause: validate against callee's throw list.
5066 +
            let errTy = try infer(self, typeNode);
5067 +
            let mut foundIdx: ?u32 = nil;
5068 +
5069 +
            for j in 0..calleeInfo.throwListLen {
5070 +
                if errTy == *calleeInfo.throwList[j] {
5071 +
                    foundIdx = j;
5072 +
                    break;
5073 +
                }
5074 +
            }
5075 +
            let idx = foundIdx else {
5076 +
                throw emitError(self, typeNode, ErrorKind::TryIncompatibleError);
5077 +
            };
5078 +
            if covered[idx] {
5079 +
                throw emitError(self, typeNode, ErrorKind::TryCatchDuplicateType);
5080 +
            }
5081 +
            covered[idx] = true;
5082 +
5083 +
            // Bind the error variable if present.
5084 +
            if let binding = clause.binding {
5085 +
                enterScope(self, clauseNode);
5086 +
                try bindValueIdent(self, binding, binding, errTy, false, 0, 0);
5087 +
            }
5088 +
        } else {
5089 +
            // Catch-all clause with no type annotation or binding.
5090 +
            hasCatchAll = true;
5091 +
        }
5092 +
        // Resolve the catch body and check assignability.
5093 +
        try visit(self, clause.body, resultTy);
5094 +
        // Only typed clauses can have bindings.
5095 +
        if let _ = clause.binding {
5096 +
            exitScope(self);
5097 +
        }
5098 +
        try checkCatchBody(self, clause.body, resultTy, hint);
5099 +
    }
5100 +
5101 +
    // Check exhaustiveness: all callee error types must be covered.
5102 +
    if not hasCatchAll {
5103 +
        for i in 0..calleeInfo.throwListLen {
5104 +
            if not covered[i] {
5105 +
                throw emitError(self, node, ErrorKind::TryCatchNonExhaustive);
5106 +
            }
5107 +
        }
5108 +
    }
5109 +
    return resultTy;
5110 +
}
5111 +
5011 5112
/// Analyze a `throw` statement.
5012 5113
fn resolveThrow(self: *mut Resolver, node: *ast::Node, expr: *ast::Node) -> Type
5013 5114
    throws (ResolveError)
5014 5115
{
5015 5116
    let fnInfo = self.currentFn
lib/std/lang/resolver/printer.rad +9 -0
507 507
            io::print("throw uses error type not declared by function");
508 508
        }
509 509
        case super::ErrorKind::TryNonThrowing => {
510 510
            io::print("try applied to expression that cannot throw");
511 511
        }
512 +
        case super::ErrorKind::TryCatchMultiError => {
513 +
            io::print("catch with inferred binding requires single error type; use typed catches for multiple error types");
514 +
        }
515 +
        case super::ErrorKind::TryCatchDuplicateType => {
516 +
            io::print("duplicate error type in catch clauses");
517 +
        }
518 +
        case super::ErrorKind::TryCatchNonExhaustive => {
519 +
            io::print("catch clauses do not cover all error types");
520 +
        }
512 521
        case super::ErrorKind::MissingTry => {
513 522
            io::print("called fallible function without using try");
514 523
        }
515 524
        case super::ErrorKind::CannotInferType => {
516 525
            io::print("cannot infer type from context");
lib/std/lang/resolver/tests.rad +44 -0
4444 4444
        let program = "fn f(opt: ?i32) { match opt { v => {}, else => {} } }";
4445 4445
        let result = try resolveProgramStr(&mut a, program);
4446 4446
        try expectNoErrors(&result);
4447 4447
    }
4448 4448
}
4449 +
4450 +
// --- Multi-error typed catch tests ---
4451 +
4452 +
@test fn testTypedCatchExhaustive() throws (testing::TestError) {
4453 +
    let mut a = testResolver();
4454 +
    let program = "union ErrA { A } union ErrB { B } fn f() -> i32 throws (ErrA, ErrB) { throw ErrA::A(); return 0; } fn g() -> i32 { return try f() catch e as ErrA { return 0; } catch e as ErrB { return 1; }; }";
4455 +
    let result = try resolveProgramStr(&mut a, program);
4456 +
    try expectNoErrors(&result);
4457 +
}
4458 +
4459 +
@test fn testTypedCatchNonExhaustive() throws (testing::TestError) {
4460 +
    let mut a = testResolver();
4461 +
    let program = "union ErrA { A } union ErrB { B } fn f() -> i32 throws (ErrA, ErrB) { throw ErrA::A(); return 0; } fn g() -> i32 { return try f() catch e as ErrA { return 0; }; }";
4462 +
    let result = try resolveProgramStr(&mut a, program);
4463 +
    try expectErrorKind(&result, super::ErrorKind::TryCatchNonExhaustive);
4464 +
}
4465 +
4466 +
@test fn testTypedCatchDuplicate() throws (testing::TestError) {
4467 +
    let mut a = testResolver();
4468 +
    let program = "union ErrA { A } union ErrB { B } fn f() -> i32 throws (ErrA, ErrB) { throw ErrA::A(); return 0; } fn g() -> i32 { return try f() catch e as ErrA { return 0; } catch e as ErrA { return 1; }; }";
4469 +
    let result = try resolveProgramStr(&mut a, program);
4470 +
    try expectErrorKind(&result, super::ErrorKind::TryCatchDuplicateType);
4471 +
}
4472 +
4473 +
@test fn testTypedCatchWithCatchAll() throws (testing::TestError) {
4474 +
    let mut a = testResolver();
4475 +
    let program = "union ErrA { A } union ErrB { B } fn f() -> i32 throws (ErrA, ErrB) { throw ErrA::A(); return 0; } fn g() -> i32 { return try f() catch e as ErrA { return 0; } catch { return 1; }; }";
4476 +
    let result = try resolveProgramStr(&mut a, program);
4477 +
    try expectNoErrors(&result);
4478 +
}
4479 +
4480 +
@test fn testTypedCatchWrongType() throws (testing::TestError) {
4481 +
    let mut a = testResolver();
4482 +
    let program = "union ErrA { A } union ErrB { B } union ErrC { C } fn f() -> i32 throws (ErrA, ErrB) { throw ErrA::A(); return 0; } fn g() -> i32 { return try f() catch e as ErrC { return 0; } catch e as ErrA { return 1; }; }";
4483 +
    let result = try resolveProgramStr(&mut a, program);
4484 +
    try expectErrorKind(&result, super::ErrorKind::TryIncompatibleError);
4485 +
}
4486 +
4487 +
@test fn testInferredCatchMultiError() throws (testing::TestError) {
4488 +
    let mut a = testResolver();
4489 +
    let program = "union ErrA { A } union ErrB { B } fn f() -> i32 throws (ErrA, ErrB) { throw ErrA::A(); return 0; } fn g() -> i32 { return try f() catch e { return 0; }; }";
4490 +
    let result = try resolveProgramStr(&mut a, program);
4491 +
    try expectErrorKind(&result, super::ErrorKind::TryCatchMultiError);
4492 +
}