# Windows only: pip install pywin32  (provides win32file / pywintypes)


import base64
import json
import mmap
import os
import threading
from typing import Any


if os.name == 'nt':
    import win32file
    import pywintypes

    PyHANDLE = Any
    PyOVERLAPPEDReadBuffer = Any
elif os.name == 'posix':
    import socket
    import tempfile
else:
    raise ImportError(f"Platform '{os.name}' not supported")


# -----------------------------------------------------------------------------
# Errors
# -----------------------------------------------------------------------------

class SDKError(Exception):
    """Root of every exception raised by this SDK; catch it to handle them all."""

class RedirectError(SDKError):
    """The server answered with a redirect instead of a result.

    ``location`` carries the redirect target taken from the response; the
    exception message reads ``"redirect: <location>"``.
    """

    def __init__(self, location: str):
        self.location = location
        super().__init__(f"redirect: {location}")

# NOTE: ConnectionError and BufferError shadow the built-in exceptions of the
# same name everywhere inside this module.
class ConnectionError(SDKError):
    """The pipe/socket could not be opened, read from or written to."""


class ProtocolError(SDKError):
    """A request could not be sent as asked (e.g. missing or mismatched URL scheme)."""


class BufferError(SDKError):
    """A shared-memory buffer could not be mapped."""


class ValidationError(SDKError):
    """A server object lacks one of its required keys."""

# class error(str):
#     None


ErrEOF: str = "EOF"                             # end of stream
ErrRedirect: str = "redirect"                   # response "error" value signalling a redirect
ErrClosed: str = "use of closed connection"     # connection already closed


# def ErrorIs(err: error, target: error) -> bool:
#     return err.find(target) != -1


# def RedirectLocation(err: error) -> str:
#     return err.split(ErrRedirect + ": ")[1]


# -----------------------------------------------------------------------------
# JSON wrapper
# -----------------------------------------------------------------------------

class JSON(dict):
    """A ``dict`` holding one JSON object exchanged with the server."""

    def String(self) -> str:
        """Return the object as JSON text indented by 4 spaces.

        Responses may hold values the ``json`` module cannot encode natively
        (decoded packet ``data`` bytes, the shared ``buf`` object, a stored
        round tripper). Those are rendered instead of raising ``TypeError``:
        bytes-like values become base64 text, anything else its ``repr()``.
        """
        return json.dumps(self, indent=4, default=self._stringDefault)

    @staticmethod
    def _stringDefault(value):
        # Fallback for json.dumps: base64 mirrors how packet "data" arrives.
        if isinstance(value, (bytes, bytearray, memoryview)):
            return base64.b64encode(value).decode('ascii')
        return repr(value)


# -----------------------------------------------------------------------------
# Shared buffers
# -----------------------------------------------------------------------------

class buffer:
    fd: int | None
    handle: mmap.mmap | None
    ref: int

    def __init__(self):
        self.fd = None
        self.handle = None
        self.ref = 0

    def __del__(self):
        if self.handle:
            self.handle.close()


def mapBuffer(name: str, length: int) -> buffer:
    """Map the first *length* bytes of shared-memory segment *name*.

    On Windows the segment is addressed by mmap ``tagname``. Elsewhere the
    file ``/dev/shm/<name>`` is opened read-write and mapped.

    NOTE(review): on Windows, ``mmap.mmap(-1, length, tagname=...)`` is
    presumably open-or-create, so a missing segment may yield a fresh
    zero-filled mapping rather than an error. Confirm.

    Raises:
        BufferError: the segment cannot be opened or mapped, including an
            invalid *length* (e.g. larger than the segment, or negative).
    """
    b = buffer()
    try:
        if os.name == 'nt':
            b.handle = mmap.mmap(-1, length, tagname=name)
        else:
            b.fd = os.open(f'/dev/shm/{name}', os.O_RDWR)
            try:
                b.handle = mmap.mmap(b.fd, length)
            except BaseException:
                # Don't leak the descriptor when the mapping itself fails.
                os.close(b.fd)
                b.fd = None
                raise

    # mmap reports bad lengths as ValueError/OverflowError, not OSError.
    except (OSError, ValueError, OverflowError) as e:
        raise BufferError(f"Failed to map buffer '{name}': {e}") from e

    return b


class buffers:
    """Registry of mapped shared-memory buffers keyed by handle, with refcounts.

    All access goes through ``mu``.
    """

    def __init__(self):
        self.m: dict[str, buffer] = {} # Buffers references
        self.clients = {}              # Clients references
        self.mu = threading.Lock()

    # Load loads shared buffer data and increments the client reference
    def Load(self, handle: str, size: int) -> buffer:
        """Return the buffer for *handle* and take one reference on it.

        The segment is mapped with *size* bytes the first time the handle is
        seen, and after every full release. Errors from ``mapBuffer``
        propagate, and nothing is registered in that case.
        """
        with self.mu:
            buf = self.m.get(handle)
            if buf is None:
                buf = mapBuffer(handle, size)
                self.m[handle] = buf
            buf.ref += 1
            return buf

    # Delete deletes the client shared buffer references
    def Delete(self, handle: str):
        """Drop one reference to *handle*; forget the entry once none remain.

        Unknown handles are ignored. The entry is only removed from the
        registry, not closed here: objects still holding the buffer keep it
        usable, and the buffer's own ``__del__`` unmaps it once the last
        reference goes away. Previously entries were never removed, so every
        mapping stayed alive for the life of the process.
        """
        with self.mu:
            buf = self.m.get(handle)
            if buf is None:
                return
            buf.ref -= 1
            if buf.ref <= 0:
                del self.m[handle]


sharedBuffers = buffers()


# -----------------------------------------------------------------------------
# pipeConn
# -----------------------------------------------------------------------------

if os.name == 'nt':
    class pipeConn:
        """Client end of a Windows named pipe, opened by name with CreateFile."""

        handle: PyHANDLE                # pipe handle returned by CreateFile
        buf: PyOVERLAPPEDReadBuffer     # 8 MiB read buffer reused by every Read

        def __init__(self, addr: str):
            """Open the existing pipe named *addr* for reading and writing.

            Raises:
                ConnectionError: the pipe cannot be opened or the read buffer
                    cannot be allocated.
            """
            try:
                self.handle = win32file.CreateFile(
                    "\\\\.\\pipe\\" + addr,
                    win32file.GENERIC_READ | win32file.GENERIC_WRITE,
                    0,                          # share mode: no sharing
                    None,                       # default security attributes
                    win32file.OPEN_EXISTING,    # opens existing pipe
                    0,                          # default attributes
                    None,                       # no template file
                )
                self.buf = win32file.AllocateReadBuffer(1024 * 1024 * 8)
            except Exception as e:
                raise ConnectionError(f"Failed to connect to pipe '{addr}': {e}") from e

        def Read(self) -> str:
            """Read the next message with a single ReadFile call into ``buf``.

            NOTE(review): annotated ``str``, but the value is whatever
            ReadFile returns for the shared buffer (presumably bytes-like).
            Confirm against pywin32.

            Raises:
                ConnectionError: ReadFile reported a non-zero status.
            """
            err, data = win32file.ReadFile(self.handle, self.buf)
            if err != 0:
                raise ConnectionError(f"ReadFile error {err}")

            return data

        def Write(self, data: bytes):
            """Write *data* with a single WriteFile call.

            Raises:
                ConnectionError: WriteFile reported a non-zero status.
            """
            err, _ = win32file.WriteFile(self.handle, data)
            if err != 0:
                raise ConnectionError(f"WriteFile error {err}")


elif os.name == 'posix':
    class pipeConn:
        """Client end of the Unix-domain stream socket ``<tempdir>/<addr>.sock``."""

        def __init__(self, addr: str):
            """Connect to the server socket for *addr*.

            Raises:
                ConnectionError: the socket cannot be created or connected.
                    A half-open socket is closed before raising instead of
                    being leaked.
            """
            sock = None
            try:
                # Create the Unix socket client and connect to the server
                path = f"{tempfile.gettempdir()}/{addr}.sock"
                sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
                sock.connect(path)
            except Exception as e:
                if sock is not None:
                    sock.close()
                raise ConnectionError(f"Failed to connect to socket '{addr}': {e}") from e
            self.client = sock

        def Read(self) -> bytes:
            """Return the next chunk (at most 8 MiB) received from the server.

            NOTE(review): one recv() on a stream socket may return only part
            of a large message. This relies on each response arriving in a
            single read. Confirm the framing with the server.

            Raises:
                ConnectionError: with ErrEOF when the peer closed the
                    connection (recv returned no data). Previously the empty
                    read surfaced later as a JSON decode error.
            """
            data = self.client.recv(1024 * 1024 * 8)
            if not data:
                raise ConnectionError(ErrEOF)
            return data

        def Write(self, data: bytes):
            """Send all of *data* to the server."""
            self.client.sendall(data)


def dialPipe(addr: str) -> pipeConn:
    """Connect to the platform pipe/socket named *addr* and return the connection."""
    connection = pipeConn(addr)
    return connection


# -----------------------------------------------------------------------------
# RoundTripper
# -----------------------------------------------------------------------------

class RoundTripper:
    """Send JSON requests over a lazily dialed pipe and decode the JSON replies.

    The first Call dials the pipe named after its URL scheme (the text before
    the first ':'). After that, only URLs with the same scheme are accepted.
    ``mu`` serializes each request together with its reply.
    """

    def __init__(self):
        self.scheme: str = ""               # pipe name, fixed by the first Call
        self.conn: pipeConn | None = None   # dialed on the first Call
        self.mu = threading.Lock()          # one request/response at a time

    def open(self, addr: str):
        """Dial the pipe named *addr* and use it as the connection."""
        self.conn = dialPipe(addr)

    def Call(self, method: str, url: str, body: JSON | None = None) -> JSON:
        """Send one request and return the decoded response.

        Args:
            method: request method, e.g. ``'GET'``.
            url: ``<scheme>:<rest>``. The scheme selects the pipe.
            body: request body. An empty JSON object is sent when it is
                omitted or falsy.

        Raises:
            ProtocolError: *url* has no scheme, or its scheme differs from the
                one this round tripper is already bound to.
            RedirectError: the response error equals ErrRedirect and carries a
                ``location``.
            SDKError: the response carries any other ``error``.
        """
        with self.mu:
            # Validate url. split(':')[0] returned the whole URL when there was
            # no ':', which then got dialed as a pipe name. partition() tells
            # "no colon" and "empty scheme" apart.
            urlScheme, sep, _ = url.partition(':')
            if not sep or urlScheme == "":
                raise ProtocolError("url scheme not found")

            # Validate connection
            if self.scheme != "" and self.scheme != urlScheme:
                raise ProtocolError("bad url scheme: " + self.scheme + "!=" + urlScheme)

            if self.scheme == "":
                self.open(urlScheme)
                self.scheme = urlScheme

            # Encode JSON request
            self.encode(JSON({
                'method': method,
                'url': url,
                'body': body or JSON()
            }))

            # Decode JSON response
            response = self.decode()

            # Check for error
            if 'error' in response:
                if response['error'] == ErrRedirect and 'location' in response:
                    raise RedirectError(response['location'])

                raise SDKError(response['error'])

            # Done
            return response

    def decode(self) -> JSON:
        """Read one message and parse it as JSON.

        A falsy decoded value (e.g. JSON ``null``) yields an empty JSON.
        """
        data = self.conn.Read()
        j = json.loads(data)
        return JSON(j or {})

    def encode(self, j: JSON):
        """Serialize *j* to UTF-8 JSON and write it to the connection."""
        data = json.dumps(j).encode()
        self.conn.Write(data)


# -----------------------------------------------------------------------------
# Client
# -----------------------------------------------------------------------------

class Client:
    """Request client built on its own RoundTripper."""

    def __init__(self):
        self.roundTripper = RoundTripper()

    def Get(self, url: str) -> JSON:
        """GET *url*; job responses are additionally checked by _validateWorkerIfNeeded."""
        response = self.roundTripper.Call('GET', url)
        self._validateWorkerIfNeeded(url, response)
        return response

    def Post(self, url: str, body: JSON) -> JSON:
        """POST *body* to *url* and return the decoded response."""
        response = self.roundTripper.Call('POST', url, body)
        return response

    def Delete(self, url: str):
        """DELETE *url*; the response is discarded."""
        self.roundTripper.Call('DELETE', url)

    def _validateWorkerIfNeeded(self, url: str, j: JSON):
        """For URLs containing ``/client/jobs``, validate *j* and its shared packets.

        Every packet carrying both ``ref`` and ``handle`` is wrapped in a
        SharedPacket bound to this client's round tripper.
        """
        if "/client/jobs" not in url:
            return

        ValidateWorker(j)

        # NOTE(review): the SharedPacket is not stored anywhere, so in CPython
        # it is collected as soon as the statement ends and its __del__ runs
        # immediately. Confirm that immediate release is intended.
        packets = j['packets'] if 'packets' in j else ()
        for pkt in packets:
            if 'ref' in pkt and 'handle' in pkt:
                SharedPacket(pkt, self.roundTripper)


# -----------------------------------------------------------------------------
# High level functions
# -----------------------------------------------------------------------------

def Get(url: str) -> JSON:
    """GET *url* through a newly created Client and return the response."""
    client = Client()
    return client.Get(url)


def Post(url: str, body: JSON) -> JSON:
    """POST *body* to *url* through a newly created Client and return the response."""
    client = Client()
    return client.Post(url, body)


def Delete(url: str):
    """DELETE *url* through a newly created Client."""
    client = Client()
    client.Delete(url)


# -----------------------------------------------------------------------------
# JSON models
# -----------------------------------------------------------------------------

def ValidateWorker(j: JSON):
    """Check that a job object has all required keys, then validate its packets.

    Raises:
        ValidationError: naming the first missing key, or whatever
            ValidatePacket raises for a packet.
    """
    required = ("name", "location", "start", "duration", "length", "status")
    missing = next((key for key in required if key not in j), None)
    if missing is not None:
        raise ValidationError(f"Missing key '{missing}'")

    # Validate packets
    packets = j['packets'] if 'packets' in j else ()
    for pkt in packets:
        ValidatePacket(pkt)


def ValidatePacket(j: JSON):
    """Validate a packet object in place and materialize its payload.

    Required keys: ``media``, ``track``, ``signal``, ``timestamp``.

    * ``data`` (optional) is base64 text and is replaced by the decoded
      bytes. A ``null`` value becomes ``b""``. A value that is already
      ``bytes``/``bytearray`` is left untouched.
    * If ``ref`` is present, the payload lives in shared memory.
      ``handle``, ``len`` and ``cap`` are then required. The segment is
      loaded through ``sharedBuffers`` (taking one reference) with ``cap``
      bytes, the buffer object is stored under ``buf``, and the first
      ``len`` bytes are copied into ``data`` as a ``bytearray``.

    Raises:
        ValidationError: a required key is missing, or ``len`` is outside
            ``0..cap``.
    """
    for k in ("media", "track", "signal", "timestamp"):
        if k not in j:
            raise ValidationError(f"Missing key '{k}' in packet")

    # Decode base64 data field
    if "data" in j:
        raw = j["data"]
        if raw is None:
            # str(None) would be base64-decoded as the text "None" -> garbage.
            j["data"] = b""
        elif not isinstance(raw, (bytes, bytearray)):
            j["data"] = base64.decodebytes(str(raw).encode())
        # already-decoded bytes: str() would garble them, so keep as-is

    # Validate shared memory fields
    if "ref" in j:
        for k in ("handle", "len", "cap"):
            if k not in j:
                raise ValidationError(f"Missing key '{k}' in shared memory")

        length, capacity = j["len"], j["cap"]
        # A negative or oversized len would silently drop or truncate bytes
        # in the slice below, so reject it before taking a reference.
        if not 0 <= length <= capacity:
            raise ValidationError(
                f"Invalid shared memory length {length} (cap {capacity})"
            )

        # Load shared memory data
        buf = sharedBuffers.Load(j['handle'], capacity)
        j["buf"] = buf
        j["data"] = bytearray(buf.handle[:length])


# -----------------------------------------------------------------------------
# SharedPacket
# -----------------------------------------------------------------------------

class SharedPacket(JSON):
    """A packet whose payload came from a shared-memory segment.

    Copies the packet's keys, adds the round tripper under ``"roundTripper"``,
    and keeps the packet's ``buf`` (if any) as an attribute. When collected,
    it drops one reference on ``handle`` in ``sharedBuffers`` and sends a
    DELETE for the packet's ``ref`` URL.
    """

    # Class default: a bare annotation creates no attribute, so __del__ used
    # to raise AttributeError whenever 'buf' was absent or __init__ failed early.
    buf: buffer | None = None

    def __init__(self, j: JSON, roundTripper: RoundTripper):
        super().__init__(j)
        self["roundTripper"] = roundTripper
        self.buf = j['buf'] if 'buf' in j else None

    def __del__(self):
        # Only packets that actually loaded a buffer hold a reference to release.
        if self.buf is not None and 'ref' in self and 'handle' in self:
            sharedBuffers.Delete(self['handle'])
            self["roundTripper"].Call('DELETE', self['ref'])
