Lower conditional expressions to IL
ff8c8e43fcef5f4dc69a22f50143e87ee58470a4f2f131569ab9f42e2b47a1d2
This was previously not implemented even though parsing was there.
1 parent
de09a137
lib/std/arch/rv64/tests/cond.expr.aggregate.rad
added
+73 -0
| 1 | + | //! Test conditional expression with aggregate types (tagged unions, optionals). |
|
| 2 | + | ||
| 3 | + | union Val { |
|
| 4 | + | Reg(u32), |
|
| 5 | + | Imm(i64), |
|
| 6 | + | } |
|
| 7 | + | ||
| 8 | + | fn pickVal(flag: bool, reg: u32) -> Val { |
|
| 9 | + | return Val::Reg(reg) if flag else Val::Imm(99); |
|
| 10 | + | } |
|
| 11 | + | ||
| 12 | + | fn pickOpt(a: ?u16, b: ?u16) -> ?u16 { |
|
| 13 | + | return b if a == nil else a; |
|
| 14 | + | } |
|
| 15 | + | ||
| 16 | + | /// Test tagged union conditional expression. |
|
| 17 | + | fn testUnion() -> i32 { |
|
| 18 | + | let v1 = pickVal(true, 7); |
|
| 19 | + | if let case Val::Reg(r) = v1 { |
|
| 20 | + | if r != 7 { |
|
| 21 | + | return 1; |
|
| 22 | + | } |
|
| 23 | + | } else { |
|
| 24 | + | return 2; |
|
| 25 | + | } |
|
| 26 | + | ||
| 27 | + | let v2 = pickVal(false, 7); |
|
| 28 | + | if let case Val::Imm(i) = v2 { |
|
| 29 | + | if i != 99 { |
|
| 30 | + | return 3; |
|
| 31 | + | } |
|
| 32 | + | } else { |
|
| 33 | + | return 4; |
|
| 34 | + | } |
|
| 35 | + | return 0; |
|
| 36 | + | } |
|
| 37 | + | ||
| 38 | + | /// Test optional conditional expression. |
|
| 39 | + | fn testOptional() -> i32 { |
|
| 40 | + | let a: ?u16 = 42; |
|
| 41 | + | let b: ?u16 = 99; |
|
| 42 | + | let n: ?u16 = nil; |
|
| 43 | + | ||
| 44 | + | let r1 = pickOpt(a, b); |
|
| 45 | + | let v1 = r1 else { |
|
| 46 | + | return 10; |
|
| 47 | + | }; |
|
| 48 | + | if v1 != 42 { |
|
| 49 | + | return 11; |
|
| 50 | + | } |
|
| 51 | + | ||
| 52 | + | let r2 = pickOpt(n, b); |
|
| 53 | + | let v2 = r2 else { |
|
| 54 | + | return 12; |
|
| 55 | + | }; |
|
| 56 | + | if v2 != 99 { |
|
| 57 | + | return 13; |
|
| 58 | + | } |
|
| 59 | + | ||
| 60 | + | return 0; |
|
| 61 | + | } |
|
| 62 | + | ||
| 63 | + | @default fn main() -> i32 { |
|
| 64 | + | let r1 = testUnion(); |
|
| 65 | + | if r1 != 0 { |
|
| 66 | + | return r1; |
|
| 67 | + | } |
|
| 68 | + | let r2 = testOptional(); |
|
| 69 | + | if r2 != 0 { |
|
| 70 | + | return r2; |
|
| 71 | + | } |
|
| 72 | + | return 0; |
|
| 73 | + | } |
lib/std/arch/rv64/tests/cond.expr.rad
added
+110 -0
| 1 | + | //! Test conditional expression lowering (`thenExpr if cond else elseExpr`). |
|
| 2 | + | ||
| 3 | + | fn min(a: u32, b: u32) -> u32 { |
|
| 4 | + | return b if b < a else a; |
|
| 5 | + | } |
|
| 6 | + | ||
| 7 | + | fn pick(cond: bool, a: i32, b: i32) -> i32 { |
|
| 8 | + | return a if cond else b; |
|
| 9 | + | } |
|
| 10 | + | ||
| 11 | + | union Color { Red, Green, Blue } |
|
| 12 | + | ||
| 13 | + | fn colorPick(isRed: bool) -> Color { |
|
| 14 | + | return Color::Red if isRed else Color::Blue; |
|
| 15 | + | } |
|
| 16 | + | ||
| 17 | + | /// Test basic scalar conditional expression. |
|
| 18 | + | fn testScalar() -> i32 { |
|
| 19 | + | let x: i32 = 10 if true else 20; |
|
| 20 | + | if x != 10 { |
|
| 21 | + | return 1; |
|
| 22 | + | } |
|
| 23 | + | let y: i32 = 10 if false else 20; |
|
| 24 | + | if y != 20 { |
|
| 25 | + | return 2; |
|
| 26 | + | } |
|
| 27 | + | return 0; |
|
| 28 | + | } |
|
| 29 | + | ||
| 30 | + | /// Test min-value pattern. |
|
| 31 | + | fn testMin() -> i32 { |
|
| 32 | + | if min(5, 3) != 3 { |
|
| 33 | + | return 10; |
|
| 34 | + | } |
|
| 35 | + | if min(3, 5) != 3 { |
|
| 36 | + | return 11; |
|
| 37 | + | } |
|
| 38 | + | if min(4, 4) != 4 { |
|
| 39 | + | return 12; |
|
| 40 | + | } |
|
| 41 | + | return 0; |
|
| 42 | + | } |
|
| 43 | + | ||
| 44 | + | /// Test with function argument as condition. |
|
| 45 | + | fn testFnArg() -> i32 { |
|
| 46 | + | if pick(true, 42, 99) != 42 { |
|
| 47 | + | return 20; |
|
| 48 | + | } |
|
| 49 | + | if pick(false, 42, 99) != 99 { |
|
| 50 | + | return 21; |
|
| 51 | + | } |
|
| 52 | + | return 0; |
|
| 53 | + | } |
|
| 54 | + | ||
| 55 | + | /// Test with enum values. |
|
| 56 | + | fn testEnum() -> i32 { |
|
| 57 | + | let c1 = colorPick(true); |
|
| 58 | + | if c1 != Color::Red { |
|
| 59 | + | return 30; |
|
| 60 | + | } |
|
| 61 | + | let c2 = colorPick(false); |
|
| 62 | + | if c2 != Color::Blue { |
|
| 63 | + | return 31; |
|
| 64 | + | } |
|
| 65 | + | return 0; |
|
| 66 | + | } |
|
| 67 | + | ||
| 68 | + | /// Test nested conditional expression (right-associative). |
|
| 69 | + | fn testNested() -> i32 { |
|
| 70 | + | let a: i32 = 0; |
|
| 71 | + | let b: i32 = 1; |
|
| 72 | + | let c: i32 = 2; |
|
| 73 | + | ||
| 74 | + | // a if false else b if false else c => a if false else (b if false else c) => c |
|
| 75 | + | let result: i32 = a if false else b if false else c; |
|
| 76 | + | if result != 2 { |
|
| 77 | + | return 40; |
|
| 78 | + | } |
|
| 79 | + | ||
| 80 | + | // a if true else b if false else c => a |
|
| 81 | + | let result2: i32 = a if true else b if false else c; |
|
| 82 | + | if result2 != 0 { |
|
| 83 | + | return 41; |
|
| 84 | + | } |
|
| 85 | + | return 0; |
|
| 86 | + | } |
|
| 87 | + | ||
| 88 | + | @default fn main() -> i32 { |
|
| 89 | + | let r1 = testScalar(); |
|
| 90 | + | if r1 != 0 { |
|
| 91 | + | return r1; |
|
| 92 | + | } |
|
| 93 | + | let r2 = testMin(); |
|
| 94 | + | if r2 != 0 { |
|
| 95 | + | return r2; |
|
| 96 | + | } |
|
| 97 | + | let r3 = testFnArg(); |
|
| 98 | + | if r3 != 0 { |
|
| 99 | + | return r3; |
|
| 100 | + | } |
|
| 101 | + | let r4 = testEnum(); |
|
| 102 | + | if r4 != 0 { |
|
| 103 | + | return r4; |
|
| 104 | + | } |
|
| 105 | + | let r5 = testNested(); |
|
| 106 | + | if r5 != 0 { |
|
| 107 | + | return r5; |
|
| 108 | + | } |
|
| 109 | + | return 0; |
|
| 110 | + | } |
lib/std/lang/lower.rad
+52 -0
| 5368 | 5368 | ||
| 5369 | 5369 | try switchToAndSeal(self, mergeBlock); |
|
| 5370 | 5370 | return il::Val::Reg(resultReg); |
|
| 5371 | 5371 | } |
|
| 5372 | 5372 | ||
| 5373 | + | /// Lower a conditional expression (`thenExpr if condition else elseExpr`). |
|
| 5374 | + | fn lowerCondExpr(self: *mut FnLowerer, node: *ast::Node, cond: ast::CondExpr) -> il::Val |
|
| 5375 | + | throws (LowerError) |
|
| 5376 | + | { |
|
| 5377 | + | let typ = resolver::typeFor(self.low.resolver, node) else { |
|
| 5378 | + | throw LowerError::MissingType(node); |
|
| 5379 | + | }; |
|
| 5380 | + | let thenBlock = try createBlock(self, "cond#then"); |
|
| 5381 | + | let elseBlock = try createBlock(self, "cond#else"); |
|
| 5382 | + | ||
| 5383 | + | if isAggregateType(typ) { |
|
| 5384 | + | let dst = try emitReserve(self, typ); |
|
| 5385 | + | let layout = resolver::getTypeLayout(typ); |
|
| 5386 | + | try emitCondBranch(self, cond.condition, thenBlock, elseBlock); |
|
| 5387 | + | ||
| 5388 | + | let mergeBlock = try createBlock(self, "cond#merge"); |
|
| 5389 | + | try switchToAndSeal(self, thenBlock); |
|
| 5390 | + | ||
| 5391 | + | let thenVal = try emitValToReg(self, try lowerExpr(self, cond.thenExpr)); |
|
| 5392 | + | emit(self, il::Instr::Blit { dst, src: thenVal, size: layout.size }); |
|
| 5393 | + | ||
| 5394 | + | try emitJmp(self, mergeBlock); |
|
| 5395 | + | try switchToAndSeal(self, elseBlock); |
|
| 5396 | + | ||
| 5397 | + | let elseVal = try emitValToReg(self, try lowerExpr(self, cond.elseExpr)); |
|
| 5398 | + | emit(self, il::Instr::Blit { dst, src: elseVal, size: layout.size }); |
|
| 5399 | + | ||
| 5400 | + | try emitJmp(self, mergeBlock); |
|
| 5401 | + | try switchToAndSeal(self, mergeBlock); |
|
| 5402 | + | ||
| 5403 | + | return il::Val::Reg(dst); |
|
| 5404 | + | } else { |
|
| 5405 | + | try emitCondBranch(self, cond.condition, thenBlock, elseBlock); |
|
| 5406 | + | ||
| 5407 | + | let resultType = ilType(self.low, typ); |
|
| 5408 | + | let resultReg = nextReg(self); |
|
| 5409 | + | let mergeBlock = try createBlockWithParam( |
|
| 5410 | + | self, "cond#merge", il::Param { value: resultReg, type: resultType } |
|
| 5411 | + | ); |
|
| 5412 | + | try switchToAndSeal(self, thenBlock); |
|
| 5413 | + | try emitJmpWithArg(self, mergeBlock, try lowerExpr(self, cond.thenExpr)); |
|
| 5414 | + | try switchToAndSeal(self, elseBlock); |
|
| 5415 | + | try emitJmpWithArg(self, mergeBlock, try lowerExpr(self, cond.elseExpr)); |
|
| 5416 | + | try switchToAndSeal(self, mergeBlock); |
|
| 5417 | + | ||
| 5418 | + | return il::Val::Reg(resultReg); |
|
| 5419 | + | } |
|
| 5420 | + | } |
|
| 5421 | + | ||
| 5373 | 5422 | /// Convert a binary operator to a comparison op, if applicable. |
|
| 5374 | 5423 | /// For `Gt`, caller must swap operands: `a > b = b < a`. |
|
| 5375 | 5424 | /// For `Gte`/`Lte`, caller must swap branch labels: `a >= b = !(a < b)`. |
|
| 5376 | 5425 | /// For `Lte`, caller must also swap operands: `a <= b = !(b < a)`. |
|
| 5377 | 5426 | fn cmpOpFrom(op: ast::BinaryOp, unsigned: bool) -> ?il::CmpOp { |
| 6366 | 6415 | val = try lowerArrayRepeatLit(self, node, repeat); |
|
| 6367 | 6416 | } |
|
| 6368 | 6417 | case ast::NodeValue::As(cast) => { |
|
| 6369 | 6418 | val = try lowerCast(self, node, cast); |
|
| 6370 | 6419 | } |
|
| 6420 | + | case ast::NodeValue::CondExpr(cond) => { |
|
| 6421 | + | val = try lowerCondExpr(self, node, cond); |
|
| 6422 | + | } |
|
| 6371 | 6423 | case ast::NodeValue::String(s) => { |
|
| 6372 | 6424 | val = try lowerStringLit(self, node, s); |
|
| 6373 | 6425 | } |
|
| 6374 | 6426 | case ast::NodeValue::Undef => { |
|
| 6375 | 6427 | let typ = resolver::typeFor(self.low.resolver, node) else { |
lib/std/lang/lower/tests/cond.expr.rad
added
+5 -0
| 1 | + | //! Test lowering of conditional expressions. |
|
| 2 | + | fn condExpr(flag: bool) -> i32 { |
|
| 3 | + | let x: i32 = 1 if flag else 2; |
|
| 4 | + | return x; |
|
| 5 | + | } |
lib/std/lang/lower/tests/cond.expr.ril
added
+10 -0
| 1 | + | fn w32 $condExpr(w8 %0) { |
|
| 2 | + | @entry0 |
|
| 3 | + | br.ne w32 %0 0 @cond#then1 @cond#else2; |
|
| 4 | + | @cond#then1 |
|
| 5 | + | jmp @cond#merge3(1); |
|
| 6 | + | @cond#else2 |
|
| 7 | + | jmp @cond#merge3(2); |
|
| 8 | + | @cond#merge3(w32 %1) |
|
| 9 | + | ret %1; |
|
| 10 | + | } |