Improve flexibility of constant expressions

359694c209cc4c8d98076bd9a69f0f8ad4c0648470a6fb8c3cf4ad6ae24a5fc6
Supports literals in more positions.
Alexis Sellier committed ago 1 parent 62ca04d4
lib/std/lang/resolver.rad +17 -9
1781 1781
            return nil;
1782 1782
        }
1783 1783
        else => {
1784 1784
            if isNumericType(to) and isNumericType(from) {
1785 1785
                // Perform range validation at compile time if possible.
1786 +
                // For unsuffixed integer expressions (`Type::Int`), only
1787 +
                // validate literals directly written by the programmer.
1788 +
                // Folded results (e.g. `0 - 65`) may not fit the target
1789 +
                // type but are valid wrapping arithmetic at runtime.
1786 1790
                if let value = constValueEntry(self, rval) {
1787 -
                    if validateConstIntRange(value, to) {
1788 -
                        return Coercion::Identity;
1791 +
                    if from != Type::Int or isNumberLiteral(rval) {
1792 +
                        if validateConstIntRange(value, to) {
1793 +
                            return Coercion::Identity;
1794 +
                        }
1795 +
                        return nil;
1789 1796
                    }
1790 -
                    return nil;
1791 1797
                }
1792 1798
                // Allow unsuffixed integer expressions to be inferred from context.
1793 1799
                if from == Type::Int {
1794 1800
                    return Coercion::NumericCast { from, to };
1795 1801
                }
2785 2791
    setNodeType(self, decl.value, bindingTy);
2786 2792
2787 2793
    return Type::Void;
2788 2794
}
2789 2795
2796 +
/// Check whether a node is a number literal.
2797 +
fn isNumberLiteral(node: *ast::Node) -> bool {
2798 +
    if let case ast::NodeValue::Number(_) = node.value {
2799 +
        return true;
2800 +
    }
2801 +
    return false;
2802 +
}
2803 +
2790 2804
/// Determine whether a node represents a compile-time constant expression.
2791 2805
pub fn isConstExpr(self: *Resolver, node: *ast::Node) -> bool {
2792 2806
    match node.value {
2793 2807
        case ast::NodeValue::Bool(_),
2794 2808
             ast::NodeValue::Char(_),
6027 6041
}
6028 6042
6029 6043
/// Try to constant-fold a binary operation on two resolved operands.
6030 6044
/// Only folds when the result type is concrete.
6031 6045
fn tryFoldBinOp(self: *mut Resolver, node: *ast::Node, binop: ast::BinOp, resultTy: Type) {
6032 -
    // Skip folding for untyped integer expressions. Their final type
6033 -
    // depends on context, so attaching a signed constant value could
6034 -
    // cause spurious range-check failures (e.g. `0 - 12` assigned to u8).
6035 -
    if resultTy == Type::Int {
6036 -
        return;
6037 -
    }
6038 6046
    let leftVal = constValueFor(self, binop.left)
6039 6047
        else return;
6040 6048
    let rightVal = constValueFor(self, binop.right)
6041 6049
        else return;
6042 6050
lib/std/lang/resolver/tests.rad +23 -3
759 759
    let arrayTy = try typeOf(&a, decl.value);
760 760
    let elemTy = try expectArrayType(arrayTy, 5);
761 761
    try testing::expect(elemTy == super::Type::I32);
762 762
}
763 763
764 -
@test fn testResolveArrayRepeatNonConstCount() throws (testing::TestError) {
764 +
@test fn testResolveArrayRepeatLiteralArithmetic() throws (testing::TestError) {
765 765
    let mut a = testResolver();
766 -
    // `3 * 1` uses untyped integer literals, so constant folding does not
767 -
    // produce a value. The repeat count requires a known compile-time value.
766 +
    // `3 * 1` folds to a compile-time constant, so the repeat count is valid.
768 767
    let result = try resolveProgramStr(&mut a, "let xs: [i32; 3] = [42; 3 * 1];");
768 +
    try expectNoErrors(&result);
769 +
}
770 +
771 +
@test fn testResolveArrayRepeatNonConstCount() throws (testing::TestError) {
772 +
    let mut a = testResolver();
773 +
    // A function call is not a constant expression.
774 +
    let result = try resolveProgramStr(&mut a, "fn f() -> u32 { return 3; } let xs: [i32; 3] = [42; f()];");
769 775
    try expectErrorKind(&result, super::ErrorKind::ConstExprRequired);
770 776
}
771 777
772 778
@test fn testResolveArrayRepeatCountMismatch() throws (testing::TestError) {
773 779
    let mut a = testResolver();
5032 5038
    try expectConstFold("const A: u64 = 10; const B: u8 = A as u8;", 1, 10);
5033 5039
    try expectConstFold("const A: i32 = 7; const B: u32 = A as u32;", 1, 7);
5034 5040
    try expectConstFold("const A: u32 = 100; const B: i32 = A as i32;", 1, 100);
5035 5041
    try expectConstFold("const A: u8 = 5; const B: u64 = (A as u32) as u64;", 1, 5);
5036 5042
    try expectConstFold("const A: u8 = 3; const B: u8 = 4; const C: i32 = (A as i32) + (B as i32);", 2, 7);
5043 +
    // Cast of unsuffixed literal arithmetic.
5044 +
    try expectConstFold("const A: u32 = (3 + 4) as u32;", 0, 7);
5045 +
    try expectConstFold("const A: u32 = ((3 + 4) as u64) as u32;", 0, 7);
5046 +
    try expectConstFold("const A: u32 = (3 + 4) as u32 + 1;", 0, 8);
5047 +
    try expectConstFold("const A: i32 = (2 as i32) * (3 + 4);", 0, 14);
5037 5048
}
5038 5049
5039 5050
/// Test `as` cast in constant expressions used as array size.
5040 5051
@test fn testConstExprCastAsArraySize() throws (testing::TestError) {
5041 5052
    let mut a = testResolver();
5048 5059
        else throw testing::TestError::Failed;
5049 5060
    let case super::SymbolData::Constant { type: super::Type::Array(arrType), .. } = sym.data
5050 5061
        else throw testing::TestError::Failed;
5051 5062
    try testing::expect(arrType.length == 4);
5052 5063
}
5064 +
5065 +
/// Test unsuffixed integer literals in constant expressions.
5066 +
@test fn testConstExprUnsuffixedLiterals() throws (testing::TestError) {
5067 +
    try expectConstFold("const A: u32 = 4 * 4;", 0, 16);
5068 +
    try expectConstFold("const B: u32 = 10; const C: u32 = B * 2;", 1, 20);
5069 +
    try expectConstFold("const D: u32 = 3 + 7;", 0, 10);
5070 +
    try expectConstFold("const E: u32 = 2 * 3 + 4;", 0, 10);
5071 +
    try expectConstFold("const F: i32 = -(3 + 4);", 0, 7);
5072 +
}
test/tests/const-expr-cast.rad +18 -6
1 1
//! returns: 0
2 2
3 3
/// Test that `as` casts work in constant expressions.
4 4
5 +
// Widening, narrowing, sign changes.
5 6
const A: i32 = 42;
6 7
const B: u64 = A as u64;
7 8
const C: u8 = 10;
8 9
const D: i32 = C as i32;
9 -
const SIZE: u32 = 4;
10 -
const SHIFTED: u64 = SIZE as u64;
10 +
const SHIFTED: u64 = 4 as u64;
11 11
12 -
// Use a cast const expression as an array size.
12 +
// Chained casts.
13 13
const LEN: u32 = 8;
14 14
const LEN2: u32 = (LEN as u64) as u32;
15 15
static BUF: [u8; LEN2] = undefined;
16 16
17 -
// Cast const used in static data initializer.
17 +
// Cast of unsuffixed literal arithmetic.
18 +
const E: u32 = (3 + 4) as u32;
19 +
// Nested casts of literal arithmetic.
20 +
const F: u32 = ((3 + 4) as u64) as u32;
21 +
// Cast + arithmetic.
22 +
const G: u32 = (3 + 4) as u32 + 1;
23 +
// Arithmetic with cast subexprs.
24 +
const H: i32 = (2 as i32) * (3 + 4);
25 +
26 +
// Cast in static data initializer.
18 27
const INIT_VAL: u8 = 0xFF;
19 28
const WIDE: u32 = INIT_VAL as u32;
20 29
21 -
// Cast with arithmetic.
30 +
// Typed const * typed const through cast.
22 31
const X: u8 = 3;
23 32
const Y: u8 = 4;
24 33
const Z: i32 = (X as i32) * (Y as i32);
25 34
26 35
@default fn main() -> i32 {
27 36
    assert BUF.len == 8;
28 37
    assert WIDE == 255;
29 38
    assert Z == 12;
30 -
39 +
    assert E == 7;
40 +
    assert F == 7;
41 +
    assert G == 8;
42 +
    assert H == 14;
31 43
    return 0;
32 44
}
test/tests/const-expr-literal.rad added +31 -0
1 +
//! returns: 0
2 +
3 +
/// Test that unsuffixed integer literals work in const expressions.
4 +
5 +
// Literal * literal.
6 +
const A: u32 = 4 * 4;
7 +
// Const * literal.
8 +
const B: u32 = 10;
9 +
const C: u32 = B * 2;
10 +
// Literal + literal.
11 +
const D: u32 = 3 + 7;
12 +
// Chained with literals.
13 +
const E: u32 = 2 * 3 + 4;
14 +
// Unary negation with literal.
15 +
const F: i32 = -5;
16 +
const G: i32 = F * 2;
17 +
18 +
static BUF: [u8; A] = undefined;
19 +
static BUF2: [u8; C] = undefined;
20 +
21 +
@default fn main() -> i32 {
22 +
    assert A == 16;
23 +
    assert BUF.len == 16;
24 +
    assert C == 20;
25 +
    assert BUF2.len == 20;
26 +
    assert D == 10;
27 +
    assert E == 10;
28 +
    assert G == -10;
29 +
30 +
    return 0;
31 +
}