import ctypes

import numpy as np
from OpenGL.GL import *
from OpenGL.GL import shaders

# pip install pyopengl

# Exception for OpenGL errors
class GlError(Exception):
    """Raised by this module for GL errors, shader/program failures and bad frame input."""

def check_gl_error():
    """Raise GlError if OpenGL has recorded any errors since the last query.

    ``glGetError`` returns (and clears) only one error flag per call, and an
    implementation may hold several at once. The flags are therefore drained
    in a loop; otherwise a leftover error would surface at a later, unrelated
    check and be blamed on the wrong call.

    Raises:
        GlError: listing every pending error code. For a single error the
            message is ``"OpenGL error: <code>"``.
    """
    errors = []
    # Defensive bound in case glGetError never reports GL_NO_ERROR.
    for _ in range(32):
        err = glGetError()
        if err == GL_NO_ERROR:
            break
        errors.append(err)
    if errors:
        raise GlError(f"OpenGL error: {', '.join(str(e) for e in errors)}")


def compileShader(type: int, source: str) -> int:
    """Create and compile a single shader object.

    Args:
        type: shader stage enum, e.g. ``GL_VERTEX_SHADER`` or ``GL_FRAGMENT_SHADER``.
        source: GLSL source text.

    Returns:
        The handle of the successfully compiled shader object.

    Raises:
        GlError: if compilation fails. The failed shader object is deleted
            first, and the message carries the driver's info log.
    """
    handle = glCreateShader(type)
    glShaderSource(handle, source)
    glCompileShader(handle)

    compiled = glGetShaderiv(handle, GL_COMPILE_STATUS)
    if compiled:
        return handle

    info = glGetShaderInfoLog(handle)
    glDeleteShader(handle)
    # PyOpenGL may hand the log back as bytes; show it as text.
    if isinstance(info, (bytes, bytearray)):
        info = info.decode()
    raise GlError(f"Shader compile failed:\n{info}")


def linkProgram(vertexSource: str, fragmentSource: str) -> int:
    """Compile a vertex/fragment shader pair and link them into a program.

    Args:
        vertexSource: GLSL source of the vertex shader.
        fragmentSource: GLSL source of the fragment shader.

    Returns:
        The handle of the linked program object.

    Raises:
        GlError: if either shader fails to compile or the program fails to
            link (the message carries the program's info log). No shader or
            program objects are leaked on failure.
    """
    vertexShader = compileShader(GL_VERTEX_SHADER, vertexSource)
    try:
        fragmentShader = compileShader(GL_FRAGMENT_SHADER, fragmentSource)
    except Exception:
        # Don't leak the already-compiled vertex shader.
        glDeleteShader(vertexShader)
        raise

    program = glCreateProgram()
    try:
        glAttachShader(program, vertexShader)
        glAttachShader(program, fragmentShader)
        glLinkProgram(program)

        if not glGetProgramiv(program, GL_LINK_STATUS):
            # Program objects have their own info log. Calling
            # glGetShaderInfoLog on a program handle is GL_INVALID_OPERATION.
            log = glGetProgramInfoLog(program)
            if isinstance(log, (bytes, bytearray)):
                log = log.decode(errors="replace")
            raise GlError(f"Program link failed:\n{log}")
    except Exception:
        glDeleteProgram(program)
        raise
    finally:
        # The shader objects are no longer needed once linking was attempted;
        # deleting them here only flags them while still attached.
        glDeleteShader(vertexShader)
        glDeleteShader(fragmentShader)

    return program



# ------------------------------------------------------------
# YUYV renderer (each texel = 4 bytes: Y0 U Y1 V)
# ------------------------------------------------------------
# Pass-through vertex shader shared by the YUYV and NV12 renderers. It
# forwards the clip-space position (a_position) unchanged and hands the
# texture coordinate (a_texCoord) to the fragment stage as v_texCoord.
# It declares no #version, so GLSL treats it as version 1.10 (hence the
# attribute/varying keywords).
_VERTEX_SHADER = """
attribute vec4 a_position;
attribute vec2 a_texCoord;
varying vec2 v_texCoord;

void main() {
    gl_Position = a_position;
    v_texCoord = a_texCoord;
}
"""

_YUYV_FRAGMENT_SHADER = """
#version 400

varying vec2 v_texCoord;
uniform sampler2D y_texture;

// YCbCr to RGB (Rec.601)
// mat3 csc = mat3(
// 	1.,   0.,       1.13983,
// 	1.,  -0.39465, -0.58060,
// 	1.,   2.03211,  0.);

// YCbCr to RGB full (Rec.2020)
mat3 csc = mat3(
    1.0,   0.0000,  1.4746,
    1.0,  -0.1646, -0.5714,
    1.0,   1.8814,  0.0000);

void main (void){
    // Fetch Luma Y - Use GPU native bilinear interpolation
    float y = texture(y_texture, v_texCoord).r;

    // Process YU & YV texel size
    vec2 textureSize = textureSize(y_texture, 0);
    vec2 texelSize0 = 0./textureSize;
    vec2 texelSize1 = 1./textureSize;
    vec2 texelSize2 = 2./textureSize;
    vec2 texelSize3 = 3./textureSize;

    // YU/YV texels
    // YU00 YV10 YU20 YV30 ...
    // YU01 YV11 YU21 YV31 ...

    // YU00 top left coord
    vec2 coord = floor(v_texCoord*textureSize / vec2(2.,1.)) * vec2(2.,1.);
    // YU00 center coord
    coord += .5;
    // YU00 normalized coord [0..1]
    coord /= textureSize;

    // Fetch Chroma Uxy
    float u00 = texture(y_texture, coord + vec2(texelSize0.x, texelSize0.y)).g;
    float u20 = texture(y_texture, coord + vec2(texelSize2.x, texelSize0.y)).g;
    float u01 = texture(y_texture, coord + vec2(texelSize0.x, texelSize1.y)).g;
    float u21 = texture(y_texture, coord + vec2(texelSize2.x, texelSize1.y)).g;

    // Fetch Chroma Vxy
    float v10 = texture(y_texture, coord + vec2(texelSize1.x, texelSize0.y)).a;
    float v30 = texture(y_texture, coord + vec2(texelSize3.x, texelSize0.y)).a;
    float v11 = texture(y_texture, coord + vec2(texelSize1.x, texelSize1.y)).a;
    float v31 = texture(y_texture, coord + vec2(texelSize3.x, texelSize1.y)).a;

    // Interpolation weight
    vec2 w = fract((v_texCoord*textureSize) / vec2(2.,1.));

    // YUV
    vec3 yuv;
    yuv.r = y;
    yuv.g = mix(
        mix(u00, u20, w.x),
        mix(u01, u21, w.x), w.y
    ) -0.5;
    yuv.b = mix(
        mix(v10, v30, w.x),
        mix(v11, v31, w.x), w.y
    ) -0.5;

    // RGBA
    gl_FragColor = vec4(yuv*csc, 1.);
}
"""

class YUYVRenderer:
    """Renders packed YUYV (YUV 4:2:2) frames onto a full-viewport quad.

    Each frame is uploaded as an RGBA texture of ``width // 2`` x ``height``
    texels. Every texel holds one macropixel ``(Y0, U, Y1, V)``, which the
    fragment shader unpacks and converts to RGB.

    All methods issue GL calls, so the GL context the renderer was created in
    must be current when they are called.
    """

    # Full-viewport quad, interleaved (x, y, s, t) per vertex. t = 0 sits at
    # the top edge (y = 1), so the first row of the frame is drawn at the top.
    _VERTICES = np.array([
        -1.0,  1.0, 0.0, 0.0,
        -1.0, -1.0, 0.0, 1.0,
         1.0, -1.0, 1.0, 1.0,
         1.0,  1.0, 1.0, 0.0,
    ], dtype=np.float32)
    # Two triangles covering the quad.
    _INDICES = np.array([0, 1, 2, 0, 2, 3], dtype=np.uint32)
    # Bytes per vertex (4 float32 components) and byte offset of (s, t).
    _STRIDE = 4 * 4
    _TEXCOORD_OFFSET = 2 * 4

    def __init__(self):
        # Compile/link the program and look up its inputs.
        self.shaderProgram = linkProgram(_VERTEX_SHADER, _YUYV_FRAGMENT_SHADER)
        self.posAttrib = glGetAttribLocation(self.shaderProgram, "a_position")
        self.texAttrib = glGetAttribLocation(self.shaderProgram, "a_texCoord")
        # Sampler uniform location (distinct from the texture object id).
        self.yLocation = glGetUniformLocation(self.shaderProgram, "y_texture")

        # The quad never changes, so upload its geometry once instead of
        # every frame.
        self.vbo = glGenBuffers(1)
        glBindBuffer(GL_ARRAY_BUFFER, self.vbo)
        glBufferData(GL_ARRAY_BUFFER, self._VERTICES.nbytes, self._VERTICES, GL_STATIC_DRAW)
        self.ebo = glGenBuffers(1)
        glBindBuffer(GL_ELEMENT_ARRAY_BUFFER, self.ebo)
        glBufferData(GL_ELEMENT_ARRAY_BUFFER, self._INDICES.nbytes, self._INDICES, GL_STATIC_DRAW)

        # Single texture for the interleaved YUYV data (RGBA texel = Y0 U Y1 V).
        self.yuyvTexture = self._newTexture()
        # (width, height) in texels that the texture storage is allocated for;
        # None until the first frame is uploaded.
        self._texSize = None

        # Black background
        glClearColor(0.0, 0.0, 0.0, 0.0)
        check_gl_error()

    @staticmethod
    def _newTexture():
        """Create a 2D texture with linear filtering and edge clamping."""
        texture = glGenTextures(1)
        glBindTexture(GL_TEXTURE_2D, texture)
        glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_LINEAR)
        glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_LINEAR)
        glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_CLAMP_TO_EDGE)
        glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE)
        return texture

    def Draw(self, width: int, height: int, data: bytes) -> None:
        """Upload one YUYV frame and draw it over the whole viewport.

        Args:
            width: frame width in pixels. It must be even, because each
                macropixel covers two pixels.
            height: frame height in pixels.
            data: packed YUYV bytes, exactly ``width * height * 2`` long.

        Raises:
            GlError: on odd width, wrong buffer length, or a pending GL error.
        """
        if width % 2 != 0:
            raise GlError("Frame width must be even for YUYV")

        # Zero-copy view over the caller's buffer.
        arr = np.frombuffer(data, dtype=np.uint8)
        expected_len = width * height * 2
        if arr.size != expected_len:
            raise GlError(f"YUYV data length mismatch: got {arr.size}, expected {expected_len}")

        glUseProgram(self.shaderProgram)

        # Re-bind the static quad each frame, since other GL code may have
        # changed buffer bindings or attribute pointers since the last draw.
        glBindBuffer(GL_ARRAY_BUFFER, self.vbo)
        glVertexAttribPointer(self.posAttrib, 2, GL_FLOAT, GL_FALSE, self._STRIDE, ctypes.c_void_p(0))
        glEnableVertexAttribArray(self.posAttrib)
        glVertexAttribPointer(self.texAttrib, 2, GL_FLOAT, GL_FALSE, self._STRIDE,
                              ctypes.c_void_p(self._TEXCOORD_OFFSET))
        glEnableVertexAttribArray(self.texAttrib)
        glBindBuffer(GL_ELEMENT_ARRAY_BUFFER, self.ebo)
        check_gl_error()

        # Source rows are tightly packed.
        glPixelStorei(GL_UNPACK_ALIGNMENT, 1)

        # Texture is width/2 x height RGBA texels (one macropixel per texel).
        glActiveTexture(GL_TEXTURE0)
        glBindTexture(GL_TEXTURE_2D, self.yuyvTexture)
        texSize = (width // 2, height)
        if self._texSize == texSize:
            # Same dimensions as the previous frame: overwrite existing storage.
            glTexSubImage2D(GL_TEXTURE_2D, 0, 0, 0, texSize[0], texSize[1],
                            GL_RGBA, GL_UNSIGNED_BYTE, arr)
        else:
            glTexImage2D(GL_TEXTURE_2D, 0, GL_RGBA, texSize[0], texSize[1], 0,
                         GL_RGBA, GL_UNSIGNED_BYTE, arr)
            self._texSize = texSize
        glUniform1i(self.yLocation, 0)

        glDrawElements(GL_TRIANGLES, self._INDICES.size, GL_UNSIGNED_INT, None)
        check_gl_error()

    def delete(self):
        """Release the GL objects owned by this renderer.

        Best effort: each object is released independently and any exception
        is ignored, presumably so teardown still proceeds if the context is
        already gone. Handles are reset to 0, so a second call does nothing.
        """
        releases = (
            ("vbo", lambda handle: glDeleteBuffers(1, [handle])),
            ("ebo", lambda handle: glDeleteBuffers(1, [handle])),
            ("yuyvTexture", lambda handle: glDeleteTextures(1, [handle])),
            ("shaderProgram", lambda handle: glDeleteProgram(handle)),
        )
        for attr, release in releases:
            handle = getattr(self, attr, 0)
            if not handle:
                continue
            try:
                release(handle)
            except Exception:
                pass  # best effort, see docstring
            setattr(self, attr, 0)


# NewYUYV creates a renderer for packed YUYV (YUV 4:2:2) frames.
def NewYUYV() -> YUYVRenderer:
    """Build a renderer for packed YUYV (YUV 4:2:2) frames.

    Construction compiles shaders and allocates GL objects, so an OpenGL
    context must be current.
    """
    renderer = YUYVRenderer()
    return renderer



# ------------------------------------------------------------
# NV12 renderer (Y plane + interleaved UV plane)
# ------------------------------------------------------------
_NV12_FRAGMENT_SHADER = """
#version 400

varying vec2 v_texCoord;
uniform sampler2D y_texture;
uniform sampler2D uv_texture;

// YCbCr to RGB full (Rec.2020)
mat3 csc = mat3(
    1.0,   0.0000,  1.4746,
    1.0,  -0.1646, -0.5714,
    1.0,   1.8814,  0.0000);

void main (void) {
    vec3 yuv;
    yuv.r = texture2D(y_texture, v_texCoord).r;
    yuv.g = texture2D(uv_texture, v_texCoord).r - 0.5;
    yuv.b = texture2D(uv_texture, v_texCoord).g - 0.5;

    gl_FragColor = vec4(yuv*csc, 1.0);
}
"""

class NV12Renderer:
    """Renders NV12 frames onto a full-viewport quad.

    An NV12 buffer is a ``width x height`` Y plane followed by a
    ``width/2 x height/2`` plane of interleaved (U, V) byte pairs. The Y plane
    is uploaded as a single-channel ``GL_R8`` texture and the UV plane as a
    two-channel ``GL_RG8`` texture. The fragment shader combines them and
    converts to RGB.

    All methods issue GL calls, so the GL context the renderer was created in
    must be current when they are called.
    """

    # Full-viewport quad, interleaved (x, y, s, t) per vertex. t = 0 sits at
    # the top edge (y = 1), so the first row of the frame is drawn at the top.
    _VERTICES = np.array([
        -1.0,  1.0, 0.0, 0.0,
        -1.0, -1.0, 0.0, 1.0,
         1.0, -1.0, 1.0, 1.0,
         1.0,  1.0, 1.0, 0.0,
    ], dtype=np.float32)
    # Two triangles covering the quad.
    _INDICES = np.array([0, 1, 2, 0, 2, 3], dtype=np.uint32)
    # Bytes per vertex (4 float32 components) and byte offset of (s, t).
    _STRIDE = 4 * 4
    _TEXCOORD_OFFSET = 2 * 4

    def __init__(self):
        # Compile/link the program and look up its inputs.
        self.shaderProgram = linkProgram(_VERTEX_SHADER, _NV12_FRAGMENT_SHADER)
        self.posAttrib = glGetAttribLocation(self.shaderProgram, "a_position")
        self.texAttrib = glGetAttribLocation(self.shaderProgram, "a_texCoord")
        # Sampler uniform locations (distinct from the texture object ids).
        self.yLocation = glGetUniformLocation(self.shaderProgram, "y_texture")
        self.uvLocation = glGetUniformLocation(self.shaderProgram, "uv_texture")

        # The quad never changes, so upload its geometry once instead of
        # every frame.
        self.vbo = glGenBuffers(1)
        glBindBuffer(GL_ARRAY_BUFFER, self.vbo)
        glBufferData(GL_ARRAY_BUFFER, self._VERTICES.nbytes, self._VERTICES, GL_STATIC_DRAW)
        self.ebo = glGenBuffers(1)
        glBindBuffer(GL_ELEMENT_ARRAY_BUFFER, self.ebo)
        glBufferData(GL_ELEMENT_ARRAY_BUFFER, self._INDICES.nbytes, self._INDICES, GL_STATIC_DRAW)

        # One texture per plane, with sampling parameters configured once.
        self.yTexture = self._newTexture()
        self.uvTexture = self._newTexture()
        # Frame (width, height) the plane textures are currently allocated for;
        # None until the first frame is uploaded.
        self._frameSize = None

        # Black background
        glClearColor(0.0, 0.0, 0.0, 0.0)
        check_gl_error()

    @staticmethod
    def _newTexture():
        """Create a 2D texture with linear filtering and edge clamping."""
        texture = glGenTextures(1)
        glBindTexture(GL_TEXTURE_2D, texture)
        glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MIN_FILTER, GL_LINEAR)
        glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_MAG_FILTER, GL_LINEAR)
        glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_S, GL_CLAMP_TO_EDGE)
        glTexParameteri(GL_TEXTURE_2D, GL_TEXTURE_WRAP_T, GL_CLAMP_TO_EDGE)
        return texture

    @staticmethod
    def _uploadPlane(unit, texture, internalFormat, pixelFormat, width, height, pixels, reallocate):
        """Bind *texture* on texture *unit* and upload one plane into it.

        Storage is (re)allocated with glTexImage2D only when *reallocate* is
        true. Otherwise the existing storage is overwritten with glTexSubImage2D.
        """
        glActiveTexture(unit)
        glBindTexture(GL_TEXTURE_2D, texture)
        if reallocate:
            glTexImage2D(GL_TEXTURE_2D, 0, internalFormat, width, height, 0,
                         pixelFormat, GL_UNSIGNED_BYTE, pixels)
        else:
            glTexSubImage2D(GL_TEXTURE_2D, 0, 0, 0, width, height,
                            pixelFormat, GL_UNSIGNED_BYTE, pixels)

    def Draw(self, width: int, height: int, data: bytes) -> None:
        """Upload one NV12 frame and draw it over the whole viewport.

        Args:
            width: frame width in pixels. It must be even.
            height: frame height in pixels. It must be even, because the UV
                plane has ``height // 2`` rows.
            data: NV12 bytes, exactly ``width * height * 3 // 2`` long.

        Raises:
            GlError: on odd dimensions, wrong buffer length, or a pending GL error.
        """
        if width % 2 != 0:
            raise GlError("Frame width must be even for NV12")
        if height % 2 != 0:
            # Otherwise the UV reshape below fails with a numpy ValueError.
            raise GlError("Frame height must be even for NV12")

        # Zero-copy view over the caller's buffer.
        arr = np.frombuffer(data, dtype=np.uint8)
        expected_len = (width * height * 3) // 2
        if arr.size != expected_len:
            raise GlError(f"NV12 data length mismatch: got {arr.size}, expected {expected_len}")

        glUseProgram(self.shaderProgram)

        # Re-bind the static quad each frame, since other GL code may have
        # changed buffer bindings or attribute pointers since the last draw.
        glBindBuffer(GL_ARRAY_BUFFER, self.vbo)
        glVertexAttribPointer(self.posAttrib, 2, GL_FLOAT, GL_FALSE, self._STRIDE, ctypes.c_void_p(0))
        glEnableVertexAttribArray(self.posAttrib)
        glVertexAttribPointer(self.texAttrib, 2, GL_FLOAT, GL_FALSE, self._STRIDE,
                              ctypes.c_void_p(self._TEXCOORD_OFFSET))
        glEnableVertexAttribArray(self.texAttrib)
        glBindBuffer(GL_ELEMENT_ARRAY_BUFFER, self.ebo)
        check_gl_error()

        # Split the buffer into the Y plane and the interleaved UV plane
        # (U -> red, V -> green of the RG texture). Both are views, not copies.
        ySize = width * height
        yPlane = arr[:ySize].reshape((height, width))
        uvPlane = arr[ySize:].reshape((height // 2, width // 2, 2))

        # Source rows are tightly packed (Y rows may have any even length).
        glPixelStorei(GL_UNPACK_ALIGNMENT, 1)

        frameSize = (width, height)
        reallocate = self._frameSize != frameSize

        self._uploadPlane(GL_TEXTURE0, self.yTexture, GL_R8, GL_RED,
                          width, height, yPlane, reallocate)
        glUniform1i(self.yLocation, 0)

        self._uploadPlane(GL_TEXTURE1, self.uvTexture, GL_RG8, GL_RG,
                          width // 2, height // 2, uvPlane, reallocate)
        glUniform1i(self.uvLocation, 1)

        # Only record the new size once both planes were allocated for it.
        self._frameSize = frameSize

        glDrawElements(GL_TRIANGLES, self._INDICES.size, GL_UNSIGNED_INT, None)
        check_gl_error()

    def delete(self):
        """Release the GL objects owned by this renderer.

        Best effort: each object is released independently and any exception
        is ignored, presumably so teardown still proceeds if the context is
        already gone. Handles are reset to 0, so a second call does nothing.
        """
        releases = (
            ("vbo", lambda handle: glDeleteBuffers(1, [handle])),
            ("ebo", lambda handle: glDeleteBuffers(1, [handle])),
            ("yTexture", lambda handle: glDeleteTextures(1, [handle])),
            ("uvTexture", lambda handle: glDeleteTextures(1, [handle])),
            ("shaderProgram", lambda handle: glDeleteProgram(handle)),
        )
        for attr, release in releases:
            handle = getattr(self, attr, 0)
            if not handle:
                continue
            try:
                release(handle)
            except Exception:
                pass  # best effort, see docstring
            setattr(self, attr, 0)


# NewNV12 creates a renderer for NV12 (Y plane + interleaved UV plane) frames.
def NewNV12() -> NV12Renderer:
    """Build a renderer for NV12 frames (Y plane followed by interleaved UV).

    Construction compiles shaders and allocates GL objects, so an OpenGL
    context must be current.
    """
    renderer = NV12Renderer()
    return renderer



# ------------------------------------------------------------
# NewRenderer creates an OpenGL surface for the specified
# pixel format
# ------------------------------------------------------------

def NewRenderer(format: str) -> object:
    """Create a renderer for a ``video/...`` pixel-format string.

    Matching is case-insensitive. Supported values are ``video/yuyv`` and
    ``video/nv12``.

    Raises:
        GlError: for any other format (the message echoes the original string).
    """
    factories = {
        'video/yuyv': NewYUYV,
        'video/nv12': NewNV12,
    }
    factory = factories.get(format.lower())
    if factory is None:
        raise GlError(f"Unsupported pixel format: {format}")
    return factory()
