#include "lt.h"
#include <atomic>
#include <cerrno>
#include <chrono>
#include <condition_variable>
#include <filesystem>
#include <iostream>
#include <map>
#include <mutex>
#include <string>
#include <utility>
#include <vector>

#ifdef _WIN32
#include <windows.h>
#else
#include <arpa/inet.h>
#include <cstring>
#include <fcntl.h>
#include <netdb.h>
#include <sys/mman.h>
#include <sys/socket.h>
#include <sys/types.h>
#include <unistd.h>
#define SOCKET int
#endif

namespace lt {

//---------------------------------------------------------------------------
// JSON helpers
//---------------------------------------------------------------------------

// Load boolean
//
// Reads j[key] into `value` when the key exists and holds a JSON boolean.
// Otherwise `value` is set to false. Returns whether the field was usable.
bool Load(json j, string const& key, bool& value)
{
    const auto found = j.find(key);
    const bool ok = found != j.end() && found->is_boolean();
    value = ok ? found->get<bool>() : false;
    return ok;
}

// Load string
//
// Reads j[key] into `value` when the key exists and holds a JSON string.
// Otherwise `value` is emptied. Returns whether the field was usable.
bool Load(json j, string const& key, string& value)
{
    const auto found = j.find(key);
    if (found == j.end() || !found->is_string()) {
        value.clear();
        return false;
    }
    value = found->get<string>();
    return true;
}

// Load number
//
// Reads the numeric field j[key] into `value` through json's get_to conversion
// for T. When the key is missing or not a JSON number, `value` is set to 0.
// Returns whether the field was usable.
template <typename T>
bool Load(json j, string const& key, T& value)
{
    const auto found = j.find(key);
    if (found != j.end() && found->is_number()) {
        found->get_to(value);
        return true;
    }
    value = 0;
    return false;
}

// Load base64 string
//
// Decodes the standard-alphabet base64 string stored under `key` into `value`.
// `value` is always replaced, never appended to. Decoding stops at the first
// character outside the alphabet (for example '=' padding or whitespace).
// Returns false, leaving `value` empty, when the key is missing or not a string.
bool Load(json j, string const& key, bytes& value)
{
    value.clear();
    if (!j.contains(key) || !j.at(key).is_string()) {
        return false;
    }

    // Reverse lookup: byte value -> 6-bit digit, or -1 for non-alphabet bytes.
    // Built once and shared by all calls.
    static const std::vector<int> decodeTable = [] {
        std::vector<int> table(256, -1);
        const char* alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
        for (int i = 0; i < 64; i++) {
            table[static_cast<unsigned char>(alphabet[i])] = i;
        }
        return table;
    }();

    const string encoded = j.at(key).get<string>();

    // `acc` is unsigned so the ever-growing left shift wraps instead of
    // overflowing a signed int (UB). Only its low 12 bits are ever read back.
    unsigned int acc = 0;
    int bits = -8;
    for (char c : encoded) {
        // Index through unsigned char: plain char may be signed, and a negative
        // index would read outside the table for bytes >= 0x80.
        const int digit = decodeTable[static_cast<unsigned char>(c)];
        if (digit == -1) {
            break;
        }
        acc = (acc << 6) | static_cast<unsigned int>(digit);
        bits += 6;
        if (bits >= 0) {
            value.push_back((acc >> bits) & 0xFFu);
            bits -= 8;
        }
    }
    return true;
}

// Load json
//
// Copies the raw JSON value stored under `key` into `value`, whatever its type.
// Unlike the other overloads, `value` is left untouched when the key is
// missing. Returns whether the key was present.
bool Load(json j, string const& key, json& value)
{
    const auto found = j.find(key);
    if (found == j.end()) {
        return false;
    }
    value = *found;
    return true;
}

#ifdef _WIN32
// Returns the system message text for the calling thread's last Win32 error
// (GetLastError()), without the trailing line break. Falls back to
// "error code N" when the message cannot be formatted.
static string getLastErrorString()
{
    // Capture the code first, before any further API call can overwrite it.
    const DWORD code = GetLastError();

    // Use the ANSI variant explicitly. The text is stored in a narrow std::string,
    // so the generic FormatMessage would produce garbage under UNICODE builds.
    LPSTR msgBuf = nullptr;
    const DWORD len = FormatMessageA(
        FORMAT_MESSAGE_ALLOCATE_BUFFER | FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS,
        NULL,
        code,
        MAKELANGID(LANG_NEUTRAL, SUBLANG_DEFAULT),
        reinterpret_cast<LPSTR>(&msgBuf),
        0, NULL);

    // On failure no buffer was allocated, so there is nothing to LocalFree.
    if (len == 0 || msgBuf == nullptr) {
        return "error code " + std::to_string(code);
    }

    string result(msgBuf, msgBuf + len);
    LocalFree(msgBuf);

    // System messages end with "\r\n". Strip it so the text embeds cleanly in
    // larger error strings.
    while (!result.empty() && (result.back() == '\r' || result.back() == '\n' || result.back() == ' ')) {
        result.pop_back();
    }
    return result;
}

//---------------------------------------------------------------------------
// SharedBuffer
//---------------------------------------------------------------------------

// Read-only view of a named Win32 file mapping.
//
// On failure `err` is set and the object stays unmapped (Data == nullptr,
// Size == 0). The destructor releases only what was actually acquired.
// Non-copyable because it owns the mapping handle and view.
class bufferObject {
public:
    bufferObject(string const& name, int size, error& err)
        : Data(nullptr)
        , Size(0)
        , Ref(0)
        , handle(NULL)
    {
        err.clear();

        handle = OpenFileMappingW(
            FILE_MAP_READ,
            FALSE,
            std::wstring(name.begin(), name.end()).c_str());
        if (handle == NULL) {
            err = "Could not open shared memory: " + getLastErrorString();
            return;
        }

        Data = static_cast<char*>(MapViewOfFile(
            handle,
            FILE_MAP_READ,
            0,
            0,
            size));
        if (Data == nullptr) {
            // Format the error before CloseHandle, which may overwrite the last
            // error. Null the handle so the destructor does not close it again.
            err = "Could not map shared memory: " + getLastErrorString();
            CloseHandle(handle);
            handle = NULL;
            return;
        }

        Size = size;
    }

    ~bufferObject()
    {
        if (Data != nullptr) {
            UnmapViewOfFile(Data);
        }
        if (handle != NULL) {
            CloseHandle(handle);
        }
    }

    bufferObject(bufferObject const&) = delete;
    bufferObject& operator=(bufferObject const&) = delete;

    char* Data; // start of the mapped view, or nullptr if mapping failed
    int Size; // mapped size in bytes (0 on failure)
    int Ref; // client reference count, maintained by the `buffers` cache

private:
    HANDLE handle;
};

#else // defined(_WIN32)

// Read-write mapping of the POSIX shared-memory file /dev/shm/<name>.
//
// On failure `err` is set and the object stays unmapped (Data == nullptr,
// Size == 0). The destructor releases only what was actually acquired.
// Non-copyable because it owns the fd and the mapping.
class bufferObject {
public:
    bufferObject(string const& name, int size, error& err)
        : Data(nullptr)
        , Size(0)
        , Ref(0)
        , fd(-1)
    {
        err.clear();

        // NOTE(review): O_CREAT is not passed, so `perm` is ignored and the file
        // must already exist.
        int perm = S_IRUSR | S_IWUSR | S_IRGRP | S_IWGRP;
        fd = open((string("/dev/shm/") + name).c_str(), O_RDWR, perm);
        // NOTE(review): ftruncate resizes the shared object for every process
        // that maps it. Confirm the producer expects consumers to do this.
        if ((fd < 0) || (ftruncate(fd, size) < 0)) {
            err = "Could not open shared memory: " + (std::string)(strerror(errno));
            return;
        }

        void* mapped = mmap(nullptr, size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0);
        if (mapped == MAP_FAILED) {
            // MAP_FAILED is not a null pointer. Keep Data null so the destructor
            // does not munmap it. The fd is still closed by the destructor.
            err = "Could not open shared memory: " + (std::string)(strerror(errno));
            return;
        }

        Data = static_cast<char*>(mapped);
        Size = size;
    }

    ~bufferObject()
    {
        if (Data) {
            munmap(Data, Size);
            Data = nullptr;
        }
        if (fd >= 0) {
            close(fd);
            fd = -1;
        }
    }

    bufferObject(bufferObject const&) = delete;
    bufferObject& operator=(bufferObject const&) = delete;

    char* Data; // start of the mapping, or nullptr if mapping failed
    int Size; // mapped size in bytes (0 on failure)
    int Ref; // client reference count, maintained by the `buffers` cache

private:
    int fd;
};

#endif

// Maps the shared-memory buffer `name` of `size` bytes into `b`.
//
// bufferObject reports failure through its out-parameter rather than throwing.
// The object is therefore always handed to `b`, even when unmapped, and
// callers must check the returned error before using it.
error mapBuffer(string const& name, int size, buffer& b)
{
    error status;
    bufferObject* mapping = new bufferObject(name, size, status);
    b.reset(mapping);
    return status;
}

// Shared bufferObject collection
class buffers {
public:
    // Load loads shared bufferObject data and increments the client reference
    buffer Load(string const& handle, int size)
    {
        std::lock_guard<std::mutex> lock(mu);

        // Load shared bufferObject data
        if (m.count(handle) == 0) {
            buffer b;
            error err = mapBuffer(handle, size, b);
            if (!err.empty()) {
                return nullptr;
            }
            m[handle] = b;
        }

        // Increment bufferObject ref
        m[handle]->Ref += 1;

        // Done
        return m[handle];
    };

    // Delete deletes the client shared bufferObject references
    void Delete(string const& handle)
    {
        std::lock_guard<std::mutex> lock(mu);

        // Free shared buffer
        if (m.count(handle) != 0) {
            m[handle]->Ref -= 1;
            if (m[handle]->Ref <= 0) {
                // m.erase(handle);
            }
        }
    };

private:
    std::map<string, buffer> m;
    std::mutex mu;
};

//---------------------------------------------------------------------------
// pipeConn class
//---------------------------------------------------------------------------

#ifdef _WIN32

// Client side of a Win32 named pipe (\\.\pipe\<addr>), used for blocking
// request/response IO. Close() is safe to call repeatedly and on a connection
// that was never dialed. Non-copyable because it owns the pipe handle.
class pipeConn {
public:
    pipeConn() = default;
    pipeConn(pipeConn const&) = delete;
    pipeConn& operator=(pipeConn const&) = delete;

    // dial opens the named pipe for read/write, closing any handle already open.
    error dial(string const& addr)
    {
        Close();

        string path = "\\\\.\\pipe\\" + addr;
        handle = CreateFileW(
            std::wstring(path.begin(), path.end()).c_str(),
            GENERIC_READ | GENERIC_WRITE, // read and write access
            0,
            NULL, // default security attributes
            OPEN_EXISTING, // opens existing pipe
            0, // default attributes
            NULL // no template file
        );

        if (handle == INVALID_HANDLE_VALUE) {
            return "Could not open named pipe: " + getLastErrorString();
        }
        return "";
    }

    // Read performs one blocking ReadFile of up to `l` bytes into `p`.
    // `n` receives the byte count, or 0 on error.
    error Read(char* p, int l, int& n)
    {
        n = 0;
        DWORD bytesRead = 0;
        if (!ReadFile(
                handle, // pipe handle
                p, // buffer to receive reply
                static_cast<DWORD>(l), // size of buffer
                &bytesRead, // number of bytes read
                NULL)) // not overlapped
        {
            return "ReadFile from pipe failed. Error=" + getLastErrorString();
        }
        n = int(bytesRead);
        return "";
    }

    // Write performs one blocking WriteFile of `l` bytes from `p`.
    // `n` receives the byte count, or 0 on error.
    error Write(const char* p, int l, int& n)
    {
        n = 0;
        DWORD bytesWritten = 0;
        if (!WriteFile(
                handle, // pipe handle
                p, // message
                static_cast<DWORD>(l), // message length
                &bytesWritten, // bytes written
                NULL)) // not overlapped
        {
            return "WriteFile to pipe failed:" + std::to_string(GetLastError());
        }

        n = int(bytesWritten);
        return "";
    }

    // Close releases the pipe handle if one is open. The handle is invalidated
    // first so a repeated Close never touches a stale or reused handle.
    error Close()
    {
        if (handle == INVALID_HANDLE_VALUE) {
            return "";
        }
        HANDLE h = handle;
        handle = INVALID_HANDLE_VALUE;
        if (!CloseHandle(h))
            return getLastErrorString();

        return "";
    }

private:
    HANDLE handle = INVALID_HANDLE_VALUE;
};

#else // defined(_WIN32)

#include <sys/un.h>

// Client side of a Unix domain stream socket at <temp dir>/<family>.sock, used
// for blocking request/response IO. Close() is safe to call repeatedly and on a
// connection that was never dialed. Non-copyable because it owns the fd.
class pipeConn {
public:
    pipeConn() = default;
    pipeConn(pipeConn const&) = delete;
    pipeConn& operator=(pipeConn const&) = delete;

    // dial connects to the socket for `family`, closing any socket already open.
    error dial(string const& family)
    {
        Close();

        string socket_path = std::filesystem::temp_directory_path().string() + "/" + family + ".sock";

        struct sockaddr_un remote;
        std::memset(&remote, 0, sizeof(remote));
        remote.sun_family = AF_UNIX;
        // sun_path is a small fixed-size array. The path and its NUL must fit.
        if (socket_path.size() >= sizeof(remote.sun_path)) {
            return "Socket path too long: " + socket_path;
        }
        std::memcpy(remote.sun_path, socket_path.c_str(), socket_path.size() + 1);

        // Create unix socket
        if ((sock = socket(AF_UNIX, SOCK_STREAM, 0)) == -1) {
            return "Unable to create socket";
        }

        // Pass the whole zero-filled struct as the address length. Computing
        // "header + path" by hand is platform-specific (BSD/macOS add sun_len).
        if (connect(sock, reinterpret_cast<struct sockaddr*>(&remote), sizeof(remote)) == -1) {
            Close();
            return "Connection failed";
        }

        return "";
    }

    // Read performs one recv of up to `l` bytes into `p`, retrying on EINTR.
    // `n` receives the byte count, or 0 on error. EOF is reported as an error
    // because the caller always expects a reply.
    error Read(char* p, int l, int& n)
    {
        n = 0;
        ssize_t got;
        do {
            got = recv(sock, p, l, 0);
        } while (got == -1 && errno == EINTR);

        if (got == -1) {
            return "recv() failed";
        }
        if (got == 0 && l > 0) {
            return "connection closed by peer";
        }
        n = static_cast<int>(got);
        return "";
    }

    // Write sends all `l` bytes from `p`. It loops because send() on a stream
    // socket may transfer fewer bytes than requested. `n` receives the count
    // sent, which is less than `l` only when an error is returned.
    error Write(const char* p, int l, int& n)
    {
#ifdef MSG_NOSIGNAL
        // Turn a closed peer into EPIPE instead of SIGPIPE, which would
        // terminate the process by default.
        const int flags = MSG_NOSIGNAL;
#else
        const int flags = 0;
#endif
        n = 0;
        while (n < l) {
            ssize_t sent = send(sock, p + n, static_cast<size_t>(l - n), flags);
            if (sent == -1) {
                if (errno == EINTR) {
                    continue;
                }
                return "send() failed";
            }
            n += static_cast<int>(sent);
        }

        return "";
    }

    // Close shuts down and closes the socket if one is open. shutdown() is
    // best-effort: it fails with ENOTCONN on an unconnected socket, and the
    // fd must be closed regardless.
    error Close()
    {
        if (sock >= 0) {
            shutdown(sock, SHUT_RDWR);
            close(sock);
            sock = -1;
        }

        return "";
    }

private:
    SOCKET sock = -1;
};

#endif

//----------------------------------------------------------------------------
// roundTripperObject class
//----------------------------------------------------------------------------

// Serialized JSON request/response transport over a single pipeConn.
//
// The connection is dialed lazily on the first call(), using the url's scheme
// as the pipe/socket name. Every later call must use the same scheme. After any
// transport or parse failure the connection is dropped and the next call
// redials. All public methods lock `mu`.
class roundTripperObject {
public:
    roundTripperObject()
        : scheme("")
        , connBuffer(kConnBufferSize, 0)
    {
    }

    ~roundTripperObject()
    {
        std::lock_guard<std::mutex> lock(mu);
        // `scheme` is non-empty only while a dialed connection is open.
        if (!scheme.empty()) {
            conn.Close();
        }
        scheme = "";
    }

    // call sends {"method", "url", "body"} and reads one JSON reply into `response`.
    //
    // Returns "" on success. Otherwise it returns one of:
    //   - a transport or parse error,
    //   - the response's "error" field,
    //   - "redirect: <location>" when that field equals ErrRedirect.
    error call(string const& method, string const& url, json const& body, json& response)
    {
        std::lock_guard<std::mutex> lock(mu);

        // Validate url. The scheme is the non-empty text before the first ':'.
        // A url with no ':' has no scheme; previously the whole url was used.
        const auto colon = url.find(':');
        if (colon == string::npos || colon == 0) {
            return "url scheme not found";
        }
        const string urlScheme = url.substr(0, colon);

        // Validate connection
        if (scheme != "" && scheme != urlScheme) {
            return "bad url scheme: " + scheme + "!=" + urlScheme;
        }
        if (scheme == "") {
            error err = open(urlScheme);
            if (!err.empty()) {
                return err;
            }
            scheme = urlScheme;
        }

        // JSON Request
        {
            json request = {
                { "method", method },
                { "url", url },
                { "body", body },
            };
            error err = encode(request);
            if (!err.empty()) {
                return err;
            }
        }

        // JSON Response
        {
            error err = decode(response);
            if (!err.empty()) {
                return err;
            }
        }

        // Check for error. Non-string fields are handled without get<string>(),
        // which would throw json::type_error.
        {
            error err;
            string location;
            if (response.contains("error")) {
                const json& e = response.at("error");
                if (e.is_string()) {
                    err = e.get<string>();
                } else if (!e.is_null()) {
                    err = e.dump();
                }
            }
            if (response.contains("location") && response.at("location").is_string()) {
                location = response.at("location").get<string>();
            }
            if (err == ErrRedirect) {
                return "redirect: " + location;
            } else if (err != "") {
                return err;
            }
        }

        // Done
        return "";
    }

private:
    // Size of the single receive buffer; a reply must fit in one read of this size.
    static constexpr std::size_t kConnBufferSize = 16 * 1024 * 1024;

    error open(string const& addr) { return conn.dial(addr); }

    // Closes the connection and forgets the scheme so the next call() redials.
    // Caller must hold `mu`.
    void disconnect()
    {
        conn.Close();
        scheme = "";
    }

    // Reads one reply and parses it into `j` without throwing.
    error decode(json& j)
    {
        int n = 0;
        error err = conn.Read(connBuffer.data(), (int)connBuffer.size(), n);
        if (!err.empty()) {
            disconnect();
            return err;
        }

        // NOTE(review): assumes one Read() yields exactly one complete JSON
        // message. A reply split across reads, or longer than the buffer, fails
        // to parse here. The stream may then be out of sync, so disconnect.
        json parsed = json::parse(connBuffer.data(), connBuffer.data() + n, nullptr, false);
        if (parsed.is_discarded()) {
            disconnect();
            return "invalid JSON response";
        }
        j = std::move(parsed);
        return "";
    }

    // Serializes `j` and writes it to the connection.
    error encode(json const& j)
    {
        int n = 0;
        string data = j.dump();
        error err = conn.Write(data.c_str(), (int)data.size(), n);
        if (!err.empty()) {
            disconnect();
            return err;
        }
        return "";
    }

    string scheme;
    pipeConn conn;
    std::vector<char> connBuffer;
    std::mutex mu;
};

//---------------------------------------------------------------------------
// PacketObject class
//---------------------------------------------------------------------------
// Shared-memory mapping cache used by PacketObject construction and destruction
// (internal linkage).
static buffers sharedBuffers{};

// Builds a packet from its JSON descriptor.
//
// "track", "media", "signal" and "timestamp" are required; "meta" is optional.
// The payload comes from one of two sources:
//   - an inline base64 "data" field, or
//   - a shared-memory reference ("ref" plus "handle", "ptr", "len", "cap").
// Without either, the packet has no payload. On failure `err` is non-empty.
PacketObject::PacketObject(const json& j, error& err)
{
    err.clear();

    // Parse fields
    if (!Load(j, "track", track)) {
        err = "error loading 'track' field";
        return;
    }
    if (!Load(j, "media", media)) {
        err = "error loading 'media' field";
        return;
    }
    if (!Load(j, "signal", signal)) {
        err = "error loading 'signal' field";
        return;
    }
    if (!Load(j, "timestamp", timestamp)) {
        err = "error loading 'timestamp' field";
        return;
    }
    Load(j, "meta", meta);

    // Decode base64 data field
    if (Load(j, "data", Data)) {
        data = Data.data();
        length = (int)Data.size();
        // Done
        return;
    }

    // Capture shared memory reference
    if (!Load(j, "ref", ref)) {
        // Done
        return;
    }

    // Load shared memory fields
    if (!Load(j, "handle", handle)) {
        err = "error loading 'handle' field";
        return;
    }
    int ptr, len, cap;
    if (!Load(j, "ptr", ptr)) {
        err = "error loading 'ptr' field";
        return;
    }
    if (!Load(j, "len", len)) {
        err = "error loading 'len' field";
        return;
    }
    if (!Load(j, "cap", cap)) {
        err = "error loading 'cap' field";
        return;
    }

    // Load shared memory data
    buf = sharedBuffers.Load(handle, cap);
    if (buf == nullptr) {
        err = "error loading shared memory 'handle': " + handle;
        return;
    }

    // Reject ranges outside the mapped region. Check against buf->Size, the size
    // actually mapped, which can differ from `cap` when the mapping was cached
    // by an earlier packet. Compute in 64 bits so ptr + len cannot overflow.
    if (ptr < 0 || len < 0 || static_cast<long long>(ptr) + len > buf->Size) {
        err = "shared memory range out of bounds for 'handle': " + handle;
        return;
    }
    data = buf->Data + ptr;
    length = len;

    // Done
    return;
}

// Releases this packet's shared-memory reference. This covers:
//   - the local mapping's client count, if the constructor obtained a mapping;
//   - the remote reference, via a DELETE request on `ref`, when a round-tripper
//     is attached.
// Packets with inline data have no `ref`, so nothing is released for them.
PacketObject::~PacketObject()
{
    if (ref == "") {
        return;
    }

    // Only a constructor that successfully obtained `buf` incremented the
    // mapping's reference count. Decrementing for packets that failed earlier
    // would steal another packet's reference.
    if (buf != nullptr) {
        sharedBuffers.Delete(handle);
    }

    if (roundTripper != nullptr) {
        // Destructors are implicitly noexcept, so an exception escaping call()
        // would invoke std::terminate. Releasing the remote reference is
        // best-effort; its errors and exceptions are deliberately ignored.
        try {
            json r;
            roundTripper->call("DELETE", ref, json(), r);
        } catch (...) {
        }
    }
}

// ---------------------------------------------------------------------------
// ClientObject class
// ---------------------------------------------------------------------------

Client::Client()
    : roundTripper(std::make_shared<roundTripperObject>())
{
}

// Forwards a request to this client's round-tripper. See roundTripperObject::call
// for the request format and error semantics.
error Client::call(string const& method, string const& url, json const& body, json& response)
{
    roundTripperObject& transport = *roundTripper;
    return transport.call(method, url, body, response);
}

} // namespace lt
