#include <Constructs/Views/EnumerateView.hpp>
#include <File/File.hpp>

#include <filesystem>
#include <format>
#include <fstream>
#include <iostream>
#include <set>
#include <sstream>
#include <string>
#include <system_error>
#include <utility>
#include <vector>

#define WIN32_LEAN_AND_MEAN
#define NOMINMAX
#include <Windows.h>
#include <ImageHlp.h>
#include <tchar.h>
#include <Helpers/String.hpp>
#include <Helpers/SysError.hpp>

using namespace RC;
namespace fs = std::filesystem;

using std::cerr;
using std::cout;
using std::endl;
using std::ifstream;
using std::ofstream;

using std::string;

// One entry of a DLL's export table as the generator sees it.
struct ExportFunction
{
    // Export ordinal (Base-biased, i.e. the number used with GetProcAddress/MAKEINTRESOURCEA).
    std::uint16_t ordinal;
    // True if the export has a name in the export name table; false for ordinal-only exports.
    bool is_named;
    // Export name, or the placeholder "ordinal<N>" for ordinal-only exports.
    std::string name;

    // `name` is taken by value and moved in, so callers passing temporaries don't pay for a copy.
    ExportFunction(std::uint16_t ordinal, bool is_named, std::string name) : ordinal(ordinal), is_named(is_named), name(std::move(name))
    {
    }
};

// Reads the export directory of the PE file at `dll_path` (parsed from a memory-mapped copy of the
// file, not loaded as a module) and returns its exports: every named export in export-name-table
// order, followed by each remaining non-empty function slot as an ordinal-only export.
// On failure an error is printed to stderr and an empty vector is returned.
std::vector<ExportFunction> DumpExports(const fs::path& dll_path)
{
    auto dll_file = File::open(dll_path);
    const auto dll_file_map = dll_file.memory_map();

    ULONG export_directory_size = 0;
    const auto* const export_directory = static_cast<const IMAGE_EXPORT_DIRECTORY*>(
            ImageDirectoryEntryToData(dll_file_map.data(), FALSE, IMAGE_DIRECTORY_ENTRY_EXPORT, &export_directory_size));

    if (export_directory == nullptr)
    {
        auto err_msg = to_string(SysError(GetLastError()));
        cerr << std::format("Failed to get export directory, reason: {}", err_msg) << '\n';
        return {};
    }

    // The export directory was located, so the DOS/NT headers it hangs off were parseable.
    auto* const dos_header = reinterpret_cast<IMAGE_DOS_HEADER*>(dll_file_map.data());
    auto* const nt_header = reinterpret_cast<IMAGE_NT_HEADERS*>(dll_file_map.data() + dos_header->e_lfanew);

    // Translates an RVA into a pointer inside the mapped file; null if no section contains it.
    const auto rva_to_ptr = [&](ULONG rva) { return ImageRvaToVa(nt_header, dll_file_map.data(), rva, nullptr); };

    const auto* const name_rvas = static_cast<const DWORD*>(rva_to_ptr(export_directory->AddressOfNames));
    const auto* const function_rvas = static_cast<const DWORD*>(rva_to_ptr(export_directory->AddressOfFunctions));
    const auto* const name_ordinals = static_cast<const uint16_t*>(rva_to_ptr(export_directory->AddressOfNameOrdinals));

    // A table may legitimately be absent only when its count is zero.
    const bool names_missing = export_directory->NumberOfNames != 0 && (name_rvas == nullptr || name_ordinals == nullptr);
    const bool functions_missing = export_directory->NumberOfFunctions != 0 && function_rvas == nullptr;
    if (names_missing || functions_missing)
    {
        cerr << "Failed to read export tables (an RVA does not map into any section)" << '\n';
        return {};
    }

    std::vector<ExportFunction> exports;
    std::set<uint16_t> exported_ordinals;

    for (DWORD i = 0; i < export_directory->NumberOfNames; ++i)
    {
        const auto* const export_name = static_cast<const char*>(rva_to_ptr(name_rvas[i]));
        if (export_name == nullptr) continue; // malformed name RVA; never build a std::string from null

        // AddressOfNameOrdinals stores *unbiased* indices into AddressOfFunctions; the real ordinal adds Base
        // (the same formula the second loop uses, so both loops agree on which ordinals are taken).
        const auto ordinal = static_cast<uint16_t>(export_directory->Base + name_ordinals[i]);

        exports.emplace_back(ordinal, true, export_name);
        exported_ordinals.insert(ordinal);
    }

    for (DWORD i = 0; i < export_directory->NumberOfFunctions; ++i)
    {
        if (function_rvas[i] == 0) continue; // unused slot in the export address table

        const auto ordinal = static_cast<uint16_t>(export_directory->Base + i);
        if (exported_ordinals.contains(ordinal)) continue; // a named function for this ordinal was already exported

        exports.emplace_back(ordinal, false, std::format("ordinal{}", ordinal));
    }

    return exports;
}

// Parses a ".exports" file in the format this tool writes:
//   line 1:  "Path: <path to the original DLL>"
//   rest:    "<ordinal> [name]" per export (no name => ordinal-only export)
// Stores the DLL path in `dll_path_out` (cleared first; left empty on failure) and returns the
// exports. Blank or unparsable lines are skipped. Errors are printed to stderr.
std::vector<ExportFunction> ReadExportsFile(const fs::path& exp_path, fs::path& dll_path_out)
{
    dll_path_out.clear();

    ifstream exp_file(exp_path);
    if (!exp_file)
    {
        cerr << std::format("Failed to open export file {}", exp_path.string()) << endl;
        return {};
    }

    // The prefix must start the line; matching it anywhere would make the fixed-offset substr wrong.
    const string path_prefix = "Path: ";
    string line;
    if (std::getline(exp_file, line) && line.starts_with(path_prefix))
    {
        dll_path_out = line.substr(path_prefix.size());
    }

    if (dll_path_out.empty())
    {
        cerr << "Failed to read export file (missing file path info)" << endl;
        return {};
    }

    std::vector<ExportFunction> exports;
    while (std::getline(exp_file, line))
    {
        std::istringstream fields(line);

        // Check the extraction itself: on overflow it fails (and the value is clamped), so an unchecked
        // read would silently accept a bogus ordinal. Ordinal 0 is not a valid .def ordinal either.
        uint16_t ordinal{};
        if (!(fields >> ordinal) || ordinal == 0) continue;

        string export_name;
        fields >> export_name; // absent for ordinal-only exports
        const bool is_named = !export_name.empty();

        exports.emplace_back(ordinal, is_named, is_named ? export_name : std::format("ordinal{}", ordinal));
    }

    return exports;
}

// Writes the ".exports" sidecar: a "Path: <dll>" header, a blank separator line, then one
// "<ordinal> <name>" line per export (name left empty for ordinal-only exports). ReadExportsFile
// parses this format back. Returns false if the file could not be written.
static bool write_exports_file(const fs::path& exports_path, const fs::path& input_dll, const std::vector<ExportFunction>& exports)
{
    ofstream exports_file(exports_path);
    exports_file << "Path: " << input_dll.string() << endl << endl;
    for (const auto& e : exports)
    {
        exports_file << std::format("{} {}", e.ordinal, e.is_named ? e.name : "") << endl;
    }
    exports_file.close();
    return !exports_file.fail();
}

// Writes the module-definition (.def) file: each export is bound to the asm thunk f<index> and
// pinned to its original ordinal. Ordinal-only exports are marked NONAME so the proxy does not
// publish the synthetic "ordinalN" name, which the original DLL never exported.
// Returns false if the file could not be written.
static bool write_def_file(const fs::path& def_path, const fs::path& input_dll_name, const std::vector<ExportFunction>& exports)
{
    ofstream def_file(def_path);
    def_file << std::format("LIBRARY {}", input_dll_name.stem().string()) << endl;
    def_file << "EXPORTS" << endl;
    for (const auto [e, index] : exports | views::enumerate)
    {
        def_file << std::format("  {}=f{} @{}{}", e.name, index, e.ordinal, e.is_named ? "" : " NONAME") << endl;
    }
    def_file.close();
    return !def_file.fail();
}

// Writes the MASM file: one thunk f<index> per export that jumps through mProcs[index], the
// table filled in at runtime by the generated setup_functions(). Returns false on write failure.
static bool write_asm_file(const fs::path& asm_path, const std::vector<ExportFunction>& exports)
{
    ofstream asm_file(asm_path);
    asm_file << ".code" << endl;
    asm_file << "extern mProcs:QWORD" << endl;
    for (size_t index = 0; index < exports.size(); ++index)
    {
        asm_file << std::format("f{} proc", index) << endl;
        asm_file << std::format("  jmp mProcs[8*{}]", index) << endl;
        asm_file << std::format("f{} endp", index) << endl;
    }
    asm_file << "end" << endl;
    asm_file.close();
    return !asm_file.fail();
}

// Writes dllmain.cpp for the proxy. The generated code loads the original DLL from the system
// directory, resolves every export into mProcs (by name, or by ordinal via MAKEINTRESOURCEA), and
// then loads UE4SS.dll unless --disable-ue4ss is on the command line. Returns false on write failure.
static bool write_dllmain_cpp(const fs::path& cpp_path, const fs::path& input_dll_name, const std::vector<ExportFunction>& exports)
{
    ofstream cpp_file(cpp_path);
    cpp_file << "#include <File/Macros.hpp>" << endl;
    cpp_file << endl;
    cpp_file << "#include <cstdint>" << endl;
    cpp_file << "#include <fstream>" << endl;
    cpp_file << "#include <string>" << endl;
    cpp_file << endl;
    cpp_file << "#define WIN32_LEAN_AND_MEAN" << endl;
    cpp_file << "#include <Windows.h>" << endl;
    cpp_file << "#include <shellapi.h>" << endl;
    cpp_file << "#include <filesystem>" << endl;
    cpp_file << endl;
    cpp_file << "#pragma comment(lib, \"user32.lib\")" << endl;
    cpp_file << "#pragma comment(lib, \"shell32.lib\")" << endl;
    cpp_file << endl;

    cpp_file << "using namespace RC;" << endl;
    cpp_file << "namespace fs = std::filesystem;" << endl;
    cpp_file << endl;

    cpp_file << "HMODULE SOriginalDll = nullptr;" << endl;
    cpp_file << std::format("extern \"C\" uintptr_t mProcs[{}] = {{0}};", exports.size()) << endl;
    cpp_file << endl;

    cpp_file << "void setup_functions()" << endl;
    cpp_file << "{" << endl;

    for (const auto [e, index] : exports | views::enumerate)
    {
        string getter = e.is_named ? std::format("\"{}\"", e.name) : std::format("MAKEINTRESOURCEA({})", e.ordinal);
        cpp_file << std::format("    mProcs[{}] = (uintptr_t)GetProcAddress(SOriginalDll, {});", index, getter) << endl;
    }

    cpp_file << "}" << endl;
    cpp_file << endl;

    cpp_file << "void load_original_dll()" << endl;
    cpp_file << "{" << endl;
    cpp_file << "    wchar_t path[MAX_PATH];" << endl;
    cpp_file << "    GetSystemDirectory(path, MAX_PATH);" << endl;
    cpp_file << endl;
    cpp_file << std::format("    std::wstring dll_path = std::wstring(path) + L\"\\\\{}\";", input_dll_name.string()) << endl;
    cpp_file << endl;
    cpp_file << "    SOriginalDll = LoadLibrary(dll_path.c_str());" << endl;
    cpp_file << "    if (!SOriginalDll)" << endl;
    cpp_file << "    {" << endl;
    cpp_file << "        MessageBox(nullptr, L\"Failed to load proxy DLL\", L\"UE4SS Error\", MB_OK | MB_ICONERROR);" << endl;
    cpp_file << "        ExitProcess(0);" << endl;
    cpp_file << "    }" << endl;
    cpp_file << "}" << endl;
    cpp_file << endl;

    cpp_file << "bool is_absolute_path(const std::string& path)" << endl;
    cpp_file << "{" << endl;
    cpp_file << "    return fs::path(path).is_absolute();" << endl;
    cpp_file << "}" << endl;
    cpp_file << endl;

    cpp_file << "bool should_disable_ue4ss()" << endl;
    cpp_file << "{" << endl;
    cpp_file << "    int argc = 0;" << endl;
    cpp_file << "    LPWSTR* argv = CommandLineToArgvW(GetCommandLineW(), &argc);" << endl;
    cpp_file << "    if (!argv)" << endl;
    cpp_file << "    {" << endl;
    cpp_file << "        return false;" << endl;
    cpp_file << "    }" << endl;
    cpp_file << endl;
    cpp_file << "    bool disable = false;" << endl;
    cpp_file << "    for (int i = 0; i < argc; i++)" << endl;
    cpp_file << "    {" << endl;
    cpp_file << "        if (wcscmp(argv[i], L\"--disable-ue4ss\") == 0)" << endl;
    cpp_file << "        {" << endl;
    cpp_file << "            disable = true;" << endl;
    cpp_file << "            break;" << endl;
    cpp_file << "        }" << endl;
    cpp_file << "    }" << endl;
    cpp_file << endl;
    cpp_file << "    LocalFree(argv);" << endl;
    cpp_file << "    return disable;" << endl;
    cpp_file << "}" << endl;
    cpp_file << endl;

    cpp_file << "std::wstring get_ue4ss_path_from_args()" << endl;
    cpp_file << "{" << endl;
    cpp_file << "    int argc = 0;" << endl;
    cpp_file << "    LPWSTR* argv = CommandLineToArgvW(GetCommandLineW(), &argc);" << endl;
    cpp_file << "    if (!argv)" << endl;
    cpp_file << "    {" << endl;
    cpp_file << "        return L\"\";" << endl;
    cpp_file << "    }" << endl;
    cpp_file << endl;
    cpp_file << "    std::wstring ue4ss_path;" << endl;
    cpp_file << "    for (int i = 0; i < argc - 1; i++)" << endl;
    cpp_file << "    {" << endl;
    cpp_file << "        if (wcscmp(argv[i], L\"--ue4ss-path\") == 0)" << endl;
    cpp_file << "        {" << endl;
    cpp_file << "            ue4ss_path = argv[i + 1];" << endl;
    cpp_file << "            break;" << endl;
    cpp_file << "        }" << endl;
    cpp_file << "    }" << endl;
    cpp_file << endl;
    cpp_file << "    LocalFree(argv);" << endl;
    cpp_file << "    return ue4ss_path;" << endl;
    cpp_file << "}" << endl;
    cpp_file << endl;

    cpp_file << "HMODULE load_ue4ss_dll(HMODULE moduleHandle)" << endl;
    cpp_file << "{" << endl;
    cpp_file << "    HMODULE hModule = nullptr;" << endl;
    cpp_file << "    wchar_t moduleFilenameBuffer[1024]{'\\0'};" << endl;
    cpp_file << "    GetModuleFileNameW(moduleHandle, moduleFilenameBuffer, sizeof(moduleFilenameBuffer) / sizeof(wchar_t));" << endl;
    cpp_file << "    const auto currentPath = std::filesystem::path(moduleFilenameBuffer).parent_path();" << endl;
    cpp_file << "    const fs::path ue4ssPath = currentPath / \"ue4ss\" / \"UE4SS.dll\";" << endl;
    cpp_file << endl;

    cpp_file << "    // Check for --ue4ss-path command line argument" << endl;
    cpp_file << "    std::wstring cmdLineUe4ssPath = get_ue4ss_path_from_args();" << endl;
    cpp_file << "    if (!cmdLineUe4ssPath.empty())" << endl;
    cpp_file << "    {" << endl;
    cpp_file << "        fs::path ue4ssArgPath = cmdLineUe4ssPath;" << endl;
    cpp_file << "        if (!ue4ssArgPath.is_absolute())" << endl;
    cpp_file << "        {" << endl;
    cpp_file << "            ue4ssArgPath = currentPath / ue4ssArgPath;" << endl;
    cpp_file << "        }" << endl;
    cpp_file << endl;
    cpp_file << "        // Attempt to load UE4SS.dll from the command line path" << endl;
    cpp_file << "        hModule = LoadLibrary(ue4ssArgPath.c_str());" << endl;
    cpp_file << "        if (hModule)" << endl;
    cpp_file << "        {" << endl;
    cpp_file << "            return hModule;" << endl;
    cpp_file << "        }" << endl;
    cpp_file << "    }" << endl;
    cpp_file << endl;

    cpp_file << "    // Check for override.txt" << endl;
    cpp_file << "    const fs::path overrideFilePath = currentPath / \"override.txt\";" << endl;
    cpp_file << "    if (fs::exists(overrideFilePath))" << endl;
    cpp_file << "    {" << endl;
    cpp_file << "        std::ifstream overrideFile(overrideFilePath);" << endl;
    cpp_file << "        std::string overridePath;" << endl;
    cpp_file << "        if (std::getline(overrideFile, overridePath))" << endl;
    cpp_file << "        {" << endl;
    cpp_file << "            fs::path ue4ssOverridePath = overridePath;" << endl;
    cpp_file << "            if (!is_absolute_path(overridePath))" << endl;
    cpp_file << "            {" << endl;
    cpp_file << "                ue4ssOverridePath = currentPath / overridePath;" << endl;
    cpp_file << "            }" << endl;
    cpp_file << endl;
    cpp_file << "            ue4ssOverridePath = ue4ssOverridePath / \"UE4SS.dll\";" << endl;
    cpp_file << endl;
    cpp_file << "            // Attempt to load UE4SS.dll from the override path" << endl;
    cpp_file << "            hModule = LoadLibrary(ue4ssOverridePath.c_str());" << endl;
    cpp_file << "            if (hModule)" << endl;
    cpp_file << "            {" << endl;
    cpp_file << "                return hModule;" << endl;
    cpp_file << "            }" << endl;
    cpp_file << "        }" << endl;
    cpp_file << "    }" << endl;
    cpp_file << endl;

    cpp_file << "    // Attempt to load UE4SS.dll from ue4ss directory" << endl;
    cpp_file << "    hModule = LoadLibrary(ue4ssPath.c_str());" << endl;
    cpp_file << "    if (!hModule)" << endl;
    cpp_file << "    {" << endl;
    cpp_file << "        // If loading from ue4ss directory fails, load from the current directory" << endl;
    cpp_file << "        hModule = LoadLibrary(L\"UE4SS.dll\");" << endl;
    cpp_file << "    }" << endl;
    cpp_file << endl;
    cpp_file << "    return hModule;" << endl;
    cpp_file << "}" << endl;
    cpp_file << endl;

    cpp_file << "BOOL WINAPI DllMain(HMODULE hInstDll, DWORD fdwReason, LPVOID lpvReserved)" << endl;
    cpp_file << "{" << endl;
    cpp_file << "    if (fdwReason == DLL_PROCESS_ATTACH)" << endl;
    cpp_file << "    {" << endl;
    cpp_file << "        load_original_dll();" << endl;
    cpp_file << "        setup_functions();" << endl;
    cpp_file << endl;
    cpp_file << "        // Check if UE4SS should be disabled via command line argument" << endl;
    cpp_file << "        if (should_disable_ue4ss())" << endl;
    cpp_file << "        {" << endl;
    cpp_file << "            // UE4SS is disabled, proxy will still forward calls to original DLL" << endl;
    cpp_file << "            return TRUE;" << endl;
    cpp_file << "        }" << endl;
    cpp_file << endl;
    cpp_file << "        HMODULE hUE4SSDll = load_ue4ss_dll(hInstDll);" << endl;
    cpp_file << "        if (!hUE4SSDll)" << endl;
    cpp_file << "        {" << endl;
    cpp_file << "            MessageBox(nullptr, L\"Failed to load UE4SS.dll. Please see the docs on correct installation: "
                "https://docs.ue4ss.com/installation-guide\", L\"UE4SS Error\", MB_OK | MB_ICONERROR);"
             << endl;
    cpp_file << "            ExitProcess(0);" << endl;
    cpp_file << "        }" << endl;
    cpp_file << "    }" << endl;
    cpp_file << "    else if (fdwReason == DLL_PROCESS_DETACH)" << endl;
    cpp_file << "    {" << endl;
    cpp_file << "        FreeLibrary(SOriginalDll);" << endl;
    cpp_file << "    }" << endl;
    cpp_file << "    return TRUE;" << endl;
    cpp_file << "}" << endl;

    cpp_file.close();
    return !cpp_file.fail();
}

// Entry point: proxy_generator.exe <input_dll_or_exports_file> <output_path>
// Given a DLL, dumps its exports (also saving them as <name>.exports); given a ".exports" file,
// reads the exports and original DLL path from it. Then writes <name>.def, <name>.asm and
// dllmain.cpp into <output_path>. Returns 0 on success, -1 on any error.
int _tmain(int argc, TCHAR* argv[])
{
    if (argc != 3)
    {
        cerr << "Invalid arguments! Expected: proxy_generator.exe <input_dll_name_or_exports_file> <output_path>" << endl;
        return -1;
    }

    const fs::path input_file = argv[1];
    const fs::path output_path = argv[2];

    if (!fs::exists(input_file))
    {
        cerr << "Input file doesn't exist!" << endl;
        return -1;
    }

    // Every output is written with ofstream, which fails silently when the directory is missing.
    if (!output_path.empty())
    {
        std::error_code ec;
        fs::create_directories(output_path, ec);
        if (ec)
        {
            cerr << std::format("Failed to create output directory {}: {}", output_path.string(), ec.message()) << endl;
            return -1;
        }
    }

    // Reports a failed write and yields the process exit code.
    const auto write_failed = [](const fs::path& path) {
        cerr << std::format("Failed to write {}", path.string()) << endl;
        return -1;
    };

    fs::path input_dll = input_file;
    fs::path input_dll_name = input_file.filename();
    std::vector<ExportFunction> exports;

    if (input_file.extension() == ".exports")
    {
        cout << std::format("Generating a proxy using {}, output path: {}", input_dll_name.string(), output_path.string()) << endl;
        exports = ReadExportsFile(input_file, input_dll);
        if (input_dll.empty())
        {
            return -1; // ReadExportsFile already reported why
        }
        input_dll_name = input_dll.filename();
    }
    else
    {
        cout << std::format("Generating a proxy for {}, output path: {}", input_dll.string(), output_path.string()) << endl;
        exports = DumpExports(input_dll);

        // Don't write an empty .exports file when the dump failed; the empty check below reports it.
        if (!exports.empty())
        {
            const auto exports_path = (output_path / input_dll_name).replace_extension("exports");
            if (!write_exports_file(exports_path, input_dll, exports)) return write_failed(exports_path);
            cout << std::format("Exports file generated at {}", exports_path.string()) << endl;
        }
    }

    cout << std::format("Export count: {}", exports.size()) << endl;

    // With no exports the generated "mProcs[0]" is a zero-length array, which does not compile,
    // and the proxy would forward nothing anyway.
    if (exports.empty())
    {
        cerr << "No exports found, nothing to proxy!" << endl;
        return -1;
    }

    const auto def_path = (output_path / input_dll_name).replace_extension("def");
    if (!write_def_file(def_path, input_dll_name, exports)) return write_failed(def_path);

    const auto asm_path = (output_path / input_dll_name).replace_extension("asm");
    if (!write_asm_file(asm_path, exports)) return write_failed(asm_path);

    const auto cpp_path = output_path / "dllmain.cpp";
    if (!write_dllmain_cpp(cpp_path, input_dll_name, exports)) return write_failed(cpp_path);

    cout << "Finished generating!" << endl;

    return 0;
}
