Implement trait method dispatch in resolver

37fa61159c53b23be15e735453f0cd92b3c2b1eb77063550f12bb795e917a739
Alexis Sellier committed ago 1 parent 0125f9b7
lib/std/lang/lower.rad +4 -0
6121 6121
        }
6122 6122
        case resolver::Coercion::ResultWrap => {
6123 6123
            let payloadType = *self.fnType.returnType;
6124 6124
            return try buildResult(self, 0, val, payloadType);
6125 6125
        }
6126 +
        case resolver::Coercion::TraitObject(_) => {
6127 +
            // TODO.
6128 +
            return val;
6129 +
        }
6126 6130
        case resolver::Coercion::Identity => return val,
6127 6131
    }
6128 6132
}
6129 6133
6130 6134
/// Lower an implicit numeric cast coercion.
lib/std/lang/resolver.rad +104 -63
176 176
    NumericCast { from: Type, to: Type },
177 177
    /// Eg. `T -> ?T`. Stores the inner value type.
178 178
    OptionalLift(Type),
179 179
    /// Wrap return value in success variant of result type.
180 180
    ResultWrap,
181 +
    /// Coerce a concrete pointer to a trait object.
182 +
    TraitObject(*TraitType),
181 183
}
182 184
183 185
/// Result of resolving a module path.
184 186
record ResolvedModule {
185 187
    /// Module entry in the graph.
611 613
    MatchProng { catchAll: bool },
612 614
    /// Match expression metadata.
613 615
    Match { isConst: bool },
614 616
    /// For-loop iteration metadata.
615 617
    ForLoop(ForLoopInfo),
618 +
    /// Trait method call metadata.
619 +
    TraitMethodCall {
620 +
        /// Trait definition.
621 +
        traitInfo: *TraitType,
622 +
        /// Method index in the v-table.
623 +
        methodIndex: u32,
624 +
    },
616 625
}
617 626
618 627
/// Combined resolver metadata for a single AST node.
619 628
pub record NodeData {
620 629
    /// Resolved type for this node.
1143 1152
/// Associate union variant metadata with a pattern or constructor node.
1144 1153
fn setVariantInfo(self: *mut Resolver, node: *ast::Node, ordinal: u32, tag: u32) {
1145 1154
    self.nodeData.entries[node.id].extra = NodeExtra::UnionVariant { ordinal, tag };
1146 1155
}
1147 1156
1157 +
/// Associate trait method call metadata with a call node.
1158 +
fn setTraitMethodCall(self: *mut Resolver, node: *ast::Node, traitInfo: *TraitType, methodIndex: u32) {
1159 +
    self.nodeData.entries[node.id].extra = NodeExtra::TraitMethodCall { traitInfo, methodIndex };
1160 +
}
1161 +
1148 1162
/// Associate for-loop metadata with a for-loop node.
1149 1163
fn setForLoopInfo(self: *mut Resolver, node: *ast::Node, info: ForLoopInfo)
1150 1164
    throws (ResolveError)
1151 1165
{
1152 1166
    self.nodeData.entries[node.id].extra = NodeExtra::ForLoop(info);
1740 1754
            if let case Type::Optional(fromInner) = from {
1741 1755
                return isAssignable(self, *inner, *fromInner, rval);
1742 1756
            }
1743 1757
            return nil;
1744 1758
        }
1759 +
        case Type::TraitObject { traitInfo, mutable: lhsMutable } => {
1760 +
            // Coerce `*T` or `*mut T` where `T` implements the trait.
1761 +
            if let case Type::Pointer { target, mutable: rhsMutable } = from {
1762 +
                if lhsMutable and not rhsMutable {
1763 +
                    return nil;
1764 +
                }
1765 +
                // Look up instance registry.
1766 +
                if findInstance(self, traitInfo, *target) != nil {
1767 +
                    return Coercion::TraitObject(traitInfo);
1768 +
                }
1769 +
            }
1770 +
            // Identity: same trait object.
1771 +
            if let case Type::TraitObject { traitInfo: rhsTrait, mutable: rhsMutable } = from {
1772 +
                if traitInfo != rhsTrait {
1773 +
                    return nil;
1774 +
                }
1775 +
                if lhsMutable and not rhsMutable {
1776 +
                    return nil;
1777 +
                }
1778 +
                return Coercion::Identity;
1779 +
            }
1780 +
            return nil;
1781 +
        }
1745 1782
        case Type::Slice { item: lhsItem, mutable: lhsMutable } => {
1746 1783
            match from {
1747 1784
                case Type::Slice { item: rhsItem, mutable: rhsMutable } => {
1748 1785
                    if lhsMutable and not rhsMutable {
1749 1786
                        return nil;
1788 1825
        }
1789 1826
    }
1790 1827
    return nil;
1791 1828
}
1792 1829
1793 -
/// Check if two nominal type descriptors are structurally equivalent.
1794 -
fn nominalTypeEqual(a: *NominalType, b: *NominalType) -> bool {
1795 -
    match *a {
1796 -
        case NominalType::Record(aRecord) => {
1797 -
            let case NominalType::Record(bRecord) = *b else return false;
1798 -
            if aRecord.fieldsLen != bRecord.fieldsLen {
1799 -
                return false;
1800 -
            }
1801 -
            for i in 0..aRecord.fieldsLen {
1802 -
                let aField = aRecord.fields[i];
1803 -
                let bField = bRecord.fields[i];
1804 -
                if aField.name != bField.name {
1805 -
                    return false;
1806 -
                }
1807 -
                if not typesEqual(aField.fieldType, bField.fieldType) {
1808 -
                    return false;
1809 -
                }
1810 -
            }
1811 -
            return true;
1812 -
        }
1813 -
        else => return false,
1814 -
    }
1815 -
}
1816 -
1817 1830
/// Check if two function type descriptors are structurally equivalent.
1818 1831
fn fnTypeEqual(a: *FnType, b: *FnType) -> bool {
1819 1832
    if a.paramTypesLen != b.paramTypesLen {
1820 1833
        return false;
1821 1834
    }
3263 3276
    try setNodeType(self, name, Type::Void);
3264 3277
3265 3278
    return sym;
3266 3279
}
3267 3280
3281 +
/// Find a trait method by name.
3282 +
fn findTraitMethod(traitType: *TraitType, name: *[u8]) -> ?*TraitMethod {
3283 +
    for i in 0..traitType.methodsLen {
3284 +
        if traitType.methods[i].name == name {
3285 +
            return &traitType.methods[i];
3286 +
        }
3287 +
    }
3288 +
    return nil;
3289 +
}
3290 +
3268 3291
/// Resolve a trait declaration body, ie. the method signatures.
3269 3292
fn resolveTraitBody(self: *mut Resolver, node: *ast::Node, methods: *ast::NodeList)
3270 3293
    throws (ResolveError)
3271 3294
{
3272 3295
    let sym = symbolFor(self, node)
3285 3308
        let case ast::NodeValue::TraitMethodSig { name, receiver, sig } = methodNode.value
3286 3309
            else continue;
3287 3310
        let methodName = try nodeName(self, name);
3288 3311
3289 3312
        // Reject duplicate method names within the same trait.
3290 -
        for j in 0..traitType.methodsLen {
3291 -
            if traitType.methods[j].name == methodName {
3292 -
                throw emitError(self, name, ErrorKind::DuplicateBinding(methodName));
3293 -
            }
3313 +
        if let _ = findTraitMethod(traitType, methodName) {
3314 +
            throw emitError(self, name, ErrorKind::DuplicateBinding(methodName));
3294 3315
        }
3295 3316
3296 3317
        // Determine receiver mutability from the receiver type node
3297 3318
        // and validate that the receiver points to the declaring trait.
3298 3319
        let case ast::NodeValue::TypeSig(typeSig) = receiver.value
3434 3455
        } = methodNode.value else continue;
3435 3456
3436 3457
        let methodName = try nodeName(self, name);
3437 3458
3438 3459
        // Find the matching trait method.
3439 -
        let mut traitMethod: ?*TraitMethod = nil;
3440 -
        for j in 0..traitInfo.methodsLen {
3441 -
            if traitInfo.methods[j].name == methodName {
3442 -
                traitMethod = &traitInfo.methods[j];
3443 -
                break;
3444 -
            }
3445 -
        }
3446 -
        let tm = traitMethod
3460 +
        let tm = findTraitMethod(traitInfo, methodName)
3447 3461
            else throw emitError(self, name, ErrorKind::UnresolvedSymbol(methodName));
3448 3462
3449 3463
        // Determine receiver mutability and validate receiver type.
3450 3464
        let mut receiverMut = false;
3451 3465
        if let case ast::NodeValue::TypeSig(typeSig) = receiverType.value {
4606 4620
        negative: false,
4607 4621
    }));
4608 4622
    return try setNodeType(self, node, Type::U32);
4609 4623
}
4610 4624
4625 +
/// Validate call arguments against a function type: check argument count,
4626 +
/// type-check each argument, and verify that throwing functions use `try`.
4627 +
fn checkCallArgs(self: *mut Resolver, node: *ast::Node, call: ast::Call, info: *FnType, ctx: CallCtx)
4628 +
    throws (ResolveError)
4629 +
{
4630 +
    if ctx == CallCtx::Normal and info.throwListLen > 0 {
4631 +
        throw emitError(self, node, ErrorKind::MissingTry);
4632 +
    }
4633 +
    if call.args.len != info.paramTypesLen {
4634 +
        throw emitError(self, node, ErrorKind::FnArgCountMismatch(CountMismatch {
4635 +
            expected: info.paramTypesLen,
4636 +
            actual: call.args.len,
4637 +
        }));
4638 +
    }
4639 +
    for i in 0..call.args.len {
4640 +
        debug::check(i < MAX_FN_PARAMS);
4641 +
4642 +
        let argNode = call.args.list[i];
4643 +
        let expectedTy = *info.paramTypes[i];
4644 +
4645 +
        try checkAssignable(self, argNode, expectedTy);
4646 +
    }
4647 +
}
4648 +
4611 4649
/// Analyze a function call expression.
4612 4650
fn resolveCall(self: *mut Resolver, node: *ast::Node, call: ast::Call, ctx: CallCtx) -> Type
4613 4651
    throws (ResolveError)
4614 4652
{
4615 4653
    let calleeTy = try infer(self, call.callee);
4634 4672
                    return try resolveRecordConstructorCall(self, node, call, ty);
4635 4673
                }
4636 4674
            }
4637 4675
        }
4638 4676
    }
4677 +
4678 +
    // Check if we have a trait method call, ie. callee is a trait object.
4679 +
    if let case ast::NodeValue::FieldAccess(access) = call.callee.value {
4680 +
        let mut parentTy = Type::Unknown;
4681 +
        if let t = typeFor(self, access.parent) {
4682 +
            parentTy = t;
4683 +
        }
4684 +
        let subjectTy = autoDeref(parentTy);
4685 +
4686 +
        if let case Type::TraitObject { traitInfo, .. } = subjectTy {
4687 +
            let methodName = try nodeName(self, access.child);
4688 +
            let method = findTraitMethod(traitInfo, methodName)
4689 +
                else throw emitError(self, access.child, ErrorKind::RecordFieldUnknown(methodName));
4690 +
4691 +
            try checkCallArgs(self, node, call, method.fnType, ctx);
4692 +
            setTraitMethodCall(self, node, traitInfo, method.index);
4693 +
4694 +
            return try setNodeType(self, node, *method.fnType.returnType);
4695 +
        }
4696 +
    }
4639 4697
    let case Type::Fn(info) = calleeTy else {
4640 4698
        // TODO: Emit type error.
4641 4699
        panic;
4642 4700
    };
4643 -
    // For normal calls, make sure we're not calling a function
4644 -
    // that throws.
4645 -
    if ctx == CallCtx::Normal {
4646 -
        if info.throwListLen > 0 {
4647 -
            throw emitError(self, node, ErrorKind::MissingTry);
4648 -
        }
4649 -
    }
4650 -
    if call.args.len != info.paramTypesLen {
4651 -
        throw emitError(self, node, ErrorKind::FnArgCountMismatch(CountMismatch {
4652 -
            expected: info.paramTypesLen,
4653 -
            actual: call.args.len,
4654 -
        }));
4655 -
    }
4656 -
    // TODO: Check what happens when we exceed `MAX_FN_PARAMS`.
4657 -
    for i in 0..call.args.len {
4658 -
        debug::check(i < MAX_FN_PARAMS);
4659 -
4660 -
        let argNode = call.args.list[i];
4661 -
        let expectedTy = *info.paramTypes[i];
4662 -
        try checkAssignable(self, argNode, expectedTy);
4663 -
    }
4701 +
    try checkCallArgs(self, node, call, info, ctx);
4664 4702
    // Associate function type to callee.
4665 4703
    try setNodeType(self, call.callee, calleeTy);
4704 +
4666 4705
    // Associate return type to call.
4667 4706
    return try setNodeType(self, node, *info.returnType);
4668 4707
}
4669 4708
4670 4709
/// Analyze an assignment expression.
5134 5173
                try setRecordFieldIndex(self, fieldNode, 1);
5135 5174
                return try setNodeType(self, node, Type::U32);
5136 5175
            }
5137 5176
            throw emitError(self, node, ErrorKind::SliceFieldUnknown(fieldName));
5138 5177
        }
5178 +
        case Type::TraitObject { traitInfo, .. } => {
5179 +
            let fieldName = try nodeName(self, access.child);
5180 +
            let method = findTraitMethod(traitInfo, fieldName)
5181 +
                else throw emitError(self, node, ErrorKind::RecordFieldUnknown(fieldName));
5182 +
5183 +
            return try setNodeType(self, node, Type::Fn(method.fnType));
5184 +
        }
5139 5185
        else => {}
5140 5186
    }
5187 +
    // FIXME: We can't move this to the `else` branch due to a resolver bug.
5141 5188
    throw emitError(self, access.parent, ErrorKind::ExpectedRecord);
5142 5189
}
5143 5190
5144 5191
/// Determine whether an expression can yield a mutable location for borrowing.
5145 5192
fn canBorrowMutFrom(self: *mut Resolver, node: *ast::Node) -> bool
5903 5950
        case Type::Pointer { target, .. } => return isTypeInferrable(*target),
5904 5951
        else => return true,
5905 5952
    }
5906 5953
}
5907 5954
5908 -
/// Analyze any node.
5909 -
pub fn resolveNode(res: *mut Resolver, node: *ast::Node) -> Diagnostics throws (ResolveError) {
5910 -
    try infer(res, node);
5911 -
    return Diagnostics { errors: res.errors };
5912 -
}
5913 -
5914 5955
/// Analyze a standalone expression by wrapping it in a synthetic function.
5915 5956
pub fn resolveExpr(
5916 5957
    self: *mut Resolver, expr: *ast::Node, arena: *mut ast::NodeArena
5917 5958
) -> Diagnostics throws (ResolveError) {
5918 5959
    let mut bodyStmts = ast::nodeList(arena, 1);