Simplify some of the compiler code

79cea13d7a6e2e7217d3bc2475a4bea2c67e7248848e2b7982f916494324717e
Just some factoring out to reduce code size.
Alexis Sellier committed ago 1 parent 0d8a2d4f
lib/std/lang/lower.rad +20 -92
693 693
        fnCount: 0,
694 694
        fnSyms,
695 695
        fnSymsLen: 0,
696 696
        options: LowerOptions { debug: false, buildTest: false },
697 697
    };
698 -
    let case ast::NodeValue::Block(block) = root.value else {
699 -
        throw LowerError::ExpectedBlock(root);
700 -
    };
701 -
    let stmtsList = block.statements.list;
702 -
    let stmtsLen = block.statements.len;
703 -
    let mut defaultFnIdx: ?u32 = nil;
698 +
    let defaultFnIdx = try lowerDecls(&mut low, root, true);
704 699
705 -
    for i in 0..stmtsLen {
706 -
        let node = stmtsList[i];
707 -
        match node.value {
708 -
            case ast::NodeValue::FnDecl(decl) => {
709 -
                if let f = try lowerFnDecl(&mut low, node, decl) {
710 -
                    if checkAttr(decl.attrs, ast::Attribute::Default) {
711 -
                        defaultFnIdx = low.fnCount;
712 -
                    }
713 -
                    try pushFn(&mut low, f);
714 -
                }
715 -
            }
716 -
            case ast::NodeValue::ConstDecl(decl) => {
717 -
                try lowerDataDecl(&mut low, node, decl.value, true);
718 -
            }
719 -
            case ast::NodeValue::StaticDecl(decl) => {
720 -
                try lowerDataDecl(&mut low, node, decl.value, false);
721 -
            }
722 -
            else => {},
723 -
        }
724 -
    }
725 700
    return il::Program {
726 701
        data: low.data[..low.dataCount],
727 702
        fns: low.fns[..low.fnCount],
728 703
        defaultFnIdx,
729 704
    };
776 751
    moduleId: u16,
777 752
    root: *ast::Node,
778 753
    isRoot: bool
779 754
) -> ?u32 throws (LowerError) {
780 755
    low.currentMod = moduleId;
756 +
    return try lowerDecls(low, root, isRoot);
757 +
}
781 758
759 +
/// Lower all top-level declarations in a block.
760 +
fn lowerDecls(low: *mut Lowerer, root: *ast::Node, isRoot: bool) -> ?u32 throws (LowerError) {
782 761
    let case ast::NodeValue::Block(block) = root.value else {
783 762
        throw LowerError::ExpectedBlock(root);
784 763
    };
785 764
    let stmtsList = block.statements.list;
786 765
    let stmtsLen = block.statements.len;
787 766
    let mut defaultFnIdx: ?u32 = nil;
788 767
789 -
    // Lower all declarations. Cross-module calls may use the fallback path
790 -
    // if the callee's module hasn't been lowered yet.
791 768
    for i in 0..stmtsLen {
792 769
        let node = stmtsList[i];
793 770
        match node.value {
794 771
            case ast::NodeValue::FnDecl(decl) => {
795 772
                if let f = try lowerFnDecl(low, node, decl) {
796 -
                    // Only check `@default` in root module.
797 773
                    if isRoot and checkAttr(decl.attrs, ast::Attribute::Default) {
798 774
                        defaultFnIdx = low.fnCount;
799 775
                    }
800 776
                    try pushFn(low, f);
801 777
                }
2956 2932
    }
2957 2933
    let val = try lowerExpr(self, opt);
2958 2934
    let cmpReg = try optionalNilReg(self, val, optTy);
2959 2935
2960 2936
    if isEq {
2961 -
        return emitEqW8(self, il::Val::Reg(cmpReg), il::Val::Imm(0));
2937 +
        return emitTypedBinOp(self, il::BinOp::Eq, il::Type::W8, il::Val::Reg(cmpReg), il::Val::Imm(0));
2962 2938
    } else {
2963 -
        return emitNeW8(self, il::Val::Reg(cmpReg), il::Val::Imm(0));
2939 +
        return emitTypedBinOp(self, il::BinOp::Ne, il::Type::W8, il::Val::Reg(cmpReg), il::Val::Imm(0));
2964 2940
    }
2965 2941
}
2966 2942
2967 2943
/// Load the payload value from a tagged value aggregate at the given offset.
2968 2944
fn tvalPayloadVal(self: *mut FnLowerer, base: il::Reg, payload: resolver::Type, valOffset: i32) -> il::Val {
3860 3836
        b: il::Val::Reg(offset)
3861 3837
    });
3862 3838
    return dst;
3863 3839
}
3864 3840
3865 -
/// Emit an 8-bit equality comparison.
3866 -
fn emitEqW8(self: *mut FnLowerer, a: il::Val, b: il::Val) -> il::Val {
3867 -
    let dst = nextReg(self);
3868 -
    emit(self, il::Instr::BinOp { op: il::BinOp::Eq, typ: il::Type::W8, dst, a, b });
3869 -
    return il::Val::Reg(dst);
3870 -
}
3871 -
3872 -
/// Emit an 8-bit inequality comparison.
3873 -
fn emitNeW8(self: *mut FnLowerer, a: il::Val, b: il::Val) -> il::Val {
3874 -
    let dst = nextReg(self);
3875 -
    emit(self, il::Instr::BinOp { op: il::BinOp::Ne, typ: il::Type::W8, dst, a, b });
3876 -
    return il::Val::Reg(dst);
3877 -
}
3878 -
3879 -
/// Emit a 32-bit equality comparison.
3880 -
fn emitEqW32(self: *mut FnLowerer, a: il::Val, b: il::Val) -> il::Val {
3881 -
    let dst = nextReg(self);
3882 -
    emit(self, il::Instr::BinOp { op: il::BinOp::Eq, typ: il::Type::W32, dst, a, b });
3883 -
    return il::Val::Reg(dst);
3884 -
}
3885 -
3886 -
/// Emit a 32-bit inequality comparison.
3887 -
fn emitNeW32(self: *mut FnLowerer, a: il::Val, b: il::Val) -> il::Val {
3888 -
    let dst = nextReg(self);
3889 -
    emit(self, il::Instr::BinOp { op: il::BinOp::Ne, typ: il::Type::W32, dst, a, b });
3890 -
    return il::Val::Reg(dst);
3891 -
}
3892 -
3893 -
/// Emit a 64-bit equality comparison.
3894 -
fn emitEqW64(self: *mut FnLowerer, a: il::Val, b: il::Val) -> il::Val {
3841 +
/// Emit a typed binary operation, returning the result as a value.
3842 +
fn emitTypedBinOp(self: *mut FnLowerer, op: il::BinOp, typ: il::Type, a: il::Val, b: il::Val) -> il::Val {
3895 3843
    let dst = nextReg(self);
3896 -
    emit(self, il::Instr::BinOp { op: il::BinOp::Eq, typ: il::Type::W64, dst, a, b });
3897 -
    return il::Val::Reg(dst);
3898 -
}
3899 -
3900 -
/// Emit a 64-bit inequality comparison.
3901 -
fn emitNeW64(self: *mut FnLowerer, a: il::Val, b: il::Val) -> il::Val {
3902 -
    let dst = nextReg(self);
3903 -
    emit(self, il::Instr::BinOp { op: il::BinOp::Ne, typ: il::Type::W64, dst, a, b });
3844 +
    emit(self, il::Instr::BinOp { op, typ, dst, a, b });
3904 3845
    return il::Val::Reg(dst);
3905 3846
}
3906 3847
3907 3848
/// Emit a tag comparison for void variant equality/inequality.
3908 3849
fn emitTagCmp(self: *mut FnLowerer, op: ast::BinaryOp, val: il::Val, tagIdx: i32, valType: resolver::Type) -> il::Val
3916 3857
        tag = il::Val::Reg(reg);
3917 3858
    } else {
3918 3859
        tag = loadTag(self, reg, TVAL_TAG_OFFSET, il::Type::W8);
3919 3860
    }
3920 3861
    if op == ast::BinaryOp::Eq {
3921 -
        return emitEqW8(self, tag, il::Val::Imm(tagIdx));
3862 +
        return emitTypedBinOp(self, il::BinOp::Eq, il::Type::W8, tag, il::Val::Imm(tagIdx));
3922 3863
    }
3923 -
    return emitNeW8(self, tag, il::Val::Imm(tagIdx));
3924 -
}
3925 -
3926 -
/// Emit a 32-bit bitwise AND.
3927 -
fn emitAndW32(self: *mut FnLowerer, a: il::Val, b: il::Val) -> il::Val {
3928 -
    let dst = nextReg(self);
3929 -
    emit(self, il::Instr::BinOp { op: il::BinOp::And, typ: il::Type::W32, dst, a, b });
3930 -
    return il::Val::Reg(dst);
3931 -
}
3932 -
3933 -
/// Emit a 32-bit bitwise OR.
3934 -
fn emitOrW32(self: *mut FnLowerer, a: il::Val, b: il::Val) -> il::Val {
3935 -
    let dst = nextReg(self);
3936 -
    emit(self, il::Instr::BinOp { op: il::BinOp::Or, typ: il::Type::W32, dst, a, b });
3937 -
    return il::Val::Reg(dst);
3864 +
    return emitTypedBinOp(self, il::BinOp::Ne, il::Type::W8, tag, il::Val::Imm(tagIdx));
3938 3865
}
3939 3866
3940 3867
/// Logical "and" between two values. Returns the result in a register.
3941 3868
fn emitLogicalAnd(self: *mut FnLowerer, left: ?il::Val, right: il::Val) -> il::Val {
3942 3869
    let prev = left else {
3943 3870
        return right;
3944 3871
    };
3945 -
    return emitAndW32(self, prev, right);
3872 +
    return emitTypedBinOp(self, il::BinOp::And, il::Type::W32, prev, right);
3946 3873
}
3947 3874
3948 3875
//////////////////////////
3949 3876
// Aggregate Comparison //
3950 3877
//////////////////////////
4003 3930
) -> il::Val throws (LowerError) {
4004 3931
    let ptrTy = resolver::Type::Pointer { target: elemTy, mutable };
4005 3932
    let ptrEq = try emitEqAtOffset(self, a, b, offset + SLICE_PTR_OFFSET, ptrTy);
4006 3933
    let lenEq = try emitEqAtOffset(self, a, b, offset + SLICE_LEN_OFFSET, resolver::Type::U32);
4007 3934
4008 -
    return emitAndW32(self, ptrEq, lenEq);
3935 +
    return emitTypedBinOp(self, il::BinOp::And, il::Type::W32, ptrEq, lenEq);
4009 3936
}
4010 3937
4011 3938
/// Check if a type may contain uninitialized payload bytes when used as the
4012 3939
/// inner type of a `nil` optional. Unions with non-void variants and nested
4013 3940
/// optionals fall into this category; records and primitives do not.
4046 3973
    let tagB = loadTag(self, b, offset + TVAL_TAG_OFFSET, il::Type::W8);
4047 3974
4048 3975
    // For simple inner types (no unions/nested optionals), use branchless comparison.
4049 3976
    // TODO: Inline this function.
4050 3977
    if not needsGuardedPayloadCmp(inner) {
4051 -
        let tagEq = emitEqW8(self, tagA, tagB);
4052 -
        let tagNil = emitEqW8(self, tagA, il::Val::Imm(0));
3978 +
        let tagEq = emitTypedBinOp(self, il::BinOp::Eq, il::Type::W8, tagA, tagB);
3979 +
        let tagNil = emitTypedBinOp(self, il::BinOp::Eq, il::Type::W8, tagA, il::Val::Imm(0));
4053 3980
        let payloadEq = try emitEqAtOffset(self, a, b, offset + valOffset, inner);
4054 3981
4055 -
        return emitAndW32(self, tagEq, emitOrW32(self, tagNil, payloadEq));
3982 +
        return emitTypedBinOp(self, il::BinOp::And, il::Type::W32, tagEq,
3983 +
            emitTypedBinOp(self, il::BinOp::Or, il::Type::W32, tagNil, payloadEq));
4056 3984
    }
4057 3985
4058 3986
    // For complex inner types, use branching comparison to avoid inspecting
4059 3987
    // uninitialized payload bytes.
4060 3988
    let resultReg = nextReg(self);
4128 4056
    let tagA = loadTag(self, a, offset + TVAL_TAG_OFFSET, il::Type::W8);
4129 4057
    let tagB = loadTag(self, b, offset + TVAL_TAG_OFFSET, il::Type::W8);
4130 4058
4131 4059
    // Fast path: all-void union just needs tag comparison.
4132 4060
    if unionInfo.isAllVoid {
4133 -
        return emitEqW8(self, tagA, tagB);
4061 +
        return emitTypedBinOp(self, il::BinOp::Eq, il::Type::W8, tagA, tagB);
4134 4062
    }
4135 4063
    // Holds the equality result.
4136 4064
    let resultReg = nextReg(self);
4137 4065
4138 4066
    // Where control flow continues after equality check is done. Receives
5411 5339
    let regA = try emitValToReg(self, a);
5412 5340
    let regB = try emitValToReg(self, b);
5413 5341
    let result = try lowerAggregateEq(self, typ, regA, regB, 0);
5414 5342
5415 5343
    if op == ast::BinaryOp::Ne {
5416 -
        return emitEqW32(self, result, il::Val::Imm(0));
5344 +
        return emitTypedBinOp(self, il::BinOp::Eq, il::Type::W32, result, il::Val::Imm(0));
5417 5345
    }
5418 5346
    return result;
5419 5347
}
5420 5348
5421 5349
/// Emit a scalar binary operation instruction.
lib/std/lang/resolver.rad +19 -33
2400 2400
        return ty;
2401 2401
    }
2402 2402
    match node.value {
2403 2403
        case ast::NodeValue::Block(block) => return try resolveBlock(self, node, block),
2404 2404
        case ast::NodeValue::Let(decl) => return try resolveLet(self, node, decl),
2405 -
        case ast::NodeValue::ConstDecl(decl) => return try resolveConst(
2406 -
            self, node, decl.ident, decl.type, decl.value, decl.attrs
2405 +
        case ast::NodeValue::ConstDecl(decl) => return try resolveConstOrStatic(
2406 +
            self, node, decl.ident, decl.type, decl.value, decl.attrs, true
2407 2407
        ),
2408 -
        case ast::NodeValue::StaticDecl(decl) => return try resolveStatic(
2409 -
            self, node, decl.ident, decl.type, decl.value, decl.attrs
2408 +
        case ast::NodeValue::StaticDecl(decl) => return try resolveConstOrStatic(
2409 +
            self, node, decl.ident, decl.type, decl.value, decl.attrs, false
2410 2410
        ),
2411 2411
        case ast::NodeValue::FnParam(param) => return try resolveFnParam(self, node, param),
2412 2412
        case ast::NodeValue::If(cond) => return try resolveIf(self, node, cond),
2413 2413
        case ast::NodeValue::CondExpr(cond) => return try resolveCondExpr(self, node, cond),
2414 2414
        case ast::NodeValue::IfLet(cond) => return try resolveIfLet(self, node, cond),
2835 2835
            actual: args.len,
2836 2836
        }));
2837 2837
    }
2838 2838
}
2839 2839
2840 -
/// Helper for analyzing `const` declarations.
2841 -
fn resolveConst(
2840 +
/// Helper for analyzing `const` and `static` declarations.
2841 +
fn resolveConstOrStatic(
2842 2842
    self: *mut Resolver,
2843 2843
    node: *ast::Node,
2844 2844
    ident: *ast::Node,
2845 2845
    typeNode: *ast::Node,
2846 2846
    valueNode: *ast::Node,
2847 -
    attrList: ?ast::Attributes
2848 -
) -> Type throws (ResolveError) {
2849 -
    let attrs = try resolveAttributes(self, attrList);
2850 -
    let bindingTy = try infer(self, typeNode);
2851 -
    let valueTy = try checkAssignable(self, valueNode, bindingTy);
2852 -
2853 -
    let constVal = constValueEntry(self, valueNode);
2854 -
    if constVal == nil and not isConstExpr(self, valueNode) {
2855 -
        throw emitError(self, valueNode, ErrorKind::ConstExprRequired);
2856 -
    }
2857 -
    try bindConstIdent(self, ident, node, bindingTy, constVal, attrs);
2858 -
    try setNodeType(self, valueNode, bindingTy);
2859 -
2860 -
    return Type::Void;
2861 -
}
2862 -
2863 -
/// Helper for analyzing `static` declarations.
2864 -
fn resolveStatic(
2865 -
    self: *mut Resolver,
2866 -
    node: *ast::Node,
2867 -
    ident: *ast::Node,
2868 -
    typeNode: *ast::Node,
2869 -
    valueNode: *ast::Node,
2870 -
    attrList: ?ast::Attributes
2847 +
    attrList: ?ast::Attributes,
2848 +
    isConst: bool
2871 2849
) -> Type throws (ResolveError) {
2872 2850
    let attrs = try resolveAttributes(self, attrList);
2873 2851
    let bindingTy = try infer(self, typeNode);
2874 2852
    let valueTy = try checkAssignable(self, valueNode, bindingTy);
2875 2853
2876 -
    if not isConstExpr(self, valueNode) {
2877 -
        throw emitError(self, valueNode, ErrorKind::ConstExprRequired);
2854 +
    if isConst {
2855 +
        let constVal = constValueEntry(self, valueNode);
2856 +
        if constVal == nil and not isConstExpr(self, valueNode) {
2857 +
            throw emitError(self, valueNode, ErrorKind::ConstExprRequired);
2858 +
        }
2859 +
        try bindConstIdent(self, ident, node, bindingTy, constVal, attrs);
2860 +
    } else {
2861 +
        if not isConstExpr(self, valueNode) {
2862 +
            throw emitError(self, valueNode, ErrorKind::ConstExprRequired);
2863 +
        }
2864 +
        try bindValueIdent(self, ident, node, bindingTy, true, 0, attrs);
2878 2865
    }
2879 -
    try bindValueIdent(self, ident, node, bindingTy, true, 0, attrs);
2880 2866
    try setNodeType(self, valueNode, bindingTy);
2881 2867
2882 2868
    return Type::Void;
2883 2869
}
2884 2870