#!/usr/bin/env python3

from sympy import *


def mat(*args):
    """Return a Matrix whose columns are the given sequences, in order.

    Each argument is flattened with ``list()``, so tuples, lists, and 3x1 or
    1x3 matrices are all accepted as a column.
    """
    rows = [list(column) for column in args]
    return Matrix(rows).transpose()

def vec(*args):
    """Return a column Matrix whose entries are *args*, in order."""
    entries = list(args)
    return Matrix(entries)

def normalize(v):
    """Return *v* divided by ``sqrt(v.dot(v))``.

    For real-valued entries this is scaling to unit Euclidean length
    (``dot`` does not conjugate). A zero vector is not special-cased: the
    division by a zero length is left to sympy.
    """
    length = sqrt(v.dot(v))
    return v / length

def elementwise(a, b, f):
    """Apply the binary function *f* entry by entry to matrices *a* and *b*.

    Returns a new Matrix with *a*'s shape whose ``(r, c)`` entry is
    ``f(a[r, c], b[r, c])``.

    Raises:
        ValueError: if *a* and *b* do not have the same shape. Without this
            check a larger *b* would be silently truncated to *a*'s shape.
    """
    if a.shape != b.shape:
        raise ValueError(
            f"elementwise: shape mismatch {a.shape} vs {b.shape}"
        )
    # Pass the dimensions explicitly so an empty *a* keeps its shape.
    entries = [
        f(a[r, c], b[r, c])
        for r in range(a.rows)
        for c in range(a.cols)
    ]
    return Matrix(a.rows, a.cols, entries)

def mul(a, b):
    """Return the element-wise (Hadamard) product of matrices *a* and *b*."""
    def times(lhs, rhs):
        return lhs * rhs
    return elementwise(a, b, times)

def div(a, b):
    """Return the element-wise quotient ``a[r, c] / b[r, c]`` of two matrices."""
    def over(numerator, denominator):
        return numerator / denominator
    return elementwise(a, b, over)


def window(viewport, depth_range):
    """Return the 4x4 viewport ("window") transform.

    With *viewport* = ``(x, y, width, height)`` and *depth_range* =
    ``(near, far)``, this maps normalized device coordinates in [-1, 1]
    onto window x in [x, x + width], window y in [y, y + height] and
    depth in [near, far]. Entries may be sympy symbols.
    """
    origin = vec(*viewport[0:2], depth_range[0])
    extent = vec(*viewport[2:4], depth_range[1] - depth_range[0])
    half = 0.5 * extent
    centre = half + origin
    # Diagonal scale in the upper-left 3x3, translation in the last column.
    rows = [
        [half[i] if j == i else 0 for j in range(3)] + [centre[i]]
        for i in range(3)
    ]
    rows.append([0, 0, 0, 1])
    return Matrix(rows)

def ortho(left, right, bottom, top, near, far):
    """Return the glOrtho projection matrix.

    Maps eye-space x in [left, right] and y in [bottom, top] to [-1, 1].
    It also maps eye-space z = -near to -1 and z = -far to +1: the
    z-scale is negative because the camera looks down -z.
    """
    extent = vec(right - left, top - bottom, far - near)
    origin = vec(-left, -bottom, near)
    scale = div(vec(2, 2, -2), extent)
    # Shift so that each lower bound lands on -1 after scaling.
    shift = mul(scale, origin) - vec(1, 1, 1)
    rows = [
        [scale[i] if j == i else 0 for j in range(3)] + [shift[i]]
        for i in range(3)
    ]
    rows.append([0, 0, 0, 1])
    return Matrix(rows)

def frustum(left, right, bottom, top, near, far):
    """Return the glFrustum perspective projection matrix.

    Built as ``ortho(...) * P``. The perspective step ``P``:

    - scales x and y by *near*;
    - sets ``w = -z_eye``;
    - uses a third row ``(near + far) * z + near * far``.

    After the divide by w, eye-space depths ``-near`` and ``-far`` therefore
    land back on ``-near`` and ``-far``. The orthographic box with the same
    bounds then maps the squashed frustum to the [-1, 1] cube.
    """
    z = vec(0, 0, near + far, -1)  # third column of P
    w = vec(0, 0, near * far, 0)   # fourth column of P
    squash = Matrix([
        [near, 0,    z[0], w[0]],
        [0,    near, z[1], w[1]],
        [0,    0,    z[2], w[2]],
        [0,    0,    z[3], w[3]],
    ])
    return ortho(left, right, bottom, top, near, far) * squash

def perspective(fovy, aspect, near, far):
    """Return the gluPerspective projection matrix.

    Builds a symmetric frustum with these extents at the near plane:

    - half-height ``near * tan(0.5 * fovy)``;
    - half-width equal to the half-height times *aspect*.

    *fovy* goes straight into sympy's ``tan``, so it is in radians here;
    gluPerspective itself takes degrees.
    """
    half_height = near * tan(0.5 * fovy)
    half_width = half_height * aspect
    return frustum(
        -half_width, half_width,
        -half_height, half_height,
        near, far,
    )

def lookat(position, target, up):
    """Return the 4x4 gluLookAt view (world-to-camera) matrix.

    Builds an orthonormal camera basis:

    - ``z`` points from *target* back towards *position*;
    - ``x = normalize(up x z)``;
    - ``y = z x x``.

    The result is ``[R_inv | R_inv * -position]`` stacked over
    ``[0, 0, 0, 1]``, where ``R_inv`` has the basis vectors as its rows.
    This translates *position* to the origin and then rotates world axes
    into camera axes.

    *position*, *target* and *up* are 3-vectors (column matrices, as used
    with ``cross``). If *up* is parallel to ``position - target``, then
    ``up x z`` is the zero vector and ``normalize`` divides by zero.
    """
    z = normalize(position - target)
    x = normalize(up.cross(z))
    y = z.cross(x)
    # mat() takes columns, so transposing gives x, y, z as the rows of R_inv
    # (the inverse of the orthonormal rotation [x y z]).
    R_inv = mat(x, y, z).T
    t_inv = -position
    # Upper 3x4: rotation block plus the rotated inverse translation.
    upper = R_inv.row_join(R_inv * t_inv)
    # Homogeneous bottom row, matching the other 4x4 builders in this file.
    return upper.col_join(Matrix([[0, 0, 0, 1]]))


# Symbolic derivations: each section prints the simplified matrix, and the
# comment directly below each call records the pprint output.

# Viewport / depth-range ("window") transform.
# https://registry.khronos.org/OpenGL-Refpages/gl4/html/glViewport.xhtml
# https://registry.khronos.org/OpenGL-Refpages/gl4/html/glDepthRange.xhtml
# https://www.khronos.org/opengl/wiki/Vertex_Post-Processing#Viewport_transform
# https://www.songho.ca/opengl/gl_viewport.html
viewport, depth_range = symbols('x, y, width, height'), symbols('near, far')
pprint(simplify(window(viewport, depth_range)))
# ⎡0.5⋅width      0               0             0.5⋅width + x   ⎤
# ⎢                                                             ⎥
# ⎢    0      0.5⋅height          0             0.5⋅height + y  ⎥
# ⎢                                                             ⎥
# ⎢    0          0       0.5⋅far - 0.5⋅near  0.5⋅far + 0.5⋅near⎥
# ⎢                                                             ⎥
# ⎣    0          0               0                   1         ⎦

# Orthographic projection (glOrtho).
# https://registry.khronos.org/OpenGL-Refpages/gl2.1/xhtml/glOrtho.xml
# https://www.songho.ca/opengl/gl_projectionmatrix.html#ortho
left, right, bottom, top, near, far = symbols('left, right, bottom, top, near, far')
pprint(simplify(ortho(left, right, bottom, top, near, far)))
# ⎡    -2                                  left + right⎤
# ⎢────────────       0            0       ────────────⎥
# ⎢left - right                            left - right⎥
# ⎢                                                    ⎥
# ⎢                  -2                    bottom + top⎥
# ⎢     0        ────────────      0       ────────────⎥
# ⎢              bottom - top              bottom - top⎥
# ⎢                                                    ⎥
# ⎢                               -2       -far - near ⎥
# ⎢     0             0        ──────────  ─────────── ⎥
# ⎢                            far - near   far - near ⎥
# ⎢                                                    ⎥
# ⎣     0             0            0            1      ⎦

# Perspective projection from explicit frustum bounds (glFrustum); the
# bound symbols declared for glOrtho above are reused unchanged.
# https://registry.khronos.org/OpenGL-Refpages/gl2.1/xhtml/glFrustum.xml
# https://www.songho.ca/opengl/gl_projectionmatrix.html#perspective
pprint(simplify(frustum(left, right, bottom, top, near, far)))
# ⎡  -2⋅near                   -left - right              ⎤
# ⎢────────────       0        ─────────────       0      ⎥
# ⎢left - right                 left - right              ⎥
# ⎢                                                       ⎥
# ⎢                -2⋅near     -bottom - top              ⎥
# ⎢     0        ────────────  ─────────────       0      ⎥
# ⎢              bottom - top   bottom - top              ⎥
# ⎢                                                       ⎥
# ⎢                             -far - near   -2⋅far⋅near ⎥
# ⎢     0             0         ───────────   ────────────⎥
# ⎢                              far - near    far - near ⎥
# ⎢                                                       ⎥
# ⎣     0             0             -1             0      ⎦

# Perspective projection from a vertical field of view (gluPerspective);
# near/far keep the same symbols as above.
# https://registry.khronos.org/OpenGL-Refpages/gl2.1/xhtml/gluPerspective.xml
# https://www.songho.ca/opengl/gl_projectionmatrix.html#fov
fovy, aspect = symbols('fovy, aspect')
pprint(simplify(perspective(fovy, aspect, near, far)))
# ⎡         1                                                    ⎤
# ⎢────────────────────        0             0            0      ⎥
# ⎢aspect⋅tan(0.5⋅fovy)                                          ⎥
# ⎢                                                              ⎥
# ⎢                            1                                 ⎥
# ⎢         0            ─────────────       0            0      ⎥
# ⎢                      tan(0.5⋅fovy)                           ⎥
# ⎢                                                              ⎥
# ⎢                                     -far - near  -2⋅far⋅near ⎥
# ⎢         0                  0        ───────────  ────────────⎥
# ⎢                                      far - near   far - near ⎥
# ⎢                                                              ⎥
# ⎣         0                  0            -1            0      ⎦

# Camera view matrix (gluLookAt) over fully symbolic 3x1 inputs.
# https://registry.khronos.org/OpenGL-Refpages/gl2.1/xhtml/gluLookAt.xml
# https://www.songho.ca/opengl/gl_camera.html#lookat
position, target, up = (
    Matrix(MatrixSymbol(name, 3, 1)) for name in ('position', 'target', 'up')
)
# NOTE(review): this derivation is disabled; the reason is not recorded
# here. Check lookat's output before re-enabling it.
# pprint(simplify(lookat(position, target, up)))