04 — The Type Checker Pass
The type checker walks the AST after the resolver and before execution. It visits every expression and statement, computing or verifying types.
The TypeChecker class
/// AST visitor that assigns a type to every expression (each expression
/// visitor returns a TypePtr) and verifies every statement (statement
/// visitors return nothing).
class TypeChecker : public ExprVisitor<TypePtr>, public StmtVisitor<void> {
public:
    /// Type-checks the given list of statements.
    void check(std::vector<StmtPtr>& stmts);

    // ---- ExprVisitor: one type rule per expression node ----
    TypePtr visitNumber(NumberExpr&) override;
    TypePtr visitBool(BoolExpr&) override;
    TypePtr visitString(StringExpr&) override;
    TypePtr visitNil(NilExpr&) override;
    TypePtr visitVar(VarExpr&) override;
    TypePtr visitBinary(BinaryExpr&) override;
    TypePtr visitUnary(UnaryExpr&) override;
    TypePtr visitCall(CallExpr&) override;
    TypePtr visitFn(FnExpr&) override;

    // ---- StmtVisitor: one check per statement node ----
    void visitLet(LetStmt&) override;
    void visitBlock(BlockStmt&) override;
    void visitIf(IfStmt&) override;
    void visitWhile(WhileStmt&) override;
    void visitReturn(ReturnStmt&) override;
    void visitPrint(PrintStmt&) override;

private:
    // NOTE(review): the visit* bodies call check() on a single dereferenced
    // node (e.g. check(*e.left), check(*s.init), check(*fn.body)); only the
    // statement-list overload is declared in this class as shown — confirm
    // the per-node overloads are declared wherever the full class lives.

    // Scope bookkeeping for the type environment.
    void beginScope();
    void endScope();

    /// Records `name` with type `t` in the type environment.
    void declare(const std::string& name, TypePtr t);

    /// Returns the type recorded for `name`.
    TypePtr lookup(const std::string& name);

    /// Throws TypeCheckError when `actual` is not compatible with `expected`;
    /// `line` is the source line quoted in the error message.
    void checkCompatible(TypePtr expected, TypePtr actual, int line);

    // Type environment: a stack of scopes, each mapping name -> TypePtr.
    std::vector<std::unordered_map<std::string, TypePtr>> scopes_;

    // Return type expected by the function currently being checked.
    TypePtr currentReturnType_;
};
Expression type rules
/// Type rule for binary expressions.
///
/// Both operands are type-checked first (left, then right); the operator then
/// decides which operand types are required and what the result type is.
/// Throws TypeCheckError (via checkCompatible) when an operand has the wrong type.
TypePtr TypeChecker::visitBinary(BinaryExpr& e) {
    const auto lhs = check(*e.left);
    const auto rhs = check(*e.right);

    // Requires both operands to be compatible with `want`. The left operand is
    // tested first, so it is the one reported when both are wrong.
    const auto requireBoth = [&](const TypePtr& want) {
        checkCompatible(want, lhs, e.line);
        checkCompatible(want, rhs, e.line);
    };

    switch (e.op) {
        // Arithmetic: Num x Num -> Num
        case Plus: case Minus: case Star: case Slash:
            requireBoth(mkNum());
            return mkNum();

        // Equality: result is Bool.
        case EqEq: case BangEq:
            // any two compatible types may be compared for equality
            // NOTE(review): the operands are not checked against each other
            // here — confirm that accepting every pair is intended.
            return mkBool();

        // Ordering: Num x Num -> Bool
        case Lt: case LtEq: case Gt: case GtEq:
            requireBoth(mkNum());
            return mkBool();

        // Logical connectives: Bool x Bool -> Bool
        case And: case Or:
            requireBoth(mkBool());
            return mkBool();
        // ...
    }
}
The checkCompatible(expected, actual, line) function throws a
TypeCheckError if !compatible(expected, actual) (see step 06 for the
compatible definition):
/// Verifies that `actual` is compatible with `expected`.
///
/// @param expected  the type required at this position
/// @param actual    the type that was found
/// @param line      source line quoted in the error message
/// @throws TypeCheckError when compatible(*expected, *actual) is false
void TypeChecker::checkCompatible(TypePtr expected, TypePtr actual, int line) {
    if (compatible(*expected, *actual)) {
        return;  // nothing to report
    }

    // Message shape: "[line N] Expected <expected>, got <actual>."
    std::string message = "[line " + std::to_string(line) + "] Expected ";
    message += typeToStr(*expected);
    message += ", got ";
    message += typeToStr(*actual);
    message += ".";
    throw TypeCheckError(message);
}
Statement rules
/// Type rule for `let` declarations.
///
/// The initialiser's type is computed first (Nil when there is no
/// initialiser). With an annotation, the initialiser must be compatible with
/// it and the variable is declared with the annotated type; without one, the
/// variable takes the initialiser's type.
void TypeChecker::visitLet(LetStmt& s) {
    TypePtr valueType;
    if (s.init) {
        valueType = check(*s.init);
    } else {
        valueType = mkNil();
    }

    // The name is declared only after the initialiser has been checked.
    if (!s.annotation) {
        declare(s.name, valueType);
        return;
    }

    checkCompatible(s.annotation, valueType, s.line);
    declare(s.name, s.annotation);
}
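If no annotation is given, the variable's type is inferred from the initialiser (or is Nil when there is no initialiser). This is plain local type inference — the type flows one way, from the initialiser expression to the variable — rather than full Hindley-Milner inference, which would additionally involve unification and generalisation: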
let x = 42; // inferred: Num
let y = "hello"; // inferred: Str
let z = x + y; // error: expected Num, got Str (for y)
Function type checking
/// Type rule for function expressions.
///
/// Opens a scope, declares each parameter (unannotated parameters are Any),
/// checks the body against the function's return type (Any when there is no
/// return annotation) and yields the resulting function type.
///
/// The enclosing function's return type and the scope opened here are
/// restored on every exit path, including when checking the body throws, so
/// the checker is left in a consistent state after a type error.
///
/// @throws TypeCheckError propagated from checking the body.
TypePtr TypeChecker::visitFn(FnExpr& fn) {
    // Undoes this function's context on destruction: puts back the enclosing
    // return type and closes the scope opened just before the guard is built.
    struct FnContextGuard {
        explicit FnContextGuard(TypeChecker& checker)
            : checker_(checker), outerReturnType_(checker.currentReturnType_) {}
        ~FnContextGuard() {
            checker_.currentReturnType_ = outerReturnType_;
            checker_.endScope();
        }
        FnContextGuard(const FnContextGuard&) = delete;
        FnContextGuard& operator=(const FnContextGuard&) = delete;

        TypeChecker& checker_;
        TypePtr outerReturnType_;
    };

    TypePtr retType = fn.retAnnotation ? fn.retAnnotation : mkAny();
    std::vector<TypePtr> paramTypes;

    {
        beginScope();
        FnContextGuard guard(*this);  // from here on the scope is always closed

        for (auto& p : fn.params) {
            TypePtr t = p.annotation ? p.annotation : mkAny();
            declare(p.name, t);
            paramTypes.push_back(t);
        }

        // `return` statements inside the body are checked against this type.
        currentReturnType_ = retType;
        check(*fn.body);
    }  // guard restores the outer return type and ends the scope here

    return mkFn(std::move(paramTypes), retType);
}
Return type checking
/// Type rule for `return`.
///
/// The returned value's type (Nil for a bare `return`) must be compatible
/// with the return type of the function currently being checked.
///
/// @throws TypeCheckError when the value's type is incompatible, or when no
///         function return type is active.
void TypeChecker::visitReturn(ReturnStmt& s) {
    // In the code shown, currentReturnType_ is only set while a function body
    // is being checked. Without this guard a `return` reached outside any
    // function would hand a null expected type to checkCompatible, which
    // dereferences it.
    // NOTE(review): an earlier pass may already reject top-level `return`;
    // this guard keeps the checker safe either way.
    if (!currentReturnType_) {
        throw TypeCheckError("[line " + std::to_string(s.line) +
                             "] 'return' outside of a function.");
    }

    TypePtr valueType = s.value ? check(*s.value) : mkNil();
    checkCompatible(currentReturnType_, valueType, s.line);
}
If the function has Any return type (no annotation), any return value
is accepted.