Support constant folding
f5bdf05cb31b838ea53a958ccba3908f71068c2f4fe8581bcf9b18e1d8bda1d2
Add constant folding for binary operations in the resolver. When both operands of a binary expression have known constant values and the result type is concrete, the operation is evaluated at compile time.
1 parent
19053951
lib/std/lang/resolver.rad
+130 -0
| 2874 | 2874 | } |
|
| 2875 | 2875 | return true; |
|
| 2876 | 2876 | } |
|
| 2877 | 2877 | return false; |
|
| 2878 | 2878 | }, |
|
| 2879 | + | case ast::NodeValue::BinOp(binop) => { |
|
| 2880 | + | // Binary expressions are constant if both operands are constant. |
|
| 2881 | + | return isConstExpr(self, binop.left) and isConstExpr(self, binop.right); |
|
| 2882 | + | }, |
|
| 2883 | + | case ast::NodeValue::UnOp(unop) => { |
|
| 2884 | + | // Unary expressions are constant if the operand is constant. |
|
| 2885 | + | return isConstExpr(self, unop.value); |
|
| 2886 | + | }, |
|
| 2879 | 2887 | else => { |
|
| 2880 | 2888 | return false; |
|
| 2881 | 2889 | } |
|
| 2882 | 2890 | } |
|
| 2883 | 2891 | } |
| 5910 | 5918 | setNodeCoercion(self, node, Coercion::ResultWrap); |
|
| 5911 | 5919 | } |
|
| 5912 | 5920 | return setNodeType(self, node, Type::Never); |
|
| 5913 | 5921 | } |
|
| 5914 | 5922 | ||
| 5923 | + | /// Convert a [`ConstInt`] to its signed two's-complement representation. |
|
| 5924 | + | fn constIntToSigned(c: ConstInt) -> i64 { |
|
| 5925 | + | if c.negative { |
|
| 5926 | + | return -(c.magnitude as i64); |
|
| 5927 | + | } |
|
| 5928 | + | return c.magnitude as i64; |
|
| 5929 | + | } |
|
| 5930 | + | ||
| 5931 | + | /// Build a [`ConstInt`] from a signed result, preserving bit width and signedness. |
|
| 5932 | + | fn constIntFromSigned(value: i64, bits: u8, signed: bool) -> ConstInt { |
|
| 5933 | + | if value < 0 { |
|
| 5934 | + | // Compute magnitude without signed overflow. |
|
| 5935 | + | let uval = value as u64; |
|
| 5936 | + | return ConstInt { |
|
| 5937 | + | magnitude: 0 - uval, |
|
| 5938 | + | bits, |
|
| 5939 | + | signed, |
|
| 5940 | + | negative: true, |
|
| 5941 | + | }; |
|
| 5942 | + | } |
|
| 5943 | + | return ConstInt { |
|
| 5944 | + | magnitude: value as u64, |
|
| 5945 | + | bits, |
|
| 5946 | + | signed, |
|
| 5947 | + | negative: false, |
|
| 5948 | + | }; |
|
| 5949 | + | } |
|
| 5950 | + | ||
| 5951 | + | /// Try to fold a binary operation on two integer constants. |
|
| 5952 | + | /// Returns the resulting constant value if successful. |
|
| 5953 | + | fn foldIntBinOp(op: ast::BinaryOp, left: ConstInt, right: ConstInt) -> ?ConstValue { |
|
| 5954 | + | // Use the wider bit width and propagate signedness. |
|
| 5955 | + | let mut bits = left.bits; |
|
| 5956 | + | if right.bits > bits { |
|
| 5957 | + | bits = right.bits; |
|
| 5958 | + | } |
|
| 5959 | + | let signed = left.signed or right.signed; |
|
| 5960 | + | let l = constIntToSigned(left); |
|
| 5961 | + | let r = constIntToSigned(right); |
|
| 5962 | + | ||
| 5963 | + | match op { |
|
| 5964 | + | // Shifts operate on unsigned magnitudes directly. |
|
| 5965 | + | case ast::BinaryOp::Shl => { |
|
| 5966 | + | return ConstValue::Int(ConstInt { |
|
| 5967 | + | magnitude: left.magnitude << right.magnitude, bits, signed, negative: left.negative, |
|
| 5968 | + | }); |
|
| 5969 | + | }, |
|
| 5970 | + | case ast::BinaryOp::Shr => { |
|
| 5971 | + | return ConstValue::Int(ConstInt { |
|
| 5972 | + | magnitude: left.magnitude >> right.magnitude, bits, signed, negative: left.negative, |
|
| 5973 | + | }); |
|
| 5974 | + | }, |
|
| 5975 | + | case ast::BinaryOp::Eq => return ConstValue::Bool(l == r), |
|
| 5976 | + | case ast::BinaryOp::Ne => return ConstValue::Bool(l != r), |
|
| 5977 | + | case ast::BinaryOp::Lt => return ConstValue::Bool(l < r), |
|
| 5978 | + | case ast::BinaryOp::Gt => return ConstValue::Bool(l > r), |
|
| 5979 | + | case ast::BinaryOp::Lte => return ConstValue::Bool(l <= r), |
|
| 5980 | + | case ast::BinaryOp::Gte => return ConstValue::Bool(l >= r), |
|
| 5981 | + | case ast::BinaryOp::Add => return ConstValue::Int(constIntFromSigned(l + r, bits, signed)), |
|
| 5982 | + | case ast::BinaryOp::Sub => return ConstValue::Int(constIntFromSigned(l - r, bits, signed)), |
|
| 5983 | + | case ast::BinaryOp::Mul => return ConstValue::Int(constIntFromSigned(l * r, bits, signed)), |
|
| 5984 | + | case ast::BinaryOp::Div => { |
|
| 5985 | + | if r == 0 { |
|
| 5986 | + | return nil; |
|
| 5987 | + | } |
|
| 5988 | + | return ConstValue::Int(constIntFromSigned(l / r, bits, signed)); |
|
| 5989 | + | }, |
|
| 5990 | + | case ast::BinaryOp::Mod => { |
|
| 5991 | + | if r == 0 { |
|
| 5992 | + | return nil; |
|
| 5993 | + | } |
|
| 5994 | + | return ConstValue::Int(constIntFromSigned(l % r, bits, signed)); |
|
| 5995 | + | }, |
|
| 5996 | + | case ast::BinaryOp::BitAnd => return ConstValue::Int(constIntFromSigned(l & r, bits, signed)), |
|
| 5997 | + | case ast::BinaryOp::BitOr => return ConstValue::Int(constIntFromSigned(l | r, bits, signed)), |
|
| 5998 | + | case ast::BinaryOp::BitXor => return ConstValue::Int(constIntFromSigned(l ^ r, bits, signed)), |
|
| 5999 | + | else => return nil, |
|
| 6000 | + | } |
|
| 6001 | + | } |
|
| 6002 | + | ||
| 6003 | + | /// Try to constant-fold a binary operation on two resolved operands. |
|
| 6004 | + | /// Only folds when the result type is concrete. |
|
| 6005 | + | fn tryFoldBinOp(self: *mut Resolver, node: *ast::Node, binop: ast::BinOp, resultTy: Type) { |
|
| 6006 | + | // Skip folding for untyped integer expressions. Their final type |
|
| 6007 | + | // depends on context, so attaching a signed constant value could |
|
| 6008 | + | // cause spurious range-check failures (e.g. `0 - 12` assigned to u8). |
|
| 6009 | + | if resultTy == Type::Int { |
|
| 6010 | + | return; |
|
| 6011 | + | } |
|
| 6012 | + | let leftVal = constValueFor(self, binop.left) |
|
| 6013 | + | else return; |
|
| 6014 | + | let rightVal = constValueFor(self, binop.right) |
|
| 6015 | + | else return; |
|
| 6016 | + | ||
| 6017 | + | // Fold integer binary ops. |
|
| 6018 | + | if let case ConstValue::Int(leftInt) = leftVal { |
|
| 6019 | + | if let case ConstValue::Int(rightInt) = rightVal { |
|
| 6020 | + | if let result = foldIntBinOp(binop.op, leftInt, rightInt) { |
|
| 6021 | + | setNodeConstValue(self, node, result); |
|
| 6022 | + | } |
|
| 6023 | + | return; |
|
| 6024 | + | } |
|
| 6025 | + | } |
|
| 6026 | + | ||
| 6027 | + | // Fold boolean binary ops. |
|
| 6028 | + | if let case ConstValue::Bool(l) = leftVal { |
|
| 6029 | + | if let case ConstValue::Bool(r) = rightVal { |
|
| 6030 | + | match binop.op { |
|
| 6031 | + | case ast::BinaryOp::And => setNodeConstValue(self, node, ConstValue::Bool(l and r)), |
|
| 6032 | + | case ast::BinaryOp::Or => setNodeConstValue(self, node, ConstValue::Bool(l or r)), |
|
| 6033 | + | case ast::BinaryOp::Eq => setNodeConstValue(self, node, ConstValue::Bool(l == r)), |
|
| 6034 | + | case ast::BinaryOp::Ne, |
|
| 6035 | + | ast::BinaryOp::Xor => setNodeConstValue(self, node, ConstValue::Bool(l != r)), |
|
| 6036 | + | else => {} |
|
| 6037 | + | } |
|
| 6038 | + | } |
|
| 6039 | + | } |
|
| 6040 | + | } |
|
| 6041 | + | ||
| 5915 | 6042 | /// Analyze a binary expression. |
|
| 5916 | 6043 | fn resolveBinOp(self: *mut Resolver, node: *ast::Node, binop: ast::BinOp) -> Type |
|
| 5917 | 6044 | throws (ResolveError) |
|
| 5918 | 6045 | { |
|
| 5919 | 6046 | let mut resultTy = Type::Unknown; |
| 6003 | 6130 | } |
|
| 6004 | 6131 | } |
|
| 6005 | 6132 | ||
| 6006 | 6133 | } |
|
| 6007 | 6134 | }; |
|
| 6135 | + | // Try constant folding after both operands are resolved. |
|
| 6136 | + | tryFoldBinOp(self, node, binop, resultTy); |
|
| 6137 | + | ||
| 6008 | 6138 | return setNodeType(self, node, resultTy); |
|
| 6009 | 6139 | } |
|
| 6010 | 6140 | ||
| 6011 | 6141 | /// Analyze a unary expression. |
|
| 6012 | 6142 | fn resolveUnOp(self: *mut Resolver, node: *ast::Node, unop: ast::UnOp) -> Type |
lib/std/lang/resolver/tests.rad
+118 -0
| 761 | 761 | try testing::expect(elemTy == super::Type::I32); |
|
| 762 | 762 | } |
|
| 763 | 763 | ||
| 764 | 764 | @test fn testResolveArrayRepeatNonConstCount() throws (testing::TestError) { |
|
| 765 | 765 | let mut a = testResolver(); |
|
| 766 | + | // `3 * 1` uses untyped integer literals, so constant folding does not |
|
| 767 | + | // produce a value. The repeat count requires a known compile-time value. |
|
| 766 | 768 | let result = try resolveProgramStr(&mut a, "let xs: [i32; 3] = [42; 3 * 1];"); |
|
| 767 | 769 | try expectErrorKind(&result, super::ErrorKind::ConstExprRequired); |
|
| 768 | 770 | } |
|
| 769 | 771 | ||
| 770 | 772 | @test fn testResolveArrayRepeatCountMismatch() throws (testing::TestError) { |
| 4905 | 4907 | let mut a = testResolver(); |
|
| 4906 | 4908 | let program = "union E { Fail } record R { x: i32 } trait T { fn (*T) get() -> i32 throws (E); } instance T for R { fn (r: *R) get() -> i32 throws (E) { throw E::Fail; return r.x; } }"; |
|
| 4907 | 4909 | let result = try resolveProgramStr(&mut a, program); |
|
| 4908 | 4910 | try expectNoErrors(&result); |
|
| 4909 | 4911 | } |
|
| 4912 | + | ||
| 4913 | + | // Constant expression folding tests ////////////////////////////////////////// |
|
| 4914 | + | ||
| 4915 | + | /// Resolve a program and verify that the constant at the given statement index |
|
| 4916 | + | /// has the expected integer magnitude. |
|
| 4917 | + | fn expectConstFold(program: *[u8], stmtIdx: u32, expected: u64) |
|
| 4918 | + | throws (testing::TestError) |
|
| 4919 | + | { |
|
| 4920 | + | let mut a = testResolver(); |
|
| 4921 | + | let result = try resolveProgramStr(&mut a, program); |
|
| 4922 | + | try expectNoErrors(&result); |
|
| 4923 | + | ||
| 4924 | + | let stmt = try getBlockStmt(result.root, stmtIdx); |
|
| 4925 | + | let sym = super::symbolFor(&a, stmt) |
|
| 4926 | + | else throw testing::TestError::Failed; |
|
| 4927 | + | let case super::SymbolData::Constant { value, .. } = sym.data |
|
| 4928 | + | else throw testing::TestError::Failed; |
|
| 4929 | + | let val = value else throw testing::TestError::Failed; |
|
| 4930 | + | let case super::ConstValue::Int(intVal) = val |
|
| 4931 | + | else throw testing::TestError::Failed; |
|
| 4932 | + | try testing::expect(intVal.magnitude == expected); |
|
| 4933 | + | } |
|
| 4934 | + | ||
| 4935 | + | /// Test arithmetic constant folding: add, sub, mul, div. |
|
| 4936 | + | @test fn testConstExprArithmetic() throws (testing::TestError) { |
|
| 4937 | + | try expectConstFold("const A: i32 = 10; const B: i32 = 20; const C: i32 = A + B;", 2, 30); |
|
| 4938 | + | try expectConstFold("const A: i32 = 50; const B: i32 = 20; const C: i32 = A - B;", 2, 30); |
|
| 4939 | + | try expectConstFold("const A: i32 = 6; const B: i32 = 7; const C: i32 = A * B;", 2, 42); |
|
| 4940 | + | try expectConstFold("const A: i32 = 100; const B: i32 = 5; const C: i32 = A / B;", 2, 20); |
|
| 4941 | + | } |
|
| 4942 | + | ||
| 4943 | + | /// Test bitwise constant folding: and, or, xor. |
|
| 4944 | + | @test fn testConstExprBitwise() throws (testing::TestError) { |
|
| 4945 | + | try expectConstFold("const A: i32 = 0xFF; const B: i32 = 0x0F; const C: i32 = A & B;", 2, 0x0F); |
|
| 4946 | + | try expectConstFold("const A: i32 = 0xF0; const B: i32 = 0x0F; const C: i32 = A | B;", 2, 0xFF); |
|
| 4947 | + | try expectConstFold("const A: i32 = 0xFF; const B: i32 = 0x0F; const C: i32 = A ^ B;", 2, 0xF0); |
|
| 4948 | + | } |
|
| 4949 | + | ||
| 4950 | + | /// Test shift constant folding. |
|
| 4951 | + | @test fn testConstExprShift() throws (testing::TestError) { |
|
| 4952 | + | try expectConstFold("const A: i32 = 1; const B: i32 = A << 4;", 1, 16); |
|
| 4953 | + | try expectConstFold("const A: i32 = 32; const B: i32 = A >> 2;", 1, 8); |
|
| 4954 | + | } |
|
| 4955 | + | ||
| 4956 | + | /// Test chained constant expressions (C depends on A + B, D depends on C). |
|
| 4957 | + | @test fn testConstExprChained() throws (testing::TestError) { |
|
| 4958 | + | try expectConstFold("const A: i32 = 10; const B: i32 = 20; const C: i32 = A + B; const D: i32 = C * 2;", 3, 60); |
|
| 4959 | + | } |
|
| 4960 | + | ||
| 4961 | + | /// Test constant expression used as array size. |
|
| 4962 | + | @test fn testConstExprAsArraySize() throws (testing::TestError) { |
|
| 4963 | + | let mut a = testResolver(); |
|
| 4964 | + | let program = "const A: u32 = 2; const B: u32 = 3; const SIZE: u32 = A + B; const ARR: [i32; SIZE] = [1, 2, 3, 4, 5];"; |
|
| 4965 | + | let result = try resolveProgramStr(&mut a, program); |
|
| 4966 | + | try expectNoErrors(&result); |
|
| 4967 | + | ||
| 4968 | + | let arrStmt = try getBlockStmt(result.root, 3); |
|
| 4969 | + | let sym = super::symbolFor(&a, arrStmt) |
|
| 4970 | + | else throw testing::TestError::Failed; |
|
| 4971 | + | let case super::SymbolData::Constant { type: super::Type::Array(arrType), .. } = sym.data |
|
| 4972 | + | else throw testing::TestError::Failed; |
|
| 4973 | + | try testing::expect(arrType.length == 5); |
|
| 4974 | + | } |
|
| 4975 | + | ||
| 4976 | + | /// Test cross-module constant expression: a constant in one module references |
|
| 4977 | + | /// a constant from another module via scope access. |
|
| 4978 | + | @test fn testCrossModuleConstExpr() throws (testing::TestError) { |
|
| 4979 | + | let mut a = testResolver(); |
|
| 4980 | + | let mut arena = ast::nodeArena(&mut AST_ARENA[..]); |
|
| 4981 | + | ||
| 4982 | + | let rootId = try registerModule(&mut MODULE_GRAPH, nil, "root", "pub mod consts; mod app;", &mut arena); |
|
| 4983 | + | let constsId = try registerModule(&mut MODULE_GRAPH, rootId, "consts", "pub const BASE: i32 = 100;", &mut arena); |
|
| 4984 | + | let appId = try registerModule(&mut MODULE_GRAPH, rootId, "app", "use root::consts; const DERIVED: i32 = consts::BASE + 50;", &mut arena); |
|
| 4985 | + | ||
| 4986 | + | let result = try resolveModuleTree(&mut a, rootId); |
|
| 4987 | + | try expectNoErrors(&result); |
|
| 4988 | + | } |
|
| 4989 | + | ||
| 4990 | + | /// Test cross-module constant expression used as array size. |
|
| 4991 | + | @test fn testCrossModuleConstExprArraySize() throws (testing::TestError) { |
|
| 4992 | + | let mut a = testResolver(); |
|
| 4993 | + | let mut arena = ast::nodeArena(&mut AST_ARENA[..]); |
|
| 4994 | + | ||
| 4995 | + | let rootId = try registerModule(&mut MODULE_GRAPH, nil, "root", "pub mod consts; mod app;", &mut arena); |
|
| 4996 | + | let constsId = try registerModule(&mut MODULE_GRAPH, rootId, "consts", "pub const WIDTH: u32 = 8; pub const HEIGHT: u32 = 4;", &mut arena); |
|
| 4997 | + | let appId = try registerModule(&mut MODULE_GRAPH, rootId, "app", "use root::consts; const TOTAL: u32 = consts::WIDTH * consts::HEIGHT; static BUF: [u8; TOTAL] = undefined;", &mut arena); |
|
| 4998 | + | ||
| 4999 | + | let result = try resolveModuleTree(&mut a, rootId); |
|
| 5000 | + | try expectNoErrors(&result); |
|
| 5001 | + | } |
|
| 5002 | + | ||
| 5003 | + | /// Test that non-constant expressions in const declarations are still rejected. |
|
| 5004 | + | @test fn testConstExprNonConstRejected() throws (testing::TestError) { |
|
| 5005 | + | let mut a = testResolver(); |
|
| 5006 | + | let program = "fn value() -> i32 { return 1; } const BAD: i32 = value() + 1;"; |
|
| 5007 | + | let result = try resolveProgramStr(&mut a, program); |
|
| 5008 | + | let err = try expectError(&result); |
|
| 5009 | + | let case super::ErrorKind::ConstExprRequired = err.kind |
|
| 5010 | + | else throw testing::TestError::Failed; |
|
| 5011 | + | } |
|
| 5012 | + | ||
| 5013 | + | /// Test unary negation in constant expressions. |
|
| 5014 | + | @test fn testConstExprUnaryNeg() throws (testing::TestError) { |
|
| 5015 | + | let mut a = testResolver(); |
|
| 5016 | + | let program = "const A: i32 = 10; const B: i32 = -A;"; |
|
| 5017 | + | let result = try resolveProgramStr(&mut a, program); |
|
| 5018 | + | try expectNoErrors(&result); |
|
| 5019 | + | } |
|
| 5020 | + | ||
| 5021 | + | /// Test unary not in constant expressions. |
|
| 5022 | + | @test fn testConstExprUnaryNot() throws (testing::TestError) { |
|
| 5023 | + | let mut a = testResolver(); |
|
| 5024 | + | let program = "const A: bool = true; const B: bool = not A;"; |
|
| 5025 | + | let result = try resolveProgramStr(&mut a, program); |
|
| 5026 | + | try expectNoErrors(&result); |
|
| 5027 | + | } |
test/tests/const-expr-array-size.rad
added
+16 -0
| 1 | + | //! returns: 0 |
|
| 2 | + | ||
| 3 | + | /// Test that constant expressions referencing other constants |
|
| 4 | + | /// can be used as array sizes. |
|
| 5 | + | ||
| 6 | + | const ROWS: u32 = 3; |
|
| 7 | + | const COLS: u32 = 4; |
|
| 8 | + | const TOTAL: u32 = ROWS * COLS; |
|
| 9 | + | ||
| 10 | + | static DATA: [i32; TOTAL] = undefined; |
|
| 11 | + | ||
| 12 | + | @default fn main() -> i32 { |
|
| 13 | + | // Verify the array has the expected length. |
|
| 14 | + | if DATA.len != 12 { return 1; } |
|
| 15 | + | return 0; |
|
| 16 | + | } |
test/tests/const-expr-refs.rad
added
+31 -0
| 1 | + | //! returns: 0 |
|
| 2 | + | ||
| 3 | + | /// Test that constant expressions can reference other constants, |
|
| 4 | + | /// including arithmetic on constants. |
|
| 5 | + | ||
| 6 | + | const A: i32 = 10; |
|
| 7 | + | const B: i32 = 20; |
|
| 8 | + | const C: i32 = A + B; |
|
| 9 | + | const D: i32 = C * 2; |
|
| 10 | + | const E: i32 = D - A; |
|
| 11 | + | const F: i32 = E / 5; |
|
| 12 | + | const G: i32 = A % 3; |
|
| 13 | + | const H: i32 = A | B; |
|
| 14 | + | const I: i32 = A & B; |
|
| 15 | + | const J: i32 = A ^ B; |
|
| 16 | + | const K: i32 = A << 2; |
|
| 17 | + | const L: i32 = B >> 1; |
|
| 18 | + | ||
| 19 | + | @default fn main() -> i32 { |
|
| 20 | + | if C != 30 { return 1; } |
|
| 21 | + | if D != 60 { return 2; } |
|
| 22 | + | if E != 50 { return 3; } |
|
| 23 | + | if F != 10 { return 4; } |
|
| 24 | + | if G != 1 { return 5; } |
|
| 25 | + | if H != 30 { return 6; } |
|
| 26 | + | if I != 0 { return 7; } |
|
| 27 | + | if J != 30 { return 8; } |
|
| 28 | + | if K != 40 { return 9; } |
|
| 29 | + | if L != 10 { return 10; } |
|
| 30 | + | return 0; |
|
| 31 | + | } |