Add auto-deref support for nested patterns
5f74ebc066787d131f48af47d8bcf25336395718b4117b7006e4238dd5330252
When a union variant field or record field holds a pointer (*T) and the pattern is a destructuring pattern, the compiler now automatically dereferences the pointer and matches against the pointed-to type.
1 parent
d9a090bf
lib/std/arch/rv64/tests/match.nested.deref.rad
added
+156 -0
| 1 | + | //! Test nested patterns through pointer dereferences (auto-deref). |
|
| 2 | + | ||
| 3 | + | union Inner { |
|
| 4 | + | A(i32), |
|
| 5 | + | B, |
|
| 6 | + | } |
|
| 7 | + | ||
| 8 | + | union Outer { |
|
| 9 | + | Some(*Inner), |
|
| 10 | + | None, |
|
| 11 | + | } |
|
| 12 | + | ||
| 13 | + | /// Match nested union variant through pointer dereference in match/case. |
|
| 14 | + | fn matchDeref(o: Outer) -> i32 { |
|
| 15 | + | match o { |
|
| 16 | + | case Outer::Some(Inner::A(x)) => { |
|
| 17 | + | return x; |
|
| 18 | + | } |
|
| 19 | + | case Outer::Some(Inner::B) => { |
|
| 20 | + | return -1; |
|
| 21 | + | } |
|
| 22 | + | else => { |
|
| 23 | + | return -2; |
|
| 24 | + | } |
|
| 25 | + | } |
|
| 26 | + | } |
|
| 27 | + | ||
| 28 | + | /// If-let-case with auto-deref nested pattern. |
|
| 29 | + | fn ifLetDeref(o: Outer) -> i32 { |
|
| 30 | + | if let case Outer::Some(Inner::A(x)) = o { |
|
| 31 | + | return x; |
|
| 32 | + | } |
|
| 33 | + | return 0; |
|
| 34 | + | } |
|
| 35 | + | ||
| 36 | + | /// Let-else with auto-deref nested pattern. |
|
| 37 | + | fn letElseDeref(o: Outer) -> i32 { |
|
| 38 | + | let case Outer::Some(Inner::A(x)) = o |
|
| 39 | + | else { return -1; }; |
|
| 40 | + | return x; |
|
| 41 | + | } |
|
| 42 | + | ||
| 43 | + | /// Record with pointer field and nested pattern through deref. |
|
| 44 | + | record Container { |
|
| 45 | + | inner: *Inner, |
|
| 46 | + | tag: i32, |
|
| 47 | + | } |
|
| 48 | + | ||
| 49 | + | union Boxed { |
|
| 50 | + | Some { c: Container }, |
|
| 51 | + | None, |
|
| 52 | + | } |
|
| 53 | + | ||
| 54 | + | /// Nested record with auto-deref on a pointer field. |
|
| 55 | + | fn matchRecordDeref(b: Boxed) -> i32 { |
|
| 56 | + | match b { |
|
| 57 | + | case Boxed::Some { c: Container { inner: Inner::A(val), tag } } => { |
|
| 58 | + | return val + tag; |
|
| 59 | + | } |
|
| 60 | + | else => { |
|
| 61 | + | return 0; |
|
| 62 | + | } |
|
| 63 | + | } |
|
| 64 | + | } |
|
| 65 | + | ||
| 66 | + | /// Auto-deref through pointer in if-let-case with record. |
|
| 67 | + | fn ifLetRecordDeref(b: Boxed) -> i32 { |
|
| 68 | + | if let case Boxed::Some { c: Container { inner: Inner::A(val), tag } } = b { |
|
| 69 | + | return val + tag; |
|
| 70 | + | } |
|
| 71 | + | return 0; |
|
| 72 | + | } |
|
| 73 | + | ||
| 74 | + | /// Void variant through pointer deref. |
|
| 75 | + | fn matchDerefVoid(o: Outer) -> i32 { |
|
| 76 | + | if let case Outer::Some(Inner::B) = o { |
|
| 77 | + | return 1; |
|
| 78 | + | } |
|
| 79 | + | return 0; |
|
| 80 | + | } |
|
| 81 | + | ||
| 82 | + | /// Auto-deref with placeholder in nested pattern. |
|
| 83 | + | fn matchDerefPlaceholder(o: Outer) -> i32 { |
|
| 84 | + | if let case Outer::Some(Inner::A(_)) = o { |
|
| 85 | + | return 1; |
|
| 86 | + | } |
|
| 87 | + | return 0; |
|
| 88 | + | } |
|
| 89 | + | ||
| 90 | + | /// Auto-deref through pointer to record in a record field. |
|
| 91 | + | record Point { |
|
| 92 | + | x: i32, |
|
| 93 | + | y: i32, |
|
| 94 | + | } |
|
| 95 | + | ||
| 96 | + | union Holder { |
|
| 97 | + | Ptr { p: *Point, z: i32 }, |
|
| 98 | + | Empty, |
|
| 99 | + | } |
|
| 100 | + | ||
| 101 | + | fn derefRecordField(h: Holder) -> i32 { |
|
| 102 | + | if let case Holder::Ptr { p: Point { x, y }, z } = h { |
|
| 103 | + | return x + y + z; |
|
| 104 | + | } |
|
| 105 | + | return 0; |
|
| 106 | + | } |
|
| 107 | + | ||
| 108 | + | @default fn main() -> i32 { |
|
| 109 | + | let innerA = Inner::A(42); |
|
| 110 | + | let innerB = Inner::B; |
|
| 111 | + | ||
| 112 | + | // matchDeref |
|
| 113 | + | assert matchDeref(Outer::Some(&innerA)) == 42; |
|
| 114 | + | assert matchDeref(Outer::Some(&innerB)) == -1; |
|
| 115 | + | assert matchDeref(Outer::None) == -2; |
|
| 116 | + | ||
| 117 | + | // ifLetDeref |
|
| 118 | + | assert ifLetDeref(Outer::Some(&innerA)) == 42; |
|
| 119 | + | assert ifLetDeref(Outer::Some(&innerB)) == 0; |
|
| 120 | + | assert ifLetDeref(Outer::None) == 0; |
|
| 121 | + | ||
| 122 | + | // letElseDeref |
|
| 123 | + | assert letElseDeref(Outer::Some(&innerA)) == 42; |
|
| 124 | + | assert letElseDeref(Outer::Some(&innerB)) == -1; |
|
| 125 | + | assert letElseDeref(Outer::None) == -1; |
|
| 126 | + | ||
| 127 | + | // matchRecordDeref |
|
| 128 | + | let c1 = Boxed::Some { c: Container { inner: &innerA, tag: 10 } }; |
|
| 129 | + | assert matchRecordDeref(c1) == 52; |
|
| 130 | + | let c2 = Boxed::Some { c: Container { inner: &innerB, tag: 20 } }; |
|
| 131 | + | assert matchRecordDeref(c2) == 0; |
|
| 132 | + | assert matchRecordDeref(Boxed::None) == 0; |
|
| 133 | + | ||
| 134 | + | // ifLetRecordDeref |
|
| 135 | + | assert ifLetRecordDeref(c1) == 52; |
|
| 136 | + | assert ifLetRecordDeref(c2) == 0; |
|
| 137 | + | assert ifLetRecordDeref(Boxed::None) == 0; |
|
| 138 | + | ||
| 139 | + | // matchDerefVoid |
|
| 140 | + | assert matchDerefVoid(Outer::Some(&innerA)) == 0; |
|
| 141 | + | assert matchDerefVoid(Outer::Some(&innerB)) == 1; |
|
| 142 | + | assert matchDerefVoid(Outer::None) == 0; |
|
| 143 | + | ||
| 144 | + | // matchDerefPlaceholder |
|
| 145 | + | assert matchDerefPlaceholder(Outer::Some(&innerA)) == 1; |
|
| 146 | + | assert matchDerefPlaceholder(Outer::Some(&innerB)) == 0; |
|
| 147 | + | assert matchDerefPlaceholder(Outer::None) == 0; |
|
| 148 | + | ||
| 149 | + | // derefRecordField |
|
| 150 | + | let pt = Point { x: 3, y: 4 }; |
|
| 151 | + | let h1 = Holder::Ptr { p: &pt, z: 5 }; |
|
| 152 | + | assert derefRecordField(h1) == 12; |
|
| 153 | + | assert derefRecordField(Holder::Empty) == 0; |
|
| 154 | + | ||
| 155 | + | return 0; |
|
| 156 | + | } |
lib/std/lang/lower.rad
+63 -27
| 3065 | 3065 | fn bindPatternVariables(self: *mut FnLowerer, subject: *MatchSubject, patterns: *mut [*ast::Node], failBlock: BlockId) throws (LowerError) { |
|
| 3066 | 3066 | for pattern in patterns { |
|
| 3067 | 3067 | ||
| 3068 | 3068 | // Handle simple variant patterns like `Variant(x)`. |
|
| 3069 | 3069 | if let arg = resolver::variantPatternBinding(self.low.resolver, pattern) { |
|
| 3070 | - | if let bindType = resolver::typeFor(self.low.resolver, arg) { |
|
| 3071 | - | let case MatchSubjectKind::Union(unionInfo) = subject.kind |
|
| 3072 | - | else panic "bindPatternVariables: expected union subject"; |
|
| 3073 | - | let valOffset = unionInfo.valOffset as i32; |
|
| 3074 | - | ||
| 3075 | - | match arg.value { |
|
| 3076 | - | case ast::NodeValue::Ident(name) => { |
|
| 3077 | - | try bindPayloadVariable(self, name, subject.val, bindType, subject.by, valOffset, false); |
|
| 3078 | - | } |
|
| 3079 | - | case ast::NodeValue::Placeholder => {} |
|
| 3080 | - | else => { |
|
| 3081 | - | // Nested pattern inside a variant call, e.g. `Variant(Inner { x, y })`. |
|
| 3082 | - | let base = emitValToReg(self, subject.val); |
|
| 3083 | - | let payloadBase = emitPtrOffset(self, base, valOffset); |
|
| 3084 | - | let fieldInfo = resolver::RecordField { |
|
| 3085 | - | name: nil, |
|
| 3086 | - | fieldType: bindType, |
|
| 3087 | - | offset: 0, |
|
| 3088 | - | }; |
|
| 3089 | - | try bindFieldVariable(self, arg, payloadBase, fieldInfo, subject.by, failBlock); |
|
| 3090 | - | } |
|
| 3070 | + | let case MatchSubjectKind::Union(unionInfo) = subject.kind |
|
| 3071 | + | else panic "bindPatternVariables: expected union subject"; |
|
| 3072 | + | let valOffset = unionInfo.valOffset as i32; |
|
| 3073 | + | ||
| 3074 | + | // Get the actual field type from the variant's record info. |
|
| 3075 | + | // This preserves the original data layout type (e.g. `*T`) even when |
|
| 3076 | + | // the resolver resolved the pattern against a dereferenced type (`T`). |
|
| 3077 | + | let variantExtra = resolver::nodeData(self.low.resolver, pattern).extra; |
|
| 3078 | + | let case resolver::NodeExtra::UnionVariant { ordinal, .. } = variantExtra |
|
| 3079 | + | else panic "bindPatternVariables: expected variant extra"; |
|
| 3080 | + | let payloadType = unionInfo.variants[ordinal].valueType; |
|
| 3081 | + | let payloadRec = resolver::getRecord(payloadType) |
|
| 3082 | + | else panic "bindPatternVariables: expected record payload"; |
|
| 3083 | + | let fieldType = payloadRec.fields[0].fieldType; |
|
| 3084 | + | ||
| 3085 | + | match arg.value { |
|
| 3086 | + | case ast::NodeValue::Ident(name) => { |
|
| 3087 | + | try bindPayloadVariable(self, name, subject.val, fieldType, subject.by, valOffset, false); |
|
| 3088 | + | } |
|
| 3089 | + | case ast::NodeValue::Placeholder => {} |
|
| 3090 | + | else => { |
|
| 3091 | + | // Nested pattern inside a variant call, e.g. `Variant(Inner { x, y })`. |
|
| 3092 | + | let base = emitValToReg(self, subject.val); |
|
| 3093 | + | let payloadBase = emitPtrOffset(self, base, valOffset); |
|
| 3094 | + | let fieldInfo = resolver::RecordField { |
|
| 3095 | + | name: nil, |
|
| 3096 | + | fieldType, |
|
| 3097 | + | offset: 0, |
|
| 3098 | + | }; |
|
| 3099 | + | try bindFieldVariable(self, arg, payloadBase, fieldInfo, subject.by, failBlock); |
|
| 3091 | 3100 | } |
|
| 3092 | 3101 | } |
|
| 3093 | 3102 | } |
|
| 3094 | 3103 | match pattern.value { |
|
| 3095 | 3104 | // Compound variant patterns like `Variant { a, b }`. |
| 3193 | 3202 | try emitNestedFieldTest(self, binding, base, fieldInfo, matchBy, failBlock); |
|
| 3194 | 3203 | return; |
|
| 3195 | 3204 | } |
|
| 3196 | 3205 | } |
|
| 3197 | 3206 | // Plain nested record destructuring pattern. |
|
| 3198 | - | let recInfo = resolver::getRecord(fieldInfo.fieldType) |
|
| 3207 | + | // Auto-deref: if the field is a pointer, load it first. |
|
| 3208 | + | let mut derefType = fieldInfo.fieldType; |
|
| 3209 | + | let mut nestedBase = emitPtrOffset(self, base, fieldInfo.offset); |
|
| 3210 | + | if let case resolver::Type::Pointer { target, .. } = fieldInfo.fieldType { |
|
| 3211 | + | let ptrReg = nextReg(self); |
|
| 3212 | + | emitLoadW64(self, ptrReg, nestedBase); |
|
| 3213 | + | nestedBase = ptrReg; |
|
| 3214 | + | derefType = *target; |
|
| 3215 | + | } |
|
| 3216 | + | let recInfo = resolver::getRecord(derefType) |
|
| 3199 | 3217 | else throw LowerError::ExpectedRecord; |
|
| 3200 | - | let nestedBase = emitPtrOffset(self, base, fieldInfo.offset); |
|
| 3201 | 3218 | ||
| 3202 | 3219 | try bindNestedRecordFields(self, nestedBase, lit, recInfo, matchBy, failBlock); |
|
| 3203 | 3220 | } |
|
| 3204 | 3221 | else => { |
|
| 3205 | 3222 | // Nested pattern requiring a test (union variant scope access, literal, etc). |
| 3217 | 3234 | base: il::Reg, |
|
| 3218 | 3235 | fieldInfo: resolver::RecordField, |
|
| 3219 | 3236 | matchBy: resolver::MatchBy, |
|
| 3220 | 3237 | failBlock: BlockId |
|
| 3221 | 3238 | ) throws (LowerError) { |
|
| 3222 | - | let fieldType = fieldInfo.fieldType; |
|
| 3239 | + | let mut fieldType = fieldInfo.fieldType; |
|
| 3223 | 3240 | let fieldPtr = emitPtrOffset(self, base, fieldInfo.offset); |
|
| 3224 | 3241 | ||
| 3242 | + | // Auto-deref: when the field is a pointer and the pattern destructures |
|
| 3243 | + | // the pointed-to value, load the pointer and use the target type. |
|
| 3244 | + | // The loaded pointer becomes the base address for the nested subject. |
|
| 3245 | + | let mut derefBase: ?il::Reg = nil; |
|
| 3246 | + | if let case resolver::Type::Pointer { target, .. } = fieldType { |
|
| 3247 | + | if resolver::isDestructuringPattern(pattern) { |
|
| 3248 | + | let ptrReg = nextReg(self); |
|
| 3249 | + | emitLoadW64(self, ptrReg, fieldPtr); |
|
| 3250 | + | derefBase = ptrReg; |
|
| 3251 | + | fieldType = *target; |
|
| 3252 | + | } |
|
| 3253 | + | } |
|
| 3225 | 3254 | // Build a MatchSubject for the nested field. |
|
| 3226 | 3255 | let ilTy = ilType(self.low, fieldType); |
|
| 3227 | 3256 | let kind = matchSubjectKind(fieldType); |
|
| 3228 | 3257 | ||
| 3229 | - | // For aggregate subjects, the value must be a pointer. |
|
| 3230 | - | let mut val = il::Val::Reg(fieldPtr); |
|
| 3231 | - | if not isAggregateType(fieldType) { |
|
| 3258 | + | // Determine the subject value. |
|
| 3259 | + | let mut val: il::Val = undefined; |
|
| 3260 | + | if let reg = derefBase { |
|
| 3261 | + | // Auto-deref: the loaded pointer is the address of the target value. |
|
| 3262 | + | val = il::Val::Reg(reg); |
|
| 3263 | + | } else if isAggregateType(fieldType) { |
|
| 3264 | + | // Aggregate: use the pointer. |
|
| 3265 | + | val = il::Val::Reg(fieldPtr); |
|
| 3266 | + | } else { |
|
| 3267 | + | // Scalar: load the value. |
|
| 3232 | 3268 | val = emitRead(self, base, fieldInfo.offset, fieldType); |
|
| 3233 | 3269 | } |
|
| 3234 | 3270 | let nestedSubject = MatchSubject { |
|
| 3235 | 3271 | val, |
|
| 3236 | 3272 | type: fieldType, |
lib/std/lang/resolver.rad
+20 -0
| 3860 | 3860 | Compare, |
|
| 3861 | 3861 | /// Identifier introduces a new binding. |
|
| 3862 | 3862 | Bind, |
|
| 3863 | 3863 | } |
|
| 3864 | 3864 | ||
| 3865 | + | /// Check whether a pattern node is a destructuring pattern that looks |
|
| 3866 | + | /// through structure (union variant, record literal, scope access). |
|
| 3867 | + | /// Identifiers, placeholders, and plain literals are not destructuring. |
|
| 3868 | + | pub fn isDestructuringPattern(pattern: *ast::Node) -> bool { |
|
| 3869 | + | match pattern.value { |
|
| 3870 | + | case ast::NodeValue::Call(_), |
|
| 3871 | + | ast::NodeValue::RecordLit(_), |
|
| 3872 | + | ast::NodeValue::ScopeAccess(_) => return true, |
|
| 3873 | + | else => return false, |
|
| 3874 | + | } |
|
| 3875 | + | } |
|
| 3876 | + | ||
| 3865 | 3877 | /// Analyze a case pattern for match, if-case, let-case, or while-case. |
|
| 3866 | 3878 | /// |
|
| 3867 | 3879 | /// At the top level, bare identifiers are compared against existing values. |
|
| 3868 | 3880 | /// Inside destructuring patterns (arrays, records), identifiers become bindings. |
|
| 3869 | 3881 | fn resolveCasePattern( |
| 3873 | 3885 | mode: IdentMode, |
|
| 3874 | 3886 | matchBy: MatchBy |
|
| 3875 | 3887 | ) throws (ResolveError) { |
|
| 3876 | 3888 | // TODO: Collapse these nested matches. |
|
| 3877 | 3889 | match scrutineeTy { |
|
| 3890 | + | case Type::Pointer { target, .. } => { |
|
| 3891 | + | // Auto-deref: when the scrutinee is a pointer and the pattern |
|
| 3892 | + | // is a destructuring pattern, resolve against the pointed-to type. |
|
| 3893 | + | if isDestructuringPattern(pattern) { |
|
| 3894 | + | try resolveCasePattern(self, pattern, *target, mode, matchBy); |
|
| 3895 | + | return; |
|
| 3896 | + | } |
|
| 3897 | + | } |
|
| 3878 | 3898 | case Type::Nominal(info) => { |
|
| 3879 | 3899 | try ensureNominalResolved(self, info, pattern); |
|
| 3880 | 3900 | ||
| 3881 | 3901 | match *info { |
|
| 3882 | 3902 | case NominalType::Union(unionType) => { |