#!/usr/bin/env python3

import sys

import matplotlib.colors
import matplotlib.pyplot as plt
import numpy as np
import PIL
import PIL.Image
import scipy.ndimage
import scipy.signal

# Command line: <input image> [<output figure path>]
if len(sys.argv) < 2:
    sys.exit("usage: %s INPUT [OUTPUT]" % (sys.argv[0] if sys.argv else "script"))
INPUT     = sys.argv[1]
# Optional path for the saved figure; when absent the figure is shown on screen.
OUTPUT    = sys.argv[2] if len(sys.argv) > 2 else None
# RGB colour blended (50/50) into the detected region in the "Mask" panel.
HIGHLIGHT = [255, 0, 0]

# Running count of panels drawn so far; show() advances it on every call.
subplot_index = 0


def show(title, *args, **kwargs):
    """Draw the next panel of the 2-row x 3-column summary figure.

    Increments the module-level ``subplot_index`` counter, makes that grid
    cell current, forwards every argument after ``title`` unchanged to
    ``plt.imshow`` and finally captions the panel with ``title``.
    """
    global subplot_index
    subplot_index = panel = subplot_index + 1
    n_rows, n_cols = 2, 3
    plt.subplot(n_rows, n_cols, panel)
    plt.imshow(*args, **kwargs)
    plt.title(title)

def blur(im, sigma, cutoff=0.01):
    """Blur a 2-D array with a separable, unit-sum Gaussian kernel.

    The kernel is truncated where the Gaussian has fallen to ``cutoff`` times
    its peak, i.e. ``sigma * sqrt(-2 ln cutoff)`` on either side. It always
    has an odd number of taps, so it is centred on a sample. An even-length
    kernel would make ``mode='same'`` shift the output by half a pixel. The
    odd length also guarantees at least one tap (identity) for tiny sigma.

    Args:
        im: 2-D array to blur.
        sigma: standard deviation of the Gaussian, in pixels.
        cutoff: relative amplitude, strictly between 0 and 1, at which the
            kernel is truncated.

    Returns:
        Array with the same shape as ``im``. Pixels outside the array are
        treated as zero (linear convolution).

    Raises:
        ValueError: if ``cutoff`` is not strictly between 0 and 1.
    """
    if not 0 < cutoff < 1:
        raise ValueError(f"cutoff must be in (0, 1), got {cutoff!r}")
    length = int(2 * sigma * np.sqrt(-2 * np.log(cutoff)))
    if length % 2 == 0:
        length += 1  # odd => centred on a sample; also turns 0 taps into 1
    signal = scipy.signal.windows.gaussian(length, sigma)
    signal /= signal.sum()  # unit-sum 1-D taps => unit-sum 2-D kernel
    kernel = np.outer(signal, signal)
    return scipy.signal.fftconvolve(im, kernel, mode='same')

def vec(x, y, adjust=False):
    """Render a vector field as an RGB image using HSV colour coding.

    Hue encodes direction: ``arctan2(y, x)`` mapped from [-pi, pi] to [0, 1].
    Saturation and value both encode magnitude, normalised so that the
    longest vector has magnitude 1. An all-zero field renders black
    rather than NaN.

    Args:
        x, y: arrays of equal shape holding the vector components.
        adjust: if True, stretch the hue range to span exactly [0, 1] so that
            small variations in direction become visible. A field with a
            single direction is left unstretched.

    Returns:
        Float array of shape ``x.shape + (3,)`` with RGB values.
    """
    mag = np.sqrt(x * x + y * y)
    peak = np.max(mag)
    if peak > 0:  # avoid 0/0 for a field with no non-zero vectors
        mag = mag / peak
    ang = np.arctan2(y, x) / (2 * np.pi) + 0.5
    if adjust:
        lo, hi = np.min(ang), np.max(ang)
        if hi > lo:  # a uniform direction has no range to stretch
            ang = (ang - lo) / (hi - lo)
    # Stack H, S, V along a new last axis, as hsv_to_rgb expects (..., 3).
    return matplotlib.colors.hsv_to_rgb(np.stack((ang, mag, mag), axis=-1))

# Load image: float grayscale for the analysis, RGB copy for the overlay.
# Converting to RGB makes the 3-channel HIGHLIGHT blend below valid for
# grayscale, palette and RGBA inputs as well. The file is opened once and
# closed deterministically.
with PIL.Image.open(INPUT) as source:
    gray = np.array(source.convert('L'), dtype='float32')
    im = np.array(source.convert('RGB'))
# Blur scale: 1/50 of the image diagonal.
sigma = np.hypot(*gray.shape) / 50

# Calculate gradients (np.gradient returns the axis-0, i.e. row/y, derivative first)
dy, dx = np.gradient(gray)
show("Gradient", vec(dx, dy))

# Fold gradients: flip every vector whose component along the dominant axis
# (the one with the larger total |gradient|) is negative. Opposite-pointing
# gradients then add up instead of cancelling when blurred.
sel = (dx if np.sum(np.abs(dx)) > np.sum(np.abs(dy)) else dy) < 0
dx[sel] = -dx[sel]
dy[sel] = -dy[sel]
show("Gradient, folded", vec(dx, dy))

# Blur gradients
dx = blur(dx, sigma)
dy = blur(dy, sigma)
show("Gradient, blurred", vec(dx, dy, True))

# Calculate dot product with maximal blurred gradient, normalized.
# Normalise to unit vectors. Zero-length vectors are divided by 1 so they
# stay zero instead of becoming NaN.
mag = np.sqrt(dx*dx + dy*dy)
divisor = np.where(mag > 0, mag, 1)
dx /= divisor
dy /= divisor
ind = np.unravel_index(np.argmax(mag), mag.shape)
mdx = dx[ind]
mdy = dy[ind]
if mag[ind] > 0:  # a completely flat image has no dominant direction
    mag /= mag[ind]
# Relative magnitude times (cosine to the dominant direction)^100. The high
# power keeps only closely aligned vectors (cos^100 halves at about 7 degrees).
weight = mag * np.power(np.maximum(0, dx*mdx + dy*mdy), 100)
weight = blur(weight, sigma)
show("Weight", weight)

# Mask and display: blend HIGHLIGHT 50/50 into pixels whose weight exceeds 1/2.
mask = weight > 1/2
im[mask] = (im[mask] + HIGHLIGHT) / 2
show("Mask", im)

# Rotate and crop
angle = np.degrees(np.arctan2(mdy, mdx))
gray = scipy.ndimage.rotate(gray, angle)
# Nearest-neighbour interpolation keeps the boolean mask binary. Both arrays
# share a shape, so the rotated outputs share one too.
mask = scipy.ndimage.rotate(mask, angle, order=0)
ys, xs = np.nonzero(mask)
if xs.size:
    x0, x1 = xs.min(), xs.max()
    y0, y1 = ys.min(), ys.max()
    cx = int((x1 - x0) * 0.1)
    cy = int((y1 - y0) * 0.4)
    # Widen by 10% on the left/right and keep the middle ~20% vertically.
    # Max indices are inclusive, hence the +1. The left edge is clamped
    # because a negative start index would wrap to the far side of the array.
    crop = gray[y0 + cy:y1 + 1 - cy, max(x0 - cx, 0):x1 + 1 + cx]
else:
    print("warning: no region passed the weight threshold; showing uncropped",
          file=sys.stderr)
    crop = gray
show("Rotated, cropped", crop, 'gray')

# Save/show
plt.tight_layout()
if OUTPUT:
    plt.savefig(OUTPUT, dpi=200)
else:
    plt.show()