Implement dynamic dispatch through v-tables
a521693edfb448a4a9cf40aa30e8e56e4c8d6c79730e5e181bc9ad74e46917ce
1 parent
e3e79cd2
lib/std/arch/rv64/tests/trait.basic.rad
added
+32 -0
| 1 | + | //! Basic trait object dispatch test. |
|
| 2 | + | record Counter { |
|
| 3 | + | value: i32, |
|
| 4 | + | } |
|
| 5 | + | ||
| 6 | + | trait Adder { |
|
| 7 | + | fn (*mut Adder) add(n: i32) -> i32; |
|
| 8 | + | } |
|
| 9 | + | ||
| 10 | + | instance Adder for Counter { |
|
| 11 | + | fn (c: *mut Counter) add(n: i32) -> i32 { |
|
| 12 | + | c.value = c.value + n; |
|
| 13 | + | return c.value; |
|
| 14 | + | } |
|
| 15 | + | } |
|
| 16 | + | ||
| 17 | + | @default fn main() -> i32 { |
|
| 18 | + | let mut c = Counter { value: 10 }; |
|
| 19 | + | let a: *mut opaque Adder = &mut c; |
|
| 20 | + | ||
| 21 | + | let result = a.add(5); |
|
| 22 | + | // c.value should be 15 now. |
|
| 23 | + | if result != 15 { |
|
| 24 | + | return 1; |
|
| 25 | + | } |
|
| 26 | + | let result2 = a.add(3); |
|
| 27 | + | // c.value should be 18 now. |
|
| 28 | + | if result2 != 18 { |
|
| 29 | + | return 2; |
|
| 30 | + | } |
|
| 31 | + | return 0; |
|
| 32 | + | } |
lib/std/lang/lower.rad
+138 -8
| 4081 | 4081 | try emitStore(self, dst, SLICE_LEN_OFFSET, resolver::Type::U32, lenVal); |
|
| 4082 | 4082 | ||
| 4083 | 4083 | return il::Val::Reg(dst); |
|
| 4084 | 4084 | } |
|
| 4085 | 4085 | ||
| 4086 | + | /// Build a trait object fat pointer from a data pointer and a v-table. |
|
| 4087 | + | fn buildTraitObject( |
|
| 4088 | + | self: *mut FnLowerer, |
|
| 4089 | + | dataVal: il::Val, |
|
| 4090 | + | traitInfo: *resolver::TraitType, |
|
| 4091 | + | inst: *resolver::InstanceEntry |
|
| 4092 | + | ) -> il::Val throws (LowerError) { |
|
| 4093 | + | let vName = vtableName(self.low, inst.moduleId, inst.concreteTypeName, traitInfo.name); |
|
| 4094 | + | ||
| 4095 | + | // Reserve space for the trait object on the stack. |
|
| 4096 | + | let slot = try emitReserveLayout(self, resolver::Layout { |
|
| 4097 | + | size: resolver::PTR_SIZE * 2, |
|
| 4098 | + | alignment: resolver::PTR_SIZE, |
|
| 4099 | + | }); |
|
| 4100 | + | ||
| 4101 | + | // Store data pointer. |
|
| 4102 | + | emit(self, il::Instr::Store { |
|
| 4103 | + | typ: il::Type::W64, |
|
| 4104 | + | src: dataVal, |
|
| 4105 | + | dst: slot, |
|
| 4106 | + | offset: TRAIT_OBJ_DATA_OFFSET, |
|
| 4107 | + | }); |
|
| 4108 | + | ||
| 4109 | + | // Store v-table address. |
|
| 4110 | + | emit(self, il::Instr::Store { |
|
| 4111 | + | typ: il::Type::W64, |
|
| 4112 | + | src: il::Val::DataSym(vName), |
|
| 4113 | + | dst: slot, |
|
| 4114 | + | offset: TRAIT_OBJ_VTABLE_OFFSET, |
|
| 4115 | + | }); |
|
| 4116 | + | return il::Val::Reg(slot); |
|
| 4117 | + | } |
|
| 4118 | + | ||
| 4086 | 4119 | /// Compute a field pointer by adding a byte offset to a base address. |
|
| 4087 | 4120 | fn emitPtrOffset(self: *mut FnLowerer, base: il::Reg, offset: i32) -> il::Reg { |
|
| 4088 | 4121 | if offset == 0 { |
|
| 4089 | 4122 | return base; |
|
| 4090 | 4123 | } |
| 6124 | 6157 | } |
|
| 6125 | 6158 | } |
|
| 6126 | 6159 | ||
| 6127 | 6160 | /// Lower a call expression, which may be a function call or type constructor. |
|
| 6128 | 6161 | fn lowerCallOrCtor(self: *mut FnLowerer, node: *ast::Node, call: ast::Call) -> il::Val throws (LowerError) { |
|
| 6162 | + | // Check for trait method dispatch. |
|
| 6163 | + | if let case resolver::NodeExtra::TraitMethodCall { |
|
| 6164 | + | traitInfo, methodIndex |
|
| 6165 | + | } = resolver::nodeData(self.low.resolver, node).extra { |
|
| 6166 | + | return try lowerTraitMethodCall(self, node, call, traitInfo, methodIndex); |
|
| 6167 | + | } |
|
| 6129 | 6168 | if let sym = resolver::nodeData(self.low.resolver, call.callee).sym { |
|
| 6130 | 6169 | if let case resolver::SymbolData::Type(nominal) = sym.data { |
|
| 6131 | 6170 | let case resolver::NominalType::Record(_) = *nominal else { |
|
| 6132 | 6171 | throw LowerError::ExpectedRecord; |
|
| 6133 | 6172 | }; |
| 6138 | 6177 | } |
|
| 6139 | 6178 | } |
|
| 6140 | 6179 | return try lowerCall(self, node, call); |
|
| 6141 | 6180 | } |
|
| 6142 | 6181 | ||
| 6182 | + | /// Lower a trait method call through v-table dispatch. |
|
| 6183 | + | /// |
|
| 6184 | + | /// Given `obj.method(args)` where `obj` is a trait object, emits: |
|
| 6185 | + | /// |
|
| 6186 | + | /// load w64 %data %obj 0 // data pointer |
|
| 6187 | + | /// load w64 %vtable %obj 8 // v-table pointer |
|
| 6188 | + | /// load w64 %fn %vtable <slot> // function pointer |
|
| 6189 | + | /// call <retTy> %ret %fn(%data, args...) |
|
| 6190 | + | /// |
|
| 6191 | + | fn lowerTraitMethodCall( |
|
| 6192 | + | self: *mut FnLowerer, |
|
| 6193 | + | node: *ast::Node, |
|
| 6194 | + | call: ast::Call, |
|
| 6195 | + | traitInfo: *resolver::TraitType, |
|
| 6196 | + | methodIndex: u32 |
|
| 6197 | + | ) -> il::Val throws (LowerError) { |
|
| 6198 | + | // Method calls look like field accesses. |
|
| 6199 | + | let case ast::NodeValue::FieldAccess(access) = call.callee.value |
|
| 6200 | + | else throw LowerError::MissingMetadata; |
|
| 6201 | + | ||
| 6202 | + | // Lower the trait object expression. |
|
| 6203 | + | let traitObjVal = try lowerExpr(self, access.parent); |
|
| 6204 | + | let traitObjReg = try emitValToReg(self, traitObjVal); |
|
| 6205 | + | ||
| 6206 | + | // Load data pointer from trait object. |
|
| 6207 | + | let dataReg = nextReg(self); |
|
| 6208 | + | emit(self, il::Instr::Load { |
|
| 6209 | + | typ: il::Type::W64, |
|
| 6210 | + | dst: dataReg, |
|
| 6211 | + | src: traitObjReg, |
|
| 6212 | + | offset: TRAIT_OBJ_DATA_OFFSET, |
|
| 6213 | + | }); |
|
| 6214 | + | ||
| 6215 | + | // Load v-table pointer from trait object. |
|
| 6216 | + | let vtableReg = nextReg(self); |
|
| 6217 | + | emit(self, il::Instr::Load { |
|
| 6218 | + | typ: il::Type::W64, |
|
| 6219 | + | dst: vtableReg, |
|
| 6220 | + | src: traitObjReg, |
|
| 6221 | + | offset: TRAIT_OBJ_VTABLE_OFFSET, |
|
| 6222 | + | }); |
|
| 6223 | + | ||
| 6224 | + | // Load function pointer from v-table at the method's slot offset. |
|
| 6225 | + | let fnPtrReg = nextReg(self); |
|
| 6226 | + | let slotOffset = (methodIndex * resolver::PTR_SIZE) as i32; |
|
| 6227 | + | ||
| 6228 | + | emit(self, il::Instr::Load { |
|
| 6229 | + | typ: il::Type::W64, |
|
| 6230 | + | dst: fnPtrReg, |
|
| 6231 | + | src: vtableReg, |
|
| 6232 | + | offset: slotOffset, |
|
| 6233 | + | }); |
|
| 6234 | + | ||
| 6235 | + | // Check if the method needs a hidden return parameter. |
|
| 6236 | + | let methodFnType = traitInfo.methods[methodIndex].fnType; |
|
| 6237 | + | let retTy = *methodFnType.returnType; |
|
| 6238 | + | let returnParam = requiresReturnParam(methodFnType); |
|
| 6239 | + | ||
| 6240 | + | // Build args: data pointer (receiver) + user args. |
|
| 6241 | + | let argOffset: u32 = 1 if returnParam else 0; |
|
| 6242 | + | let args = try allocVals(self, call.args.len + 1 + argOffset); |
|
| 6243 | + | ||
| 6244 | + | if returnParam { |
|
| 6245 | + | args[0] = il::Val::Reg(try emitReserve(self, retTy)); |
|
| 6246 | + | } |
|
| 6247 | + | // Data pointer is the receiver (first argument). |
|
| 6248 | + | args[argOffset] = il::Val::Reg(dataReg); |
|
| 6249 | + | ||
| 6250 | + | // Lower user arguments. |
|
| 6251 | + | for i in 0..call.args.len { |
|
| 6252 | + | args[i + 1 + argOffset] = try lowerExpr(self, call.args.list[i]); |
|
| 6253 | + | } |
|
| 6254 | + | // Emit the call. Hidden return param calls always return a pointer. |
|
| 6255 | + | let callRetTy = il::Type::W64 if returnParam else ilType(self.low, retTy); |
|
| 6256 | + | let mut dst: ?il::Reg = nil; |
|
| 6257 | + | ||
| 6258 | + | if returnParam or retTy != resolver::Type::Void { |
|
| 6259 | + | dst = nextReg(self); |
|
| 6260 | + | } |
|
| 6261 | + | emit(self, il::Instr::Call { |
|
| 6262 | + | retTy: callRetTy, |
|
| 6263 | + | dst, |
|
| 6264 | + | func: il::Val::Reg(fnPtrReg), |
|
| 6265 | + | args, |
|
| 6266 | + | }); |
|
| 6267 | + | ||
| 6268 | + | if let d = dst { |
|
| 6269 | + | return il::Val::Reg(d); |
|
| 6270 | + | } |
|
| 6271 | + | return il::Val::Undef; |
|
| 6272 | + | } |
|
| 6273 | + | ||
| 6143 | 6274 | /// Check if a call is to a compiler intrinsic and lower it directly. |
|
| 6144 | 6275 | fn lowerIntrinsicCall(self: *mut FnLowerer, call: ast::Call) -> ?il::Val throws (LowerError) { |
|
| 6145 | 6276 | // Get the callee symbol and check if it's marked as an intrinsic. |
|
| 6146 | 6277 | let sym = resolver::nodeData(self.low.resolver, call.callee).sym else { |
|
| 6147 | 6278 | // Expressions or function pointers may not have an associated symbol. |
| 6223 | 6354 | throw LowerError::ExpectedFunction; |
|
| 6224 | 6355 | }; |
|
| 6225 | 6356 | let retTy = resolver::typeFor(self.low.resolver, node) else { |
|
| 6226 | 6357 | throw LowerError::MissingType(node); |
|
| 6227 | 6358 | }; |
|
| 6228 | - | let isThrowing = fnInfo.throwListLen > 0; |
|
| 6229 | - | let needsRetBuf = isThrowing or (isAggregateType(retTy) and not isSmallAggregate(retTy)); |
|
| 6359 | + | let returnParam = requiresReturnParam(fnInfo); |
|
| 6230 | 6360 | ||
| 6231 | 6361 | // Lower function value and arguments, reserving an extra slot for the |
|
| 6232 | 6362 | // hidden return buffer when needed. |
|
| 6233 | 6363 | let callee = try lowerCallee(self, call.callee); |
|
| 6234 | - | let offset: u32 = 1 if needsRetBuf else 0; |
|
| 6364 | + | let offset: u32 = 1 if returnParam else 0; |
|
| 6235 | 6365 | let args = try allocVals(self, call.args.len + offset); |
|
| 6236 | 6366 | for i in 0..call.args.len { |
|
| 6237 | 6367 | args[i + offset] = try lowerExpr(self, call.args.list[i]); |
|
| 6238 | 6368 | } |
|
| 6239 | 6369 | ||
| 6240 | 6370 | // Allocate the return buffer when needed. |
|
| 6241 | - | if needsRetBuf { |
|
| 6242 | - | if isThrowing { |
|
| 6371 | + | if returnParam { |
|
| 6372 | + | if fnInfo.throwListLen > 0 { |
|
| 6243 | 6373 | let successType = *fnInfo.returnType; |
|
| 6244 | 6374 | let layout = resolver::getResultLayout(successType, &fnInfo.throwList[..fnInfo.throwListLen]); |
|
| 6245 | 6375 | ||
| 6246 | 6376 | args[0] = il::Val::Reg(try emitReserveLayout(self, layout)); |
|
| 6247 | 6377 | } else { |
| 6309 | 6439 | } |
|
| 6310 | 6440 | case resolver::Coercion::ResultWrap => { |
|
| 6311 | 6441 | let payloadType = *self.fnType.returnType; |
|
| 6312 | 6442 | return try buildResult(self, 0, val, payloadType); |
|
| 6313 | 6443 | } |
|
| 6314 | - | case resolver::Coercion::TraitObject(_) => { |
|
| 6315 | - | // TODO. |
|
| 6316 | - | return val; |
|
| 6444 | + | case resolver::Coercion::TraitObject { traitInfo, inst } => { |
|
| 6445 | + | return try buildTraitObject(self, val, traitInfo, inst); |
|
| 6317 | 6446 | } |
|
| 6318 | 6447 | case resolver::Coercion::Identity => return val, |
|
| 6319 | 6448 | } |
|
| 6320 | 6449 | } |
|
| 6321 | 6450 |
| 6651 | 6780 | return il::Type::W8; |
|
| 6652 | 6781 | } |
|
| 6653 | 6782 | return il::Type::W64; |
|
| 6654 | 6783 | } |
|
| 6655 | 6784 | case resolver::Type::Fn(_) => return il::Type::W64, |
|
| 6785 | + | case resolver::Type::TraitObject { .. } => return il::Type::W64, |
|
| 6656 | 6786 | // Void functions return zero at the IL level. |
|
| 6657 | 6787 | case resolver::Type::Void => return il::Type::W32, |
|
| 6658 | 6788 | // FIXME: We shouldn't try to lower this type, it should be behind a pointer. |
|
| 6659 | 6789 | case resolver::Type::Opaque => return il::Type::W32, |
|
| 6660 | 6790 | // FIXME: This should be resolved to a concrete integer type in the resolver. |
lib/std/lang/resolver.rad
+8 -3
| 177 | 177 | /// Eg. `T -> ?T`. Stores the inner value type. |
|
| 178 | 178 | OptionalLift(Type), |
|
| 179 | 179 | /// Wrap return value in success variant of result type. |
|
| 180 | 180 | ResultWrap, |
|
| 181 | 181 | /// Coerce a concrete pointer to a trait object. |
|
| 182 | - | TraitObject(*TraitType), |
|
| 182 | + | TraitObject { |
|
| 183 | + | /// Trait type information. |
|
| 184 | + | traitInfo: *TraitType, |
|
| 185 | + | /// Instance entry for v-table lookup. |
|
| 186 | + | inst: *InstanceEntry, |
|
| 187 | + | }, |
|
| 183 | 188 | } |
|
| 184 | 189 | ||
| 185 | 190 | /// Result of resolving a module path. |
|
| 186 | 191 | record ResolvedModule { |
|
| 187 | 192 | /// Module entry in the graph. |
| 1761 | 1766 | if let case Type::Pointer { target, mutable: rhsMutable } = from { |
|
| 1762 | 1767 | if lhsMutable and not rhsMutable { |
|
| 1763 | 1768 | return nil; |
|
| 1764 | 1769 | } |
|
| 1765 | 1770 | // Look up instance registry. |
|
| 1766 | - | if findInstance(self, traitInfo, *target) != nil { |
|
| 1767 | - | return Coercion::TraitObject(traitInfo); |
|
| 1771 | + | if let inst = findInstance(self, traitInfo, *target) { |
|
| 1772 | + | return Coercion::TraitObject { traitInfo, inst }; |
|
| 1768 | 1773 | } |
|
| 1769 | 1774 | } |
|
| 1770 | 1775 | // Identity: same trait object. |
|
| 1771 | 1776 | if let case Type::TraitObject { traitInfo: rhsTrait, mutable: rhsMutable } = from { |
|
| 1772 | 1777 | if traitInfo != rhsTrait { |