"""
Python library to ease processing of video frames.
"""

import sys
import os
import signal
import argparse
import json
import datetime
import shlex
import ffmpeg  # ffmpeg-python
import numpy as np
import PIL.Image
import cv2


def error(message):
    """Report ``message`` as a fatal error and terminate the program.

    The message is printed to stderr, prefixed with the program name
    (``sys.argv[0]``), and the process exits with status 1 by raising
    ``SystemExit``.
    """
    print(f"{sys.argv[0]}: Error: {message}", file=sys.stderr)
    # sys.exit is always available; the bare ``exit`` builtin is injected by
    # the ``site`` module and is missing under ``python -S`` / frozen apps.
    sys.exit(1)


def resize(image, width, height, resample):
    """Return ``image`` scaled to ``width`` x ``height`` as a NumPy array.

    The array is converted to a Pillow image, resized with the Pillow
    resampling filter ``resample`` (e.g. ``PIL.Image.NEAREST``) and converted
    back.
    """
    pil_image = PIL.Image.fromarray(image)
    target_size = (width, height)
    resized_image = pil_image.resize(target_size, resample)
    return np.array(resized_image)


def text(image, text, x, y, thickness, font=cv2.FONT_HERSHEY_SIMPLEX):
    """Draw outlined ``text`` on ``image`` near ``(x, y)`` and return the image.

    The text is scaled down (never up) so it fits the image width minus a
    10 pixel margin on each side, and its anchor point is clamped so the
    string stays inside that margin. It is drawn twice: first in black with
    three times the stroke ``thickness`` (an outline), then in white with
    ``thickness``.
    """
    margin = 10
    img_h, img_w = image.shape[:2]
    (text_w, text_h), _baseline = cv2.getTextSize(text, font, 1, 1)
    # Shrink, but never enlarge, to fit between the horizontal margins.
    scale = min(1, (img_w - 2 * margin) / text_w)
    scaled_w = int(scale * text_w)
    scaled_h = int(scale * text_h)
    # putText's origin is the bottom-left of the string, hence the text
    # height is added to the lower bound for y.
    anchor = (
        np.clip(x, margin, img_w - scaled_w - margin),
        np.clip(y, margin + scaled_h, img_h - margin),
    )
    passes = (
        ((0, 0, 0), 3 * thickness),      # black outline
        ((255, 255, 255), thickness),    # white fill
    )
    for colour, stroke in passes:
        image = cv2.putText(image, text, anchor, font, scale, colour, stroke)
    return image


def fix_rect(args, x, y, w, h):
    """Fit a rectangle to the output aspect ratio inside the working image.

    The rectangle ``(x, y, w, h)`` (left, top, width, height in working-image
    pixels) is grown so its aspect ratio matches ``args.output_aspect``,
    shrunk uniformly if it would exceed the working image, re-centred on the
    original rectangle and finally shifted so it lies inside the working
    image.

    Args:
        args: Namespace providing ``output_aspect``, ``working_width`` and
            ``working_height``.
        x, y: Left and top edge of the rectangle.
        w, h: Width and height of the rectangle. A side that would come out
            as zero after growing (e.g. an empty ``0x0`` rectangle) is made
            at least 1 pixel instead of raising ``ZeroDivisionError``.

    Returns:
        Tuple ``(x, y, w, h)`` of the adjusted rectangle.
    """
    # Grow into correct aspect ratio. The trailing 1 keeps degenerate
    # rectangles from producing a zero size, which would divide by zero in
    # the scale computation below.
    nw = max(w, int(h * args.output_aspect), 1)
    nh = max(h, int(w / args.output_aspect), 1)
    # Optionally shrink into working image size (same factor on both axes).
    scale = min(1, args.working_width / nw, args.working_height / nh)
    nw = int(nw * scale)
    nh = int(nh * scale)
    # Center new region on old region.
    nx = x - (nw - w) // 2
    ny = y - (nh - h) // 2
    # Push into working image: clamp at the left/top edge, then pull back by
    # however far the region overhangs the right/bottom edge.
    nx = max(0, nx) - max(0, (nx + nw) - args.working_width)
    ny = max(0, ny) - max(0, (ny + nh) - args.working_height)
    return nx, ny, nw, nh


def run(description, process, init=None, args_pre=None, args_post=None):
    """Parse the command line, then stream a video through ``process``.

    The input file is decoded by FFmpeg into raw RGB frames. Every kept frame
    (honouring ``--start-frame``, ``--end-frame`` and ``--decimate``) is
    resized to the working size and handed to ``process``. The frames it
    returns are annotated, optionally previewed with OpenCV, and piped into
    up to two FFmpeg encoders: the debug file and the final output file.

    Args:
        description: Description shown by ``--help``.
        process: Callable ``process(args, state, frame, frame_num)`` returning
            ``(output_frame, debug_frame)``. ``frame`` is an RGB ``uint8``
            array of shape ``(working_height, working_width, 3)``.
            ``output_frame`` is piped to an encoder that expects
            ``output_width`` x ``output_height`` RGB frames, and
            ``debug_frame`` to one that expects ``working_width`` x
            ``working_height`` RGB frames. Each is only used when the
            corresponding file is requested.
        init: Optional callable ``init(args)`` whose result is passed to
            ``process`` as ``state``.
        args_pre: Optional callable that receives the ``ArgumentParser``
            before the built-in arguments are added.
        args_post: Optional callable that receives the ``ArgumentParser``
            after the built-in arguments are added.
    """
    # Parse arguments
    parser = argparse.ArgumentParser(description=description)
    if args_pre is not None:
        args_pre(parser)
    parser.add_argument(
        '--start-frame', metavar='FRAME', type=int,
        help="starting input frame to process (inclusive)")
    parser.add_argument(
        '--end-frame', metavar='FRAME', type=int,
        help="ending input frame to process (exclusive)")
    parser.add_argument(
        '--decimate', metavar="COUNT", type=int, default=1,
        help="only use every %(metavar)s input frame")
    parser.add_argument(
        '--width', type=int,
        help="width of output video")
    parser.add_argument(
        '--height', type=int,
        help="height of output video")
    parser.add_argument(
        '--codec',
        help="codec of output video"
    )
    parser.add_argument(
        '--pix-fmt',
        help="pixel format of output video")
    parser.add_argument(
        '--extra-args', metavar='ARGS',
        help="""
        extra arguments to pass to FFmpeg, as a JSON object (do not specify the
        leading dash for keys, use a null value for arguments that do not take
        a parameter)
        """)
    parser.add_argument(
        '--no-preview', action='store_true',
        help="do not show any previews")
    parser.add_argument(
        '--no-write', action='store_true',
        help="do not write any files")
    parser.add_argument(
        '--no-audio', action='store_true',
        help="do not include audio in output files")
    parser.add_argument(
        '--overwrite', action='store_true',
        help="overwrite output files")
    parser.add_argument(
        '--debug', metavar='DEBUG_FILE',
        help="produce debug output (leave empty to base filename on input)")
    parser.add_argument(
        '--output', metavar='OUTPUT_FILE',
        help="produce final output (leave empty to base filename on input)")
    parser.add_argument(
        'input', metavar='INPUT_FILE',
        help="input file")
    if args_post is not None:
        args_post(parser)
    args = parser.parse_args()

    # Check arguments. An empty --debug/--output value means "derive the file
    # name from the input name"; an omitted one (None) means "do not produce
    # that file".
    root, ext = os.path.splitext(os.path.basename(args.input))
    for kind in ["debug", "output"]:
        if getattr(args, kind) == "":
            setattr(args, kind, f"{root}_{kind}{ext}")
    for path in (args.output, args.debug):
        # Skip omitted files: os.path.isfile(None) raises TypeError (at least
        # on POSIX) instead of returning False.
        if path and os.path.isfile(path) and not args.overwrite:
            error(f"File exists: '{path}', use --overwrite to overwrite.")
    if not os.path.isfile(args.input):
        error(f"File does not exist: '{args.input}'.")
    if args.decimate < 1:
        # 0 would divide by zero in the decimation check below.
        error("Decimate count must be at least 1.")

    # Probe input. All streams are probed so the presence of audio can be
    # detected; frame parameters come from the first video stream.
    probe = ffmpeg.probe(args.input)
    streams = probe['streams']
    video_streams = [s for s in streams if s.get('codec_type') == 'video']
    if not video_streams:
        error(f"No video stream found in '{args.input}'.")
    video = video_streams[0]
    has_audio = any(s.get('codec_type') == 'audio' for s in streams)

    # Common parameters.
    args.duration = float(probe['format']['duration'])

    # Input parameters.
    args.input_codec = video['codec_name']
    args.input_pix_fmt = video['pix_fmt']
    args.input_width = video['width']
    args.input_height = video['height']
    args.input_aspect = args.input_width / args.input_height
    fps_num, fps_den = map(int, video['r_frame_rate'].split("/"))
    args.input_fps = fps_num / fps_den
    # Not every container records a frame count; estimate it from the
    # duration when it is absent.
    nb_frames = video.get('nb_frames')
    if nb_frames:
        args.frame_count = int(nb_frames)
    else:
        args.frame_count = int(round(args.duration * args.input_fps))

    # Output parameters.
    args.output_codec = args.codec or args.input_codec
    args.output_pix_fmt = args.pix_fmt or args.input_pix_fmt
    args.output_width = args.width or args.input_width
    args.output_height = args.height or int(
        args.output_width / args.input_aspect
    )
    args.output_aspect = args.output_width / args.output_height
    args.output_fps = args.input_fps / args.decimate

    # Working parameters: at most the input size, and at most twice the
    # output size (at the input aspect ratio), rounded up to even numbers.
    args.working_width = min(
        args.input_width,
        2 * max(
            args.output_width,
            int(args.output_height * args.input_aspect)
        )
    )
    args.working_height = min(
        args.input_height,
        2 * max(
            args.output_height,
            int(args.output_width / args.input_aspect)
        )
    )
    args.working_width += args.working_width % 2
    args.working_height += args.working_height % 2
    args.thickness = max(1, int(args.working_width / 1000))

    # Fill in default arguments.
    if args.start_frame is None:
        args.start_frame = 0
    if args.end_frame is None:
        args.end_frame = args.frame_count
    if args.start_frame < 0 or args.end_frame < args.start_frame:
        error("Frame range must satisfy 0 <= start frame <= end frame.")
    if args.extra_args is None:
        args.extra_args = {}
    else:
        try:
            args.extra_args = json.loads(args.extra_args)
        except json.decoder.JSONDecodeError:
            error("Extra arguments is not valid JSON.")
        if not isinstance(args.extra_args, dict):
            error("Extra arguments is not a JSON object.")

    # Open files.
    debug_size = f'{args.working_width}x{args.working_height}'
    output_size = f'{args.output_width}x{args.output_height}'
    pipe_args = {
        'format': 'rawvideo',
        'pix_fmt': 'rgb24',
    }
    # 'r' lives in the dict (before the extra args) so a user-supplied "r"
    # overrides it instead of raising a duplicate-keyword TypeError.
    output_args = {
        'vcodec': args.output_codec,
        'pix_fmt': args.output_pix_fmt,
        'shortest': None,
        'r': args.output_fps,
        **args.extra_args,
    }
    audio_args = []
    # Mapping audio from an input without any audio stream makes FFmpeg fail.
    if has_audio and not args.no_audio:
        audio_args.append(
            ffmpeg
            .input(args.input)
            .audio
            .filter(
                'atrim',
                start=f"{args.start_frame/args.input_fps}s",
                end=f"{args.end_frame/args.input_fps}s",
            )
            # atrim keeps the original timestamps; reset them so the trimmed
            # audio starts at zero, like the piped video does.
            .filter('asetpts', 'PTS-STARTPTS')
        )
    input_stream = (
        ffmpeg
        .input(args.input)
        .output('pipe:', **pipe_args)
        .global_args('-loglevel', 'error')
        .run_async(pipe_stdout=True)
    )
    debug_stream = None
    output_stream = None
    if not args.no_write:
        if args.debug:
            debug_stream = (
                ffmpeg
                .input('pipe:', **pipe_args, s=debug_size, r=args.output_fps)
                .output(*audio_args, args.debug, **output_args)
                .global_args('-loglevel', 'error')
                .run_async(pipe_stdin=True, overwrite_output=args.overwrite)
            )
        if args.output:
            output_stream = (
                ffmpeg
                .input('pipe:', **pipe_args, s=output_size, r=args.output_fps)
                .output(*audio_args, args.output, **output_args)
                .global_args('-loglevel', 'error')
                .run_async(pipe_stdin=True, overwrite_output=args.overwrite)
            )

    # Loop invariants: the command line shown on debug frames (built from a
    # copy so sys.argv itself is not modified), total duration and the byte
    # size of one raw RGB input frame.
    shown_argv = list(sys.argv)
    if shown_argv:
        shown_argv[0] = os.path.basename(shown_argv[0])
    command_line = " ".join(map(shlex.quote, shown_argv))
    time_duration = datetime.timedelta(seconds=int(args.duration))
    frame_size = args.input_width * args.input_height * 3

    # Set up signal handler: Ctrl-C only raises a flag so the loop can stop
    # cleanly and the encoders can finalize their files.
    sigint = False

    def sigint_handler(signum, frame):
        nonlocal sigint
        sigint = True

    previous_sigint_handler = signal.signal(signal.SIGINT, sigint_handler)
    previewing = False

    try:
        # Call init (inside the try so the FFmpeg processes are cleaned up
        # even if it fails).
        state = init(args) if init is not None else None

        # Process.
        for frame_num in range(min(args.frame_count, args.end_frame)):
            # Gather and print info.
            time_now = datetime.timedelta(
                seconds=int(frame_num / args.input_fps))
            elapsed_frames = f"{frame_num} / {args.frame_count-1}"
            elapsed_time = f"{time_now} / {time_duration}"
            sys.stdout.write(f"{elapsed_frames} ({elapsed_time})\r")
            # A carriage return does not trigger a line-buffered flush.
            sys.stdout.flush()

            # Read input.
            if sigint:
                break
            input_bytes = input_stream.stdout.read(frame_size)
            if len(input_bytes) < frame_size:
                # End of stream, or a truncated last frame that could not be
                # reshaped.
                break
            input_frame = (
                np
                .frombuffer(input_bytes, np.uint8)
                .reshape([args.input_height, args.input_width, 3])
            )

            # Check for start frame.
            if frame_num < args.start_frame:
                continue

            # Check for decimate frame.
            if (frame_num - args.start_frame) % args.decimate != 0:
                continue

            # Resize to working size.
            frame = resize(
                input_frame,
                args.working_width,
                args.working_height,
                PIL.Image.NEAREST,
            )

            # Call process.
            output_frame, debug_frame = process(args, state, frame, frame_num)

            # Show info.
            if args.debug:
                debug_frame = text(
                    debug_frame, elapsed_frames, 0, 0, args.thickness)
                debug_frame = text(
                    debug_frame, elapsed_time, args.working_width, 0,
                    args.thickness)
                debug_frame = text(
                    debug_frame, command_line, 0, args.working_height,
                    args.thickness)

            # Show preview windows.
            if not args.no_preview:
                if args.debug:
                    cv2.imshow(
                        f"{args.debug}",
                        cv2.cvtColor(debug_frame, cv2.COLOR_RGB2BGR),
                    )
                    previewing = True
                if args.output:
                    cv2.imshow(
                        f"{args.output}",
                        cv2.cvtColor(output_frame, cv2.COLOR_RGB2BGR),
                    )
                    previewing = True
                # Keep only the low byte: some HighGUI backends report
                # modifier state in the higher bits of the key code.
                if (cv2.waitKey(1) & 0xFF) in (ord('q'), 27):
                    break

            # Write files.
            if not args.no_write:
                if sigint:
                    break
                try:
                    if debug_stream is not None:
                        debug_stream.stdin.write(
                            debug_frame.astype(np.uint8).tobytes())
                    if output_stream is not None:
                        output_stream.stdin.write(
                            output_frame.astype(np.uint8).tobytes())
                except BrokenPipeError:
                    # FFmpeg has probably written some error message to stderr,
                    # so just break.
                    break
    except KeyboardInterrupt:
        pass
    finally:
        try:
            print("")
            # Close and wait.
            for stream in (debug_stream, output_stream):
                if stream is None:
                    continue
                try:
                    stream.stdin.close()
                except BrokenPipeError:
                    # The encoder already exited (its error went to stderr);
                    # the pipe is closed anyway, so just reap the process.
                    pass
                stream.wait()
            input_stream.send_signal(signal.SIGINT)
            input_stream.communicate()
            if previewing:
                cv2.destroyAllWindows()
        finally:
            # Give the caller back its original Ctrl-C behaviour.
            if previous_sigint_handler is not None:
                signal.signal(signal.SIGINT, previous_sigint_handler)