Fix trait soundness bugs
ae7f9152c109d8e1f6b277a6993310c49511dfd45fd8f68cf5538b226ff6f58d
1 parent
f952258d
lib/std/arch/rv64/tests/trait.aggregate.ret.rad
added
+90 -0
| 1 | + | //! Test trait methods returning aggregate types. |
|
| 2 | + | //! |
|
| 3 | + | //! Exercises: trait methods that return structs (both small and larger |
|
| 4 | + | //! than pointer size) and optionals via v-table dispatch. |
|
| 5 | + | ||
| 6 | + | record Point { |
|
| 7 | + | x: i32, |
|
| 8 | + | y: i32, |
|
| 9 | + | } |
|
| 10 | + | ||
| 11 | + | record Vec3 { |
|
| 12 | + | x: i32, |
|
| 13 | + | y: i32, |
|
| 14 | + | z: i32, |
|
| 15 | + | } |
|
| 16 | + | ||
| 17 | + | trait Geometry { |
|
| 18 | + | fn (*Geometry) origin() -> Point; |
|
| 19 | + | fn (*Geometry) center() -> Vec3; |
|
| 20 | + | fn (*Geometry) maybe() -> ?i32; |
|
| 21 | + | } |
|
| 22 | + | ||
| 23 | + | record Circle { |
|
| 24 | + | cx: i32, |
|
| 25 | + | cy: i32, |
|
| 26 | + | radius: i32, |
|
| 27 | + | } |
|
| 28 | + | ||
| 29 | + | instance Geometry for Circle { |
|
| 30 | + | fn (c: *Circle) origin() -> Point { |
|
| 31 | + | return Point { x: c.cx, y: c.cy }; |
|
| 32 | + | } |
|
| 33 | + | ||
| 34 | + | fn (c: *Circle) center() -> Vec3 { |
|
| 35 | + | return Vec3 { x: c.cx, y: c.cy, z: 0 }; |
|
| 36 | + | } |
|
| 37 | + | ||
| 38 | + | fn (c: *Circle) maybe() -> ?i32 { |
|
| 39 | + | if c.radius > 0 { |
|
| 40 | + | return c.radius; |
|
| 41 | + | } |
|
| 42 | + | return nil; |
|
| 43 | + | } |
|
| 44 | + | } |
|
| 45 | + | ||
| 46 | + | @default fn main() -> i32 { |
|
| 47 | + | let c = Circle { cx: 10, cy: 20, radius: 5 }; |
|
| 48 | + | let g: *opaque Geometry = &c; |
|
| 49 | + | ||
| 50 | + | // Small struct return (Point, 8 bytes = pointer size). |
|
| 51 | + | let p = g.origin(); |
|
| 52 | + | if p.x != 10 { |
|
| 53 | + | return 1; |
|
| 54 | + | } |
|
| 55 | + | if p.y != 20 { |
|
| 56 | + | return 2; |
|
| 57 | + | } |
|
| 58 | + | ||
| 59 | + | // Larger struct return (Vec3, 12 bytes > pointer size). |
|
| 60 | + | let v = g.center(); |
|
| 61 | + | if v.x != 10 { |
|
| 62 | + | return 3; |
|
| 63 | + | } |
|
| 64 | + | if v.y != 20 { |
|
| 65 | + | return 4; |
|
| 66 | + | } |
|
| 67 | + | if v.z != 0 { |
|
| 68 | + | return 5; |
|
| 69 | + | } |
|
| 70 | + | ||
| 71 | + | // Optional return - Some case. |
|
| 72 | + | let m = g.maybe(); |
|
| 73 | + | if let val = m { |
|
| 74 | + | if val != 5 { |
|
| 75 | + | return 6; |
|
| 76 | + | } |
|
| 77 | + | } else { |
|
| 78 | + | return 7; |
|
| 79 | + | } |
|
| 80 | + | ||
| 81 | + | // Optional return - None case. |
|
| 82 | + | let c2 = Circle { cx: 0, cy: 0, radius: 0 }; |
|
| 83 | + | let g2: *opaque Geometry = &c2; |
|
| 84 | + | let m2 = g2.maybe(); |
|
| 85 | + | if let _ = m2 { |
|
| 86 | + | return 8; |
|
| 87 | + | } |
|
| 88 | + | ||
| 89 | + | return 0; |
|
| 90 | + | } |
lib/std/arch/rv64/tests/trait.array.optional.rad
added
+79 -0
| 1 | + | //! Test trait objects stored in arrays and used as optional values. |
|
| 2 | + | //! |
|
| 3 | + | //! Exercises: arrays of trait objects dispatched in loops, |
|
| 4 | + | //! and optional trait object values. |
|
| 5 | + | ||
| 6 | + | record Adder { |
|
| 7 | + | n: i32, |
|
| 8 | + | } |
|
| 9 | + | ||
| 10 | + | record Multiplier { |
|
| 11 | + | n: i32, |
|
| 12 | + | } |
|
| 13 | + | ||
| 14 | + | trait Transform { |
|
| 15 | + | fn (*Transform) apply(x: i32) -> i32; |
|
| 16 | + | } |
|
| 17 | + | ||
| 18 | + | instance Transform for Adder { |
|
| 19 | + | fn (a: *Adder) apply(x: i32) -> i32 { |
|
| 20 | + | return x + a.n; |
|
| 21 | + | } |
|
| 22 | + | } |
|
| 23 | + | ||
| 24 | + | instance Transform for Multiplier { |
|
| 25 | + | fn (m: *Multiplier) apply(x: i32) -> i32 { |
|
| 26 | + | return x * m.n; |
|
| 27 | + | } |
|
| 28 | + | } |
|
| 29 | + | ||
| 30 | + | /// Apply a chain of transforms to a value. |
|
| 31 | + | fn applyAll(transforms: *[*opaque Transform], value: i32) -> i32 { |
|
| 32 | + | let mut result = value; |
|
| 33 | + | for i in 0..transforms.len { |
|
| 34 | + | result = transforms[i].apply(result); |
|
| 35 | + | } |
|
| 36 | + | return result; |
|
| 37 | + | } |
|
| 38 | + | ||
| 39 | + | /// Apply an optional transform, returning the original value if nil. |
|
| 40 | + | fn applyMaybe(t: ?*opaque Transform, value: i32) -> i32 { |
|
| 41 | + | if let tr = t { |
|
| 42 | + | return tr.apply(value); |
|
| 43 | + | } |
|
| 44 | + | return value; |
|
| 45 | + | } |
|
| 46 | + | ||
| 47 | + | @default fn main() -> i32 { |
|
| 48 | + | let a1 = Adder { n: 10 }; |
|
| 49 | + | let m1 = Multiplier { n: 3 }; |
|
| 50 | + | let a2 = Adder { n: 5 }; |
|
| 51 | + | ||
| 52 | + | let t1: *opaque Transform = &a1; |
|
| 53 | + | let t2: *opaque Transform = &m1; |
|
| 54 | + | let t3: *opaque Transform = &a2; |
|
| 55 | + | ||
| 56 | + | // Array of trait objects. |
|
| 57 | + | let transforms: [*opaque Transform; 3] = [t1, t2, t3]; |
|
| 58 | + | // Chain: (1 + 10) * 3 + 5 = 38 |
|
| 59 | + | let result = applyAll(&transforms[..], 1); |
|
| 60 | + | if result != 38 { |
|
| 61 | + | return 1; |
|
| 62 | + | } |
|
| 63 | + | ||
| 64 | + | // Optional trait object - Some case. |
|
| 65 | + | let opt: ?*opaque Transform = t1; |
|
| 66 | + | let r2 = applyMaybe(opt, 5); |
|
| 67 | + | if r2 != 15 { |
|
| 68 | + | return 2; |
|
| 69 | + | } |
|
| 70 | + | ||
| 71 | + | // Optional trait object - None case. |
|
| 72 | + | let none: ?*opaque Transform = nil; |
|
| 73 | + | let r3 = applyMaybe(none, 5); |
|
| 74 | + | if r3 != 5 { |
|
| 75 | + | return 3; |
|
| 76 | + | } |
|
| 77 | + | ||
| 78 | + | return 0; |
|
| 79 | + | } |
lib/std/arch/rv64/tests/trait.throws.rad
added
+61 -0
| 1 | + | //! Test trait methods with throws clauses. |
|
| 2 | + | ||
| 3 | + | union ParseError { |
|
| 4 | + | InvalidInput, |
|
| 5 | + | Overflow, |
|
| 6 | + | } |
|
| 7 | + | ||
| 8 | + | record StrictParser { |
|
| 9 | + | limit: i32, |
|
| 10 | + | } |
|
| 11 | + | ||
| 12 | + | trait Parser { |
|
| 13 | + | fn (*Parser) parse(n: i32) -> i32 throws (ParseError); |
|
| 14 | + | } |
|
| 15 | + | ||
| 16 | + | instance Parser for StrictParser { |
|
| 17 | + | fn (p: *StrictParser) parse(n: i32) -> i32 throws (ParseError) { |
|
| 18 | + | if n < 0 { |
|
| 19 | + | throw ParseError::InvalidInput; |
|
| 20 | + | } |
|
| 21 | + | if n > p.limit { |
|
| 22 | + | throw ParseError::Overflow; |
|
| 23 | + | } |
|
| 24 | + | return n * 2; |
|
| 25 | + | } |
|
| 26 | + | } |
|
| 27 | + | ||
| 28 | + | @default fn main() -> i32 { |
|
| 29 | + | let sp = StrictParser { limit: 100 }; |
|
| 30 | + | let p: *opaque Parser = &sp; |
|
| 31 | + | ||
| 32 | + | // Success path. |
|
| 33 | + | let r1 = try p.parse(5) catch { |
|
| 34 | + | return 1; |
|
| 35 | + | }; |
|
| 36 | + | if r1 != 10 { |
|
| 37 | + | return 2; |
|
| 38 | + | } |
|
| 39 | + | ||
| 40 | + | // Error path: negative input. |
|
| 41 | + | let mut caught = false; |
|
| 42 | + | let r2 = try p.parse(-1) catch { |
|
| 43 | + | caught = true; |
|
| 44 | + | 0 |
|
| 45 | + | }; |
|
| 46 | + | if not caught { |
|
| 47 | + | return 3; |
|
| 48 | + | } |
|
| 49 | + | ||
| 50 | + | // Error path: overflow. |
|
| 51 | + | caught = false; |
|
| 52 | + | let r3 = try p.parse(200) catch { |
|
| 53 | + | caught = true; |
|
| 54 | + | 0 |
|
| 55 | + | }; |
|
| 56 | + | if not caught { |
|
| 57 | + | return 4; |
|
| 58 | + | } |
|
| 59 | + | ||
| 60 | + | return 0; |
|
| 61 | + | } |
lib/std/lang/lower.rad
+52 -10
| 5954 | 5954 | // Type of the try expression, which is either the return type of the function |
|
| 5955 | 5955 | // if successful, or an optional of it, if using `try?`. |
|
| 5956 | 5956 | let tryExprTy = resolver::typeFor(self.low.resolver, node) else { |
|
| 5957 | 5957 | throw LowerError::MissingType(node); |
|
| 5958 | 5958 | }; |
|
| 5959 | - | let resVal = try lowerCall(self, t.expr, callExpr); |
|
| 5959 | + | // Check for trait method dispatch. |
|
| 5960 | + | let mut resVal: il::Val = undefined; |
|
| 5961 | + | if let case resolver::NodeExtra::TraitMethodCall { |
|
| 5962 | + | traitInfo, methodIndex |
|
| 5963 | + | } = resolver::nodeData(self.low.resolver, t.expr).extra { |
|
| 5964 | + | resVal = try lowerTraitMethodCall(self, t.expr, callExpr, traitInfo, methodIndex); |
|
| 5965 | + | } else { |
|
| 5966 | + | resVal = try lowerCall(self, t.expr, callExpr); |
|
| 5967 | + | } |
|
| 5960 | 5968 | let base = try emitValToReg(self, resVal); // The result value. |
|
| 5961 | 5969 | let tagReg = resultTagReg(self, base); // The result tag. |
|
| 5962 | 5970 | ||
| 5963 | 5971 | let okBlock = try createBlock(self, "ok"); // Block if success. |
|
| 5964 | 5972 | let errBlock = try createBlock(self, "err"); // Block if failure. |
| 6252 | 6260 | ||
| 6253 | 6261 | // Build args: data pointer (receiver) + user args. |
|
| 6254 | 6262 | let argOffset: u32 = 1 if returnParam else 0; |
|
| 6255 | 6263 | let args = try allocVals(self, call.args.len + 1 + argOffset); |
|
| 6256 | 6264 | ||
| 6257 | - | if returnParam { |
|
| 6258 | - | args[0] = il::Val::Reg(try emitReserve(self, retTy)); |
|
| 6259 | - | } |
|
| 6260 | - | // Data pointer is the receiver (first argument). |
|
| 6265 | + | // Data pointer is the receiver (first argument after hidden return param). |
|
| 6261 | 6266 | args[argOffset] = il::Val::Reg(dataReg); |
|
| 6262 | 6267 | ||
| 6263 | 6268 | // Lower user arguments. |
|
| 6264 | 6269 | for i in 0..call.args.len { |
|
| 6265 | 6270 | args[i + 1 + argOffset] = try lowerExpr(self, call.args.list[i]); |
|
| 6266 | 6271 | } |
|
| 6267 | - | // Emit the call. Hidden return param calls always return a pointer. |
|
| 6268 | - | let callRetTy = il::Type::W64 if returnParam else ilType(self.low, retTy); |
|
| 6269 | - | let mut dst: ?il::Reg = nil; |
|
| 6270 | 6272 | ||
| 6271 | - | if returnParam or retTy != resolver::Type::Void { |
|
| 6273 | + | // Allocate the return buffer when needed. |
|
| 6274 | + | if returnParam { |
|
| 6275 | + | if methodFnType.throwListLen > 0 { |
|
| 6276 | + | let successType = *methodFnType.returnType; |
|
| 6277 | + | let layout = resolver::getResultLayout( |
|
| 6278 | + | successType, &methodFnType.throwList[..methodFnType.throwListLen]); |
|
| 6279 | + | ||
| 6280 | + | args[0] = il::Val::Reg(try emitReserveLayout(self, layout)); |
|
| 6281 | + | } else { |
|
| 6282 | + | args[0] = il::Val::Reg(try emitReserve(self, retTy)); |
|
| 6283 | + | } |
|
| 6284 | + | let dst = nextReg(self); |
|
| 6285 | + | ||
| 6286 | + | emit(self, il::Instr::Call { |
|
| 6287 | + | retTy: il::Type::W64, |
|
| 6288 | + | dst, |
|
| 6289 | + | func: il::Val::Reg(fnPtrReg), |
|
| 6290 | + | args, |
|
| 6291 | + | }); |
|
| 6292 | + | return il::Val::Reg(dst); |
|
| 6293 | + | } |
|
| 6294 | + | ||
| 6295 | + | // Scalar call: allocate destination register for non-void return types. |
|
| 6296 | + | let mut dst: ?il::Reg = nil; |
|
| 6297 | + | if retTy != resolver::Type::Void { |
|
| 6272 | 6298 | dst = nextReg(self); |
|
| 6273 | 6299 | } |
|
| 6274 | 6300 | emit(self, il::Instr::Call { |
|
| 6275 | - | retTy: callRetTy, |
|
| 6301 | + | retTy: ilType(self.low, retTy), |
|
| 6276 | 6302 | dst, |
|
| 6277 | 6303 | func: il::Val::Reg(fnPtrReg), |
|
| 6278 | 6304 | args, |
|
| 6279 | 6305 | }); |
|
| 6280 | 6306 | ||
| 6281 | 6307 | if let d = dst { |
|
| 6308 | + | if isSmallAggregate(retTy) { |
|
| 6309 | + | let slot = try emitReserveLayout( |
|
| 6310 | + | self, |
|
| 6311 | + | resolver::Layout { |
|
| 6312 | + | size: resolver::PTR_SIZE, |
|
| 6313 | + | alignment: resolver::PTR_SIZE |
|
| 6314 | + | }); |
|
| 6315 | + | ||
| 6316 | + | emit(self, il::Instr::Store { |
|
| 6317 | + | typ: il::Type::W64, |
|
| 6318 | + | src: il::Val::Reg(d), |
|
| 6319 | + | dst: slot, |
|
| 6320 | + | offset: 0, |
|
| 6321 | + | }); |
|
| 6322 | + | return il::Val::Reg(slot); |
|
| 6323 | + | } |
|
| 6282 | 6324 | return il::Val::Reg(d); |
|
| 6283 | 6325 | } |
|
| 6284 | 6326 | return il::Val::Undef; |
|
| 6285 | 6327 | } |
|
| 6286 | 6328 |
lib/std/lang/lower/tests/trait.dispatch.rad
added
+23 -0
| 1 | + | record Acc { |
|
| 2 | + | n: i32, |
|
| 3 | + | } |
|
| 4 | + | ||
| 5 | + | trait Ops { |
|
| 6 | + | fn (*Ops) get() -> i32; |
|
| 7 | + | fn (*mut Ops) set(n: i32); |
|
| 8 | + | } |
|
| 9 | + | ||
| 10 | + | instance Ops for Acc { |
|
| 11 | + | fn (a: *Acc) get() -> i32 { |
|
| 12 | + | return a.n; |
|
| 13 | + | } |
|
| 14 | + | ||
| 15 | + | fn (a: *mut Acc) set(n: i32) { |
|
| 16 | + | a.n = n; |
|
| 17 | + | } |
|
| 18 | + | } |
|
| 19 | + | ||
| 20 | + | fn dispatch(o: *mut opaque Ops) -> i32 { |
|
| 21 | + | o.set(42); |
|
| 22 | + | return o.get(); |
|
| 23 | + | } |
lib/std/lang/lower/tests/trait.dispatch.ril
added
+29 -0
| 1 | + | data $"vtable::Acc::Ops" align 8 { |
|
| 2 | + | fn $"Acc::get"; |
|
| 3 | + | fn $"Acc::set"; |
|
| 4 | + | } |
|
| 5 | + | ||
| 6 | + | fn w32 $"Acc::get"(w64 %0) { |
|
| 7 | + | @entry0 |
|
| 8 | + | sload w32 %1 %0 0; |
|
| 9 | + | ret %1; |
|
| 10 | + | } |
|
| 11 | + | ||
| 12 | + | fn w32 $"Acc::set"(w64 %0, w32 %1) { |
|
| 13 | + | @entry0 |
|
| 14 | + | store w32 %1 %0 0; |
|
| 15 | + | ret; |
|
| 16 | + | } |
|
| 17 | + | ||
| 18 | + | fn w32 $dispatch(w64 %0) { |
|
| 19 | + | @entry0 |
|
| 20 | + | load w64 %1 %0 0; |
|
| 21 | + | load w64 %2 %0 8; |
|
| 22 | + | load w64 %3 %2 8; |
|
| 23 | + | call w32 %3(%1, 42); |
|
| 24 | + | load w64 %4 %0 0; |
|
| 25 | + | load w64 %5 %0 8; |
|
| 26 | + | load w64 %6 %5 0; |
|
| 27 | + | call w32 %7 %6(%4); |
|
| 28 | + | ret %7; |
|
| 29 | + | } |
lib/std/lang/lower/tests/trait.object.rad
added
+20 -0
| 1 | + | record Counter { |
|
| 2 | + | value: i32, |
|
| 3 | + | } |
|
| 4 | + | ||
| 5 | + | trait Adder { |
|
| 6 | + | fn (*mut Adder) add(n: i32) -> i32; |
|
| 7 | + | } |
|
| 8 | + | ||
| 9 | + | instance Adder for Counter { |
|
| 10 | + | fn (c: *mut Counter) add(n: i32) -> i32 { |
|
| 11 | + | c.value = c.value + n; |
|
| 12 | + | return c.value; |
|
| 13 | + | } |
|
| 14 | + | } |
|
| 15 | + | ||
| 16 | + | fn use_adder() -> i32 { |
|
| 17 | + | let mut c = Counter { value: 0 }; |
|
| 18 | + | let a: *mut opaque Adder = &mut c; |
|
| 19 | + | return a.add(1); |
|
| 20 | + | } |
lib/std/lang/lower/tests/trait.object.ril
added
+26 -0
| 1 | + | data $"vtable::Counter::Adder" align 8 { |
|
| 2 | + | fn $"Counter::add"; |
|
| 3 | + | } |
|
| 4 | + | ||
| 5 | + | fn w32 $"Counter::add"(w64 %0, w32 %1) { |
|
| 6 | + | @entry0 |
|
| 7 | + | sload w32 %2 %0 0; |
|
| 8 | + | add w32 %3 %2 %1; |
|
| 9 | + | store w32 %3 %0 0; |
|
| 10 | + | sload w32 %4 %0 0; |
|
| 11 | + | ret %4; |
|
| 12 | + | } |
|
| 13 | + | ||
| 14 | + | fn w32 $use_adder() { |
|
| 15 | + | @entry0 |
|
| 16 | + | reserve %0 4 4; |
|
| 17 | + | store w32 0 %0 0; |
|
| 18 | + | reserve %1 16 8; |
|
| 19 | + | store w64 %0 %1 0; |
|
| 20 | + | store w64 $"vtable::Counter::Adder" %1 8; |
|
| 21 | + | load w64 %2 %1 0; |
|
| 22 | + | load w64 %3 %1 8; |
|
| 23 | + | load w64 %4 %3 0; |
|
| 24 | + | call w32 %5 %4(%2, 1); |
|
| 25 | + | ret %5; |
|
| 26 | + | } |
lib/std/lang/lower/tests/trait.supertrait.rad
added
+23 -0
| 1 | + | record Widget { |
|
| 2 | + | x: i32, |
|
| 3 | + | } |
|
| 4 | + | ||
| 5 | + | trait Base { |
|
| 6 | + | fn (*Base) get() -> i32; |
|
| 7 | + | } |
|
| 8 | + | ||
| 9 | + | trait Child: Base { |
|
| 10 | + | fn (*mut Child) set(n: i32); |
|
| 11 | + | } |
|
| 12 | + | ||
| 13 | + | instance Base for Widget { |
|
| 14 | + | fn (w: *Widget) get() -> i32 { |
|
| 15 | + | return w.x; |
|
| 16 | + | } |
|
| 17 | + | } |
|
| 18 | + | ||
| 19 | + | instance Child for Widget { |
|
| 20 | + | fn (w: *mut Widget) set(n: i32) { |
|
| 21 | + | w.x = n; |
|
| 22 | + | } |
|
| 23 | + | } |
lib/std/lang/lower/tests/trait.supertrait.ril
added
+20 -0
| 1 | + | data $"vtable::Widget::Base" align 8 { |
|
| 2 | + | fn $"Widget::get"; |
|
| 3 | + | } |
|
| 4 | + | ||
| 5 | + | data $"vtable::Widget::Child" align 8 { |
|
| 6 | + | fn $"Widget::get"; |
|
| 7 | + | fn $"Widget::set"; |
|
| 8 | + | } |
|
| 9 | + | ||
| 10 | + | fn w32 $"Widget::get"(w64 %0) { |
|
| 11 | + | @entry0 |
|
| 12 | + | sload w32 %1 %0 0; |
|
| 13 | + | ret %1; |
|
| 14 | + | } |
|
| 15 | + | ||
| 16 | + | fn w32 $"Widget::set"(w64 %0, w32 %1) { |
|
| 17 | + | @entry0 |
|
| 18 | + | store w32 %1 %0 0; |
|
| 19 | + | ret; |
|
| 20 | + | } |
lib/std/lang/resolver.rad
+19 -15
| 3506 | 3506 | // Find the matching trait method. |
|
| 3507 | 3507 | let tm = findTraitMethod(traitInfo, methodName) |
|
| 3508 | 3508 | else throw emitError(self, name, ErrorKind::UnresolvedSymbol(methodName)); |
|
| 3509 | 3509 | ||
| 3510 | 3510 | // Determine receiver mutability and validate receiver type. |
|
| 3511 | - | let mut receiverMut = false; |
|
| 3512 | - | if let case ast::NodeValue::TypeSig(typeSig) = receiverType.value { |
|
| 3513 | - | if let case ast::TypeSig::Pointer { mutable, valueType } = typeSig { |
|
| 3514 | - | receiverMut = mutable; |
|
| 3515 | - | // Validate that the receiver type annotation matches the |
|
| 3516 | - | // concrete type from the instance declaration. |
|
| 3517 | - | let annotatedTy = try infer(self, valueType); |
|
| 3518 | - | if not typesEqual(annotatedTy, concreteType) { |
|
| 3519 | - | throw emitTypeMismatch(self, receiverType, TypeMismatch { |
|
| 3520 | - | expected: concreteType, |
|
| 3521 | - | actual: annotatedTy, |
|
| 3522 | - | }); |
|
| 3523 | - | } |
|
| 3524 | - | } |
|
| 3511 | + | // The receiver must be `*Type` or `*mut Type`. |
|
| 3512 | + | let case ast::NodeValue::TypeSig(typeSig) = receiverType.value |
|
| 3513 | + | else throw emitError(self, receiverType, ErrorKind::TraitReceiverMismatch); |
|
| 3514 | + | let case ast::TypeSig::Pointer { mutable: receiverMut, valueType } = typeSig |
|
| 3515 | + | else throw emitError(self, receiverType, ErrorKind::TraitReceiverMismatch); |
|
| 3516 | + | ||
| 3517 | + | // Validate that the receiver type annotation matches the |
|
| 3518 | + | // concrete type from the instance declaration. |
|
| 3519 | + | let annotatedTy = try infer(self, valueType); |
|
| 3520 | + | if not typesEqual(annotatedTy, concreteType) { |
|
| 3521 | + | throw emitTypeMismatch(self, receiverType, TypeMismatch { |
|
| 3522 | + | expected: concreteType, |
|
| 3523 | + | actual: annotatedTy, |
|
| 3524 | + | }); |
|
| 3525 | 3525 | } |
|
| 3526 | 3526 | ||
| 3527 | 3527 | // Check receiver mutability matches in both directions. |
|
| 3528 | 3528 | if tm.mutable and not receiverMut { |
|
| 3529 | 3529 | throw emitError(self, receiverType, ErrorKind::ImmutableBinding); |
| 4744 | 4744 | if let t = typeFor(self, access.parent) { |
|
| 4745 | 4745 | parentTy = t; |
|
| 4746 | 4746 | } |
|
| 4747 | 4747 | let subjectTy = autoDeref(parentTy); |
|
| 4748 | 4748 | ||
| 4749 | - | if let case Type::TraitObject { traitInfo, .. } = subjectTy { |
|
| 4749 | + | if let case Type::TraitObject { traitInfo, mutable: objMutable } = subjectTy { |
|
| 4750 | 4750 | let methodName = try nodeName(self, access.child); |
|
| 4751 | 4751 | let method = findTraitMethod(traitInfo, methodName) |
|
| 4752 | 4752 | else throw emitError(self, access.child, ErrorKind::RecordFieldUnknown(methodName)); |
|
| 4753 | 4753 | ||
| 4754 | + | // Reject mutable-receiver methods called on immutable trait objects. |
|
| 4755 | + | if method.mutable and not objMutable { |
|
| 4756 | + | throw emitError(self, access.parent, ErrorKind::ImmutableBinding); |
|
| 4757 | + | } |
|
| 4754 | 4758 | try checkCallArgs(self, node, call, method.fnType, ctx); |
|
| 4755 | 4759 | setTraitMethodCall(self, node, traitInfo, method.index); |
|
| 4756 | 4760 | ||
| 4757 | 4761 | return try setNodeType(self, node, *method.fnType.returnType); |
|
| 4758 | 4762 | } |
lib/std/lang/resolver/tests.rad
+69 -0
| 253 | 253 | if let case super::ErrorKind::MissingTraitMethod(actualName) = *actual { |
|
| 254 | 254 | return mem::eq(actualName, expectedName); |
|
| 255 | 255 | } |
|
| 256 | 256 | return false; |
|
| 257 | 257 | } |
|
| 258 | + | if let case super::ErrorKind::MissingSupertraitInstance(expectedName) = expected { |
|
| 259 | + | if let case super::ErrorKind::MissingSupertraitInstance(actualName) = *actual { |
|
| 260 | + | return mem::eq(actualName, expectedName); |
|
| 261 | + | } |
|
| 262 | + | return false; |
|
| 263 | + | } |
|
| 258 | 264 | return *actual == expected; |
|
| 259 | 265 | } |
|
| 260 | 266 | ||
| 261 | 267 | /// Extract the first error and ensure it has the expected kind. |
|
| 262 | 268 | fn expectErrorKind(result: *TestResult, kind: super::ErrorKind) -> *super::Error |
| 4612 | 4618 | let appId = try registerModule(&mut MODULE_GRAPH, rootId, "app", "use root::defs; fn test() -> i32 { let mut c = defs::Counter { value: 10 }; let a: *mut opaque defs::Adder = &mut c; return a.add(5); }", &mut arena); |
|
| 4613 | 4619 | ||
| 4614 | 4620 | let result = try resolveModuleTree(&mut a, rootId); |
|
| 4615 | 4621 | try expectNoErrors(&result); |
|
| 4616 | 4622 | } |
|
| 4623 | + | ||
| 4624 | + | /// Calling a mutable-receiver trait method on an immutable trait object |
|
| 4625 | + | /// must be rejected. |
|
| 4626 | + | @test fn testResolveTraitMutMethodOnImmutableObject() throws (testing::TestError) { |
|
| 4627 | + | let mut a = testResolver(); |
|
| 4628 | + | let program = "record Counter { value: i32 } trait Adder { fn (*mut Adder) add(n: i32) -> i32; } instance Adder for Counter { fn (c: *mut Counter) add(n: i32) -> i32 { c.value = c.value + n; return c.value; } } fn caller(a: *opaque Adder) -> i32 { return a.add(1); }"; |
|
| 4629 | + | let result = try resolveProgramStr(&mut a, program); |
|
| 4630 | + | try expectErrorKind(&result, super::ErrorKind::ImmutableBinding); |
|
| 4631 | + | } |
|
| 4632 | + | ||
| 4633 | + | /// Immutable methods on an immutable trait object should be accepted. |
|
| 4634 | + | @test fn testResolveTraitImmutableMethodOnImmutableObject() throws (testing::TestError) { |
|
| 4635 | + | let mut a = testResolver(); |
|
| 4636 | + | let program = "record Counter { value: i32 } trait Reader { fn (*Reader) get() -> i32; } instance Reader for Counter { fn (c: *Counter) get() -> i32 { return c.value; } } fn caller(r: *opaque Reader) -> i32 { return r.get(); }"; |
|
| 4637 | + | let result = try resolveProgramStr(&mut a, program); |
|
| 4638 | + | try expectNoErrors(&result); |
|
| 4639 | + | } |
|
| 4640 | + | ||
| 4641 | + | /// Both mutable and immutable methods on a mutable trait object should work. |
|
| 4642 | + | @test fn testResolveTraitMixedMethodsOnMutableObject() throws (testing::TestError) { |
|
| 4643 | + | let mut a = testResolver(); |
|
| 4644 | + | let program = "record Counter { value: i32 } trait Ops { fn (*mut Ops) inc(); fn (*Ops) get() -> i32; } instance Ops for Counter { fn (c: *mut Counter) inc() { c.value = c.value + 1; } fn (c: *Counter) get() -> i32 { return c.value; } } fn caller(o: *mut opaque Ops) -> i32 { o.inc(); return o.get(); }"; |
|
| 4645 | + | let result = try resolveProgramStr(&mut a, program); |
|
| 4646 | + | try expectNoErrors(&result); |
|
| 4647 | + | } |
|
| 4648 | + | ||
| 4649 | + | /// Instance method body type must match the trait return type. |
|
| 4650 | + | /// The trait declares `-> i32` but the body returns `bool`. |
|
| 4651 | + | @test fn testResolveInstanceReturnTypeMismatch() throws (testing::TestError) { |
|
| 4652 | + | let mut a = testResolver(); |
|
| 4653 | + | let program = "record R { x: i32 } trait T { fn (*T) get() -> i32; } instance T for R { fn (r: *R) get() -> bool { return true; } }"; |
|
| 4654 | + | let result = try resolveProgramStr(&mut a, program); |
|
| 4655 | + | let err = try expectError(&result); |
|
| 4656 | + | let case super::ErrorKind::TypeMismatch(_) = err.kind |
|
| 4657 | + | else throw testing::TestError::Failed; |
|
| 4658 | + | } |
|
| 4659 | + | ||
| 4660 | + | /// Diamond supertrait inheritance: traits B and C both extend A. |
|
| 4661 | + | /// Declaring them independently should work fine. |
|
| 4662 | + | @test fn testResolveTraitDiamondSupertrait() throws (testing::TestError) { |
|
| 4663 | + | let mut a = testResolver(); |
|
| 4664 | + | let program = "trait A { fn (*A) f() -> i32; } trait B: A { fn (*B) g() -> i32; } trait C: A { fn (*C) h() -> i32; }"; |
|
| 4665 | + | let result = try resolveProgramStr(&mut a, program); |
|
| 4666 | + | try expectNoErrors(&result); |
|
| 4667 | + | } |
|
| 4668 | + | ||
| 4669 | + | /// Diamond supertrait with a combined trait that would cause duplicate |
|
| 4670 | + | /// method names should be detected. |
|
| 4671 | + | @test fn testResolveTraitDiamondDuplicateMethod() throws (testing::TestError) { |
|
| 4672 | + | let mut a = testResolver(); |
|
| 4673 | + | let program = "trait A { fn (*A) f() -> i32; } trait B: A { fn (*B) g() -> i32; } trait C: A { fn (*C) h() -> i32; } trait D: B + C { fn (*D) i() -> i32; }"; |
|
| 4674 | + | let result = try resolveProgramStr(&mut a, program); |
|
| 4675 | + | // B inherits `f` from A, C inherits `f` from A. D: B + C sees duplicate `f`. |
|
| 4676 | + | try expectErrorKind(&result, super::ErrorKind::DuplicateBinding("f")); |
|
| 4677 | + | } |
|
| 4678 | + | ||
| 4679 | + | /// Supertrait instance must exist when declaring a combined trait instance. |
|
| 4680 | + | @test fn testResolveInstanceMissingSupertraitInstance() throws (testing::TestError) { |
|
| 4681 | + | let mut a = testResolver(); |
|
| 4682 | + | let program = "trait Base { fn (*Base) f() -> i32; } trait Child: Base { fn (*Child) g() -> i32; } record R { x: i32 } instance Child for R { fn (r: *R) g() -> i32 { return r.x; } }"; |
|
| 4683 | + | let result = try resolveProgramStr(&mut a, program); |
|
| 4684 | + | try expectErrorKind(&result, super::ErrorKind::MissingSupertraitInstance("Base")); |
|
| 4685 | + | } |