Implement dynamic dispatch through v-tables

a521693edfb448a4a9cf40aa30e8e56e4c8d6c79730e5e181bc9ad74e46917ce
Alexis Sellier committed ago 1 parent e3e79cd2
lib/std/arch/rv64/tests/trait.basic.rad added +32 -0
1 +
//! Basic trait object dispatch test.
2 +
record Counter {
3 +
    value: i32,
4 +
}
5 +
6 +
trait Adder {
7 +
    fn (*mut Adder) add(n: i32) -> i32;
8 +
}
9 +
10 +
instance Adder for Counter {
11 +
    fn (c: *mut Counter) add(n: i32) -> i32 {
12 +
        c.value = c.value + n;
13 +
        return c.value;
14 +
    }
15 +
}
16 +
17 +
@default fn main() -> i32 {
18 +
    let mut c = Counter { value: 10 };
19 +
    let a: *mut opaque Adder = &mut c;
20 +
21 +
    let result = a.add(5);
22 +
    // c.value should be 15 now.
23 +
    if result != 15 {
24 +
        return 1;
25 +
    }
26 +
    let result2 = a.add(3);
27 +
    // c.value should be 18 now.
28 +
    if result2 != 18 {
29 +
        return 2;
30 +
    }
31 +
    return 0;
32 +
}
lib/std/lang/lower.rad +138 -8
4081 4081
    try emitStore(self, dst, SLICE_LEN_OFFSET, resolver::Type::U32, lenVal);
4082 4082
4083 4083
    return il::Val::Reg(dst);
4084 4084
}
4085 4085
4086 +
/// Build a trait object fat pointer from a data pointer and a v-table.
4087 +
fn buildTraitObject(
4088 +
    self: *mut FnLowerer,
4089 +
    dataVal: il::Val,
4090 +
    traitInfo: *resolver::TraitType,
4091 +
    inst: *resolver::InstanceEntry
4092 +
) -> il::Val throws (LowerError) {
4093 +
    let vName = vtableName(self.low, inst.moduleId, inst.concreteTypeName, traitInfo.name);
4094 +
4095 +
    // Reserve space for the trait object on the stack.
4096 +
    let slot = try emitReserveLayout(self, resolver::Layout {
4097 +
        size: resolver::PTR_SIZE * 2,
4098 +
        alignment: resolver::PTR_SIZE,
4099 +
    });
4100 +
4101 +
    // Store data pointer.
4102 +
    emit(self, il::Instr::Store {
4103 +
        typ: il::Type::W64,
4104 +
        src: dataVal,
4105 +
        dst: slot,
4106 +
        offset: TRAIT_OBJ_DATA_OFFSET,
4107 +
    });
4108 +
4109 +
    // Store v-table address.
4110 +
    emit(self, il::Instr::Store {
4111 +
        typ: il::Type::W64,
4112 +
        src: il::Val::DataSym(vName),
4113 +
        dst: slot,
4114 +
        offset: TRAIT_OBJ_VTABLE_OFFSET,
4115 +
    });
4116 +
    return il::Val::Reg(slot);
4117 +
}
4118 +
4086 4119
/// Compute a field pointer by adding a byte offset to a base address.
4087 4120
fn emitPtrOffset(self: *mut FnLowerer, base: il::Reg, offset: i32) -> il::Reg {
4088 4121
    if offset == 0 {
4089 4122
        return base;
4090 4123
    }
6124 6157
    }
6125 6158
}
6126 6159
6127 6160
/// Lower a call expression, which may be a function call or type constructor.
6128 6161
fn lowerCallOrCtor(self: *mut FnLowerer, node: *ast::Node, call: ast::Call) -> il::Val throws (LowerError) {
6162 +
    // Check for trait method dispatch.
6163 +
    if let case resolver::NodeExtra::TraitMethodCall {
6164 +
        traitInfo, methodIndex
6165 +
    } = resolver::nodeData(self.low.resolver, node).extra {
6166 +
        return try lowerTraitMethodCall(self, node, call, traitInfo, methodIndex);
6167 +
    }
6129 6168
    if let sym = resolver::nodeData(self.low.resolver, call.callee).sym {
6130 6169
        if let case resolver::SymbolData::Type(nominal) = sym.data {
6131 6170
            let case resolver::NominalType::Record(_) = *nominal else {
6132 6171
                throw LowerError::ExpectedRecord;
6133 6172
            };
6138 6177
        }
6139 6178
    }
6140 6179
    return try lowerCall(self, node, call);
6141 6180
}
6142 6181
6182 +
/// Lower a trait method call through v-table dispatch.
6183 +
///
6184 +
/// Given `obj.method(args)` where `obj` is a trait object, emits:
6185 +
///
6186 +
///     load w64 %data %obj 0          // data pointer
6187 +
///     load w64 %vtable %obj 8        // v-table pointer
6188 +
///     load w64 %fn %vtable <slot>    // function pointer
6189 +
///     call <retTy> %ret %fn(%data, args...)
6190 +
///
6191 +
fn lowerTraitMethodCall(
6192 +
    self: *mut FnLowerer,
6193 +
    node: *ast::Node,
6194 +
    call: ast::Call,
6195 +
    traitInfo: *resolver::TraitType,
6196 +
    methodIndex: u32
6197 +
) -> il::Val throws (LowerError) {
6198 +
    // Method calls look like field accesses.
6199 +
    let case ast::NodeValue::FieldAccess(access) = call.callee.value
6200 +
        else throw LowerError::MissingMetadata;
6201 +
6202 +
    // Lower the trait object expression.
6203 +
    let traitObjVal = try lowerExpr(self, access.parent);
6204 +
    let traitObjReg = try emitValToReg(self, traitObjVal);
6205 +
6206 +
    // Load data pointer from trait object.
6207 +
    let dataReg = nextReg(self);
6208 +
    emit(self, il::Instr::Load {
6209 +
        typ: il::Type::W64,
6210 +
        dst: dataReg,
6211 +
        src: traitObjReg,
6212 +
        offset: TRAIT_OBJ_DATA_OFFSET,
6213 +
    });
6214 +
6215 +
    // Load v-table pointer from trait object.
6216 +
    let vtableReg = nextReg(self);
6217 +
    emit(self, il::Instr::Load {
6218 +
        typ: il::Type::W64,
6219 +
        dst: vtableReg,
6220 +
        src: traitObjReg,
6221 +
        offset: TRAIT_OBJ_VTABLE_OFFSET,
6222 +
    });
6223 +
6224 +
    // Load function pointer from v-table at the method's slot offset.
6225 +
    let fnPtrReg = nextReg(self);
6226 +
    let slotOffset = (methodIndex * resolver::PTR_SIZE) as i32;
6227 +
6228 +
    emit(self, il::Instr::Load {
6229 +
        typ: il::Type::W64,
6230 +
        dst: fnPtrReg,
6231 +
        src: vtableReg,
6232 +
        offset: slotOffset,
6233 +
    });
6234 +
6235 +
    // Check if the method needs a hidden return parameter.
6236 +
    let methodFnType = traitInfo.methods[methodIndex].fnType;
6237 +
    let retTy = *methodFnType.returnType;
6238 +
    let returnParam = requiresReturnParam(methodFnType);
6239 +
6240 +
    // Build args: data pointer (receiver) + user args.
6241 +
    let argOffset: u32 = 1 if returnParam else 0;
6242 +
    let args = try allocVals(self, call.args.len + 1 + argOffset);
6243 +
6244 +
    if returnParam {
6245 +
        args[0] = il::Val::Reg(try emitReserve(self, retTy));
6246 +
    }
6247 +
    // Data pointer is the receiver (first argument).
6248 +
    args[argOffset] = il::Val::Reg(dataReg);
6249 +
6250 +
    // Lower user arguments.
6251 +
    for i in 0..call.args.len {
6252 +
        args[i + 1 + argOffset] = try lowerExpr(self, call.args.list[i]);
6253 +
    }
6254 +
    // Emit the call. Hidden return param calls always return a pointer.
6255 +
    let callRetTy = il::Type::W64 if returnParam else ilType(self.low, retTy);
6256 +
    let mut dst: ?il::Reg = nil;
6257 +
6258 +
    if returnParam or retTy != resolver::Type::Void {
6259 +
        dst = nextReg(self);
6260 +
    }
6261 +
    emit(self, il::Instr::Call {
6262 +
        retTy: callRetTy,
6263 +
        dst,
6264 +
        func: il::Val::Reg(fnPtrReg),
6265 +
        args,
6266 +
    });
6267 +
6268 +
    if let d = dst {
6269 +
        return il::Val::Reg(d);
6270 +
    }
6271 +
    return il::Val::Undef;
6272 +
}
6273 +
6143 6274
/// Check if a call is to a compiler intrinsic and lower it directly.
6144 6275
fn lowerIntrinsicCall(self: *mut FnLowerer, call: ast::Call) -> ?il::Val throws (LowerError) {
6145 6276
    // Get the callee symbol and check if it's marked as an intrinsic.
6146 6277
    let sym = resolver::nodeData(self.low.resolver, call.callee).sym else {
6147 6278
        // Expressions or function pointers may not have an associated symbol.
6223 6354
        throw LowerError::ExpectedFunction;
6224 6355
    };
6225 6356
    let retTy = resolver::typeFor(self.low.resolver, node) else {
6226 6357
        throw LowerError::MissingType(node);
6227 6358
    };
6228 -
    let isThrowing = fnInfo.throwListLen > 0;
6229 -
    let needsRetBuf = isThrowing or (isAggregateType(retTy) and not isSmallAggregate(retTy));
6359 +
    let returnParam = requiresReturnParam(fnInfo);
6230 6360
6231 6361
    // Lower function value and arguments, reserving an extra slot for the
6232 6362
    // hidden return buffer when needed.
6233 6363
    let callee = try lowerCallee(self, call.callee);
6234 -
    let offset: u32 = 1 if needsRetBuf else 0;
6364 +
    let offset: u32 = 1 if returnParam else 0;
6235 6365
    let args = try allocVals(self, call.args.len + offset);
6236 6366
    for i in 0..call.args.len {
6237 6367
        args[i + offset] = try lowerExpr(self, call.args.list[i]);
6238 6368
    }
6239 6369
6240 6370
    // Allocate the return buffer when needed.
6241 -
    if needsRetBuf {
6242 -
        if isThrowing {
6371 +
    if returnParam {
6372 +
        if fnInfo.throwListLen > 0 {
6243 6373
            let successType = *fnInfo.returnType;
6244 6374
            let layout = resolver::getResultLayout(successType, &fnInfo.throwList[..fnInfo.throwListLen]);
6245 6375
6246 6376
            args[0] = il::Val::Reg(try emitReserveLayout(self, layout));
6247 6377
        } else {
6309 6439
        }
6310 6440
        case resolver::Coercion::ResultWrap => {
6311 6441
            let payloadType = *self.fnType.returnType;
6312 6442
            return try buildResult(self, 0, val, payloadType);
6313 6443
        }
6314 -
        case resolver::Coercion::TraitObject(_) => {
6315 -
            // TODO.
6316 -
            return val;
6444 +
        case resolver::Coercion::TraitObject { traitInfo, inst } => {
6445 +
            return try buildTraitObject(self, val, traitInfo, inst);
6317 6446
        }
6318 6447
        case resolver::Coercion::Identity => return val,
6319 6448
    }
6320 6449
}
6321 6450
6651 6780
                return il::Type::W8;
6652 6781
            }
6653 6782
            return il::Type::W64;
6654 6783
        }
6655 6784
        case resolver::Type::Fn(_) => return il::Type::W64,
6785 +
        case resolver::Type::TraitObject { .. } => return il::Type::W64,
6656 6786
        // Void functions return zero at the IL level.
6657 6787
        case resolver::Type::Void => return il::Type::W32,
6658 6788
        // FIXME: We shouldn't try to lower this type, it should be behind a pointer.
6659 6789
        case resolver::Type::Opaque => return il::Type::W32,
6660 6790
        // FIXME: This should be resolved to a concrete integer type in the resolver.
lib/std/lang/resolver.rad +8 -3
177 177
    /// Eg. `T -> ?T`. Stores the inner value type.
178 178
    OptionalLift(Type),
179 179
    /// Wrap return value in success variant of result type.
180 180
    ResultWrap,
181 181
    /// Coerce a concrete pointer to a trait object.
182 -
    TraitObject(*TraitType),
182 +
    TraitObject {
183 +
        /// Trait type information.
184 +
        traitInfo: *TraitType,
185 +
        /// Instance entry for v-table lookup.
186 +
        inst: *InstanceEntry,
187 +
    },
183 188
}
184 189
185 190
/// Result of resolving a module path.
186 191
record ResolvedModule {
187 192
    /// Module entry in the graph.
1761 1766
            if let case Type::Pointer { target, mutable: rhsMutable } = from {
1762 1767
                if lhsMutable and not rhsMutable {
1763 1768
                    return nil;
1764 1769
                }
1765 1770
                // Look up instance registry.
1766 -
                if findInstance(self, traitInfo, *target) != nil {
1767 -
                    return Coercion::TraitObject(traitInfo);
1771 +
                if let inst = findInstance(self, traitInfo, *target) {
1772 +
                    return Coercion::TraitObject { traitInfo, inst };
1768 1773
                }
1769 1774
            }
1770 1775
            // Identity: same trait object.
1771 1776
            if let case Type::TraitObject { traitInfo: rhsTrait, mutable: rhsMutable } = from {
1772 1777
                if traitInfo != rhsTrait {