#include <glshader.hpp>

#include <algorithm>
#include <array>
#include <cerrno>
#include <cstring>
#include <fstream>
#include <ios>
#include <list>
#include <memory>
#include <regex>
#include <sstream>
#include <stdexcept>
#include <string>
#include <tuple>
#include <utility>
#include <vector>

#include <GL/glew.h>

#include <str.hpp>


// A location in a shader source file: (full path, 1-based line number, line
// text). Chains of these (innermost include site first) are used by source_()
// to detect repeated includes and to describe #include chains in errors.
using Here = std::tuple<std::string, int, std::string>;


// Name buffer size used by for_variable_() when the driver reports a maximum
// variable name length of 0 despite a non-zero variable count.
constexpr auto max_length_workaround = 4096;


// Define the static data members of Shader, each initialized with INIT (here
// always `{}`). decltype(NAME) reuses the type declared in the class, so the
// types do not have to be repeated here.
// NOLINTNEXTLINE
#define GLSHADER_INIT_(NAME, INIT) decltype(NAME) NAME INIT;
GLSHADER_INIT_(Shader::root_, {})
GLSHADER_INIT_(Shader::defines_, {})
GLSHADER_INIT_(Shader::verts_, {})
GLSHADER_INIT_(Shader::frags_, {})
GLSHADER_INIT_(Shader::uniform_buffers_, {})


// Read the single integer OpenGL state value `name` with glGetIntegerv and
// convert it to `Type`. When `supported` is false the query is skipped and a
// zero value is returned instead, so callers can avoid querying enums the
// current context may not know.
template<typename Type>
static Type get_integer_(GLenum name, bool supported = true)
{
    auto value = GLint{0};
    if (!supported)
        return static_cast<Type>(value);
    glGetIntegerv(name, &value);
    return static_cast<Type>(value);
}


// Check the OpenGL error state and throw a std::runtime_error if an error is
// pending. The message is `error`, then the error code (known codes via the
// STR_COND chain from str.hpp, otherwise the raw value in hexadecimal), then
// `error_hint` in parentheses when it is non-empty.
void Shader::error_(
    std::string const & error,
    std::string const & error_hint
)
{
    auto const gl_error = glGetError();

    // Nothing pending: nothing to report.
    if (gl_error == GL_NO_ERROR)
        return;

    throw std::runtime_error{STR(
        error << "; " <<
        "got error " << (
            STR_COND(gl_error, GL_INVALID_ENUM)
            STR_COND(gl_error, GL_INVALID_VALUE)
            STR_COND(gl_error, GL_INVALID_OPERATION)
            STR_COND(gl_error, GL_OUT_OF_MEMORY)
            STR(std::hex << std::showbase << gl_error)
        ) <<
        (error_hint.empty() ? "" : STR(" (" << error_hint << ")")) <<
        "."
    )};
}


static void info_log_action_(
    std::string const & error,
    void (action)(GLuint object),
    GLuint object,
    GLenum status_enum,
    void (GLAPIENTRY * getObjectiv)(
        GLuint object, GLenum pname, GLint * params
    ),
    void (GLAPIENTRY * getObjectInfoLog)(
        GLuint object, GLsizei max_length, GLsizei * length, GLchar * info_log
    )
)
{
    // Perform action.
    action(object);

    // Check status.
    auto status = GLint{};
    getObjectiv(object, status_enum, &status);
    if (status)
        return;

    // Get info log length.
    auto info_log_length = GLint{};
    getObjectiv(object, GL_INFO_LOG_LENGTH, &info_log_length);

    // Get info log content.
    // NOLINTNEXTLINE
    auto info_log = std::unique_ptr<GLchar[]>(new GLchar[info_log_length]);
    if (info_log_length)
        getObjectInfoLog(object, info_log_length, nullptr, &info_log[0]);

    // Throw.
    throw std::runtime_error{STR(
        error << (!info_log_length ? "." : STR(":\n" << &info_log[0]))
    )};
}


// Load the GLSL source at `path` (prefixed with `root` when non-empty) and
// return it preprocessed:
// - `#include "..."` / `#include <...>` directives are replaced by the
//   (recursively processed) included file. Quoted includes are resolved
//   relative to the including file's directory.
// - After `#version`, `#extension GL_ARB_shading_language_include : enable`
//   (when the driver has the extension and no behavior was chosen yet) and one
//   `#define` per entry of `defines` are inserted.
// - A user `#extension GL_ARB_shading_language_include` line is consumed (it
//   is replaced by an empty line) and is what allows #include at top level.
// - When the driver has the extension and its behavior is not "disable", a
//   `#line <n> "<file>"` directive precedes every source line.
//
// Parameters:
// - `error` prefixes every thrown std::runtime_error message.
// - `extension_behavior` is the GL_ARB_shading_language_include behavior in
//   effect (empty at top level).
// - `included_by` is the chain of include sites that led here, innermost
//   first (empty at top level).
// - `parent_version_number` is the top-level `#version` number. It is
//   forwarded to included files, which may not contain `#version` themselves,
//   so that their `#line` directives use the right numbering convention.
static std::string source_(
    std::string const & error,
    std::string const & path,
    std::string const & root,
    Shader::Defines const & defines,
    std::string extension_behavior = {},
    std::list<Here> included_by = {},
    int parent_version_number = 0
)
{
    // Set here error.
    auto const here_error = [&](std::list<Here> const & include_here)
    {
        return STR_JOIN(
            "\n",
            it,
            std::get<0>(it) << ":" <<
            std::get<1>(it) << ": " <<
            std::get<2>(it),
            include_here
        );
    };

    // Set include helper.
    auto const include_here = [&](Here const & here)
    {
        auto include_here = included_by;
        include_here.push_front(here);
        return include_here;
    };

    // Set full path, collapsing "." and ".." components.
    auto path_full = std::string{};
    {
        auto istream = std::istringstream(path);
        auto part    = std::string{};
        auto parts   = std::vector<std::string>{};
        parts.reserve(
            static_cast<std::size_t>(std::count(path.begin(), path.end(), '/'))
            + 1
        );
        while (std::getline(istream, part, '/'))
        {
            if (part == ".." && !parts.empty())
                parts.pop_back();
            if (part != ".." && part != ".")
                parts.push_back(std::move(part));
        }
        path_full = STR_JOIN('/', it, it, parts);
    }
    if (!root.empty())
        path_full = STR(root << "/" << path_full);

    // Define and open input stream.
    auto istream = std::ifstream{path_full};
    if (!istream)
        throw std::runtime_error{STR(
            error << "; " <<
            "could not open file '" << path_full << "':\n" <<
            here_error(included_by) << (included_by.empty() ? "" : ":\n") <<
            std::strerror(errno)
        )};

    // Define output stream.
    auto ostream = std::ostringstream{};

    // Define parse regexes.
    static auto const re_ignored   = std::regex{R"(\s*//.*$)"};
    static auto const re_words     = std::regex{R"((\w+(?:\s+\w+)*))"};
    static auto const re_spec      = std::regex{R"((\w+)\s*:\s*(\w+))"};
    static auto const re_quoted    = std::regex{R"((["<])([^">]*)([">]))"};
    static auto const re_version   = std::regex{R"(\s*#\s*version\s*(.*))"};
    static auto const re_extension = std::regex{R"(\s*#\s*extension\s*(.*))"};
    static auto const re_include   = std::regex{R"(\s*#\s*include\s*(.*))"};

    // Parse.
    auto version_number    = 0;
    auto extension_enabled = false;
    auto line_number       = 0;
    auto line              = std::string{};
    auto match             = std::smatch{};
    auto here              = [&]()
    {
        return Here{path_full, line_number, line};
    };
    while (++line_number, std::getline(istream, line))
    {
        // Strip a trailing carriage return (CRLF line endings). The ECMAScript
        // `.` used by the regexes below does not match '\r', so leaving it in
        // place would make every directive on such a line go unrecognized.
        if (!line.empty() && line.back() == '\r')
            line.pop_back();

        // Remove ignored.
        auto const content = std::regex_replace(line, re_ignored, "");

        // Output `#line`. Before GLSL 3.30 the line after `#line n` is line
        // n + 1, from 3.30 on it is line n. Included files cannot declare
        // `#version`, so they use the top-level file's version.
        auto const effective_version_number =
            included_by.empty() ? version_number : parent_version_number;
        auto const line_number_offset =
            effective_version_number < 330 ? -1 : 0;
        if (GLEW_ARB_shading_language_include)
            if (!extension_behavior.empty() && extension_behavior != "disable")
                ostream
                    << "#line" << " "
                    << line_number + line_number_offset << " "
                    << "\"" << path_full << "\""
                    << "\n";

        // Process version.
        if (std::regex_match(content, match, re_version))
        {
            // Parse.
            auto const words = match.str(1);
            if (!std::regex_match(words, match, re_words))
                throw std::runtime_error{STR(
                    error << "; " <<
                    "malformed #version:\n" <<
                    here_error(include_here(here()))
                )};
            auto const version = match.str(1);

            // Check for errors.
            if (version_number)
                throw std::runtime_error{STR(
                    error << "; " <<
                    "found repeated #version:\n" <<
                    here_error(include_here(here()))
                )};
            if (!included_by.empty())
                throw std::runtime_error{STR(
                    error <<  "; " <<
                    "found #version in #include:\n" <<
                    here_error(include_here(here()))
                )};

            // Process. std::stoi throws std::invalid_argument or
            // std::out_of_range (both std::logic_error) for a version that is
            // not a number, e.g. `#version core`; report it as malformed.
            try
            {
                version_number = std::stoi(version);
            }
            catch (std::logic_error const &)
            {
                throw std::runtime_error{STR(
                    error << "; " <<
                    "malformed #version:\n" <<
                    here_error(include_here(here()))
                )};
            }

            // Output.
            ostream << line << "\n";
            if (GLEW_ARB_shading_language_include)
            {
                if (extension_behavior.empty())
                {
                    extension_behavior = "enable";
                    ostream
                        << "#extension GL_ARB_shading_language_include : "
                        << extension_behavior << "\n";
                }
            }
            for (auto const & define : defines)
                ostream
                    << "#define "
                    << define.first << " "
                    << define.second << "\n";
        }

        // Process extension.
        else if (std::regex_match(content, match, re_extension))
        {
            // Parse.
            auto const spec = match.str(1);
            if (!std::regex_match(spec, match, re_spec))
                throw std::runtime_error{STR(
                    error << "; " <<
                    "malformed #extension:\n" <<
                    here_error(include_here(here()))
                )};
            auto const extension = match.str(1);
            auto const behavior  = match.str(2);

            if (extension == "GL_ARB_shading_language_include")
            {
                // Check for errors.
                if (!included_by.empty())
                    throw std::runtime_error{STR(
                        error <<  "; " <<
                        "found #extension GL_ARB_shading_language_include " <<
                        "in #include:\n" <<
                        here_error(include_here(here()))
                    )};

                // Process; the directive itself is replaced by an empty line.
                extension_enabled = behavior != "disable";
                extension_behavior = behavior;
                line = "";
            }

            // Output.
            ostream << line << "\n";
        }

        // Process include.
        else if (std::regex_match(content, match, re_include))
        {
            // Parse.
            auto const quoted = match.str(1);
            if (!std::regex_match(quoted, match, re_quoted))
                throw std::runtime_error{STR(
                    error << "; " <<
                    "malformed #include:\n" <<
                    here_error(include_here(here()))
                )};
            auto const quote_open   = match.str(1);
            auto const include_path = match.str(2);
            auto const quote_close  = match.str(3);

            // Check for errors.
            if (!(
                (quote_open == "\"" && quote_close == "\"") ||
                (quote_open == "<"  && quote_close == ">" )
            ))
                throw std::runtime_error{STR(
                    error << "; " <<
                    "mismatched #include quotes '" << quote_open << "' and '"
                    << quote_close << "':\n" <<
                    here_error(include_here(here()))
                )};
            if (!extension_enabled && included_by.empty())
                throw std::runtime_error{STR(
                    error << "; " <<
                    "#include found but #extension " <<
                    "GL_ARB_shading_language_include not enabled:\n" <<
                    here_error(include_here(here()))
                )};

            // Process, unless this exact include site is already part of the
            // include chain (which stops include cycles).
            auto source = std::string{};
            if (included_by.end() == std::find(
                included_by.begin(), included_by.end(), here()
            ))
            {
                auto include_path_full = include_path;
                if (quote_open == "\"")
                {
                    auto const pos = path.rfind('/');
                    if (pos != path.npos && pos != 0)
                        include_path_full = STR(
                            path.substr(0, pos + 1) << include_path
                        );
                }
                source = source_(
                    error,
                    include_path_full, root,
                    defines,
                    extension_behavior,
                    include_here(here()),
                    effective_version_number
                );
            }

            // Output.
            ostream << source << "\n";
        }

        // Non-processed line.
        else
        {
            // Output.
            ostream << line << "\n";
        }
    }

    // Check for version.
    if (!version_number && included_by.empty())
        throw std::runtime_error{STR(
            error << "; " <<
            "found no #version."
        )};

    // Return.
    return ostream.str();
}


// Call `function(index, max_length, name)` once for every active variable of
// `program`. The variable count is read with `count_enum`, and the size of the
// shared name buffer with `max_length_enum`. `name` points to a scratch buffer
// of `max_length` characters that `function` may fill, for example with
// glGetActive*.
template<typename Function>
static void for_variable_(
    GLuint program,
    GLenum count_enum,
    GLenum max_length_enum,
    Function function
)
{
    // Number of active variables of the requested kind.
    auto count_value = GLint{};
    glGetProgramiv(program, count_enum, &count_value);
    auto const count = static_cast<GLuint>(count_value);

    // Size of the name buffer handed to `function`.
    auto buffer_length = GLsizei{};
    glGetProgramiv(program, max_length_enum, &buffer_length);

    // Work around drivers that report no maximum length despite having
    // active variables.
    if (count != 0 && buffer_length == 0)
        buffer_length = max_length_workaround;

    // One scratch buffer shared by all invocations.
    // NOLINTNEXTLINE
    auto const name = std::unique_ptr<GLchar[]>(new GLchar[buffer_length]);
    for (auto index = GLuint{0}; index != count; ++index)
        function(index, buffer_length, name.get());
}


// Build and link a shader program from the shader files in `paths`.
//
// Each path's file extension (vert, tesc, tese, geom, frag, comp) selects the
// shader stage. Sources are loaded through source_() with root_ and defines_,
// then compiled. Vertex input locations from verts_ and, on GL 3.0+, fragment
// output locations from frags_ are bound before linking. After linking:
// - every active vertex input with a location must appear in verts_;
// - active uniforms are recorded in uniforms_;
// - active uniform blocks (when uniform buffer objects are available) are
//   bound to shared uniform buffers obtained from uniform_buffer_().
//
// Throws std::runtime_error on failure. The partially built program is then
// deleted, which also detaches and deletes its shaders, before rethrowing.
Shader::Shader(Paths const & paths)
:
    program_{0},
    program_name_{STR(
        "shader program " << STR_JOIN(", ", it, "'" << it << "'", paths)
    )},
    uniforms_{},
    uniform_blocks_{}
{
    // Get label limits.
    static auto const max_label_length = get_integer_<GLsizei>(
        GL_MAX_LABEL_LENGTH, GLEW_VERSION_4_3 || GLEW_KHR_debug
    );

    // Length to pass to glObjectLabel for `label`. A label's length must be
    // strictly less than GL_MAX_LABEL_LENGTH, otherwise GL_INVALID_VALUE is
    // generated and left pending, where a later error_() check would pick it
    // up. Longer labels are therefore truncated to max_label_length - 1.
    auto const label_length = [](std::string const & label)
    {
        return std::max(
            GLsizei{0},
            std::min(
                max_label_length - 1,
                static_cast<GLsizei>(label.length())
            )
        );
    };

    try
    {
        // Create program.
        program_ = glCreateProgram();
        if (!program_)
            throw std::runtime_error{STR(
                "Failed to create " << program_name_ << "."
            )};

        // Label program.
        if (GLEW_VERSION_4_3 || GLEW_KHR_debug)
            glObjectLabel(
                GL_PROGRAM,
                program_,
                label_length(program_name_),
                program_name_.c_str()
            );

        // Process shader paths.
        auto shaders = std::vector<GLuint>{};
        shaders.reserve(paths.size());
        for (auto const & path : paths)
        {
            // Set shader name.
            auto const shader_name = STR(
                "shader '" << path << "' of " << program_name_
            );

            // Infer shader type from path extension.
            auto const type_error = STR(
                "Failed to infer type of " << shader_name
            );
            auto const type_pos = path.rfind('.');
            if (type_pos == path.npos)
                throw std::runtime_error{STR(
                    type_error << "; " <<
                    "no file extension."
                )};
            auto const type_name = path.substr(type_pos + 1);
            auto const type =
                type_name == "vert" ? GL_VERTEX_SHADER :
                type_name == "tesc" ? GL_TESS_CONTROL_SHADER :
                type_name == "tese" ? GL_TESS_EVALUATION_SHADER :
                type_name == "geom" ? GL_GEOMETRY_SHADER :
                type_name == "frag" ? GL_FRAGMENT_SHADER :
                type_name == "comp" ? GL_COMPUTE_SHADER :
                GLenum{0};
            if (!type)
                throw std::runtime_error{STR(
                    type_error << "; " <<
                    "unknown file extension '" << type_name << "'."
                )};

            // Create, attach, and flag shader for deletion when detached.
            auto const shader = glCreateShader(type);
            if (!shader)
                throw std::runtime_error{STR(
                    "Failed to create " << type_name << " shader for " <<
                    shader_name << "."
                )};
            shaders.push_back(shader);
            glAttachShader(program_, shader);
            glDeleteShader(shader);

            // Label shader.
            if (GLEW_VERSION_4_3 || GLEW_KHR_debug)
                glObjectLabel(
                    GL_SHADER,
                    shader,
                    label_length(shader_name),
                    shader_name.c_str()
                );

            // Set shader source.
            auto const source_error = STR("Failed to source " << shader_name);
            auto const source = source_(source_error, path, root_, defines_);
            auto const sources = std::array<char const *, 1>{{
                source.c_str()
            }};
            glShaderSource(shader, sources.size(), &sources[0], nullptr);

            // Compile shader.
            info_log_action_(
                STR("Failed to compile " << shader_name),
                glCompileShader, shader,
                GL_COMPILE_STATUS, glGetShaderiv, glGetShaderInfoLog
            );
        }

        // Set vertex input locations.
        for (auto const & vert : verts_)
            glBindAttribLocation(
                program_, vert.second, vert.first.c_str()
            );

        // Set fragment output locations.
        if (GLEW_VERSION_3_0)
            for (auto const & frag : frags_)
                glBindFragDataLocation(
                    program_, frag.second, frag.first.c_str()
                );

        // Link program.
        info_log_action_(
            STR("Failed to link " << program_name_),
            glLinkProgram, program_,
            GL_LINK_STATUS, glGetProgramiv, glGetProgramInfoLog
        );

        // Detach shaders.
        for (auto const & shader : shaders)
            glDetachShader(program_, shader);

        // Initialize vertex inputs.
        for_variable_(
            program_,
            GL_ACTIVE_ATTRIBUTES, GL_ACTIVE_ATTRIBUTE_MAX_LENGTH,
            [&](GLuint index, GLsizei max_length, GLchar * name)
            {
                GLint  size{};
                GLenum type{};
                glGetActiveAttrib(
                    program_, index, max_length,
                    nullptr, &size, &type, name
                );
                auto location = glGetAttribLocation(program_, name);
                if (location != -1 && verts_.find(name) == verts_.end())
                    throw std::runtime_error{STR(
                        "Failed to initialize vertex input '" << name <<
                        "' of " << program_name_ << "."
                    )};
            }
        );

        // Initialize uniforms.
        for_variable_(
            program_,
            GL_ACTIVE_UNIFORMS, GL_ACTIVE_UNIFORM_MAX_LENGTH,
            [&](GLuint index, GLsizei max_length, GLchar * name)
            {
                GLint  size{};
                GLenum type{};
                glGetActiveUniform(
                    program_, index, max_length,
                    nullptr, &size, &type, name
                );
                auto location = glGetUniformLocation(program_, name);
                if (location != -1)
                    uniforms_.emplace(name, Uniform{location, false});
            }
        );

        // Initialize uniform blocks. Only query them when uniform buffer
        // objects are available: otherwise GL_ACTIVE_UNIFORM_BLOCKS is not a
        // valid query and would leave GL_INVALID_ENUM pending.
        if (GLEW_VERSION_3_1 || GLEW_ARB_uniform_buffer_object)
            for_variable_(
                program_,
                GL_ACTIVE_UNIFORM_BLOCKS,
                GL_ACTIVE_UNIFORM_BLOCK_MAX_NAME_LENGTH,
                [&](GLuint index, GLsizei max_length, GLchar * name)
                {
                    glGetActiveUniformBlockName(
                        program_, index, max_length,
                        nullptr, name
                    );
                    auto error = STR(
                        "Failed to initialize uniform block '" << name <<
                        "' of " << program_name_
                    );
                    auto size = GLint{};
                    glGetActiveUniformBlockiv(
                        program_, index, GL_UNIFORM_BLOCK_DATA_SIZE, &size
                    );
                    auto & uniform_buffer = *uniform_buffer_(
                        error, name, false, size
                    );
                    glUniformBlockBinding(
                        program_, index, uniform_buffer.binding
                    );
                    uniform_blocks_.emplace(
                        name, UniformBlock{uniform_buffer}
                    );
                }
            );
    }
    catch (...)
    {
        // Delete program (and detach and delete shaders).
        if (program_)
            glDeleteProgram(program_);
        throw;
    }
}


// Move-construct from `other`, taking over its program handle, name and
// uniform bookkeeping. `other` is left holding program 0, so its destructor
// does not delete the program that `*this` now owns.
Shader::Shader(Shader && other) noexcept
:
    program_       {0},
    program_name_  {std::move(other.program_name_)},
    uniforms_      {std::move(other.uniforms_)},
    uniform_blocks_{std::move(other.uniform_blocks_)}
{
    std::swap(program_, other.program_);
}


// Delete the owned program, if any. A moved-from Shader holds program 0 and
// owns nothing.
Shader::~Shader()
{
    if (program_ == 0)
        return;
    glDeleteProgram(program_);
}


// Throw a std::runtime_error, prefixed by `error`, naming the first entry of
// the map-like container `uniforms` (in its iteration order) for which
// `set(entry.second)` returns false. `uniform_type` names the kind of entry
// ("uniform", "uniform block") in the message.
//
// `uniforms` is taken by const reference, so validation does not copy the
// whole container on every call.
template<typename Uniforms, typename Set>
static void uniforms_validate_(
    std::string const & error,
    std::string const & uniform_type,
    Uniforms const & uniforms,
    Set set
)
{
    // Find the first entry that has not been set.
    auto const it = std::find_if(uniforms.begin(), uniforms.end(),
        [&](typename Uniforms::value_type const & entry)
        {
            return !set(entry.second);
        }
    );

    // Error if an unset entry was found.
    if (it != uniforms.end())
        throw std::runtime_error{STR(
            error << "; " <<
            uniform_type << " '" << it->first << "' not set."
        )};
}


void Shader::validate_() const
{
    // Set error.
    auto const validate_error = STR("Failed to validate " << program_name_);

    // Validate program.
    info_log_action_(
        validate_error,
        glValidateProgram, program_,
        GL_VALIDATE_STATUS, glGetProgramiv, glGetProgramInfoLog
    );

    // Assert current.
    current_(validate_error);

    // Validate uniforms.
    uniforms_validate_(validate_error, "uniform", uniforms_,
        [](Uniform const & uniform)
        {
            return uniform.set;
        }
    );

    // Validate uniform blocks.
    uniforms_validate_(validate_error, "uniform block", uniform_blocks_,
        [](UniformBlock const & uniform_block)
        {
            return uniform_block.buffer.set;
        }
    );
}


// Throw a std::runtime_error prefixed by `error` unless this program is the
// one reported by GL_CURRENT_PROGRAM.
void Shader::current_(std::string const & error) const
{
    auto const current_program = get_integer_<GLuint>(GL_CURRENT_PROGRAM);
    if (current_program == program_)
        return;
    throw std::runtime_error{STR(
        error << "; " <<
        "shader program not current."
    )};
}


// Look up the active uniform `name`. Returns a pointer to it when present.
// Otherwise returns nullptr, or, when `required`, throws a std::runtime_error
// prefixed by `error`. The message hints at a uniform block when one with the
// same name exists.
Shader::Uniform * Shader::uniform_(
    std::string const & error,
    std::string const & name,
    bool required
)
{
    auto const found = uniforms_.find(name);
    if (found != uniforms_.end())
        return &found->second;

    if (!required)
        return nullptr;

    auto const is_block = uniform_blocks_.find(name) != uniform_blocks_.end();
    throw std::runtime_error{STR(
        error << "; " <<
        "uniform required but not found" <<
        (is_block ? " (did you mean the uniform block?)" : "") << "."
    )};
}


// Look up the active uniform block `name`. When present, its buffer size must
// equal `size`; otherwise a std::runtime_error prefixed by `error` is thrown.
// When absent, returns nullptr, or, when `required`, throws. The message hints
// at a uniform when one with the same name exists.
Shader::UniformBlock * Shader::uniform_block_(
    std::string const & error,
    std::string const & name,
    bool required,
    GLsizeiptr size
)
{
    auto const found = uniform_blocks_.find(name);

    if (found == uniform_blocks_.end())
    {
        if (!required)
            return nullptr;
        auto const is_uniform = uniforms_.find(name) != uniforms_.end();
        throw std::runtime_error{STR(
            error << "; " <<
            "uniform block required but not found" <<
            (is_uniform ? " (did you mean the uniform?)" : "") << "."
        )};
    }

    auto & block = found->second;
    if (block.buffer.size != size)
        throw std::runtime_error{STR(
            error << "; " <<
            "expected size " << block.buffer.size << " but got " <<
            size << "."
        )};
    return &block;
}


// Look up, or create, the shared uniform buffer `name`.
//
// If it already exists, its size must equal `size`; otherwise a
// std::runtime_error prefixed by `error` is thrown. If it does not exist and
// `required` is true, an error is thrown. Otherwise a new buffer of `size`
// bytes is created with GL_DYNAMIC_DRAW, given the next free uniform buffer
// binding point, bound there, and marked as not set.
//
// Creation is all-or-nothing. The limits and any pending OpenGL error are
// checked before anything is created. If the GL allocation fails, the buffer
// is deleted and the map entry removed, so no half-initialized buffer is left
// under `name`.
Shader::UniformBuffer * Shader::uniform_buffer_(
    std::string const & error,
    std::string const & name,
    bool required,
    GLsizeiptr size
)
{
    // Check availability of uniform buffer objects.
    auto const supported = GLEW_VERSION_3_1 || GLEW_ARB_uniform_buffer_object;

    // Get uniform block / buffer limits. They are only queried when supported,
    // so that an unsupported context is not left with GL_INVALID_ENUM pending.
    static auto const max_uniform_block_size = get_integer_<GLuint>(
        GL_MAX_UNIFORM_BLOCK_SIZE, supported
    );
    static auto const max_uniform_buffer_bindings = get_integer_<GLuint>(
        GL_MAX_UNIFORM_BUFFER_BINDINGS, supported
    );

    // Define next binding.
    static auto next_uniform_buffer_binding_ = GLuint{0};

    // Return if found.
    auto it = uniform_buffers_.find(name);
    if (it != uniform_buffers_.end())
    {
        auto & uniform_buffer = it->second;
        if (size != uniform_buffer.size)
            throw std::runtime_error{STR(
                error << "; " <<
                "expected size " << uniform_buffer.size << " but got " << size
                << "."
            )};
        return &uniform_buffer;
    }

    // Error if required.
    if (required)
        throw std::runtime_error{STR(
            error << "; " <<
            "uniform buffer required but not found."
        )};

    // Set create error.
    auto const create_error = STR(
        error << ":\n" <<
        "Failed to create uniform buffer '" << name << "'"
    );

    // Check availability.
    if (!supported)
        throw std::runtime_error{STR(
            create_error << "; " <<
            "ARB_uniform_buffer_object not available."
        )};

    // Check for errors before creating anything.
    if (size > max_uniform_block_size)
        throw std::runtime_error{STR(
            create_error << "; " <<
            "buffer has size " << size << " but max size is " <<
            max_uniform_block_size << "."
        )};
    if (next_uniform_buffer_binding_ >= max_uniform_buffer_bindings)
        throw std::runtime_error{STR(
            create_error << "; " <<
            "buffer would have binding " << next_uniform_buffer_binding_ <<
            " but max bindings is " << max_uniform_buffer_bindings << "."
        )};

    // Make sure no earlier, unrelated OpenGL error gets attributed to the
    // calls below.
    error_(create_error, "unprocessed previous error");

    // Create storage.
    auto emplace = uniform_buffers_.emplace(name, UniformBuffer{});
    if (!emplace.second)
        throw std::runtime_error{STR(
            create_error << "; " <<
            "already exists."
        )};
    auto & uniform_buffer = emplace.first->second;

    // Generate, bind, and allocate. On failure, delete the buffer and remove
    // the entry again so that later lookups do not find it.
    auto buffer = GLuint{0};
    try
    {
        glGenBuffers(1, &buffer);
        glBindBuffer(GL_UNIFORM_BUFFER, buffer);
        glBufferData(
            GL_UNIFORM_BUFFER,
            size,
            nullptr,
            GL_DYNAMIC_DRAW
        );
        error_(create_error);
    }
    catch (...)
    {
        if (buffer)
            glDeleteBuffers(1, &buffer);
        uniform_buffers_.erase(emplace.first);
        throw;
    }
    uniform_buffer.buffer = buffer;
    uniform_buffer.size   = size;

    // Allocate binding and bind.
    uniform_buffer.binding = next_uniform_buffer_binding_++;
    glBindBufferBase(
        GL_UNIFORM_BUFFER,
        uniform_buffer.binding,
        uniform_buffer.buffer
    );

    // Mark as unset.
    uniform_buffer.set = false;

    // Return.
    return &uniform_buffer;
}
