Support constant folding

f5bdf05cb31b838ea53a958ccba3908f71068c2f4fe8581bcf9b18e1d8bda1d2
Add constant folding for binary operations in the resolver. When both
operands of a binary expression have known constant values and the
result type is concrete, the operation is evaluated at compile time.
Alexis Sellier committed ago 1 parent 19053951
lib/std/lang/resolver.rad +130 -0
2874 2874
                }
2875 2875
                return true;
2876 2876
            }
2877 2877
            return false;
2878 2878
        },
2879 +
        case ast::NodeValue::BinOp(binop) => {
2880 +
            // Binary expressions are constant if both operands are constant.
2881 +
            return isConstExpr(self, binop.left) and isConstExpr(self, binop.right);
2882 +
        },
2883 +
        case ast::NodeValue::UnOp(unop) => {
2884 +
            // Unary expressions are constant if the operand is constant.
2885 +
            return isConstExpr(self, unop.value);
2886 +
        },
2879 2887
        else => {
2880 2888
            return false;
2881 2889
        }
2882 2890
    }
2883 2891
}
5910 5918
        setNodeCoercion(self, node, Coercion::ResultWrap);
5911 5919
    }
5912 5920
    return setNodeType(self, node, Type::Never);
5913 5921
}
5914 5922
5923 +
/// Convert a [`ConstInt`] to its signed two's-complement representation.
5924 +
fn constIntToSigned(c: ConstInt) -> i64 {
5925 +
    if c.negative {
5926 +
        return -(c.magnitude as i64);
5927 +
    }
5928 +
    return c.magnitude as i64;
5929 +
}
5930 +
5931 +
/// Build a [`ConstInt`] from a signed result, preserving bit width and signedness.
5932 +
fn constIntFromSigned(value: i64, bits: u8, signed: bool) -> ConstInt {
5933 +
    if value < 0 {
5934 +
        // Compute magnitude without signed overflow.
5935 +
        let uval = value as u64;
5936 +
        return ConstInt {
5937 +
            magnitude: 0 - uval,
5938 +
            bits,
5939 +
            signed,
5940 +
            negative: true,
5941 +
        };
5942 +
    }
5943 +
    return ConstInt {
5944 +
        magnitude: value as u64,
5945 +
        bits,
5946 +
        signed,
5947 +
        negative: false,
5948 +
    };
5949 +
}
5950 +
5951 +
/// Try to fold a binary operation on two integer constants.
5952 +
/// Returns the resulting constant value if successful.
5953 +
fn foldIntBinOp(op: ast::BinaryOp, left: ConstInt, right: ConstInt) -> ?ConstValue {
5954 +
    // Use the wider bit width and propagate signedness.
5955 +
    let mut bits = left.bits;
5956 +
    if right.bits > bits {
5957 +
        bits = right.bits;
5958 +
    }
5959 +
    let signed = left.signed or right.signed;
5960 +
    let l = constIntToSigned(left);
5961 +
    let r = constIntToSigned(right);
5962 +
5963 +
    match op {
5964 +
        // Shifts operate on unsigned magnitudes directly.
5965 +
        case ast::BinaryOp::Shl => {
5966 +
            return ConstValue::Int(ConstInt {
5967 +
                magnitude: left.magnitude << right.magnitude, bits, signed, negative: left.negative,
5968 +
            });
5969 +
        },
5970 +
        case ast::BinaryOp::Shr => {
5971 +
            return ConstValue::Int(ConstInt {
5972 +
                magnitude: left.magnitude >> right.magnitude, bits, signed, negative: left.negative,
5973 +
            });
5974 +
        },
5975 +
        case ast::BinaryOp::Eq  => return ConstValue::Bool(l == r),
5976 +
        case ast::BinaryOp::Ne  => return ConstValue::Bool(l != r),
5977 +
        case ast::BinaryOp::Lt  => return ConstValue::Bool(l < r),
5978 +
        case ast::BinaryOp::Gt  => return ConstValue::Bool(l > r),
5979 +
        case ast::BinaryOp::Lte => return ConstValue::Bool(l <= r),
5980 +
        case ast::BinaryOp::Gte => return ConstValue::Bool(l >= r),
5981 +
        case ast::BinaryOp::Add => return ConstValue::Int(constIntFromSigned(l + r, bits, signed)),
5982 +
        case ast::BinaryOp::Sub => return ConstValue::Int(constIntFromSigned(l - r, bits, signed)),
5983 +
        case ast::BinaryOp::Mul => return ConstValue::Int(constIntFromSigned(l * r, bits, signed)),
5984 +
        case ast::BinaryOp::Div => {
5985 +
            if r == 0 {
5986 +
                return nil;
5987 +
            }
5988 +
            return ConstValue::Int(constIntFromSigned(l / r, bits, signed));
5989 +
        },
5990 +
        case ast::BinaryOp::Mod => {
5991 +
            if r == 0 {
5992 +
                return nil;
5993 +
            }
5994 +
            return ConstValue::Int(constIntFromSigned(l % r, bits, signed));
5995 +
        },
5996 +
        case ast::BinaryOp::BitAnd => return ConstValue::Int(constIntFromSigned(l & r, bits, signed)),
5997 +
        case ast::BinaryOp::BitOr  => return ConstValue::Int(constIntFromSigned(l | r, bits, signed)),
5998 +
        case ast::BinaryOp::BitXor => return ConstValue::Int(constIntFromSigned(l ^ r, bits, signed)),
5999 +
        else => return nil,
6000 +
    }
6001 +
}
6002 +
6003 +
/// Try to constant-fold a binary operation on two resolved operands.
6004 +
/// Only folds when the result type is concrete.
6005 +
fn tryFoldBinOp(self: *mut Resolver, node: *ast::Node, binop: ast::BinOp, resultTy: Type) {
6006 +
    // Skip folding for untyped integer expressions. Their final type
6007 +
    // depends on context, so attaching a signed constant value could
6008 +
    // cause spurious range-check failures (e.g. `0 - 12` assigned to u8).
6009 +
    if resultTy == Type::Int {
6010 +
        return;
6011 +
    }
6012 +
    let leftVal = constValueFor(self, binop.left)
6013 +
        else return;
6014 +
    let rightVal = constValueFor(self, binop.right)
6015 +
        else return;
6016 +
6017 +
    // Fold integer binary ops.
6018 +
    if let case ConstValue::Int(leftInt) = leftVal {
6019 +
        if let case ConstValue::Int(rightInt) = rightVal {
6020 +
            if let result = foldIntBinOp(binop.op, leftInt, rightInt) {
6021 +
                setNodeConstValue(self, node, result);
6022 +
            }
6023 +
            return;
6024 +
        }
6025 +
    }
6026 +
6027 +
    // Fold boolean binary ops.
6028 +
    if let case ConstValue::Bool(l) = leftVal {
6029 +
        if let case ConstValue::Bool(r) = rightVal {
6030 +
            match binop.op {
6031 +
                case ast::BinaryOp::And => setNodeConstValue(self, node, ConstValue::Bool(l and r)),
6032 +
                case ast::BinaryOp::Or => setNodeConstValue(self, node, ConstValue::Bool(l or r)),
6033 +
                case ast::BinaryOp::Eq => setNodeConstValue(self, node, ConstValue::Bool(l == r)),
6034 +
                case ast::BinaryOp::Ne,
6035 +
                     ast::BinaryOp::Xor => setNodeConstValue(self, node, ConstValue::Bool(l != r)),
6036 +
                else => {}
6037 +
            }
6038 +
        }
6039 +
    }
6040 +
}
6041 +
5915 6042
/// Analyze a binary expression.
5916 6043
fn resolveBinOp(self: *mut Resolver, node: *ast::Node, binop: ast::BinOp) -> Type
5917 6044
    throws (ResolveError)
5918 6045
{
5919 6046
    let mut resultTy = Type::Unknown;
6003 6130
                }
6004 6131
            }
6005 6132
6006 6133
        }
6007 6134
    };
6135 +
    // Try constant folding after both operands are resolved.
6136 +
    tryFoldBinOp(self, node, binop, resultTy);
6137 +
6008 6138
    return setNodeType(self, node, resultTy);
6009 6139
}
6010 6140
6011 6141
/// Analyze a unary expression.
6012 6142
fn resolveUnOp(self: *mut Resolver, node: *ast::Node, unop: ast::UnOp) -> Type
lib/std/lang/resolver/tests.rad +118 -0
761 761
    try testing::expect(elemTy == super::Type::I32);
762 762
}
763 763
764 764
@test fn testResolveArrayRepeatNonConstCount() throws (testing::TestError) {
765 765
    let mut a = testResolver();
766 +
    // `3 * 1` uses untyped integer literals, so constant folding does not
767 +
    // produce a value. The repeat count requires a known compile-time value.
766 768
    let result = try resolveProgramStr(&mut a, "let xs: [i32; 3] = [42; 3 * 1];");
767 769
    try expectErrorKind(&result, super::ErrorKind::ConstExprRequired);
768 770
}
769 771
770 772
@test fn testResolveArrayRepeatCountMismatch() throws (testing::TestError) {
4905 4907
    let mut a = testResolver();
4906 4908
    let program = "union E { Fail } record R { x: i32 } trait T { fn (*T) get() -> i32 throws (E); } instance T for R { fn (r: *R) get() -> i32 throws (E) { throw E::Fail; return r.x; } }";
4907 4909
    let result = try resolveProgramStr(&mut a, program);
4908 4910
    try expectNoErrors(&result);
4909 4911
}
4912 +
4913 +
// Constant expression folding tests //////////////////////////////////////////
4914 +
4915 +
/// Resolve a program and verify that the constant at the given statement index
4916 +
/// has the expected integer magnitude.
4917 +
fn expectConstFold(program: *[u8], stmtIdx: u32, expected: u64)
4918 +
    throws (testing::TestError)
4919 +
{
4920 +
    let mut a = testResolver();
4921 +
    let result = try resolveProgramStr(&mut a, program);
4922 +
    try expectNoErrors(&result);
4923 +
4924 +
    let stmt = try getBlockStmt(result.root, stmtIdx);
4925 +
    let sym = super::symbolFor(&a, stmt)
4926 +
        else throw testing::TestError::Failed;
4927 +
    let case super::SymbolData::Constant { value, .. } = sym.data
4928 +
        else throw testing::TestError::Failed;
4929 +
    let val = value else throw testing::TestError::Failed;
4930 +
    let case super::ConstValue::Int(intVal) = val
4931 +
        else throw testing::TestError::Failed;
4932 +
    try testing::expect(intVal.magnitude == expected);
4933 +
}
4934 +
4935 +
/// Test arithmetic constant folding: add, sub, mul, div.
4936 +
@test fn testConstExprArithmetic() throws (testing::TestError) {
4937 +
    try expectConstFold("const A: i32 = 10; const B: i32 = 20; const C: i32 = A + B;", 2, 30);
4938 +
    try expectConstFold("const A: i32 = 50; const B: i32 = 20; const C: i32 = A - B;", 2, 30);
4939 +
    try expectConstFold("const A: i32 = 6; const B: i32 = 7; const C: i32 = A * B;", 2, 42);
4940 +
    try expectConstFold("const A: i32 = 100; const B: i32 = 5; const C: i32 = A / B;", 2, 20);
4941 +
}
4942 +
4943 +
/// Test bitwise constant folding: and, or, xor.
4944 +
@test fn testConstExprBitwise() throws (testing::TestError) {
4945 +
    try expectConstFold("const A: i32 = 0xFF; const B: i32 = 0x0F; const C: i32 = A & B;", 2, 0x0F);
4946 +
    try expectConstFold("const A: i32 = 0xF0; const B: i32 = 0x0F; const C: i32 = A | B;", 2, 0xFF);
4947 +
    try expectConstFold("const A: i32 = 0xFF; const B: i32 = 0x0F; const C: i32 = A ^ B;", 2, 0xF0);
4948 +
}
4949 +
4950 +
/// Test shift constant folding.
4951 +
@test fn testConstExprShift() throws (testing::TestError) {
4952 +
    try expectConstFold("const A: i32 = 1; const B: i32 = A << 4;", 1, 16);
4953 +
    try expectConstFold("const A: i32 = 32; const B: i32 = A >> 2;", 1, 8);
4954 +
}
4955 +
4956 +
/// Test chained constant expressions (C depends on A + B, D depends on C).
4957 +
@test fn testConstExprChained() throws (testing::TestError) {
4958 +
    try expectConstFold("const A: i32 = 10; const B: i32 = 20; const C: i32 = A + B; const D: i32 = C * 2;", 3, 60);
4959 +
}
4960 +
4961 +
/// Test constant expression used as array size.
4962 +
@test fn testConstExprAsArraySize() throws (testing::TestError) {
4963 +
    let mut a = testResolver();
4964 +
    let program = "const A: u32 = 2; const B: u32 = 3; const SIZE: u32 = A + B; const ARR: [i32; SIZE] = [1, 2, 3, 4, 5];";
4965 +
    let result = try resolveProgramStr(&mut a, program);
4966 +
    try expectNoErrors(&result);
4967 +
4968 +
    let arrStmt = try getBlockStmt(result.root, 3);
4969 +
    let sym = super::symbolFor(&a, arrStmt)
4970 +
        else throw testing::TestError::Failed;
4971 +
    let case super::SymbolData::Constant { type: super::Type::Array(arrType), .. } = sym.data
4972 +
        else throw testing::TestError::Failed;
4973 +
    try testing::expect(arrType.length == 5);
4974 +
}
4975 +
4976 +
/// Test cross-module constant expression: a constant in one module references
4977 +
/// a constant from another module via scope access.
4978 +
@test fn testCrossModuleConstExpr() throws (testing::TestError) {
4979 +
    let mut a = testResolver();
4980 +
    let mut arena = ast::nodeArena(&mut AST_ARENA[..]);
4981 +
4982 +
    let rootId = try registerModule(&mut MODULE_GRAPH, nil, "root", "pub mod consts; mod app;", &mut arena);
4983 +
    let constsId = try registerModule(&mut MODULE_GRAPH, rootId, "consts", "pub const BASE: i32 = 100;", &mut arena);
4984 +
    let appId = try registerModule(&mut MODULE_GRAPH, rootId, "app", "use root::consts; const DERIVED: i32 = consts::BASE + 50;", &mut arena);
4985 +
4986 +
    let result = try resolveModuleTree(&mut a, rootId);
4987 +
    try expectNoErrors(&result);
4988 +
}
4989 +
4990 +
/// Test cross-module constant expression used as array size.
4991 +
@test fn testCrossModuleConstExprArraySize() throws (testing::TestError) {
4992 +
    let mut a = testResolver();
4993 +
    let mut arena = ast::nodeArena(&mut AST_ARENA[..]);
4994 +
4995 +
    let rootId = try registerModule(&mut MODULE_GRAPH, nil, "root", "pub mod consts; mod app;", &mut arena);
4996 +
    let constsId = try registerModule(&mut MODULE_GRAPH, rootId, "consts", "pub const WIDTH: u32 = 8; pub const HEIGHT: u32 = 4;", &mut arena);
4997 +
    let appId = try registerModule(&mut MODULE_GRAPH, rootId, "app", "use root::consts; const TOTAL: u32 = consts::WIDTH * consts::HEIGHT; static BUF: [u8; TOTAL] = undefined;", &mut arena);
4998 +
4999 +
    let result = try resolveModuleTree(&mut a, rootId);
5000 +
    try expectNoErrors(&result);
5001 +
}
5002 +
5003 +
/// Test that non-constant expressions in const declarations are still rejected.
5004 +
@test fn testConstExprNonConstRejected() throws (testing::TestError) {
5005 +
    let mut a = testResolver();
5006 +
    let program = "fn value() -> i32 { return 1; } const BAD: i32 = value() + 1;";
5007 +
    let result = try resolveProgramStr(&mut a, program);
5008 +
    let err = try expectError(&result);
5009 +
    let case super::ErrorKind::ConstExprRequired = err.kind
5010 +
        else throw testing::TestError::Failed;
5011 +
}
5012 +
5013 +
/// Test unary negation in constant expressions.
5014 +
@test fn testConstExprUnaryNeg() throws (testing::TestError) {
5015 +
    let mut a = testResolver();
5016 +
    let program = "const A: i32 = 10; const B: i32 = -A;";
5017 +
    let result = try resolveProgramStr(&mut a, program);
5018 +
    try expectNoErrors(&result);
5019 +
}
5020 +
5021 +
/// Test unary not in constant expressions.
5022 +
@test fn testConstExprUnaryNot() throws (testing::TestError) {
5023 +
    let mut a = testResolver();
5024 +
    let program = "const A: bool = true; const B: bool = not A;";
5025 +
    let result = try resolveProgramStr(&mut a, program);
5026 +
    try expectNoErrors(&result);
5027 +
}
test/tests/const-expr-array-size.rad added +16 -0
1 +
//! returns: 0
2 +
3 +
/// Test that constant expressions referencing other constants
4 +
/// can be used as array sizes.
5 +
6 +
const ROWS: u32 = 3;
7 +
const COLS: u32 = 4;
8 +
const TOTAL: u32 = ROWS * COLS;
9 +
10 +
static DATA: [i32; TOTAL] = undefined;
11 +
12 +
@default fn main() -> i32 {
13 +
    // Verify the array has the expected length.
14 +
    if DATA.len != 12 { return 1; }
15 +
    return 0;
16 +
}
test/tests/const-expr-refs.rad added +31 -0
1 +
//! returns: 0
2 +
3 +
/// Test that constant expressions can reference other constants,
4 +
/// including arithmetic on constants.
5 +
6 +
const A: i32 = 10;
7 +
const B: i32 = 20;
8 +
const C: i32 = A + B;
9 +
const D: i32 = C * 2;
10 +
const E: i32 = D - A;
11 +
const F: i32 = E / 5;
12 +
const G: i32 = A % 3;
13 +
const H: i32 = A | B;
14 +
const I: i32 = A & B;
15 +
const J: i32 = A ^ B;
16 +
const K: i32 = A << 2;
17 +
const L: i32 = B >> 1;
18 +
19 +
@default fn main() -> i32 {
20 +
    if C != 30 { return 1; }
21 +
    if D != 60 { return 2; }
22 +
    if E != 50 { return 3; }
23 +
    if F != 10 { return 4; }
24 +
    if G != 1 { return 5; }
25 +
    if H != 30 { return 6; }
26 +
    if I != 0 { return 7; }
27 +
    if J != 30 { return 8; }
28 +
    if K != 40 { return 9; }
29 +
    if L != 10 { return 10; }
30 +
    return 0;
31 +
}