// cp-06 / tests / test_compiler.cpp
//
// Unit tests for the AST -> bytecode pipeline.
//
// Strategy: feed source through the full pipeline (lex/parse/resolve/typecheck/compile),
// then assert either:
//   (a) properties of the resulting Chunk (opcode sequence, constants pool), or
//   (b) substring matches in the disassembled text.
//
// (a) is preferred for short programs where the exact instruction list is the
// thing under test (arithmetic, locals). (b) is preferred for control flow
// where exact offsets are noisy but landmark opcodes / line annotations
// matter.

#include "compiler.hpp"
#include "disassembler.hpp"
#include "lexer.hpp"
#include "parser.hpp"
#include "resolver.hpp"
#include "typecheck.hpp"

#include <cassert>
#include <cstdint>
#include <cstdio>
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

using namespace ml;

namespace
{

    /// Everything one run of the pipeline produced for a single source string.
    ///
    /// Each `*Ok` flag records whether the corresponding stage finished cleanly.
    /// `diagnostics` accumulates messages from every stage in pipeline order,
    /// plus the text of any caught exception.
    struct CompileOutput
    {
        bool parsedOk{false};   // parseProgram() returned without throwing
        bool resolvedOk{false}; // resolver produced no diagnostics
        bool typedOk{false};    // type checker produced no diagnostics
        bool compiledOk{false}; // compiler result reported ok()
        Chunk chunk;            // bytecode handed back by the compiler
        std::vector<std::string> diagnostics;
        std::string disasm;     // textual disassembly of `chunk`
    };

    // Copies every entry of `src` onto the end of `dst`, converting each entry
    // to std::string via push_back.
    template <typename DiagList>
    void appendDiagnostics(std::vector<std::string> &dst, DiagList &src)
    {
        for (auto &entry : src)
            dst.push_back(entry);
    }

    /// Runs `src` through lex -> parse -> resolve -> typecheck -> compile and
    /// records what each stage produced.
    ///
    /// Later stages still run when an earlier stage reports diagnostics, so a
    /// single call can surface messages from several stages. A thrown
    /// std::exception aborts the remaining stages: its message is recorded as
    /// an "exception: ..." diagnostic, and flags of stages that had not yet
    /// finished stay false.
    CompileOutput compileSource(const std::string &src)
    {
        CompileOutput result;
        try
        {
            Lexer lexer(src);
            Parser parser(lexer.tokenize());
            auto program = parser.parseProgram();
            result.parsedOk = true;

            Resolver resolver;
            auto resolveDiags = resolver.resolve(program);
            result.resolvedOk = resolveDiags.empty();
            appendDiagnostics(result.diagnostics, resolveDiags);

            TypeChecker checker;
            auto typeDiags = checker.check(program);
            result.typedOk = typeDiags.empty();
            appendDiagnostics(result.diagnostics, typeDiags);

            Compiler compiler;
            auto compiled = compiler.compile(program);
            result.compiledOk = compiled.ok();
            appendDiagnostics(result.diagnostics, compiled.diagnostics);

            result.chunk = std::move(compiled.chunk);
            result.disasm = disassemble(result.chunk);
        }
        catch (const std::exception &err)
        {
            result.diagnostics.push_back(std::string("exception: ") + err.what());
        }
        return result;
    }

    // ---- minimal test harness ----
    //
    // Every CHECK / CHECK_CONTAINS invocation counts toward g_total. A failing
    // one also counts toward g_failures and prints a "FAIL file:line: ..." line
    // to stderr. main() turns the two counters into a summary and exit status.
    int g_failures = 0; // failed checks so far
    int g_total = 0;    // executed checks so far

    // Records the outcome of one boolean check; `expr` is the stringified
    // condition as written at the call site.
    void recordCheck(bool passed, const char *expr, const char *file, int line)
    {
        ++g_total;
        if (passed)
            return;
        ++g_failures;
        std::cerr << "FAIL " << file << ":" << line << ": " << expr << "\n";
    }

    // Records a substring check: passes when `needle` occurs in `haystack`.
    // On failure the whole haystack is echoed to help locate the mismatch.
    void recordContains(const std::string &haystack, const std::string &needle,
                        const char *file, int line)
    {
        ++g_total;
        if (haystack.find(needle) != std::string::npos)
            return;
        ++g_failures;
        std::cerr << "FAIL " << file << ":" << line
                  << ": expected to find \"" << needle << "\" in:\n"
                  << haystack << "\n";
    }

#define CHECK(cond) recordCheck(static_cast<bool>(cond), #cond, __FILE__, __LINE__)

#define CHECK_CONTAINS(haystack, needle) \
    recordContains((haystack), (needle), __FILE__, __LINE__)

    bool opsMatch(const Chunk &c, const std::vector<Op> &expected)
    {
        std::vector<Op> got;
        size_t off = 0;
        while (off < c.code.size())
        {
            Op op = static_cast<Op>(c.code[off]);
            got.push_back(op);
            switch (op)
            {
            // 2-byte (opcode + 1 operand) instructions.
            case Op::Constant:
            case Op::DefGlobal:
            case Op::GetGlobal:
            case Op::SetGlobal:
            case Op::GetLocal:
            case Op::SetLocal:
            case Op::Call:
            case Op::Closure:
            case Op::GetUpvalue:
            case Op::SetUpvalue:
                off += 2;
                break;
            // 3-byte (opcode + 2-byte operand) instructions.
            case Op::Jump:
            case Op::JumpIfFalse:
            case Op::Loop:
                off += 3;
                break;
            default:
                off += 1;
            }
        }
        if (got.size() != expected.size())
            return false;
        for (size_t i = 0; i < got.size(); ++i)
            if (got[i] != expected[i])
                return false;
        return true;
    }

    void dumpOps(const Chunk &c, std::ostream &os) __attribute__((unused));
    void dumpOps(const Chunk &c, std::ostream &os)
    {
        size_t off = 0;
        while (off < c.code.size())
        {
            os << opName(static_cast<Op>(c.code[off])) << " ";
            Op op = static_cast<Op>(c.code[off]);
            switch (op)
            {
            case Op::Constant:
            case Op::DefGlobal:
            case Op::GetGlobal:
            case Op::SetGlobal:
            case Op::GetLocal:
            case Op::SetLocal:
            case Op::Call:
            case Op::Closure:
            case Op::GetUpvalue:
            case Op::SetUpvalue:
                off += 2;
                break;
            case Op::Jump:
            case Op::JumpIfFalse:
            case Op::Loop:
                off += 3;
                break;
            default:
                off += 1;
            }
        }
        os << "\n";
    }

    // ---- individual test cases ----

    // Precedence: `*` binds tighter than `+`, so all three literals are pushed
    // first and Mul is emitted before Add.
    void test_arithmetic_literal()
    {
        auto out = compileSource("print 1 + 2 * 3;");
        CHECK(out.compiledOk);

        // Expected stream:
        //   Constant(1) Constant(2) Constant(3)   operands, left to right
        //   Mul Add                               2*3 first, then 1+(...)
        //   Print Return
        CHECK(opsMatch(out.chunk,
                       {Op::Constant, Op::Constant, Op::Constant,
                        Op::Mul, Op::Add,
                        Op::Print, Op::Return}));

        // One pool entry per numeric literal.
        CHECK(out.chunk.constants.size() == 3);
    }

    // Unary `!` over a parenthesised comparison, combined with short-circuit `||`.
    void test_unary_and_logic()
    {
        auto out = compileSource("print !(1 < 2) || true;");
        CHECK(out.compiledOk);

        // Left operand  !(1 < 2):   Constant(1) Constant(2) Lt Not
        // `||` branching:           JumpIfFalse Jump Pop
        // Right operand `true`:     True
        // Statement tail:           Print Return
        CHECK(opsMatch(out.chunk, {Op::Constant, Op::Constant, Op::Lt, Op::Not,
                                   Op::JumpIfFalse, Op::Jump, Op::Pop, Op::True,
                                   Op::Print, Op::Return}));
    }

    // Top-level `let`/`var` define globals. A reassignment is an expression
    // statement, so its value is popped after SetGlobal. Reads use GetGlobal.
    void test_global_let_var()
    {
        const char *const source = "let x = 10;\nvar y = x + 1;\ny = y * 2;\nprint y;\n";
        auto out = compileSource(source);
        CHECK(out.compiledOk);

        // let x = 10;      Constant DefGlobal
        // var y = x + 1;   GetGlobal Constant Add DefGlobal
        // y = y * 2;       GetGlobal Constant Mul SetGlobal Pop
        // print y;         GetGlobal Print
        // (end of script)  Return
        CHECK(opsMatch(out.chunk, {Op::Constant, Op::DefGlobal,
                                   Op::GetGlobal, Op::Constant, Op::Add, Op::DefGlobal,
                                   Op::GetGlobal, Op::Constant, Op::Mul, Op::SetGlobal, Op::Pop,
                                   Op::GetGlobal, Op::Print,
                                   Op::Return}));
    }

    // Block-scoped `let`s live in stack slots: no global opcodes are used, and
    // leaving the block pops each local.
    void test_locals_block()
    {
        auto out = compileSource("{\n  let a = 1;\n  let b = 2;\n  print a + b;\n}\n");
        CHECK(out.compiledOk);

        // let a = 1;     Constant               (a -> slot 0)
        // let b = 2;     Constant               (b -> slot 1)
        // print a + b;   GetLocal GetLocal Add Print
        // }              Pop Pop                (one per local, b then a)
        // (end)          Return
        CHECK(opsMatch(out.chunk, {Op::Constant,
                                   Op::Constant,
                                   Op::GetLocal, Op::GetLocal, Op::Add, Op::Print,
                                   Op::Pop, Op::Pop,
                                   Op::Return}));

        // Locals are addressed by slot, so their names must not be interned
        // as string constants. One check is recorded per pool entry.
        for (const auto &c : out.chunk.constants)
            CHECK(c.kind != Value::K::Str || (c.s != "a" && c.s != "b"));
    }

    // Assigning to a `let` binding must fail with a diagnostic that mentions
    // "immutable".
    void test_let_immutable()
    {
        auto out = compileSource("{\n  let a = 1;\n  a = 2;\n}\n");
        CHECK(!out.compiledOk);

        bool sawImmutableMsg = false;
        for (const std::string &msg : out.diagnostics)
        {
            if (msg.find("immutable") == std::string::npos)
                continue;
            sawImmutableMsg = true;
            break;
        }
        CHECK(sawImmutableMsg);
    }

    // if/else is checked twice: through landmark mnemonics in the disassembly
    // and through the exact opcode sequence.
    void test_if_else_disasm()
    {
        auto out = compileSource("if (true) { print 1; } else { print 2; }\n");
        CHECK(out.compiledOk);

        // "JUMP " (with the trailing space) cannot be satisfied by the text
        // "JUMP_IF_FALSE", so it looks for the unconditional jump.
        CHECK_CONTAINS(out.disasm, "JUMP_IF_FALSE");
        CHECK_CONTAINS(out.disasm, "JUMP ");
        CHECK_CONTAINS(out.disasm, "PRINT");

        // Expected lowering:
        //   True                    condition
        //   JumpIfFalse (to else)   Pop
        //   Constant(1) Print       then-branch
        //   Jump (to end)           Pop
        //   Constant(2) Print       else-branch
        //   Return
        CHECK(opsMatch(out.chunk, {Op::True,
                                   Op::JumpIfFalse, Op::Pop,
                                   Op::Constant, Op::Print,
                                   Op::Jump, Op::Pop,
                                   Op::Constant, Op::Print,
                                   Op::Return}));
    }

    // A while loop must contain a conditional exit (JUMP_IF_FALSE) and a LOOP
    // instruction. Exact offsets are not pinned; only these landmarks are.
    void test_while_loop_emits_loop_op()
    {
        const std::string source = "var i = 0;\nwhile (i < 3) { i = i + 1; }\n";
        auto out = compileSource(source);
        CHECK(out.compiledOk);
        CHECK_CONTAINS(out.disasm, "LOOP");
        CHECK_CONTAINS(out.disasm, "JUMP_IF_FALSE");
    }

    // Two identical string literals must share a single constant-pool entry.
    void test_string_constant_dedup()
    {
        auto out = compileSource("print \"hi\";\nprint \"hi\";\n");
        CHECK(out.compiledOk);

        int hiCount = 0;
        for (const auto &k : out.chunk.constants)
            hiCount += (k.kind == Value::K::Str && k.s == "hi") ? 1 : 0;
        CHECK(hiCount == 1);
    }

    // The line table runs parallel to the bytecode (one entry per code byte)
    // and attributes instructions to the source line they came from.
    void test_line_table()
    {
        auto out = compileSource(
            "print 1;\n"
            "print 2;\n");
        CHECK(out.compiledOk);
        CHECK(out.chunk.lines.size() == out.chunk.code.size());

        // The first emitted instruction (the Constant for `1`) belongs to line 1.
        // The emptiness guard keeps front() well-defined if compilation failed.
        CHECK(!out.chunk.lines.empty() && out.chunk.lines.front() == 1);

        // The second statement must contribute at least one line-2 entry.
        bool seenLine2 = false;
        for (int l : out.chunk.lines)
        {
            if (l == 2)
                seenLine2 = true;
        }
        CHECK(seenLine2);
    }

    // cp-06 rejects `fn` declarations outright. The body is not walked, so its
    // inner `return` is not reported separately; both come back in cp-07.
    void test_function_declaration_deferred_to_cp07()
    {
        auto out = compileSource("fn add(a, b) { return a + b; }\n");
        CHECK(!out.compiledOk);

        bool sawFnMsg = false;
        for (size_t i = 0; i < out.diagnostics.size() && !sawFnMsg; ++i)
            sawFnMsg = out.diagnostics[i].find("function declarations") != std::string::npos;
        CHECK(sawFnMsg);
    }

    // Call expressions are also deferred to cp-07. The call site must report
    // its own diagnostic, separate from the rejected declaration of `id`.
    void test_call_expression_deferred_to_cp07()
    {
        auto out = compileSource("fn id(x) { return x; }\nprint id(42);\n");
        CHECK(!out.compiledOk);

        const std::string wanted = "function calls are implemented in cp-07";
        bool sawCallMsg = false;
        for (const auto &diag : out.diagnostics)
        {
            if (diag.find(wanted) != std::string::npos)
                sawCallMsg = true;
        }
        CHECK(sawCallMsg);
    }

    void test_local_shadowing_across_scopes_ok()
    {
        auto out = compileSource(
            "{\n"
            "  let x = 1;\n"
            "  {\n"
            "    let x = 2;\n"
            "    print x;\n"
            "  }\n"
            "  print x;\n"
            "}\n");
        CHECK(out.compiledOk);
        // Inner print should GetLocal(slot 1), outer GetLocal(slot 0).
        // We just check both appear and the disassembly is well-formed.
        CHECK_CONTAINS(out.disasm, "GET_LOCAL");
    }

    // Re-declaring a name in the same block is an error. The resolver reports
    // it (cp-04 behaviour), so either the resolve or the compile stage may be
    // the one that fails.
    void test_local_same_scope_redeclaration_errors()
    {
        auto out = compileSource("{\n  let a = 1;\n  let a = 2;\n}\n");
        CHECK(!out.resolvedOk || !out.compiledOk);

        bool sawDup = false;
        for (const std::string &d : out.diagnostics)
            sawDup = sawDup || (d.find("already declared") != std::string::npos);
        CHECK(sawDup);
    }

    // `&&` guards its right operand with a single JumpIfFalse (no extra Jump).
    void test_short_circuit_and()
    {
        auto out = compileSource("print false && true;");
        CHECK(out.compiledOk);

        // False | JumpIfFalse Pop | True | Print Return
        CHECK(opsMatch(out.chunk, {Op::False, Op::JumpIfFalse, Op::Pop, Op::True,
                                   Op::Print, Op::Return}));
    }

    // `%` must show up as MOD in the disassembly. `+` between two string
    // literals must show up as ADD.
    void test_modulo_and_string_concat()
    {
        auto out = compileSource("print 10 % 3;\nprint \"a\" + \"b\";\n");
        CHECK(out.compiledOk);
        CHECK_CONTAINS(out.disasm, "MOD");
        CHECK_CONTAINS(out.disasm, "ADD");
    }

} // namespace

// Runs every test case in a fixed order, prints "<passed>/<total> passed" to
// stdout, and exits with status 1 if any check failed (0 otherwise).
int main()
{
    using TestFn = void (*)();
    const TestFn tests[] = {
        test_arithmetic_literal,
        test_unary_and_logic,
        test_global_let_var,
        test_locals_block,
        test_let_immutable,
        test_if_else_disasm,
        test_while_loop_emits_loop_op,
        test_string_constant_dedup,
        test_line_table,
        test_function_declaration_deferred_to_cp07,
        test_call_expression_deferred_to_cp07,
        test_local_shadowing_across_scopes_ok,
        test_local_same_scope_redeclaration_errors,
        test_short_circuit_and,
        test_modulo_and_string_concat,
    };
    for (TestFn run : tests)
        run();

    const int passed = g_total - g_failures;
    std::cout << passed << "/" << g_total << " passed\n";
    if (g_failures != 0)
        return 1;
    return 0;
}
