Add auto-deref support for nested patterns

5f74ebc066787d131f48af47d8bcf25336395718b4117b7006e4238dd5330252
When a union variant field or record field holds a pointer (*T) and the
pattern is a destructuring pattern, the compiler now automatically
dereferences the pointer and matches against the pointed-to type.
Alexis Sellier committed ago 1 parent d9a090bf
lib/std/arch/rv64/tests/match.nested.deref.rad added +156 -0
1 +
//! Test nested patterns through pointer dereferences (auto-deref).
2 +
3 +
union Inner {
4 +
    A(i32),
5 +
    B,
6 +
}
7 +
8 +
union Outer {
9 +
    Some(*Inner),
10 +
    None,
11 +
}
12 +
13 +
/// Match nested union variant through pointer dereference in match/case.
14 +
fn matchDeref(o: Outer) -> i32 {
15 +
    match o {
16 +
        case Outer::Some(Inner::A(x)) => {
17 +
            return x;
18 +
        }
19 +
        case Outer::Some(Inner::B) => {
20 +
            return -1;
21 +
        }
22 +
        else => {
23 +
            return -2;
24 +
        }
25 +
    }
26 +
}
27 +
28 +
/// If-let-case with auto-deref nested pattern.
29 +
fn ifLetDeref(o: Outer) -> i32 {
30 +
    if let case Outer::Some(Inner::A(x)) = o {
31 +
        return x;
32 +
    }
33 +
    return 0;
34 +
}
35 +
36 +
/// Let-else with auto-deref nested pattern.
37 +
fn letElseDeref(o: Outer) -> i32 {
38 +
    let case Outer::Some(Inner::A(x)) = o
39 +
        else { return -1; };
40 +
    return x;
41 +
}
42 +
43 +
/// Record with pointer field and nested pattern through deref.
44 +
record Container {
45 +
    inner: *Inner,
46 +
    tag: i32,
47 +
}
48 +
49 +
union Boxed {
50 +
    Some { c: Container },
51 +
    None,
52 +
}
53 +
54 +
/// Nested record with auto-deref on a pointer field.
55 +
fn matchRecordDeref(b: Boxed) -> i32 {
56 +
    match b {
57 +
        case Boxed::Some { c: Container { inner: Inner::A(val), tag } } => {
58 +
            return val + tag;
59 +
        }
60 +
        else => {
61 +
            return 0;
62 +
        }
63 +
    }
64 +
}
65 +
66 +
/// Auto-deref through pointer in if-let-case with record.
67 +
fn ifLetRecordDeref(b: Boxed) -> i32 {
68 +
    if let case Boxed::Some { c: Container { inner: Inner::A(val), tag } } = b {
69 +
        return val + tag;
70 +
    }
71 +
    return 0;
72 +
}
73 +
74 +
/// Void variant through pointer deref.
75 +
fn matchDerefVoid(o: Outer) -> i32 {
76 +
    if let case Outer::Some(Inner::B) = o {
77 +
        return 1;
78 +
    }
79 +
    return 0;
80 +
}
81 +
82 +
/// Auto-deref with placeholder in nested pattern.
83 +
fn matchDerefPlaceholder(o: Outer) -> i32 {
84 +
    if let case Outer::Some(Inner::A(_)) = o {
85 +
        return 1;
86 +
    }
87 +
    return 0;
88 +
}
89 +
90 +
/// Auto-deref through pointer to record in a record field.
91 +
record Point {
92 +
    x: i32,
93 +
    y: i32,
94 +
}
95 +
96 +
union Holder {
97 +
    Ptr { p: *Point, z: i32 },
98 +
    Empty,
99 +
}
100 +
101 +
fn derefRecordField(h: Holder) -> i32 {
102 +
    if let case Holder::Ptr { p: Point { x, y }, z } = h {
103 +
        return x + y + z;
104 +
    }
105 +
    return 0;
106 +
}
107 +
108 +
@default fn main() -> i32 {
109 +
    let innerA = Inner::A(42);
110 +
    let innerB = Inner::B;
111 +
112 +
    // matchDeref
113 +
    assert matchDeref(Outer::Some(&innerA)) == 42;
114 +
    assert matchDeref(Outer::Some(&innerB)) == -1;
115 +
    assert matchDeref(Outer::None) == -2;
116 +
117 +
    // ifLetDeref
118 +
    assert ifLetDeref(Outer::Some(&innerA)) == 42;
119 +
    assert ifLetDeref(Outer::Some(&innerB)) == 0;
120 +
    assert ifLetDeref(Outer::None) == 0;
121 +
122 +
    // letElseDeref
123 +
    assert letElseDeref(Outer::Some(&innerA)) == 42;
124 +
    assert letElseDeref(Outer::Some(&innerB)) == -1;
125 +
    assert letElseDeref(Outer::None) == -1;
126 +
127 +
    // matchRecordDeref
128 +
    let c1 = Boxed::Some { c: Container { inner: &innerA, tag: 10 } };
129 +
    assert matchRecordDeref(c1) == 52;
130 +
    let c2 = Boxed::Some { c: Container { inner: &innerB, tag: 20 } };
131 +
    assert matchRecordDeref(c2) == 0;
132 +
    assert matchRecordDeref(Boxed::None) == 0;
133 +
134 +
    // ifLetRecordDeref
135 +
    assert ifLetRecordDeref(c1) == 52;
136 +
    assert ifLetRecordDeref(c2) == 0;
137 +
    assert ifLetRecordDeref(Boxed::None) == 0;
138 +
139 +
    // matchDerefVoid
140 +
    assert matchDerefVoid(Outer::Some(&innerA)) == 0;
141 +
    assert matchDerefVoid(Outer::Some(&innerB)) == 1;
142 +
    assert matchDerefVoid(Outer::None) == 0;
143 +
144 +
    // matchDerefPlaceholder
145 +
    assert matchDerefPlaceholder(Outer::Some(&innerA)) == 1;
146 +
    assert matchDerefPlaceholder(Outer::Some(&innerB)) == 0;
147 +
    assert matchDerefPlaceholder(Outer::None) == 0;
148 +
149 +
    // derefRecordField
150 +
    let pt = Point { x: 3, y: 4 };
151 +
    let h1 = Holder::Ptr { p: &pt, z: 5 };
152 +
    assert derefRecordField(h1) == 12;
153 +
    assert derefRecordField(Holder::Empty) == 0;
154 +
155 +
    return 0;
156 +
}
lib/std/lang/lower.rad +63 -27
3065 3065
fn bindPatternVariables(self: *mut FnLowerer, subject: *MatchSubject, patterns: *mut [*ast::Node], failBlock: BlockId) throws (LowerError) {
3066 3066
    for pattern in patterns {
3067 3067
3068 3068
        // Handle simple variant patterns like `Variant(x)`.
3069 3069
        if let arg = resolver::variantPatternBinding(self.low.resolver, pattern) {
3070 -
            if let bindType = resolver::typeFor(self.low.resolver, arg) {
3071 -
                let case MatchSubjectKind::Union(unionInfo) = subject.kind
3072 -
                    else panic "bindPatternVariables: expected union subject";
3073 -
                let valOffset = unionInfo.valOffset as i32;
3074 -
3075 -
                match arg.value {
3076 -
                    case ast::NodeValue::Ident(name) => {
3077 -
                        try bindPayloadVariable(self, name, subject.val, bindType, subject.by, valOffset, false);
3078 -
                    }
3079 -
                    case ast::NodeValue::Placeholder => {}
3080 -
                    else => {
3081 -
                        // Nested pattern inside a variant call, e.g. `Variant(Inner { x, y })`.
3082 -
                        let base = emitValToReg(self, subject.val);
3083 -
                        let payloadBase = emitPtrOffset(self, base, valOffset);
3084 -
                        let fieldInfo = resolver::RecordField {
3085 -
                            name: nil,
3086 -
                            fieldType: bindType,
3087 -
                            offset: 0,
3088 -
                        };
3089 -
                        try bindFieldVariable(self, arg, payloadBase, fieldInfo, subject.by, failBlock);
3090 -
                    }
3070 +
            let case MatchSubjectKind::Union(unionInfo) = subject.kind
3071 +
                else panic "bindPatternVariables: expected union subject";
3072 +
            let valOffset = unionInfo.valOffset as i32;
3073 +
3074 +
            // Get the actual field type from the variant's record info.
3075 +
            // This preserves the original data layout type (e.g. `*T`) even when
3076 +
            // the resolver resolved the pattern against a dereferenced type (`T`).
3077 +
            let variantExtra = resolver::nodeData(self.low.resolver, pattern).extra;
3078 +
            let case resolver::NodeExtra::UnionVariant { ordinal, .. } = variantExtra
3079 +
                else panic "bindPatternVariables: expected variant extra";
3080 +
            let payloadType = unionInfo.variants[ordinal].valueType;
3081 +
            let payloadRec = resolver::getRecord(payloadType)
3082 +
                else panic "bindPatternVariables: expected record payload";
3083 +
            let fieldType = payloadRec.fields[0].fieldType;
3084 +
3085 +
            match arg.value {
3086 +
                case ast::NodeValue::Ident(name) => {
3087 +
                    try bindPayloadVariable(self, name, subject.val, fieldType, subject.by, valOffset, false);
3088 +
                }
3089 +
                case ast::NodeValue::Placeholder => {}
3090 +
                else => {
3091 +
                    // Nested pattern inside a variant call, e.g. `Variant(Inner { x, y })`.
3092 +
                    let base = emitValToReg(self, subject.val);
3093 +
                    let payloadBase = emitPtrOffset(self, base, valOffset);
3094 +
                    let fieldInfo = resolver::RecordField {
3095 +
                        name: nil,
3096 +
                        fieldType,
3097 +
                        offset: 0,
3098 +
                    };
3099 +
                    try bindFieldVariable(self, arg, payloadBase, fieldInfo, subject.by, failBlock);
3091 3100
                }
3092 3101
            }
3093 3102
        }
3094 3103
        match pattern.value {
3095 3104
            // Compound variant patterns like `Variant { a, b }`.
3193 3202
                    try emitNestedFieldTest(self, binding, base, fieldInfo, matchBy, failBlock);
3194 3203
                    return;
3195 3204
                }
3196 3205
            }
3197 3206
            // Plain nested record destructuring pattern.
3198 -
            let recInfo = resolver::getRecord(fieldInfo.fieldType)
3207 +
            // Auto-deref: if the field is a pointer, load it first.
3208 +
            let mut derefType = fieldInfo.fieldType;
3209 +
            let mut nestedBase = emitPtrOffset(self, base, fieldInfo.offset);
3210 +
            if let case resolver::Type::Pointer { target, .. } = fieldInfo.fieldType {
3211 +
                let ptrReg = nextReg(self);
3212 +
                emitLoadW64(self, ptrReg, nestedBase);
3213 +
                nestedBase = ptrReg;
3214 +
                derefType = *target;
3215 +
            }
3216 +
            let recInfo = resolver::getRecord(derefType)
3199 3217
                else throw LowerError::ExpectedRecord;
3200 -
            let nestedBase = emitPtrOffset(self, base, fieldInfo.offset);
3201 3218
3202 3219
            try bindNestedRecordFields(self, nestedBase, lit, recInfo, matchBy, failBlock);
3203 3220
        }
3204 3221
        else => {
3205 3222
            // Nested pattern requiring a test (union variant scope access, literal, etc).
3217 3234
    base: il::Reg,
3218 3235
    fieldInfo: resolver::RecordField,
3219 3236
    matchBy: resolver::MatchBy,
3220 3237
    failBlock: BlockId
3221 3238
) throws (LowerError) {
3222 -
    let fieldType = fieldInfo.fieldType;
3239 +
    let mut fieldType = fieldInfo.fieldType;
3223 3240
    let fieldPtr = emitPtrOffset(self, base, fieldInfo.offset);
3224 3241
3242 +
    // Auto-deref: when the field is a pointer and the pattern destructures
3243 +
    // the pointed-to value, load the pointer and use the target type.
3244 +
    // The loaded pointer becomes the base address for the nested subject.
3245 +
    let mut derefBase: ?il::Reg = nil;
3246 +
    if let case resolver::Type::Pointer { target, .. } = fieldType {
3247 +
        if resolver::isDestructuringPattern(pattern) {
3248 +
            let ptrReg = nextReg(self);
3249 +
            emitLoadW64(self, ptrReg, fieldPtr);
3250 +
            derefBase = ptrReg;
3251 +
            fieldType = *target;
3252 +
        }
3253 +
    }
3225 3254
    // Build a MatchSubject for the nested field.
3226 3255
    let ilTy = ilType(self.low, fieldType);
3227 3256
    let kind = matchSubjectKind(fieldType);
3228 3257
3229 -
    // For aggregate subjects, the value must be a pointer.
3230 -
    let mut val = il::Val::Reg(fieldPtr);
3231 -
    if not isAggregateType(fieldType) {
3258 +
    // Determine the subject value.
3259 +
    let mut val: il::Val = undefined;
3260 +
    if let reg = derefBase {
3261 +
        // Auto-deref: the loaded pointer is the address of the target value.
3262 +
        val = il::Val::Reg(reg);
3263 +
    } else if isAggregateType(fieldType) {
3264 +
        // Aggregate: use the pointer.
3265 +
        val = il::Val::Reg(fieldPtr);
3266 +
    } else {
3267 +
        // Scalar: load the value.
3232 3268
        val = emitRead(self, base, fieldInfo.offset, fieldType);
3233 3269
    }
3234 3270
    let nestedSubject = MatchSubject {
3235 3271
        val,
3236 3272
        type: fieldType,
lib/std/lang/resolver.rad +20 -0
3860 3860
    Compare,
3861 3861
    /// Identifier introduces a new binding.
3862 3862
    Bind,
3863 3863
}
3864 3864
3865 +
/// Check whether a pattern node is a destructuring pattern that looks
3866 +
/// through structure (union variant, record literal, scope access).
3867 +
/// Identifiers, placeholders, and plain literals are not destructuring.
3868 +
pub fn isDestructuringPattern(pattern: *ast::Node) -> bool {
3869 +
    match pattern.value {
3870 +
        case ast::NodeValue::Call(_),
3871 +
             ast::NodeValue::RecordLit(_),
3872 +
             ast::NodeValue::ScopeAccess(_) => return true,
3873 +
        else => return false,
3874 +
    }
3875 +
}
3876 +
3865 3877
/// Analyze a case pattern for match, if-case, let-case, or while-case.
3866 3878
///
3867 3879
/// At the top level, bare identifiers are compared against existing values.
3868 3880
/// Inside destructuring patterns (arrays, records), identifiers become bindings.
3869 3881
fn resolveCasePattern(
3873 3885
    mode: IdentMode,
3874 3886
    matchBy: MatchBy
3875 3887
) throws (ResolveError) {
3876 3888
    // TODO: Collapse these nested matches.
3877 3889
    match scrutineeTy {
3890 +
        case Type::Pointer { target, .. } => {
3891 +
            // Auto-deref: when the scrutinee is a pointer and the pattern
3892 +
            // is a destructuring pattern, resolve against the pointed-to type.
3893 +
            if isDestructuringPattern(pattern) {
3894 +
                try resolveCasePattern(self, pattern, *target, mode, matchBy);
3895 +
                return;
3896 +
            }
3897 +
        }
3878 3898
        case Type::Nominal(info) => {
3879 3899
            try ensureNominalResolved(self, info, pattern);
3880 3900
3881 3901
            match *info {
3882 3902
                case NominalType::Union(unionType) => {