#!/usr/bin/env python3


## Help
"""
glregistry 1.0

Cache and query the OpenGL registry locally.

Usage:
  glregistry xml
  glregistry xml-path
  glregistry ext            <extension>
  glregistry ext-path       <extension>
  glregistry exts
  glregistry exts-download
  glregistry exts-all       <name>
  glregistry vendors
  glregistry type           <type>
  glregistry aliases        <enum>
  glregistry value          <enum>
  glregistry enum           <value>
  glregistry supports       <name>
  glregistry names          [<support>]
  glregistry groups         [<enum>]
  glregistry enums          [<group>]
  glregistry enums-tree     [<group>]
  glregistry params         [<group>]
  glregistry params-tree    [<group>]
  glregistry audit          [<path>]
  glregistry audit-tree     [<path>]
  glregistry refs           <name>
  glregistry refs-all       <name>
  glregistry -h|--help

Commands:
  xml
    Download the registry XML and open it with an editor.
  xml-path
    Download the registry XML and print its local path.
  ext <extension>
    Download the <extension> spec and open it with an editor.
  ext-path <extension>
    Download the <extension> spec and print its local path.
  exts
    Print the names of all extension specs.
  exts-download
    Download all extension specs.
  exts-all <name>
    Print all downloaded extensions that mention <name>.
  vendors
    Print all vendor abbreviations.
  type <type>
    Print the definition of <type>.
  aliases <enum>
    Print the KHR, ARB, and EXT aliases of <enum>.
  value <enum>
    Print the value of <enum>.
  enum <value>
    Print the enum(s) that has the given <value>.
  supports <name>
    Print the OpenGL version or extension required to use <name>.
  names [<support>]
    Print the names introduced by the OpenGL version or extension <support> if
    given, or all names if omitted. The special values VERSION and EXTENSION
    print the names introduced by all versions or all extensions respectively.
  groups [<enum>]
    Print the groups of <enum> if given, or all groups if omitted.
  enums [<group>]
    Print the enums in <group> if given, or all enums if omitted.
  enums-tree [<group>]
    Print the enums in <group> if given, or all enums if omitted, sorted on
    support, in a tree.
  params [<group>]
    Print the parameter names of <group> if given, or all parameter names if
    omitted.
  params-tree [<group>]
    Print the parameter names of <group> if given, or all parameter names if
    omitted, sorted on count, together with the commands sorted on support, in
    a tree.
  audit [<path>]
    Search files in <path> if given, or the current directory if omitted,
    recursively for OpenGL API names and print them sorted on location,
    support, and name, in a list.
  audit-tree [<path>]
    Search files in <path> if given, or the current directory if omitted,
    recursively for OpenGL API names and print them sorted on support, name,
    and location, in a tree.
  refs <name>
    Print all URLs of all reference pages with name <name>.
  refs-all <name>
    Print all URLs of all reference pages that mention <name>, sorted on
    support, in a tree.

Environment variables:
  GLREGISTRY_CACHE
    The directory to cache files in. Defaults to `$XDG_CACHE_HOME/glregistry`
    or, if `$XDG_CACHE_HOME` is not defined, `$HOME/.cache/glregistry`.
  GLREGISTRY_EDITOR
    The editor to use when opening files. Defaults to `$EDITOR` or, if
    `$EDITOR` is not defined, `editor` if it exists in `$PATH`, else `vi`. The
    value is interpreted by the shell.
  GLREGISTRY_PAGER
    The pager to use when viewing output. Defaults to `$PAGER` or, if `$PAGER`
    is not defined, `pager` if it exists in `$PATH`, else `less`. The value is
    interpreted by the shell. If the `$LESS` environment variable is unset, it
    is set to `FR`.
  GLREGISTRY_COLORS
    If standard out is a terminal, the colors used in output of the `enums`,
    `enums-tree`, `params`, `params-tree`, `audit`, and `audit-tree` commands.
    It uses the same format (and defaults) as GREP_COLORS, i.e. a
    colon-separated list of capabilities: `ms` (matching selected), `fn` (file
    name), `ln` (line and column number), `se` (separators). Added custom
    capabilities are: `ve` (version), `ex` (extension), `un` (unsupported).
    Defaults to `ms=01;31:fn=35:ln=32:se=36:ve=01;34:ex=34:un=01;33`.
"""


## Imports
import os
import sys
import collections
import re
import functools
import subprocess
import shutil
import shlex
import fnmatch
import urllib.request
import docopt
from lxml import etree


## Constants
# Base URL that reference-page URLs are joined onto (see `refs_`).
REFPAGES_URL  = 'https://registry.khronos.org/OpenGL-Refpages/'
# Base URL that registry-relative paths are joined onto (see `download`).
REGISTRY_URL  = 'https://github.com/KhronosGroup/OpenGL-Registry/raw/main/'
# Git repository that `refs_` clones into the cache.
REFPAGES_GIT  = 'https://github.com/KhronosGroup/OpenGL-Refpages'
# Path of the registry XML, relative to both `REGISTRY_URL` and the cache.
XML_PATH      = 'xml/gl.xml'
# `User-Agent` header installed on the global URL opener in `main`.
USER_AGENT    = 'Mozilla/5.0'
# Default pattern of `grep`: a word starting with `gl` or `GL_`, followed by a
# digit or uppercase letter and at least one more letter, digit, or underscore.
REGEX         = r'\b(gl|GL_)[0-9A-Z][0-9A-Za-z_]+\b'
# `fnmatch` patterns of directory and file names that `grep` skips by default.
EXCLUDE_DIRS  = ['.?*', '_*']
EXCLUDE_FILES = ['README*', 'TODO*']
# Number of leading bytes `grep` reads to look for a NUL byte; files that
# contain one are skipped.
BINARY_PEEK   = 1024
# Number of spaces per indentation level in `indentjoin`.
INDENT        = 2
# Resolve a directory: `$GLREGISTRY_<var>` if set and non-empty, otherwise the
# `glregistry` directory inside `$XDG_<var>_HOME` or, if that is unset or
# empty, inside the user-expanded `default`.
ENV_XDG = lambda var, default: (
    os.environ.get(f'GLREGISTRY_{var}') or
    os.path.join(
        (
            os.environ.get(f'XDG_{var}_HOME') or
            os.path.expanduser(default)
        ),
        'glregistry',
    )
)
# Resolve a program: `$GLREGISTRY_<var>`, then `$<var>`, then the lowercased
# `<var>` if `shutil.which` finds it in `$PATH`, else `default`.
ENV_PRG = lambda var, default: (
    os.environ.get(f'GLREGISTRY_{var}') or
    os.environ.get(var) or
    shutil.which(var.lower()) or
    default
)
CACHE  = ENV_XDG('CACHE',  os.path.join('~', '.cache'))
EDITOR = ENV_PRG('EDITOR', 'vi')
PAGER  = ENV_PRG('PAGER',  'less')
# Value passed as `$LESS` to the pager by `page`; `FR` if `$LESS` is unset or
# empty.
LESS   = os.environ.get('LESS') or 'FR'
# Mapping from color capability to its escape-sequence parameters, as used by
# `color`. Parsed from `$GLREGISTRY_COLORS` or, if that is unset or empty, from
# `$GREP_COLORS` (or its default) joined with the defaults of the custom
# capabilities. Entries without `=` are padded with an empty value, and
# capabilities that are not listed map to the empty string.
COLORS = collections.defaultdict(str, [
    (color.split('=') + [''])[:2]
    for color in
    filter(None, (
        os.environ.get('GLREGISTRY_COLORS') or
        (lambda x, y: ':'.join([x, x and y]))(
            (
                os.environ.get('GREP_COLORS') or
                'ms=01;31:fn=35:ln=32:se=36'
            ),
            've=01;34:ex=34:un=01;33',
        )
    ).split(':'))
])
# XPath predicate body: attribute `a`, read as a list separated by `s`,
# contains `v` as a whole item (both sides are wrapped in `s` before the
# substring test).
IN    = lambda a, v, s: f"contains(concat('{s}',@{a},'{s}'),'{s}{v}{s}')"
# XPath predicate body: attribute `a` equals `v` or is absent.
MAYBE = lambda a, v:    f"(@{a}='{v}' or not(@{a}))"
# XPath locations of the registry sections that are queried.
TYPES    = "/registry/types"
ENUMS    = "/registry/enums"
COMMANDS = "/registry/commands"
# Per support category: the XPath selecting its elements and the name of the
# attribute that identifies each element (version number or extension name).
CATEGORY_ATTRIBS = {
    'VERSION': [
        "/registry/feature[@api='gl']",
        'number',
    ],
    'EXTENSION': [
        f"/registry/extensions/extension[{IN('supported','gl','|')}]",
        'name',
    ],
}
# XPath steps selecting `require`/`remove` children whose `api` is `gl` or
# absent and whose `profile` is `core` or absent.
REQUIRE = f"require[{MAYBE('api','gl')} and {MAYBE('profile','core')}]"
REMOVE  = f"remove[{ MAYBE('api','gl')} and {MAYBE('profile','core')}]"
# Pairs of change step and the prefix `supports` puts in front of the support
# name found through it: nothing for requirements, `<` for removals.
CHANGE_PREFIXES = [
    [REQUIRE, '' ],
    [REMOVE,  '<'],
]
# Vendor suffixes tried by `aliases`; their order here is also the order they
# sort in (see `KEY_SUBS`).
VENDORS = ['KHR', 'ARB', 'EXT']
# Regex substitutions applied in order by `key` to build sort keys: strip the
# change prefixes, then replace a trailing vendor suffix and a leading
# `GL_<vendor>_` by the index of the vendor in `VENDORS`.
KEY_SUBS = [
    *[[f'^{   prefix}',  f''       ] for _, prefix in CHANGE_PREFIXES],
    *[[f'{    vendor}$', f'{   i}' ] for i, vendor in enumerate(VENDORS)],
    *[[f'^GL_{vendor}_', f'GL_{i}_'] for i, vendor in enumerate(VENDORS)],
]


## Helpers


### `log`
def log(*args, **kwargs):
    """Print `args` to standard error and flush immediately.

    Extra keyword arguments (such as `end`) are forwarded to `print`.
    """
    options = dict(file=sys.stderr, flush=True)
    print(*args, **options, **kwargs)


### `edit`
def edit(paths):
    """Open `paths` in the editor, or copy them to standard out.

    The editor command is run through the shell with the shell-quoted paths
    appended. If no editor is configured or standard out is not a terminal,
    the contents of the files are written to standard out instead.
    """
    if not (EDITOR and sys.stdout.isatty()):
        for path in paths:
            with open(path) as source:
                shutil.copyfileobj(source, sys.stdout)
        return
    quoted = [shlex.quote(path) for path in paths]
    subprocess.run(' '.join([EDITOR] + quoted), shell=True)

### `page`
def page(lines):
    """Show `lines`, one per line, in the pager or on standard out.

    All of `lines` is consumed before anything is shown. The pager command is
    run through the shell with `$LESS` set, and is only used when there is
    output, a pager is configured, and standard out is a terminal.
    """
    text = ''.join(f'{line}\n' for line in lines)
    if not (text and PAGER and sys.stdout.isatty()):
        sys.stdout.write(text)
        return
    command = f'LESS={shlex.quote(LESS)} {PAGER}'
    subprocess.run(command, shell=True, text=True, input=text)

### `color`
def color(capability, string):
    """Wrap `string` in the escape sequences of color `capability`.

    `string` is returned unchanged if standard out is not a terminal.
    """
    if sys.stdout.isatty():
        return f'\x1b[{COLORS[capability]}m{string}\x1b[m'
    return string


### `color_supports`
def color_supports(supports):
    """Yield each support in `supports` colored according to its kind.

    `UNSUPPORTED` and removals (prefixed with `<`) use the `un` capability,
    names starting with `GL_` use `ex`, and everything else uses `ve`.
    """
    for support in supports:
        unsupported = support == 'UNSUPPORTED' or support.startswith('<')
        if unsupported:
            capability = 'un'
        elif support.startswith('GL_'):
            capability = 'ex'
        else:
            capability = 've'
        yield color(capability, support)


### `indentjoin`
def indentjoin(indent, sep, parts):
    """Join the stringified `parts` with the colored `sep`, indented.

    The result is prefixed with `indent` levels of `INDENT` spaces each.
    """
    padding = ' ' * INDENT * indent
    separator = color('se', sep)
    return padding + separator.join(str(part) for part in parts)


### `removeprefix`
def removeprefix(prefix, string):
    """Return `string` without a leading `prefix`, if it has one.

    Note the argument order: the prefix comes first.
    """
    if not string.startswith(prefix):
        return string
    return string[len(prefix):]


### `key`
def key(item):
    """Return the sort key of `item`.

    The key is `item` after applying every substitution in `KEY_SUBS`, in
    order.
    """
    for pattern, replacement in KEY_SUBS:
        item = re.sub(pattern, replacement, item)
    return item


### `download`
def download(path, exit_on_failure=True):
    """Download registry file `path` into the cache unless already cached.

    `path` is relative to both `REGISTRY_URL` and `CACHE`. Progress and the
    outcome are logged to standard error.

    The response is first written to a temporary `.part` file next to the
    destination and only moved into place once it is complete, so that an
    interrupted download is never mistaken for a cached file on later runs.

    Returns the local path, which does not exist if the download failed and
    `exit_on_failure` is false. If the download failed with a `URLError` and
    `exit_on_failure` is true, the program exits with status 1.
    """
    remote = urllib.parse.urljoin(REGISTRY_URL, path)
    local  = os.path.join        (CACHE,        path)
    if os.path.exists(local):
        return local
    partial = f'{local}.part'
    try:
        log(f"Downloading '{path}' ... ", end='')
        with urllib.request.urlopen(remote) as response:
            os.makedirs(os.path.dirname(local), exist_ok=True)
            try:
                with open(partial, 'wb') as f:
                    shutil.copyfileobj(response, f)
                # Atomic within a file system: the cache only ever contains
                # complete files under their final names.
                os.replace(partial, local)
            finally:
                # Only present if copying or moving failed.
                try:
                    os.remove(partial)
                except FileNotFoundError:
                    pass
            reason = response.reason
    except urllib.error.URLError as error:
        log(error.reason)
        if exit_on_failure:
            # `sys.exit` rather than the `exit` builtin, which is only
            # provided by the `site` module.
            sys.exit(1)
    else:
        log(reason)
    return local


### `grep`
def grep(
    path=None,
    regex=REGEX,
    exclude_dirs=EXCLUDE_DIRS,
    exclude_files=EXCLUDE_FILES,
    silent=False,
):
    """Search `path` recursively for matches of `regex`.

    Yields `(file, line, column, match)` tuples with 1-based line and column
    numbers. `path` may be a single file or a directory and defaults to the
    current directory. Directory and file names matching any of the `fnmatch`
    patterns in `exclude_dirs` and `exclude_files` are skipped while walking,
    as are files with a NUL byte among their first `BINARY_PEEK` bytes. The
    remaining names are visited in sorted order. Errors are logged, unless
    `silent` is true, and never raised.
    """
    path = path or '.'
    here = f'.{os.path.sep}'

    def report(error, file=None):
        # Also used as the `os.walk` error callback, which passes no file.
        file = removeprefix(here, file or error.filename)
        if silent:
            return
        if isinstance(error, OSError):
            message = error.strerror
        elif isinstance(error, UnicodeDecodeError):
            message = error.reason
        else:
            message = error
        log(f"{file}: {message}")

    def keep(patterns, names):
        kept = set(names)
        for pattern in patterns:
            kept.difference_update(fnmatch.filter(kept, pattern))
        return sorted(kept)

    def scan(file):
        try:
            with open(file, 'rb') as f:
                if 0 in f.read(BINARY_PEEK):
                    return
            with open(file, errors='ignore') as f:
                file = removeprefix(here, file)
                for number, text in enumerate(f, start=1):
                    for match in re.finditer(regex, text):
                        yield file, number, match.start() + 1, match.group()
        except Exception as error:
            report(error, file)

    if os.path.isfile(path):
        yield from scan(path)
        return
    for root, dirs, files in os.walk(path, onerror=report):
        # Assign in place so that `os.walk` honors the pruned directories.
        dirs [:] = keep(exclude_dirs,  dirs)
        files[:] = keep(exclude_files, files)
        for file in files:
            yield from scan(os.path.join(root, file))

## Commands


### `xml_`
def xml_():
    """Return the parsed registry XML, downloading it if it is not cached."""
    path = download(XML_PATH)
    return etree.parse(path)


### `xml`
def xml():
    """Return a list with the local path of the registry XML.

    The file is downloaded first if it is not cached.
    """
    path = download(XML_PATH)
    return [path]


### `ext_`
def ext_(extension):
    """Return the registry-relative path of the spec of `extension`.

    `extension` must have the form `GL_<vendor>_<name>`, which maps to
    `extensions/<vendor>/<vendor>_<name>.txt`. For any other name, including
    one with too few underscore-separated parts, a message is printed to
    standard error and the program exits with status 1.
    """
    prefix, _, rest = extension.partition('_')
    vendor, _, name = rest.partition('_')
    # `sys.exit` with a string prints it to standard error and exits with
    # status 1.
    if prefix != 'GL':
        sys.exit("Extension names must start with 'GL_'.")
    if not vendor or not name:
        sys.exit("Extension names must have the form 'GL_<vendor>_<name>'.")
    return f'extensions/{vendor}/{vendor}_{name}.txt'


### `ext`
def ext(extension):
    """Return a list with the local path of the spec of `extension`.

    The spec is downloaded first if it is not cached.
    """
    path = download(ext_(extension))
    return [path]


### `exts`
def exts(xml):
    """Return the sorted names of all extensions in the `EXTENSION` category.
    """
    category, attrib = CATEGORY_ATTRIBS['EXTENSION']
    found = list(xml.xpath(f"{category}/@{attrib}"))
    found.sort(key=key)
    return found


### `exts_download`
def exts_download(xml):
    """Download the specs of all extensions that are not cached yet.

    Failed downloads are logged and skipped. Returns an empty list, so that
    there is nothing to page.
    """
    # `map` is lazy, so each path is computed right before its download.
    for path in map(ext_, exts(xml)):
        download(path, exit_on_failure=False)
    return []


### `exts_all`
def exts_all(xml, name):
    """Return the sorted names of the cached extensions that mention `name`.

    Only specs already present in the cache are searched; nothing is
    downloaded here. `name` is searched for as a whole word, both as given
    and without its `gl` or `GL_` prefix. `xml` is currently unused.
    """
    directory = os.path.join(CACHE, 'extensions')
    variants = set([
        name,
        removeprefix('gl',  name),
        removeprefix('GL_', name),
    ])
    found = set()
    for variant in variants:
        for file, *_ in grep(directory, rf'\b{variant}\b', [], []):
            stem = os.path.splitext(os.path.basename(file))[0]
            found.add('GL_' + stem)
    return sorted(found, key=key)


### `vendors`
def vendors(xml):
    """Return the sorted vendor abbreviations of all extensions.

    The vendor is the second underscore-separated part of an extension name.
    """
    abbreviations = {extension.split('_')[1] for extension in exts(xml)}
    return sorted(abbreviations, key=key)


### `type`
def type(xml, type):
    """Return a list with the definition of `type`.

    The definition is the string value of the parent of the `name` element
    whose text is `type`.
    """
    query = f"string({TYPES}/type/name[text()='{type}']/..)"
    definition = xml.xpath(query)
    return [definition]


### `aliases`
def aliases(xml, name, supports_=None, vendors_=[]):
    """Return the sorted vendor-suffixed aliases of enum `name`.

    An alias is `name` with one of the suffixes in `VENDORS` appended that is
    itself supported and has the same value as `name`. Names that already end
    in a vendor suffix, have no value, or are unsupported have no aliases.

    `supports_` may pass in the already known supports of `name` to avoid
    looking them up again.
    """
    # NOTE: The mutable default of `vendors_` is deliberate: it is filled in
    # place on first use and so caches the vendor list between calls.
    if not vendors_:
        vendors_[:] = vendors(xml)
    if any(name.endswith(f'_{vendor}') for vendor in vendors_):
        return []
    value_ = value(xml, name)
    if not value_:
        return []
    supports_ = supports_ or supports(xml, name, False)
    if not supports_:
        return []
    # Only the suffixes in `VENDORS` are tried, not all of `vendors_`.
    candidates = (f'{name}_{vendor}' for vendor in VENDORS)
    found = [
        alias
        for alias in candidates
        if supports(xml, alias, False) and value(xml, alias) == value_
    ]
    return sorted(found, key=key)


### `value`
def value(xml, enum):
    """Return a list of the `value` attributes of the enums named `enum`."""
    query = f"{ENUMS}/enum[@name='{enum}']/@value"
    return xml.xpath(query)


### `enum`
def enum(xml, value):
    """Return the sorted names of the enums whose value equals `value`.

    Values are compared as integers. Both `value` and the registry values
    are parsed as hexadecimal if they start with `0x`, else as decimal.
    """
    def conv(s):
        base = 16 if s.startswith('0x') else 10
        return int(s, base)
    wanted = conv(value)
    matches = [
        element.get('name')
        for element in xml.xpath(f"{ENUMS}/enum")
        if conv(element.get('value')) == wanted
    ]
    return sorted(matches, key=key)


### `supports`
@functools.cache
def supports(xml, name, use_aliases=True):
    """Return the sorted supports of `name`.

    A support is the version number or extension name of a feature or
    extension that requires `name`, or that removes it, in which case the
    support is prefixed with `<`. If `name` is itself the name of an
    extension, the result is `['EXTENSION']`. If `use_aliases` is true and
    `name` has any support, the supports of its aliases are included too.

    Results are cached per combination of arguments.
    """
    ext_category, ext_attrib = CATEGORY_ATTRIBS['EXTENSION']
    if xml.xpath(f"{ext_category}[@{ext_attrib}='{name}']"):
        return ['EXTENSION']
    found = []
    for category, attrib in CATEGORY_ATTRIBS.values():
        for change, prefix in CHANGE_PREFIXES:
            query = f"{category}/{change}/*[@name='{name}']/../../@{attrib}"
            found.extend(f'{prefix}{support}' for support in xml.xpath(query))
    if found and use_aliases:
        for alias in aliases(xml, name, found):
            found.extend(supports(xml, alias, False))
    return sorted(found, key=key)


### `names`
def names(xml, support=None):
    """Return the sorted names required by `support`.

    `support` may be a category (a key of `CATEGORY_ATTRIBS`) to select all
    of its members, a version number or extension name to select just that
    one, or omitted to select everything.
    """
    if support in CATEGORY_ATTRIBS:
        category, _ = CATEGORY_ATTRIBS[support]
        selections = [(category, "")]
    elif support:
        selections = [
            (category, f"[@{attrib}='{support}']")
            for category, attrib in CATEGORY_ATTRIBS.values()
        ]
    else:
        selections = [
            (category, "")
            for category, _ in CATEGORY_ATTRIBS.values()
        ]
    found = set()
    for category, predicate in selections:
        found.update(xml.xpath(f"{category}{predicate}/{REQUIRE}/*/@name"))
    return sorted(found, key=key)


### `groups`
def groups(xml, enum=None):
    """Return the sorted groups of `enum`, or of all enums if omitted.

    The `group` attribute of an enum is a comma-separated list of groups.
    """
    predicate = f"[@name='{enum}']" if enum else ""
    found = set()
    for attribute in xml.xpath(f"{ENUMS}/enum{predicate}/@group"):
        found.update(attribute.split(','))
    return sorted(found)


### `enums_`
def enums_(xml, group=None):
    """Return the supported enums in `group`, keyed by their supports.

    The result maps each tuple of supports to the list of enum names that
    have exactly those supports. All enums are considered if `group` is
    omitted. Enums without any support are left out.
    """
    predicate = f"[{IN('group',group,',')}]" if group else ""
    by_supports = collections.defaultdict(list)
    for name in xml.xpath(f"{ENUMS}/enum{predicate}/@name"):
        supports_ = supports(xml, name)
        if not supports_:
            continue
        by_supports[tuple(supports_)].append(name)
    return by_supports


### `enums`
def enums(xml, group=None):
    """Return the sorted supported enums in `group`, or all if omitted."""
    flat = []
    for members in enums_(xml, group).values():
        flat.extend(members)
    return sorted(flat, key=key)


### `enums_tree`
def enums_tree(xml, group=None):
    """Yield the lines of a tree of the enums in `group`, by supports.

    Each set of supports is followed by its enums, indented one level.
    """
    tree = enums_(xml, group)
    # The keys are unique, so sorting them orders the same as sorting items.
    for supports in sorted(tree):
        yield indentjoin(0, ',', color_supports(supports))
        for enum in sorted(tree[supports]):
            yield indentjoin(1, '', [color('ms', enum)])


### `params_`
def params_(xml, group=None):
    """Return the parameters of supported commands, with their commands.

    Only parameters whose `group` attribute equals `group` are considered,
    or all parameters if `group` is omitted. The result maps
    `(count, param)` to a mapping from each tuple of supports to the list of
    commands with those supports that take a parameter named `param`.
    `count` is the negated number of occurrences of the parameter name, so
    that sorting on the keys puts the most used names first.
    """
    predicate = f"[@group='{group}']" if group else ""
    counts    = collections.defaultdict(int)
    commands  = collections.defaultdict(lambda: collections.defaultdict(list))
    for xmlcommand in xml.xpath(f"{COMMANDS}/command/param{predicate}/.."):
        command   = xmlcommand.xpath("string(proto/name)")
        supports_ = supports(xml, command)
        if not supports_:
            continue
        supports_key = tuple(supports_)
        for xmlparam in xmlcommand.xpath(f"param{predicate}"):
            param = xmlparam.xpath("string(name)")
            commands[param][supports_key].append(command)
            counts[param] -= 1
    return {
        (count, param): commands[param]
        for param, count in counts.items()
    }


### `params`
def params(xml, group=None):
    """Return the sorted parameter names of `group`, or all if omitted."""
    return sorted(param for _, param in params_(xml, group))


### `params_tree`
def params_tree(xml, group=None):
    """Yield the lines of a tree of the parameter names of `group`.

    Each parameter name is shown with its number of occurrences, most used
    first, followed by the supports of its commands one level down and the
    commands themselves two levels down.
    """
    tree = params_(xml, group)
    for count, param in sorted(tree):
        # Counts are stored negated for sorting; show them positive.
        heading = [color('ms', param), color('ln', -count)]
        yield indentjoin(0, ':', heading)
        occurrences = tree[(count, param)]
        for supports_ in sorted(occurrences):
            yield indentjoin(1, ',', color_supports(supports_))
            for command in sorted(occurrences[supports_]):
                yield indentjoin(2, '', [color('fn', command)])


### `audit_`
def audit_(xml, path=None):
    """Return the locations of the OpenGL API names found in `path`.

    The result maps each tuple of supports to a mapping from name to the
    list of `[file, line, column]` locations where it was found. Names
    without any support are collected under `('UNSUPPORTED',)`.
    """
    found = collections.defaultdict(lambda: collections.defaultdict(list))
    for file, line, column, name in grep(path):
        supports_ = supports(xml, name) or ['UNSUPPORTED']
        found[tuple(supports_)][name].append([file, line, column])
    return found


### `audit`
def audit(xml, path=None):
    """Yield one line per OpenGL API name found in `path`.

    Lines have the form `file:line:column:supports:name` and are sorted on
    location, then supports, then name.
    """
    rows = []
    for supports, names in audit_(xml, path).items():
        for name, locations in names.items():
            for file, line, column in locations:
                rows.append([file, line, column, supports, name])
    rows.sort()
    for file, line, column, supports, name in rows:
        fields = [
            color('fn', file),
            color('ln', line),
            color('ln', column),
            indentjoin(0, ',', color_supports(supports)),
            color('ms', name),
        ]
        yield indentjoin(0, ':', fields)


### `audit_tree`
def audit_tree(xml, path=None):
    """Yield the lines of a tree of the OpenGL API names found in `path`.

    Each set of supports is followed by its names one level down and the
    `file:line:column` locations of each name two levels down.
    """
    tree = audit_(xml, path)
    for supports in sorted(tree):
        yield indentjoin(0, ',', color_supports(supports))
        names = tree[supports]
        for name in sorted(names):
            yield indentjoin(1, '', [color('ms', name)])
            for file, line, column in sorted(names[name]):
                location = [
                    color('fn', file),
                    color('ln', line),
                    color('ln', column),
                ]
                yield indentjoin(2, ':', location)


### `refs_`
def refs_(name):
    """Return the reference pages that mention `name`, keyed by support.

    The reference pages repository is cloned into the cache on first use;
    the program exits with status 1 if cloning fails. Files of the clone are
    searched for `name` as a whole word. A file counts as a reference page
    if its top-level directory starts with `gl`, its immediate directory
    ends with `html`, and its extension is `.xml` or `.xhtml`.

    The result maps each support (the top-level directory without its `gl`
    prefix) to a set of `(page, url)` tuples, where `page` is the file name
    without extension.
    """
    local = os.path.join(CACHE, os.path.basename(REFPAGES_GIT))
    if not os.path.exists(local):
        os.makedirs(os.path.dirname(local), exist_ok=True)
        clone = subprocess.run(['git', 'clone', REFPAGES_GIT, local])
        if clone.returncode != 0:
            # Do not go on to search a missing or incomplete clone and
            # report an empty result as if it was genuine.
            log(f"Cloning '{REFPAGES_GIT}' failed.")
            sys.exit(1)
    refs_ = collections.defaultdict(set)
    for file, *_ in grep(local, rf'\b{name}\b', [], [], True):
        file  = removeprefix(f'{local}{os.path.sep}', file)
        parts = os.path.normpath(file).split(os.path.sep)
        if len(parts) < 3:
            # Not nested deep enough to be a reference page.
            continue
        support, directory, base = parts[0], parts[-2], parts[-1]
        if not (support.startswith('gl') and directory.endswith('html')):
            continue
        page, extension = os.path.splitext(base)
        if extension in ['.xml', '.xhtml']:
            url = urllib.parse.urljoin(REFPAGES_URL, file)
            refs_[removeprefix('gl', support)].add((page, url))
    return refs_


### `refs`
def refs(name):
    """Return the sorted URLs of the reference pages named `name`."""
    urls = [
        url
        for locations in refs_(name).values()
        for name_, url in locations
        if name_ == name
    ]
    urls.sort()
    return urls


### `refs_all`
def refs_all(name):
    """Yield the lines of a tree of the reference pages mentioning `name`.

    Each support is followed by the `page:url` lines of its reference pages,
    indented one level.
    """
    tree = refs_(name)
    for support in sorted(tree):
        yield indentjoin(0, ',', color_supports([support]))
        for name_, url in sorted(tree[support]):
            entry = [color('ms', name_), color('fn', url)]
            yield indentjoin(1, ':', entry)


## Main
def main():
    """Parse the command line and run the selected command.

    Installs a global URL opener that sends `USER_AGENT`, parses the
    arguments against the module docstring, and passes the result of each
    selected command to `edit` or `page`.
    """
    opener = urllib.request.build_opener()
    opener.addheaders = [('User-Agent', USER_AGENT)]
    urllib.request.install_opener(opener)
    args = docopt.docopt(__doc__)
    # Command, how to show its result, and how to compute it. The lambdas
    # defer all work, including loading the XML, until a command is selected.
    commands = [
        ('xml',           edit, lambda: xml          ()),
        ('xml-path',      page, lambda: xml          ()),
        ('ext',           edit, lambda: ext          (args['<extension>'])),
        ('ext-path',      page, lambda: ext          (args['<extension>'])),
        ('exts',          page, lambda: exts         (xml_())),
        ('exts-download', page, lambda: exts_download(xml_())),
        ('exts-all',      page, lambda: exts_all     (xml_(), args['<name>'])),
        ('vendors',       page, lambda: vendors      (xml_())),
        ('type',          page, lambda: type         (xml_(), args['<type>'])),
        ('aliases',       page, lambda: aliases      (xml_(), args['<enum>'])),
        ('value',         page, lambda: value        (xml_(), args['<enum>'])),
        ('enum',          page, lambda: enum         (xml_(), args['<value>'])),
        ('supports',      page, lambda: supports     (xml_(), args['<name>'])),
        ('names',         page, lambda: names        (xml_(), args['<support>'])),
        ('groups',        page, lambda: groups       (xml_(), args['<enum>'])),
        ('enums',         page, lambda: enums        (xml_(), args['<group>'])),
        ('enums-tree',    page, lambda: enums_tree   (xml_(), args['<group>'])),
        ('params',        page, lambda: params       (xml_(), args['<group>'])),
        ('params-tree',   page, lambda: params_tree  (xml_(), args['<group>'])),
        ('audit',         page, lambda: audit        (xml_(), args['<path>'])),
        ('audit-tree',    page, lambda: audit_tree   (xml_(), args['<path>'])),
        ('refs',          page, lambda: refs         (args['<name>'])),
        ('refs-all',      page, lambda: refs_all     (args['<name>'])),
    ]
    for command, show, run in commands:
        if args[command]:
            show(run())


if __name__ == '__main__':
    main()