import base64
import ctypes
import json
import mmap
import threading
import os


#
# Errors
#

class error(str):
    """Annotation type for error values: a plain string describing a failure.

    Functions in this module return such a string (or ``None`` when there is
    no error) alongside their result, Go style. The values actually returned
    are ordinary ``str`` objects; nothing in this module instantiates this
    class.
    """


# Well-known error strings. ErrorIs matches them as substrings, so an error
# that merely contains one of these (e.g. with context appended) still matches.
ErrEOF: str = "EOF"
ErrRedirect: str = "redirect"            # RoundTripper.Call reports "redirect: <location>"
# NOTE(review): the name suggests a closed connection while the message says
# the agent was not found; not used in this chunk — confirm intended meaning.
ErrClosed: str = "lt100agent not found"


def ErrorIs(err: "error", target: "error") -> bool:
    """Report whether ``target`` occurs anywhere inside ``err``.

    Errors here are plain strings that may carry extra context (for example
    ``"redirect: <location>"``), so this is a substring test, not equality.
    A ``None`` error means "no error" and never matches.
    """
    if err is None:
        return False
    return target in err


def RedirectLocation(err: error) -> str:
    """Return the location part of a ``"redirect: <location>"`` error.

    RoundTripper.Call builds such errors from a redirect response. Only the
    first separator is split on, so a location that itself contains
    ``"redirect: "`` is returned intact.

    Raises:
        IndexError: if ``err`` does not contain the ``"redirect: "`` separator.
    """
    return err.split(ErrRedirect + ": ", 1)[1]


#
# JSON
#


class JSON(dict):
    """A ``dict`` of decoded JSON data with a pretty-printing helper."""

    def String(self) -> str:
        """Serialise this mapping to JSON text indented by four spaces."""
        encoder = json.JSONEncoder(indent=4)
        return encoder.encode(self)


#
# Shared buffers
#


class buffer:
    """A memory-mapped shared buffer plus a client reference count.

    The mapping is filled in by ``mapBuffer``. ``ref`` is adjusted by
    ``buffers.Load`` / ``buffers.Delete``. The mapping and any backing file
    descriptor are released by ``close()``, which also runs on garbage
    collection.
    """
    fd: int            # POSIX only: descriptor of the mapped file, -1 if none
    handle: mmap.mmap  # the mapping; None until mapped and after close()
    ref: int           # number of outstanding client references

    def __init__(self):
        self.ref = 0
        # Defaults keep close()/__del__ safe when mapping never happened.
        self.fd = -1
        self.handle = None

    def close(self):
        """Unmap the buffer and close its descriptor; safe to call repeatedly."""
        # getattr: tolerate instances created without running __init__.
        handle = getattr(self, "handle", None)
        if handle is not None:
            self.handle = None
            handle.close()
        fd = getattr(self, "fd", -1)
        if fd >= 0:
            self.fd = -1
            os.close(fd)

    def __del__(self):
        self.close()


def mapBuffer(name: str, length: int) -> tuple[buffer, error]:
    """Map ``length`` bytes of the shared-memory segment called ``name``.

    On Windows the segment is mapped through mmap's ``tagname``; elsewhere the
    file ``/dev/shm/<name>`` is opened read/write and mapped.

    Returns:
        ``(buffer, None)``. Failures to open or map the segment propagate as
        exceptions from ``os.open`` / ``mmap.mmap``; the error slot is always
        ``None``.
    """
    b = buffer()
    if os.name == 'nt':
        b.handle = mmap.mmap(-1, length, tagname=name)
    else:
        fd = os.open('/dev/shm/' + name, os.O_RDWR)
        try:
            b.handle = mmap.mmap(fd, length)
        except BaseException:
            # Don't leak the descriptor when the mapping fails.
            os.close(fd)
            raise
        # Only hand the fd to the buffer once it owns a live mapping.
        b.fd = fd
    return b, None


class buffers:
    """Thread-safe, reference-counted cache of mapped shared buffers.

    Buffers are keyed by their shared-memory handle name. ``Load`` maps a
    handle on first use and counts one more reference. ``Delete`` counts one
    fewer and forgets the entry when no reference is left. The mapping itself
    is then released once no packet holds the buffer object any more.
    """

    def __init__(self):
        self.m = dict()             # Buffers references, keyed by handle
        self.clients = dict()       # Clients references (not used by the methods below)
        self.mu = threading.Lock()  # guards self.m and the buffers' ref counts

    def Load(self, handle: str, size: int) -> "tuple[buffer, bool]":
        """Return the buffer for ``handle``, mapping ``size`` bytes on first use.

        Increments the buffer's reference count. Returns ``(buffer, True)``,
        or ``(None, False)`` when mapBuffer reports an error.
        """
        with self.mu:
            b = self.m.get(handle)
            if b is None:
                # Map shared buffer data
                b, err = mapBuffer(handle, size)
                if err is not None:
                    return None, False
                self.m[handle] = b
            b.ref += 1
            return b, True

    def Delete(self, handle: str):
        """Drop one client reference to ``handle``; unknown handles are ignored."""
        with self.mu:
            b = self.m.get(handle)
            if b is None:
                return
            b.ref -= 1
            if b.ref <= 0:
                # Last reference gone: stop caching it so the mapping can be
                # released instead of accumulating for the process lifetime.
                del self.m[handle]


# Module-wide cache of mapped shared buffers: ValidatePacket loads packet
# payloads through it and SharedPacket.__del__ releases them.
sharedBuffers = buffers()


#
# pipeConn
#

if os.name == 'nt':
    import win32file

    class pipeConn:
        """Client end of the agent's Windows named pipe for ``addr``.

        Read/Write turn a non-zero win32 status into an error string;
        exceptions raised by win32file itself propagate unchanged.
        """

        def __init__(self, addr: str):
            self.handle = win32file.CreateFile(
                "\\\\.\\pipe\\" + addr,
                win32file.GENERIC_READ | win32file.GENERIC_WRITE,
                0,                          # mode
                None,                       # default security attributes
                win32file.OPEN_EXISTING,    # opens existing pipe
                0,                          # default attributes
                None,                       # no template file
            )
            # Reused for every read; presumably one message must fit in
            # 8 MiB — confirm against the agent.
            self.buf = win32file.AllocateReadBuffer(1024 * 1024 * 8)

        def Read(self) -> tuple[bytes, error]:
            err, data = win32file.ReadFile(self.handle, self.buf)
            if err != 0:
                return None, "ReadFile error " + str(err)
            # Done
            return data, None

        def Write(self, data: bytes) -> error:
            err, _ = win32file.WriteFile(self.handle, data)
            if err != 0:
                return "WriteFile error " + str(err)
            # Done
            return None


elif os.name == 'posix':
    import socket
    import tempfile

    class pipeConn:
        """Client end of the agent's Unix domain socket ``<tmpdir>/<addr>.sock``.

        Read/Write report I/O failures as error strings instead of raising,
        so callers such as RoundTripper can drop the connection and re-dial.
        """

        def __init__(self, addr: str):
            # Create the Unix socket client
            self.client = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
            # Connect to the server; don't leak the socket if that fails
            try:
                self.client.connect(tempfile.gettempdir()+'/'+addr+'.sock')
            except BaseException:
                self.client.close()
                raise

        def Read(self) -> "tuple[bytes, error]":
            """Receive up to 8 MiB.

            Returns ``(data, None)``, ``(None, ErrEOF)`` once the peer has
            closed the connection, or ``(None, "recv error ...")``.
            """
            # NOTE(review): on a stream socket one recv() may return only part
            # of a message; this assumes every response arrives in a single
            # read — confirm against the agent's framing.
            try:
                data = self.client.recv(1024 * 1024 * 8)
            except OSError as e:
                return None, "recv error " + str(e)
            if not data:
                # An empty read means the peer performed an orderly shutdown.
                return None, ErrEOF
            return data, None

        def Write(self, data: bytes) -> "error":
            """Send all of ``data``; return None or ``"sendall error ..."``."""
            try:
                self.client.sendall(data)
            except OSError as e:
                return "sendall error " + str(e)
            return None

else:
    raise ImportError(
        "Sorry: no implementation for your platform ('{}') available".format(
            os.name)
    )


def dialPipe(addr: str) -> tuple[pipeConn, error]:
    """Open a pipeConn to the agent endpoint named ``addr``.

    Connection failures surface as exceptions raised by ``pipeConn``; the
    error slot of the returned pair is always None.
    """
    connection = pipeConn(addr)
    return connection, None


#
# RoundTripper
#

class RoundTripper:
    """Exchanges JSON requests and responses with the agent, one at a time.

    The connection is dialled lazily by the first ``Call``. The url scheme
    (the text before the first ``':'``) is the endpoint name passed to
    dialPipe, and every later call must use the same scheme. If writing the
    request or reading the response fails, the connection is dropped so the
    next ``Call`` dials again.
    """

    def __init__(self):
        self.scheme = ""                # scheme of the open connection; "" = none
        self.conn = None                # connection from dialPipe, None until dialled
        self.mu = threading.Lock()      # serialises Call: one exchange at a time

    def open(self, addr: str) -> "error":
        """Dial ``addr`` via dialPipe, keep the connection and return its error."""
        self.conn, err = dialPipe(addr)
        return err

    def _reset(self):
        """Forget the current connection so the next Call dials again."""
        self.scheme = ""
        self.conn = None

    def _exchange(self, method: str, url: str, body: "JSON") -> "tuple[JSON, error]":
        """Write one request and read its response (caller holds ``mu``)."""
        err = self.encode({
            'method': method,
            'url': url,
            'body': body,
        })
        if err is not None:
            return None, err
        return self.decode()

    def Call(self, method: str, url: str, body: "JSON") -> "tuple[JSON, error]":
        """Send ``method``/``url``/``body`` to the agent and return its reply.

        Returns ``(response, None)`` on success, otherwise ``(None, err)``.
        A response ``"error"`` field becomes ``err``. A redirect that carries
        a ``"location"`` is reported as ``"redirect: <location>"``. Exceptions
        raised while talking to the agent propagate after the connection has
        been dropped.
        """
        with self.mu:
            # Validate url: the scheme is the non-empty text before the first ':'
            urlScheme, colon, _ = url.partition(':')
            if not colon or urlScheme == "":
                return None, "url scheme not found"

            # Validate connection
            if self.scheme != "" and self.scheme != urlScheme:
                return None, "bad url scheme: " + self.scheme + "!=" + urlScheme
            if self.scheme == "":
                err = self.open(urlScheme)
                if err is not None:
                    return None, err
                self.scheme = urlScheme

            # JSON request / response
            try:
                response, err = self._exchange(method, url, body)
            except BaseException:
                # A half-finished exchange leaves the stream in an unknown state.
                self._reset()
                raise
            if err is not None:
                self._reset()
                return None, err

            # Check for error
            if 'error' in response:
                if response['error'] == ErrRedirect and 'location' in response:
                    return None, response['error'] + ": " + response['location']
                # Done
                return None, response['error']

            # Done
            return response, None

    def decode(self) -> "tuple[JSON, error]":
        """Read one message and decode it as a JSON object.

        Returns ``(JSON(...), None)``; a JSON ``null`` yields an empty
        ``JSON``. Returns ``(None, err)`` when the read fails or the payload
        is not a UTF-8 encoded JSON object.
        """
        # Read
        data, err = self.conn.Read()
        if err is not None:
            return None, err
        # Decode; bytes() also accepts buffer objects such as memoryview
        try:
            j = json.loads(bytes(data).decode())
        except ValueError as e:  # covers JSONDecodeError and UnicodeDecodeError
            return None, "invalid JSON response: " + str(e)
        # Done
        if j is None:
            return JSON(), None
        if not isinstance(j, dict):
            return None, "invalid JSON response: not an object"
        return JSON(j), None

    def encode(self, j: "JSON") -> "error":
        """Serialise ``j`` (None becomes ``{}``) and write it to the connection."""
        if j is None:
            j = JSON()
        # Encode
        data = json.dumps(j)
        # Write
        return self.conn.Write(data.encode())

#
# Client
#


class Client:
    """Request API for the agent on top of a shared RoundTripper.

    ``roundTripper`` is a class attribute, so every Client instance uses the
    same connection. That includes the short-lived instances created by the
    module-level helpers.
    """

    roundTripper = RoundTripper()

    def Get(self, url: str) -> tuple[JSON, error]:
        """GET ``url`` and return ``(response, None)`` or ``(None, err)``.

        For ``/client/jobs`` urls the response is first validated with
        ValidateWorker. Then every packet carrying both ``ref`` and ``handle``
        is replaced inside ``j['packets']`` by a SharedPacket. The
        SharedPacket releases the shared buffer when the caller drops it.
        """
        j, err = self.call('GET', url, None)
        if err is not None:
            return None, err

        # Worker is being created
        if "/client/jobs" in url:
            # Validate worker
            err = ValidateWorker(j)
            if err is not None:
                return None, err

            # Referenced packets: store each wrapper back into the list.
            # Merely rebinding the loop variable would discard the
            # SharedPacket at once and run its release (__del__) before the
            # caller ever sees the packet.
            if 'packets' in j:
                packets = j['packets']
                for index, packet in enumerate(packets):
                    if 'ref' in packet and 'handle' in packet:
                        packets[index] = SharedPacket(packet, self.roundTripper)

        # Done
        return j, None

    def Post(self, url: str, body: JSON) -> tuple[JSON, error]:
        """POST ``body`` to ``url``; returns ``(response, err)``."""
        return self.call('POST', url, body)

    def Delete(self, url: str) -> error:
        """DELETE ``url``; returns only the error (None on success)."""
        _, err = self.call('DELETE', url, None)
        return err

    def call(self, method: str, url: str, body: JSON) -> tuple[JSON, error]:
        """Forward a request to the shared RoundTripper."""
        return self.roundTripper.Call(method, url, body)

#
# High level functions
#


def Get(url: str) -> tuple[JSON, error]:
    """GET ``url`` through a fresh Client (which shares the class-level connection)."""
    client = Client()
    return client.Get(url)


def Post(url: str, body: JSON) -> tuple[JSON, error]:
    """POST ``body`` to ``url`` through a fresh Client."""
    client = Client()
    return client.Post(url, body)


def Delete(url: str) -> error:
    """DELETE ``url`` through a fresh Client; returns only the error."""
    client = Client()
    return client.Delete(url)


#
# JSON models
#

def ValidateWorker(j: JSON) -> error:
    """Check that a worker response carries every required key.

    Each entry of ``j['packets']`` (when present) is then checked with
    ValidatePacket, which may rewrite packet fields in place.

    Returns the first problem found as an error string, or None.
    """
    required = ("name", "location", "start", "duration", "length", "status")
    for key in required:
        if key not in j:
            return f"missing key: '{key}'"

    # Validate packets
    packets = j['packets'] if 'packets' in j else ()
    for packet in packets:
        problem = ValidatePacket(packet)
        if problem is not None:
            return problem

    # Done
    return None


def ValidatePacket(j: "JSON") -> "error":
    """Validate one packet and put its payload into ``j['data']``.

    A packet must carry ``media``, ``track``, ``signal`` and ``timestamp``.
    Its payload then comes from one of two places:

    * inline: ``data`` holds base64 text and is replaced by the decoded bytes;
    * shared memory: ``ref`` is present and ``handle``/``ptr``/``len``/``cap``
      locate the payload. The buffer is loaded through ``sharedBuffers``
      (which increments its reference count) and stored in ``j['buf']``.
      ``len`` bytes starting at ``ptr`` are copied into ``j['data']``.

    A packet with neither ``data`` nor ``ref`` is accepted unchanged.

    Returns None when valid, otherwise an error string. ``j`` is modified in
    place.
    """
    for key in ("media", "track", "signal", "timestamp"):
        if key not in j:
            return f"missing key: '{key}'"

    # Inline payload: decode the base64 data field
    if "data" in j:
        j["data"] = base64.decodebytes(str(j['data']).encode())
        return None

    # No shared-memory reference either: nothing more to check
    if "ref" not in j:
        return None
    for key in ("handle", "ptr", "len", "cap"):
        if key not in j:
            return f"missing key: '{key}'"

    # Load shared memory data
    j["buf"], ok = sharedBuffers.Load(j['handle'], j["cap"])
    if not ok:
        # Must be a plain error string like every other return here; a
        # (None, err) tuple would be passed on verbatim by ValidateWorker.
        return "error loading shared memory 'data'"
    # TODO: Server has to return small individual shared buffers instead of a
    # segmented big one; until then the payload is copied out of the mapping.
    start = j['ptr']
    j["data"] = bytearray(j["buf"].handle[start:start + j['len']])

    # Done
    return None


class SharedPacket(JSON):
    """A packet that references shared memory (it carries ``ref`` and ``handle``).

    The RoundTripper it arrived on is kept under the ``"roundTripper"`` key.
    When the packet is garbage-collected, it drops its shared-buffer reference
    and sends a DELETE for ``self['ref']``.
    """

    def __init__(self, j: JSON, roundTripper: RoundTripper):
        super().__init__(j)
        self["roundTripper"] = roundTripper

    def __del__(self):
        needed = ('roundTripper', 'ref', 'handle')
        if not all(name in self for name in needed):
            return
        # Drop the local buffer reference, then DELETE the packet's ref url.
        # NOTE(review): RoundTripper.Call takes a non-reentrant Lock; if this
        # finaliser ever ran while the same thread is inside Call, it would
        # deadlock.
        sharedBuffers.Delete(self['handle'])
        transport = self["roundTripper"]
        transport.Call('DELETE', self['ref'], None)
