04 — The Type Checker Pass

The type checker walks the AST after the resolver and before execution. It visits every expression and statement, computing or verifying types.

The TypeChecker class

// Static type-checking pass. Each ExprVisitor method returns the type of the
// expression it visits; each StmtVisitor method verifies a statement. Type
// mismatches are reported by throwing TypeCheckError (via checkCompatible).
class TypeChecker : public ExprVisitor<TypePtr>, public StmtVisitor<void> {
    // Type environment: a stack of scopes, each mapping name → TypePtr.
    std::vector<std::unordered_map<std::string, TypePtr>> scopes_;
    // Expected return type of the function currently being checked (set and
    // restored by visitFn). Brace-initialised so it holds a defined, null
    // value before any function is entered; a plain declaration would leave
    // it indeterminate if TypePtr is a raw pointer.
    TypePtr currentReturnType_{};

    void beginScope();
    void endScope();
    void declare(const std::string& name, TypePtr t);
    TypePtr lookup(const std::string& name);

    // NOTE(review): the visit methods call check(*expr) / check(*stmt) on
    // individual nodes, but only check(std::vector<StmtPtr>&) is declared in
    // this class. Inside member functions that public overload hides any free
    // function named check, so per-node overloads (e.g. `TypePtr check(Expr&)`
    // and `void check(Stmt&)`) must be declared here as members. Confirm the
    // AST base-class names before adding them.

    // Throws TypeCheckError when `actual` is not compatible with `expected`.
    void checkCompatible(TypePtr expected, TypePtr actual, int line);
public:
    // Entry point: type-checks a list of statements (presumably the whole
    // program or one REPL input).
    void check(std::vector<StmtPtr>& stmts);

    // ExprVisitor: each returns the type of the visited expression.
    TypePtr visitNumber(NumberExpr&) override;
    TypePtr visitBool(BoolExpr&)     override;
    TypePtr visitString(StringExpr&) override;
    TypePtr visitNil(NilExpr&)       override;
    TypePtr visitVar(VarExpr&)       override;
    TypePtr visitBinary(BinaryExpr&) override;
    TypePtr visitUnary(UnaryExpr&)   override;
    TypePtr visitCall(CallExpr&)     override;
    TypePtr visitFn(FnExpr&)         override;

    // StmtVisitor: each verifies the visited statement.
    void visitLet(LetStmt&)       override;
    void visitBlock(BlockStmt&)   override;
    void visitIf(IfStmt&)         override;
    void visitWhile(WhileStmt&)   override;
    void visitReturn(ReturnStmt&) override;
    void visitPrint(PrintStmt&)   override;
};

Expression type rules

TypePtr TypeChecker::visitBinary(BinaryExpr& e) {
    // Both operands are typed before the operator rule is applied, so an
    // error inside an operand is reported before an operator mismatch.
    auto L = check(*e.left);
    auto R = check(*e.right);
    switch (e.op) {
    case Plus: case Minus: case Star: case Slash:
        // Arithmetic: both operands must be numbers; the result is a number.
        checkCompatible(mkNum(), L, e.line);
        checkCompatible(mkNum(), R, e.line);
        return mkNum();
    case EqEq: case BangEq:
        // Equality is accepted for operands of any types: no compatibility
        // check is made between L and R. The result is always Bool.
        // NOTE(review): if mixed-type equality should be rejected, add a
        // compatibility check between L and R here.
        return mkBool();
    case Lt: case LtEq: case Gt: case GtEq:
        // Ordering: both operands must be numbers; the result is Bool.
        checkCompatible(mkNum(), L, e.line);
        checkCompatible(mkNum(), R, e.line);
        return mkBool();
    case And: case Or:
        // Logical connectives: both operands must be Bool.
        checkCompatible(mkBool(), L, e.line);
        checkCompatible(mkBool(), R, e.line);
        return mkBool();
    // ...
    }
    // Every handled operator returns above. Falling off the end of a
    // value-returning function is undefined behaviour, so an operator
    // without a typing rule is reported as a type error instead.
    throw TypeCheckError("[line " + std::to_string(e.line) +
        "] No type rule for this binary operator.");
}

The checkCompatible(expected, actual, line) function throws a TypeCheckError, with the line number and both type names in its message, when compatible(*expected, *actual) is false (step 06 defines compatible):

void TypeChecker::checkCompatible(TypePtr expected, TypePtr actual, int line) {
    // Types agree: nothing to report.
    if (compatible(*expected, *actual))
        return;

    // Assemble "[line N] Expected X, got Y." one piece at a time.
    std::string message = "[line " + std::to_string(line) + "] Expected ";
    message += typeToStr(*expected);
    message += ", got ";
    message += typeToStr(*actual);
    message += ".";
    throw TypeCheckError(message);
}

Statement rules

void TypeChecker::visitLet(LetStmt& s) {
    // Type of the initialiser; a declaration with no initialiser is typed Nil.
    const TypePtr initType = [&]() -> TypePtr {
        if (!s.init)
            return mkNil();
        return check(*s.init);
    }();

    // The name is declared only after its initialiser has been checked, so
    // the initialiser is checked without this new binding in scope.
    if (!s.annotation) {
        // Unannotated: the initialiser's type becomes the declared type.
        declare(s.name, initType);
        return;
    }

    // Annotated: the initialiser must fit the annotation, which then wins.
    checkCompatible(s.annotation, initType, s.line);
    declare(s.name, s.annotation);
}

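If no annotation is given, the type is inferred from the initialiser (a `let` with neither an annotation nor an initialiser gets type Nil). This is simple local, forward type inference: the initialiser's type is copied onto the variable. It is not Hindley–Milner, which would introduce type variables, unify constraints across all uses, and generalise let-bound types.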

let x = 42;       // inferred: Num
let y = "hello";  // inferred: Str
let z = x + y;    // error: expected Num, got Str (for y)

Function type checking

TypePtr TypeChecker::visitFn(FnExpr& fn) {
    // Restores the checker state this function changes on every exit path.
    // Without it, a TypeCheckError thrown while declaring a parameter or
    // checking the body would skip the restore: the parameter scope would stay
    // open and currentReturnType_ would keep this function's return type,
    // which is wrong for any later use of the same checker (e.g. a REPL).
    // Assumes endScope() does not throw, since it runs during unwinding.
    struct FnScopeGuard {
        TypeChecker& checker;
        TypePtr savedReturnType;
        ~FnScopeGuard() {
            checker.currentReturnType_ = savedReturnType;
            checker.endScope();
        }
    };

    // Parameters live in a fresh scope; unannotated parameters are typed Any.
    beginScope();
    FnScopeGuard guard{*this, currentReturnType_};

    std::vector<TypePtr> paramTypes;
    for (auto& p : fn.params) {
        TypePtr t = p.annotation ? p.annotation : mkAny();
        declare(p.name, t);
        paramTypes.push_back(t);
    }

    // `return` statements in the body are checked against this type
    // (Any when there is no return annotation).
    TypePtr retType = fn.retAnnotation ? fn.retAnnotation : mkAny();
    currentReturnType_ = retType;
    check(*fn.body);

    // guard's destructor restores currentReturnType_ and closes the scope.
    return mkFn(std::move(paramTypes), retType);
}

Return type checking

void TypeChecker::visitReturn(ReturnStmt& s) {
    // currentReturnType_ is set by visitFn only while a function body is
    // being checked. Outside a body it holds the checker's initial value
    // (null unless check() sets it), and checkCompatible would dereference
    // a null pointer, so report the misplaced return as a type error.
    if (!currentReturnType_)
        throw TypeCheckError("[line " + std::to_string(s.line) +
            "] Can't return from top-level code.");

    // A bare `return;` yields Nil.
    TypePtr t = s.value ? check(*s.value) : mkNil();
    checkCompatible(currentReturnType_, t, s.line);
}

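If the function has no return annotation, its return type is Any, and any return value is accepted. This relies on compatible treating Any as compatible with every type (see step 06).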