Lower conditional expressions to IL

ff8c8e43fcef5f4dc69a22f50143e87ee58470a4f2f131569ab9f42e2b47a1d2
This was previously not implemented even though parsing was there.
Alexis Sellier committed ago 1 parent de09a137
lib/std/arch/rv64/tests/cond.expr.aggregate.rad added +73 -0
1 +
//! Test conditional expression with aggregate types (tagged unions, optionals).
2 +
3 +
union Val {
4 +
    Reg(u32),
5 +
    Imm(i64),
6 +
}
7 +
8 +
fn pickVal(flag: bool, reg: u32) -> Val {
9 +
    return Val::Reg(reg) if flag else Val::Imm(99);
10 +
}
11 +
12 +
fn pickOpt(a: ?u16, b: ?u16) -> ?u16 {
13 +
    return b if a == nil else a;
14 +
}
15 +
16 +
/// Test tagged union conditional expression.
17 +
fn testUnion() -> i32 {
18 +
    let v1 = pickVal(true, 7);
19 +
    if let case Val::Reg(r) = v1 {
20 +
        if r != 7 {
21 +
            return 1;
22 +
        }
23 +
    } else {
24 +
        return 2;
25 +
    }
26 +
27 +
    let v2 = pickVal(false, 7);
28 +
    if let case Val::Imm(i) = v2 {
29 +
        if i != 99 {
30 +
            return 3;
31 +
        }
32 +
    } else {
33 +
        return 4;
34 +
    }
35 +
    return 0;
36 +
}
37 +
38 +
/// Test optional conditional expression.
39 +
fn testOptional() -> i32 {
40 +
    let a: ?u16 = 42;
41 +
    let b: ?u16 = 99;
42 +
    let n: ?u16 = nil;
43 +
44 +
    let r1 = pickOpt(a, b);
45 +
    let v1 = r1 else {
46 +
        return 10;
47 +
    };
48 +
    if v1 != 42 {
49 +
        return 11;
50 +
    }
51 +
52 +
    let r2 = pickOpt(n, b);
53 +
    let v2 = r2 else {
54 +
        return 12;
55 +
    };
56 +
    if v2 != 99 {
57 +
        return 13;
58 +
    }
59 +
60 +
    return 0;
61 +
}
62 +
63 +
@default fn main() -> i32 {
64 +
    let r1 = testUnion();
65 +
    if r1 != 0 {
66 +
        return r1;
67 +
    }
68 +
    let r2 = testOptional();
69 +
    if r2 != 0 {
70 +
        return r2;
71 +
    }
72 +
    return 0;
73 +
}
lib/std/arch/rv64/tests/cond.expr.rad added +110 -0
1 +
//! Test conditional expression lowering (`thenExpr if cond else elseExpr`).
2 +
3 +
fn min(a: u32, b: u32) -> u32 {
4 +
    return b if b < a else a;
5 +
}
6 +
7 +
fn pick(cond: bool, a: i32, b: i32) -> i32 {
8 +
    return a if cond else b;
9 +
}
10 +
11 +
union Color { Red, Green, Blue }
12 +
13 +
fn colorPick(isRed: bool) -> Color {
14 +
    return Color::Red if isRed else Color::Blue;
15 +
}
16 +
17 +
/// Test basic scalar conditional expression.
18 +
fn testScalar() -> i32 {
19 +
    let x: i32 = 10 if true else 20;
20 +
    if x != 10 {
21 +
        return 1;
22 +
    }
23 +
    let y: i32 = 10 if false else 20;
24 +
    if y != 20 {
25 +
        return 2;
26 +
    }
27 +
    return 0;
28 +
}
29 +
30 +
/// Test min-value pattern.
31 +
fn testMin() -> i32 {
32 +
    if min(5, 3) != 3 {
33 +
        return 10;
34 +
    }
35 +
    if min(3, 5) != 3 {
36 +
        return 11;
37 +
    }
38 +
    if min(4, 4) != 4 {
39 +
        return 12;
40 +
    }
41 +
    return 0;
42 +
}
43 +
44 +
/// Test with function argument as condition.
45 +
fn testFnArg() -> i32 {
46 +
    if pick(true, 42, 99) != 42 {
47 +
        return 20;
48 +
    }
49 +
    if pick(false, 42, 99) != 99 {
50 +
        return 21;
51 +
    }
52 +
    return 0;
53 +
}
54 +
55 +
/// Test with enum values.
56 +
fn testEnum() -> i32 {
57 +
    let c1 = colorPick(true);
58 +
    if c1 != Color::Red {
59 +
        return 30;
60 +
    }
61 +
    let c2 = colorPick(false);
62 +
    if c2 != Color::Blue {
63 +
        return 31;
64 +
    }
65 +
    return 0;
66 +
}
67 +
68 +
/// Test nested conditional expression (right-associative).
69 +
fn testNested() -> i32 {
70 +
    let a: i32 = 0;
71 +
    let b: i32 = 1;
72 +
    let c: i32 = 2;
73 +
74 +
    // a if false else b if false else c => a if false else (b if false else c) => c
75 +
    let result: i32 = a if false else b if false else c;
76 +
    if result != 2 {
77 +
        return 40;
78 +
    }
79 +
80 +
    // a if true else b if false else c => a
81 +
    let result2: i32 = a if true else b if false else c;
82 +
    if result2 != 0 {
83 +
        return 41;
84 +
    }
85 +
    return 0;
86 +
}
87 +
88 +
@default fn main() -> i32 {
89 +
    let r1 = testScalar();
90 +
    if r1 != 0 {
91 +
        return r1;
92 +
    }
93 +
    let r2 = testMin();
94 +
    if r2 != 0 {
95 +
        return r2;
96 +
    }
97 +
    let r3 = testFnArg();
98 +
    if r3 != 0 {
99 +
        return r3;
100 +
    }
101 +
    let r4 = testEnum();
102 +
    if r4 != 0 {
103 +
        return r4;
104 +
    }
105 +
    let r5 = testNested();
106 +
    if r5 != 0 {
107 +
        return r5;
108 +
    }
109 +
    return 0;
110 +
}
lib/std/lang/lower.rad +52 -0
5368 5368
5369 5369
    try switchToAndSeal(self, mergeBlock);
5370 5370
    return il::Val::Reg(resultReg);
5371 5371
}
5372 5372
5373 +
/// Lower a conditional expression (`thenExpr if condition else elseExpr`).
5374 +
fn lowerCondExpr(self: *mut FnLowerer, node: *ast::Node, cond: ast::CondExpr) -> il::Val
5375 +
    throws (LowerError)
5376 +
{
5377 +
    let typ = resolver::typeFor(self.low.resolver, node) else {
5378 +
        throw LowerError::MissingType(node);
5379 +
    };
5380 +
    let thenBlock = try createBlock(self, "cond#then");
5381 +
    let elseBlock = try createBlock(self, "cond#else");
5382 +
5383 +
    if isAggregateType(typ) {
5384 +
        let dst = try emitReserve(self, typ);
5385 +
        let layout = resolver::getTypeLayout(typ);
5386 +
        try emitCondBranch(self, cond.condition, thenBlock, elseBlock);
5387 +
5388 +
        let mergeBlock = try createBlock(self, "cond#merge");
5389 +
        try switchToAndSeal(self, thenBlock);
5390 +
5391 +
        let thenVal = try emitValToReg(self, try lowerExpr(self, cond.thenExpr));
5392 +
        emit(self, il::Instr::Blit { dst, src: thenVal, size: layout.size });
5393 +
5394 +
        try emitJmp(self, mergeBlock);
5395 +
        try switchToAndSeal(self, elseBlock);
5396 +
5397 +
        let elseVal = try emitValToReg(self, try lowerExpr(self, cond.elseExpr));
5398 +
        emit(self, il::Instr::Blit { dst, src: elseVal, size: layout.size });
5399 +
5400 +
        try emitJmp(self, mergeBlock);
5401 +
        try switchToAndSeal(self, mergeBlock);
5402 +
5403 +
        return il::Val::Reg(dst);
5404 +
    } else {
5405 +
        try emitCondBranch(self, cond.condition, thenBlock, elseBlock);
5406 +
5407 +
        let resultType = ilType(self.low, typ);
5408 +
        let resultReg = nextReg(self);
5409 +
        let mergeBlock = try createBlockWithParam(
5410 +
            self, "cond#merge", il::Param { value: resultReg, type: resultType }
5411 +
        );
5412 +
        try switchToAndSeal(self, thenBlock);
5413 +
        try emitJmpWithArg(self, mergeBlock, try lowerExpr(self, cond.thenExpr));
5414 +
        try switchToAndSeal(self, elseBlock);
5415 +
        try emitJmpWithArg(self, mergeBlock, try lowerExpr(self, cond.elseExpr));
5416 +
        try switchToAndSeal(self, mergeBlock);
5417 +
5418 +
        return il::Val::Reg(resultReg);
5419 +
    }
5420 +
}
5421 +
5373 5422
/// Convert a binary operator to a comparison op, if applicable.
5374 5423
/// For `Gt`, caller must swap operands: `a > b = b < a`.
5375 5424
/// For `Gte`/`Lte`, caller must swap branch labels: `a >= b = !(a < b)`.
5376 5425
/// For `Lte`, caller must also swap operands: `a <= b = !(b < a)`.
5377 5426
fn cmpOpFrom(op: ast::BinaryOp, unsigned: bool) -> ?il::CmpOp {
6366 6415
            val = try lowerArrayRepeatLit(self, node, repeat);
6367 6416
        }
6368 6417
        case ast::NodeValue::As(cast) => {
6369 6418
            val = try lowerCast(self, node, cast);
6370 6419
        }
6420 +
        case ast::NodeValue::CondExpr(cond) => {
6421 +
            val = try lowerCondExpr(self, node, cond);
6422 +
        }
6371 6423
        case ast::NodeValue::String(s) => {
6372 6424
            val = try lowerStringLit(self, node, s);
6373 6425
        }
6374 6426
        case ast::NodeValue::Undef => {
6375 6427
            let typ = resolver::typeFor(self.low.resolver, node) else {
lib/std/lang/lower/tests/cond.expr.rad added +5 -0
1 +
//! Test lowering of conditional expressions.
2 +
fn condExpr(flag: bool) -> i32 {
3 +
    let x: i32 = 1 if flag else 2;
4 +
    return x;
5 +
}
lib/std/lang/lower/tests/cond.expr.ril added +10 -0
1 +
fn w32 $condExpr(w8 %0) {
2 +
  @entry0
3 +
    br.ne w32 %0 0 @cond#then1 @cond#else2;
4 +
  @cond#then1
5 +
    jmp @cond#merge3(1);
6 +
  @cond#else2
7 +
    jmp @cond#merge3(2);
8 +
  @cond#merge3(w32 %1)
9 +
    ret %1;
10 +
}