Implement trait declarations in resolver

40c5b38c365af48a154f7db1ccc37ceaa07a0b2196a7fc4ef43db78b325084a9
Bind trait names so they are available for forward references. Resolve
trait method signatures.

After this commit, `trait Foo { ... }` is parsed, bound, and its
method signatures are type-checked. But there is no way to use the
trait yet.
Alexis Sellier committed ago 1 parent 7c675193
lib/std/lang/ast.rad +3 -0
3 3
4 4
use std::io;
5 5
use std::debug;
6 6
use std::lang::alloc;
7 7
8 +
/// Maximum number of trait methods.
9 +
pub const MAX_TRAIT_METHODS: u32 = 8;
10 +
8 11
/// Arena for all parser allocations.
9 12
///
10 13
/// Uses a bump allocator for both AST nodes and node pointer arrays.
11 14
pub record NodeArena {
12 15
    /// Bump allocator for all allocations.
lib/std/lang/parser.rad +8 -4
13 13
pub const U32_MAX: u32 = 0xFFFFFFFF;
14 14
/// Maximum representable `u64` value.
15 15
pub const U64_MAX: u64 = 0xFFFFFFFFFFFFFFFF;
16 16
/// Maximum number of fields in a record.
17 17
pub const MAX_RECORD_FIELDS: u32 = 32;
18 -
/// Maximum number of trait methods.
19 -
pub const MAX_TRAIT_METHODS: u32 = 8;
20 18
21 19
/// Maximum number of parser errors before aborting.
22 20
const MAX_ERRORS: u32 = 8;
23 21
24 22
/// Parser error type.
1892 1890
    let mut params = ast::nodeList(p.arena, 8);
1893 1891
1894 1892
    if not check(p, scanner::TokenKind::RParen) {
1895 1893
        loop {
1896 1894
            let param = try parseFnParam(p);
1895 +
            if params.len >= params.list.len {
1896 +
                throw failParsing(p, "too many function parameters");
1897 +
            }
1897 1898
            ast::nodeListPush(&mut params, param);
1898 1899
1899 1900
            if not consume(p, scanner::TokenKind::Comma) {
1900 1901
                break;
1901 1902
            }
2312 2313
{
2313 2314
    try expect(p, scanner::TokenKind::Trait, "expected `trait`");
2314 2315
    let name = try parseIdent(p, "expected trait name");
2315 2316
    try expect(p, scanner::TokenKind::LBrace, "expected `{` after trait name");
2316 2317
2317 -
    let mut methods = ast::nodeList(p.arena, MAX_TRAIT_METHODS);
2318 +
    let mut methods = ast::nodeList(p.arena, ast::MAX_TRAIT_METHODS);
2318 2319
    while not check(p, scanner::TokenKind::RBrace) and
2319 2320
          not check(p, scanner::TokenKind::Eof)
2320 2321
    {
2321 2322
        let method = try parseTraitMethodSig(p);
2322 2323
        ast::nodeListPush(&mut methods, method);
2357 2358
    let traitName = try parseTypePath(p);
2358 2359
    try expect(p, scanner::TokenKind::For, "expected `for` after trait name");
2359 2360
    let targetType = try parseTypePath(p);
2360 2361
    try expect(p, scanner::TokenKind::LBrace, "expected `{` after target type");
2361 2362
2362 -
    let mut methods = ast::nodeList(p.arena, MAX_TRAIT_METHODS);
2363 +
    let mut methods = ast::nodeList(p.arena, ast::MAX_TRAIT_METHODS);
2363 2364
    while not check(p, scanner::TokenKind::RBrace) and
2364 2365
          not check(p, scanner::TokenKind::Eof)
2365 2366
    {
2366 2367
        let method = try parseInstanceMethodDecl(p);
2367 2368
        ast::nodeListPush(&mut methods, method);
2404 2405
    try expect(p, open, listExpectMessage(open));
2405 2406
    let mut items = ast::nodeList(p.arena, 8);
2406 2407
2407 2408
    while not check(p, close) {
2408 2409
        let item = try parseItem(p);
2410 +
        if items.len >= items.list.len {
2411 +
            throw failParsing(p, "too many items in list");
2412 +
        }
2409 2413
        ast::nodeListPush(&mut items, item);
2410 2414
2411 2415
        if not consume(p, scanner::TokenKind::Comma) {
2412 2416
            break;
2413 2417
        }
lib/std/lang/resolver.rad +202 -8
42 42
/// as tags are stored using 8-bits only.
43 43
pub const MAX_UNION_VARIANTS: u32 = 128;
44 44
/// Maximum nesting of loops.
45 45
pub const MAX_LOOP_DEPTH: u32 = 16;
46 46
47 +
/// Trait definition stored in the resolver.
48 +
pub record TraitType {
49 +
    /// Trait name.
50 +
    name: *[u8],
51 +
    /// Method signatures.
52 +
    methods: [TraitMethod; ast::MAX_TRAIT_METHODS],
53 +
    /// Number of methods.
54 +
    methodsLen: u32,
55 +
}
56 +
57 +
/// A single method signature within a trait.
58 +
pub record TraitMethod {
59 +
    /// Method name.
60 +
    name: *[u8],
61 +
    /// Function type for the method, excluding the receiver.
62 +
    fnType: *FnType,
63 +
    /// Whether the receiver is mutable.
64 +
    mutable: bool,
65 +
    /// Vtable slot index.
66 +
    index: u32,
67 +
}
68 +
47 69
/// Identifier for the synthetic `len` field.
48 70
pub const LEN_FIELD: *[u8] = "len";
49 71
/// Identifier for the synthetic `ptr` field.
50 72
pub const PTR_FIELD: *[u8] = "ptr";
51 73
295 317
        /// Module scope.
296 318
        scope: *mut Scope,
297 319
    },
298 320
    /// Payload describing type symbols with their resolved type.
299 321
    Type(*mut NominalType),
322 +
    /// Trait symbol.
323 +
    Trait(*mut TraitType),
300 324
}
301 325
302 326
/// Resolved symbol allocated during semantic analysis.
303 327
pub record Symbol {
304 328
    /// Symbol name in source code.
493 517
    DuplicateMatchPattern,
494 518
    /// `match` has an unreachable `else`: all cases are already handled.
495 519
    UnreachableElse,
496 520
    /// Builtin called with wrong number of arguments.
497 521
    BuiltinArgCountMismatch(CountMismatch),
522 +
    /// Instance method receiver mutability does not match the trait declaration.
523 +
    ReceiverMutabilityMismatch,
524 +
    /// Duplicate instance declaration for the same (trait, type) pair.
525 +
    DuplicateInstance,
526 +
    /// Trait name used as a value expression.
527 +
    UnexpectedTraitName,
528 +
    /// Trait method receiver does not point to the declaring trait.
529 +
    TraitReceiverMismatch,
530 +
    /// Function declaration has too many parameters.
531 +
    FnParamOverflow(CountMismatch),
532 +
    /// Function declaration has too many throws.
533 +
    FnThrowOverflow(CountMismatch),
534 +
    /// Trait declaration has too many methods.
535 +
    TraitMethodOverflow(CountMismatch),
498 536
    /// Internal error.
499 537
    Internal,
500 538
}
501 539
502 540
/// Fixed-capacity diagnostic sink.
2567 2605
                },
2568 2606
                case SymbolData::Module { .. } => {
2569 2607
                    // Module identifiers alone aren't valid expressions.
2570 2608
                    throw emitError(self, node, ErrorKind::UnexpectedModuleName);
2571 2609
                }
2610 +
                case SymbolData::Trait(_) => {
2611 +
                    throw emitError(self, node, ErrorKind::UnexpectedTraitName);
2612 +
                }
2572 2613
            }
2573 2614
        },
2574 2615
        case ast::NodeValue::Super => {
2575 2616
            // `super` by itself is invalid, must be used in scope access.
2576 2617
            throw emitError(self, node, ErrorKind::InvalidModulePath);
2964 3005
        localCount: 0,
2965 3006
    };
2966 3007
    // Enter the function scope to process parameters.
2967 3008
    enterFn(self, node, &fnType);
2968 3009
3010 +
    if decl.sig.params.len > fnType.paramTypes.len {
3011 +
        exitFn(self);
3012 +
        throw emitError(self, node, ErrorKind::FnParamOverflow(CountMismatch {
3013 +
            expected: fnType.paramTypes.len,
3014 +
            actual: decl.sig.params.len,
3015 +
        }));
3016 +
    }
2969 3017
    for i in 0..decl.sig.params.len {
2970 -
        if fnType.paramTypesLen >= fnType.paramTypes.len {
2971 -
            break;
2972 -
        }
2973 3018
        let paramNode = decl.sig.params.list[i];
2974 3019
        let paramTy = try infer(self, paramNode) catch {
2975 3020
            exitFn(self);
2976 3021
            throw ResolveError::Failure;
2977 3022
        };
2978 3023
        fnType.paramTypes[fnType.paramTypesLen] = allocType(self, paramTy);
2979 3024
        fnType.paramTypesLen += 1;
2980 3025
    }
2981 3026
3027 +
    if decl.sig.throwList.len > fnType.throwList.len {
3028 +
        exitFn(self);
3029 +
        throw emitError(self, node, ErrorKind::FnThrowOverflow(CountMismatch {
3030 +
            expected: fnType.throwList.len,
3031 +
            actual: decl.sig.throwList.len,
3032 +
        }));
3033 +
    }
2982 3034
    for i in 0..decl.sig.throwList.len {
2983 -
        if fnType.throwListLen >= fnType.throwList.len {
2984 -
            break;
2985 -
        }
2986 3035
        let throwNode = decl.sig.throwList.list[i];
2987 3036
        let throwTy = try infer(self, throwNode) catch {
2988 3037
            exitFn(self);
2989 3038
            throw ResolveError::Failure;
2990 3039
        };
3144 3193
    let nominalTy = allocNominalType(self, NominalType::Placeholder(node));
3145 3194
3146 3195
    return try bindTypeIdent(self, name, node, nominalTy, attrMask);
3147 3196
}
3148 3197
3198 +
/// Allocate a trait type descriptor and return a pointer to it.
3199 +
fn allocTraitType(self: *mut Resolver, name: *[u8]) -> *mut TraitType {
3200 +
    let p = try! alloc::alloc(&mut self.arena, @sizeOf(TraitType), @alignOf(TraitType));
3201 +
    let entry = p as *mut TraitType;
3202 +
    *entry = TraitType { name, methods: undefined, methodsLen: 0 };
3203 +
3204 +
    return entry;
3205 +
}
3206 +
3207 +
/// Bind a trait name in the current scope.
3208 +
fn bindTraitName(self: *mut Resolver, node: *ast::Node, name: *ast::Node, attrs: ?ast::Attributes) -> *mut Symbol
3209 +
    throws (ResolveError)
3210 +
{
3211 +
    let attrMask = try resolveAttributes(self, attrs);
3212 +
    try ensureDefaultAttrNotAllowed(self, node, attrMask);
3213 +
3214 +
    let traitName = try nodeName(self, name);
3215 +
    let traitType = allocTraitType(self, traitName);
3216 +
    let data = SymbolData::Trait(traitType);
3217 +
    let sym = try bindIdent(self, traitName, node, data, attrMask, self.scope);
3218 +
3219 +
    try setNodeType(self, node, Type::Void);
3220 +
    try setNodeType(self, name, Type::Void);
3221 +
3222 +
    return sym;
3223 +
}
3224 +
3225 +
/// Resolve a trait declaration body, ie. the method signatures.
3226 +
fn resolveTraitBody(self: *mut Resolver, node: *ast::Node, methods: *ast::NodeList)
3227 +
    throws (ResolveError)
3228 +
{
3229 +
    let sym = symbolFor(self, node)
3230 +
        else return;
3231 +
    let case SymbolData::Trait(traitType) = sym.data
3232 +
        else return;
3233 +
3234 +
    if methods.len > traitType.methods.len {
3235 +
        throw emitError(self, node, ErrorKind::TraitMethodOverflow(CountMismatch {
3236 +
            expected: traitType.methods.len,
3237 +
            actual: methods.len,
3238 +
        }));
3239 +
    }
3240 +
    for i in 0..methods.len {
3241 +
        let methodNode = methods.list[i];
3242 +
        let case ast::NodeValue::TraitMethodSig { name, receiver, sig } = methodNode.value
3243 +
            else continue;
3244 +
        let methodName = try nodeName(self, name);
3245 +
3246 +
        // Reject duplicate method names within the same trait.
3247 +
        for j in 0..traitType.methodsLen {
3248 +
            if traitType.methods[j].name == methodName {
3249 +
                throw emitError(self, name, ErrorKind::DuplicateBinding(methodName));
3250 +
            }
3251 +
        }
3252 +
3253 +
        // Determine receiver mutability from the receiver type node
3254 +
        // and validate that the receiver points to the declaring trait.
3255 +
        let case ast::NodeValue::TypeSig(typeSig) = receiver.value
3256 +
            else throw emitError(self, receiver, ErrorKind::TraitReceiverMismatch);
3257 +
        let case ast::TypeSig::Pointer { mutable, valueType } = typeSig
3258 +
            else throw emitError(self, receiver, ErrorKind::TraitReceiverMismatch);
3259 +
        let case ast::NodeValue::TypeSig(innerSig) = valueType.value
3260 +
            else throw emitError(self, receiver, ErrorKind::TraitReceiverMismatch);
3261 +
        let case ast::TypeSig::Nominal(nameNode) = innerSig
3262 +
            else throw emitError(self, receiver, ErrorKind::TraitReceiverMismatch);
3263 +
        let receiverTargetName = try nodeName(self, nameNode);
3264 +
3265 +
        if receiverTargetName != traitType.name {
3266 +
            throw emitError(self, receiver, ErrorKind::TraitReceiverMismatch);
3267 +
        }
3268 +
        // Resolve parameter types and return type.
3269 +
        let mut fnType = FnType {
3270 +
            paramTypes: undefined,
3271 +
            paramTypesLen: 0,
3272 +
            returnType: allocType(self, Type::Void),
3273 +
            throwList: undefined,
3274 +
            throwListLen: 0,
3275 +
            localCount: 0,
3276 +
        };
3277 +
        if sig.params.len > fnType.paramTypes.len {
3278 +
            throw emitError(self, methodNode, ErrorKind::FnParamOverflow(CountMismatch {
3279 +
                expected: fnType.paramTypes.len,
3280 +
                actual: sig.params.len,
3281 +
            }));
3282 +
        }
3283 +
        for j in 0..sig.params.len {
3284 +
            let paramNode = sig.params.list[j];
3285 +
            let paramTy = try infer(self, paramNode);
3286 +
3287 +
            fnType.paramTypes[fnType.paramTypesLen] = allocType(self, paramTy);
3288 +
            fnType.paramTypesLen += 1;
3289 +
        }
3290 +
        if let ret = sig.returnType {
3291 +
            fnType.returnType = allocType(self, try infer(self, ret));
3292 +
        }
3293 +
        // Resolve throws list.
3294 +
        if sig.throwList.len > fnType.throwList.len {
3295 +
            throw emitError(self, methodNode, ErrorKind::FnThrowOverflow(CountMismatch {
3296 +
                expected: fnType.throwList.len,
3297 +
                actual: sig.throwList.len,
3298 +
            }));
3299 +
        }
3300 +
        for j in 0..sig.throwList.len {
3301 +
            let throwNode = sig.throwList.list[j];
3302 +
            let throwTy = try infer(self, throwNode);
3303 +
3304 +
            fnType.throwList[fnType.throwListLen] = allocType(self, throwTy);
3305 +
            fnType.throwListLen += 1;
3306 +
        }
3307 +
3308 +
        traitType.methods[traitType.methodsLen] = TraitMethod {
3309 +
            name: methodName,
3310 +
            fnType: allocFnType(self, fnType),
3311 +
            mutable,
3312 +
            index: traitType.methodsLen,
3313 +
        };
3314 +
        traitType.methodsLen += 1;
3315 +
3316 +
        try setNodeType(self, methodNode, Type::Void);
3317 +
    }
3318 +
}
3319 +
3149 3320
/// Resolve union variant types after all type names are bound (Phase 2 of type resolution).
3150 3321
fn resolveUnionBody(self: *mut Resolver, node: *ast::Node, decl: ast::UnionDecl)
3151 3322
    throws (ResolveError)
3152 3323
{
3153 3324
    // Get the type symbol that was bound to this declaration node.
4606 4777
            return try setNodeType(self, node, ty);
4607 4778
        }
4608 4779
        case SymbolData::Module { .. } => {
4609 4780
            throw emitError(self, node, ErrorKind::UnexpectedModuleName);
4610 4781
        }
4782 +
        case SymbolData::Trait(_) => { // Trait names are not values.
4783 +
            throw emitError(self, node, ErrorKind::UnexpectedTraitName);
4784 +
        }
4611 4785
    }
4612 4786
    return try setNodeType(self, node, ty);
4613 4787
}
4614 4788
4615 4789
/// Analyze a field access expression.
5375 5549
                returnType: allocType(self, Type::Void),
5376 5550
                throwList: undefined,
5377 5551
                throwListLen: 0,
5378 5552
                localCount: 0,
5379 5553
            };
5380 -
            debug::check(t.params.len <= fnType.paramTypes.len);
5381 -
            debug::check(t.throwList.len <= fnType.throwList.len);
5554 +
            if t.params.len > fnType.paramTypes.len {
5555 +
                throw emitError(self, node, ErrorKind::FnParamOverflow(CountMismatch {
5556 +
                    expected: fnType.paramTypes.len,
5557 +
                    actual: t.params.len,
5558 +
                }));
5559 +
            }
5560 +
            if t.throwList.len > fnType.throwList.len {
5561 +
                throw emitError(self, node, ErrorKind::FnThrowOverflow(CountMismatch {
5562 +
                    expected: fnType.throwList.len,
5563 +
                    actual: t.throwList.len,
5564 +
                }));
5565 +
            }
5382 5566
5383 5567
            for i in 0..t.params.len {
5384 5568
                let paramTy = try infer(self, t.params.list[i]);
5385 5569
5386 5570
                fnType.paramTypes[fnType.paramTypesLen] = allocType(self, paramTy);
5493 5677
            case ast::NodeValue::UnionDecl(decl) => {
5494 5678
                if symbolFor(self, node) == nil {
5495 5679
                    try bindTypeName(self, node, decl.name, decl.attrs) catch {};
5496 5680
                }
5497 5681
            }
5682 +
            case ast::NodeValue::TraitDecl { name, attrs, .. } => {
5683 +
                if symbolFor(self, node) == nil {
5684 +
                    try bindTraitName(self, node, name, attrs) catch {};
5685 +
                }
5686 +
            }
5498 5687
            else => {}
5499 5688
        }
5500 5689
    }
5501 5690
}
5502 5691
5513 5702
            case ast::NodeValue::UnionDecl(decl) => {
5514 5703
                try resolveUnionBody(self, node, decl) catch {
5515 5704
                    // Continue resolving other types even if one fails.
5516 5705
                };
5517 5706
            }
5707 +
            case ast::NodeValue::TraitDecl { methods, .. } => {
5708 +
                try resolveTraitBody(self, node, &methods) catch {
5709 +
                    // Continue resolving other types even if one fails.
5710 +
                };
5711 +
            }
5518 5712
            else => {
5519 5713
                // Ignore other declarations.
5520 5714
            }
5521 5715
        }
5522 5716
    }
lib/std/lang/resolver/printer.rad +38 -0
242 242
            io::print("type");
243 243
        }
244 244
        case super::SymbolData::Module { .. } => {
245 245
            io::print("module");
246 246
        }
247 +
        case super::SymbolData::Trait(_) => {
248 +
            io::print("trait");
249 +
        }
247 250
    }
248 251
249 252
    match symbol.data {
250 253
        case super::SymbolData::Value { mutable, type, .. } => {
251 254
            if mutable {
275 278
            // TODO
276 279
        }
277 280
        case super::SymbolData::Module { .. } => {
278 281
            // TODO
279 282
        }
283 +
        case super::SymbolData::Trait(_) => {
284 +
            io::print(" ");
285 +
            io::print(symbol.name);
286 +
            io::print("\n");
287 +
        }
280 288
    }
281 289
}
282 290
283 291
/// Print a single diagnostic entry.
284 292
fn printError(err: *super::Error, res: *super::Resolver) {
534 542
        case super::ErrorKind::UnionVariantPayloadUnexpected(name) => {
535 543
            io::print("union variant '");
536 544
            io::print(name);
537 545
            io::print("' does not expect a payload");
538 546
        }
547 +
        case super::ErrorKind::ReceiverMutabilityMismatch => {
548 +
            io::print("instance receiver mutability does not match trait declaration");
549 +
        }
550 +
        case super::ErrorKind::DuplicateInstance => {
551 +
            io::print("duplicate instance declaration for the same trait and type");
552 +
        }
553 +
        case super::ErrorKind::UnexpectedTraitName => {
554 +
            io::print("trait name cannot be used as a value");
555 +
        }
556 +
        case super::ErrorKind::TraitReceiverMismatch => {
557 +
            io::print("trait method receiver must be a pointer to the declaring trait");
558 +
        }
559 +
        case super::ErrorKind::FnParamOverflow(mismatch) => {
560 +
            io::print("too many function parameters: maximum ");
561 +
            io::printU32(mismatch.expected);
562 +
            io::print(", got ");
563 +
            io::printU32(mismatch.actual);
564 +
        }
565 +
        case super::ErrorKind::FnThrowOverflow(mismatch) => {
566 +
            io::print("too many function throws: maximum ");
567 +
            io::printU32(mismatch.expected);
568 +
            io::print(", got ");
569 +
            io::printU32(mismatch.actual);
570 +
        }
571 +
        case super::ErrorKind::TraitMethodOverflow(mismatch) => {
572 +
            io::print("too many trait methods: maximum ");
573 +
            io::printU32(mismatch.expected);
574 +
            io::print(", got ");
575 +
            io::printU32(mismatch.actual);
576 +
        }
539 577
        case super::ErrorKind::Internal => {
540 578
            io::print("internal compiler error");
541 579
        }
542 580
        case super::ErrorKind::RecordFieldOutOfOrder { .. } => {
543 581
            io::print("record field out of order");