#include "shader.hpp"

#include <vector>
#include <string>
#include <fstream>
#include <stdexcept>
#include <iostream>

#include <GL/glew.h>

#include "str.hpp"


// Named uniform buffers shared by all Shader instances (static member),
// keyed by buffer name. Populated by uniform_buffer_(), looked up by
// uniform(), and pruned by uniform_buffer_delete().
Shader::StringCache<Shader::UniformBuffer>
Shader::uniform_buffer_cache_{};

// Next uniform buffer binding point to hand out. uniform_buffer_() takes a
// value from here and increments it for every new buffer; binding points are
// not given back when uniform_buffer_delete() removes a buffer.
GLuint
Shader::uniform_buffer_binding_next_{0};


/// Runs `glAction(object)` and throws if the resulting status is false.
///
/// @param glAction           GL operation to perform (e.g. glCompileShader).
/// @param object             Shader or program name to act on.
/// @param status_enum        Status to query afterwards (e.g. GL_COMPILE_STATUS).
/// @param glGetObjectiv      Matching parameter query (glGetShaderiv/glGetProgramiv).
/// @param glGetObjectInfoLog Matching info log query.
/// @param what               Message prefix for the exception.
/// @throws std::runtime_error `what`, followed by a newline and the object's
///         info log if the driver supplied a non-empty one.
static void checked_action(
    void (glAction)(
        GLuint object
    ),
    GLuint object,
    GLenum status_enum,
    void (glGetObjectiv)(
        GLuint object, GLenum pname, GLint *params
    ),
    void (glGetObjectInfoLog)(
        GLuint object, GLsizei maxLength, GLsizei *length, GLchar *infoLog
    ),
    std::string const & what
)
{
    // Perform action.
    glAction(object);

    // Check status. Initialised so a failed query reads as "not OK".
    GLint status = 0;
    glGetObjectiv(object, status_enum, &status);
    if (status)
        return;

    // Get info log length (the reported length includes the terminating
    // null character; 0 means there is no log).
    GLint info_log_length = 0;
    glGetObjectiv(object, GL_INFO_LOG_LENGTH, &info_log_length);

    // Get info log content. The buffer is zero-filled and GL null-terminates
    // what it writes, so data() is always a valid C string.
    std::string info_log;
    if (info_log_length > 0) {
        std::vector<GLchar> info_log_buf(info_log_length);
        glGetObjectInfoLog(
            object, info_log_length, nullptr, info_log_buf.data()
        );
        // Format the buffer we just filled (previously this read the empty
        // `info_log` string, so the log was silently dropped).
        info_log = STR("\n" << info_log_buf.data());
    }

    // Throw.
    throw std::runtime_error{STR(what << info_log)};
}


/// Maps a shader file extension (without the dot) to its GL shader stage.
///
/// @param type_str Extension such as "vert" or "frag".
/// @return The matching GL_*_SHADER enum, or 0 if the extension is unknown.
static GLenum shader_type(std::string type_str)
{
    // Recognised extensions, checked in order.
    struct Extension { char const * suffix; GLenum type; };
    static Extension const extensions[] = {
        {"vert", GL_VERTEX_SHADER},
        {"tesc", GL_TESS_CONTROL_SHADER},
        {"tese", GL_TESS_EVALUATION_SHADER},
        {"geom", GL_GEOMETRY_SHADER},
        {"frag", GL_FRAGMENT_SHADER},
        {"comp", GL_COMPUTE_SHADER},
    };

    for (auto const & extension : extensions)
        if (type_str == extension.suffix)
            return extension.type;

    // Not one of the known extensions.
    return 0;
}


// Returns a printable name for a buffer usage enum.
//
// The nine glBufferData usage hints below are rendered via STR_CASE.
// NOTE(review): presumably STR_CASE(X) returns the enumerator's name, i.e.
// "GL_...". uniform_buffer_() relies on that "GL_" prefix to tell valid usages
// from invalid ones; confirm against str.hpp.
// Any other value is returned as a hexadecimal number with std::showbase.
static std::string uniform_buffer_usage_str(GLenum usage)
{
    switch(usage)
    {
        STR_CASE(GL_STREAM_DRAW)
        STR_CASE(GL_STREAM_READ)
        STR_CASE(GL_STREAM_COPY)
        STR_CASE(GL_STATIC_DRAW)
        STR_CASE(GL_STATIC_READ)
        STR_CASE(GL_STATIC_COPY)
        STR_CASE(GL_DYNAMIC_DRAW)
        STR_CASE(GL_DYNAMIC_READ)
        STR_CASE(GL_DYNAMIC_COPY)
        default:
            // Unknown usage: fall back to the raw numeric value in hex.
            return STR(std::hex << std::showbase << usage);
    }
}


/// Builds, compiles and links a program from the given shader source files.
///
/// @param paths Shader source paths; each file's extension selects its stage.
/// @param name  Display/debug name; if empty, the quoted, comma-separated
///              list of paths is used instead.
/// @throws std::runtime_error if any step of program creation fails.
Shader::Shader(std::vector<std::string> paths, std::string name)
:
    program_{0},
    paths_{std::move(paths)},
    name_{name.empty()
        ? STR_JOIN(", ", "'" << it << "'", paths_)
        : std::move(name)
    },
    uniform_location_cache_{},
    uniform_block_index_cache_{}
{
    new_();
}


/// Copy constructor: builds a brand new GL program from `other`'s sources.
///
/// @throws std::runtime_error if compiling or linking the copy fails.
Shader::Shader(Shader const & other)
:
    program_                  {0},
    paths_                    {other.paths_},
    name_                     {other.name_},
    // Uniform locations and block indices are assigned per program object
    // when that object is linked. new_() links a separate program, so the
    // values cached for `other` are not guaranteed to hold here. Start
    // empty and let them be re-queried on first use.
    uniform_location_cache_   {},
    uniform_block_index_cache_{}
{
    new_();
}


/// Copy assignment: replaces this program with a newly built copy of `other`.
///
/// The copy is built first and then moved in. If compiling or linking
/// throws, *this is left unchanged (strong exception guarantee). Before, the
/// current program was deleted first, and a failure left a dangling program
/// name behind. Self-assignment is safe.
///
/// @throws std::runtime_error if compiling or linking the copy fails.
Shader & Shader::operator=(Shader const & other)
{
    Shader copy{other};
    *this = std::move(copy);
    return *this;
}


/// Move constructor: takes over `other`'s program and state. `other` is left
/// holding program name 0, so its destructor releases nothing.
Shader::Shader(Shader && other)
:
    program_                  {other.program_},
    paths_                    {std::move(other.paths_)},
    name_                     {std::move(other.name_)},
    uniform_location_cache_   {std::move(other.uniform_location_cache_)},
    uniform_block_index_cache_{std::move(other.uniform_block_index_cache_)}
{
    // The program now belongs to *this; disown it in the source.
    other.program_ = 0;
}


/// Move assignment: releases this program and takes over `other`'s.
/// `other` is left holding program name 0.
Shader & Shader::operator=(Shader && other)
{
    // Self-move would otherwise delete our own program and leave us empty.
    if (this == &other)
        return *this;

    delete_();
    program_                   = other.program_;
    paths_                     = std::move(other.paths_);
    name_                      = std::move(other.name_);
    uniform_location_cache_    = std::move(other.uniform_location_cache_);
    uniform_block_index_cache_ = std::move(other.uniform_block_index_cache_);
    other.program_ = 0;
    return *this;
}


/// Destructor: releases the GL program owned by this shader (if any).
Shader::~Shader()
{
    this->delete_();
}


/// Makes this program the current program via glUseProgram.
/// @return *this, for chaining.
Shader & Shader::use()
{
    glUseProgram(this->program_);
    return *this;
}


/// Runs glValidateProgram on this program.
/// @return *this, for chaining.
/// @throws std::runtime_error with the validation log if validation fails.
Shader & Shader::validate()
{
    std::string const failure_message =
        STR("Failed to validate shader program " << name_ << ".");
    checked_action(
        glValidateProgram,
        program_,
        GL_VALIDATE_STATUS,
        glGetProgramiv,
        glGetProgramInfoLog,
        failure_message
    );
    return *this;
}


void Shader::new_()
{
    try
    {
        // Create program.
        program_ = glCreateProgram();
        if (!program_)
            throw std::runtime_error{STR(
                "Failed to create shader program for " << name_ << "."
            )};

        // Label program.
        glObjectLabel(GL_PROGRAM, program_, -1, name_.c_str());

        // Process shader paths.
        std::vector<GLuint> shaders;
        for (auto const & path : paths_)
        {
            // Infer shader type from path extension.
            // https://www.khronos.org/opengles/sdk/tools/Reference-Compiler/
            auto pos = path.rfind(".");
            if (pos == path.npos)
                throw std::runtime_error{STR(
                    "Failed to infer shader type of '" << path <<
                    "' of shader program " << name_ << "; no file extension."
                )};
            auto type_str = path.substr(pos + 1);
            auto type = shader_type(type_str);
            if (!type)
                throw std::runtime_error{STR(
                    "Failed to infer shader type of '" << path <<
                    "' of shader program " << name_ <<
                    "; unknown file extension '" << type_str << "'."
                )};

            // Create, attach, and flag shader for deletion when detached.
            auto shader = glCreateShader(type);
            if (!shader)
                throw std::runtime_error{STR(
                    "Failed to create " << type_str << " shader for " << path
                    << " of shader program " << name_ << "."
                )};
            glAttachShader(program_, shader);
            glDeleteShader(shader);
            shaders.push_back(shader);

            // Label shader.
            glObjectLabel(GL_SHADER, shader, -1, path.c_str());

            // Set shader source.
            std::ifstream source_file(path);
            if (!source_file)
                throw std::runtime_error{STR(
                    "Failed to open " << type_str << " shader '" << path <<
                    "' of shader program " << name_ << "."
                )};
            auto source_str = STR(source_file.rdbuf());
            std::vector<char const *> sources = { source_str.c_str() };
            glShaderSource(shader, sources.size(), &sources[0], nullptr);

            // Compile shader and check for errors.
            checked_action(
                glCompileShader, shader,
                GL_COMPILE_STATUS, glGetShaderiv, glGetShaderInfoLog,
                STR(
                    "Failed to compile " << type_str << " shader '" << path <<
                    "' of shader program " << name_ << "."
                )
            );
        }

        // Link program and check for errors.
        checked_action(
            glLinkProgram, program_,
            GL_LINK_STATUS, glGetProgramiv, glGetProgramInfoLog,
            STR("Failed to link shader program " << name_ << ".")
        );

        // Detach shaders.
        for (auto shader : shaders)
            glDetachShader(program_, shader);
    }
    catch (...)
    {
        delete_();
        throw;
    }
}


/// Deletes the GL program (which also detaches its shaders; shaders already
/// flagged for deletion are then freed) and resets program_ to 0.
///
/// Deleting name 0 is silently ignored by OpenGL, so calling this twice, or
/// on a moved-from/empty Shader, is harmless.
void Shader::delete_() {
    glDeleteProgram(program_);
    // Forget the name so a later call cannot delete a program name that GL
    // has since handed out to another object.
    program_ = 0;
}


/// Binds uniform block `name` of this program to the binding point of the
/// cached uniform buffer `buffer_name`.
///
/// @return *this, for chaining.
/// @throws std::runtime_error if no uniform buffer `buffer_name` exists, or
///         (via uniform_block_index_) if the block is not found.
Shader & Shader::uniform(
    std::string const & name, std::string const & buffer_name
) {
    // Find uniform buffer in cache.
    auto cache_entry = uniform_buffer_cache_.find(buffer_name);
    if (cache_entry == uniform_buffer_cache_.end())
        throw std::runtime_error{STR(
            "Failed to set uniform block '" << name << "' of shader program "
            << name_ << "; no uniform buffer '" << buffer_name << "'."
        )};
    auto const & uniform_buffer = cache_entry->second;

    // Bind.
    glUniformBlockBinding(
        program_, uniform_block_index_(name), uniform_buffer.binding
    );

    // Return.
    return *this;
}


/// Deletes the uniform buffer `name` and removes it from the shared cache.
/// The binding point it used is not made available again.
///
/// @throws std::runtime_error if no such uniform buffer exists.
void Shader::uniform_buffer_delete(std::string const & name) {
    auto found = uniform_buffer_cache_.find(name);
    if (found == uniform_buffer_cache_.end()) {
        throw std::runtime_error{STR(
            "Failed to delete uniform buffer '" << name << "'; does not exist."
        )};
    }

    // Free the GL buffer object, then drop its cache entry.
    glDeleteBuffers(1, &found->second.buffer);
    uniform_buffer_cache_.erase(found);
}


void Shader::ensure_current_(
    std::string const & operation,
    std::string const & name
)
{
    GLuint current_program;
    glGetIntegerv(GL_CURRENT_PROGRAM, (GLint *)&current_program);
    if (current_program != program_) {
        auto action = name.empty()
            ? operation
            : STR(operation << " '" << name << "'");
        throw std::runtime_error{STR(
            "Failed to " << action << " of shader program " << name_ <<
            "; program is not current."
        )};
    }
}


/// Returns the location of uniform `name` in this program, using a
/// per-instance cache to avoid repeated GL queries.
///
/// @throws std::runtime_error if GL reports no such active uniform (-1).
GLint Shader::uniform_location_(std::string const & name)
{
    auto const cached = uniform_location_cache_.find(name);
    if (cached == uniform_location_cache_.end()) {
        // First lookup for this name: ask GL and remember the answer.
        GLint const location = glGetUniformLocation(program_, name.c_str());
        if (location == -1) {
            throw std::runtime_error{STR(
                "Failed to get location of uniform '" << name <<
                "' of shader program " << name_ << "."
            )};
        }
        uniform_location_cache_.emplace(name, location);
        return location;
    }
    return cached->second;
}


/// Returns the index of uniform block `name` in this program, using a
/// per-instance cache to avoid repeated GL queries.
///
/// @throws std::runtime_error if GL reports GL_INVALID_INDEX.
GLuint Shader::uniform_block_index_(std::string const & name) {
    auto const hit = uniform_block_index_cache_.find(name);
    if (hit != uniform_block_index_cache_.end()) {
        return hit->second;
    }

    // Not cached yet: query GL and remember a valid answer.
    GLuint const index = glGetUniformBlockIndex(program_, name.c_str());
    if (index == GL_INVALID_INDEX) {
        throw std::runtime_error{STR(
            "Failed to get index of uniform block '" << name <<
            "' in shader program " << name_ << "."
        )};
    }
    uniform_block_index_cache_.emplace(name, index);
    return index;
}


/// Returns the GL buffer for the shared uniform buffer `name`, creating it
/// on first use.
///
/// On creation it allocates `size` bytes with `usage` and binds the buffer
/// to the next free uniform buffer binding point (left bound to
/// GL_UNIFORM_BUFFER).
/// If the buffer already exists, `size` and `usage` must match what it was
/// created with.
///
/// @throws std::runtime_error on a size/usage mismatch, an invalid size or
///         usage, or when GL_MAX_UNIFORM_BUFFER_BINDINGS would be exceeded.
GLuint Shader::uniform_buffer_(
    std::string const & name, GLsizeiptr size, GLenum usage
) {
    // Get usage string (its "GL_" prefix also serves as validation below).
    auto const usage_str = uniform_buffer_usage_str(usage);

    // Try cache.
    auto cache_entry = uniform_buffer_cache_.find(name);
    if (cache_entry != uniform_buffer_cache_.end()) {
        // Get uniform buffer.
        auto const & uniform_buffer = cache_entry->second;

        // Check size and usage.
        if (size != uniform_buffer.size)
            throw std::runtime_error{STR(
                "Failed to set data of uniform buffer '" << name <<
                "'; data has size " << size << " but buffer has size " <<
                uniform_buffer.size << "."
            )};
        if (usage != uniform_buffer.usage)
            throw std::runtime_error{STR(
                "Failed to set data of uniform buffer '" << name <<
                "'; data has usage " << usage_str << " but buffer has usage "
                << uniform_buffer_usage_str(uniform_buffer.usage) << "."
            )};

        // Return.
        return uniform_buffer.buffer;
    }

    // Validate size and usage.
    if (size <= 0)
        throw std::runtime_error{STR(
            "Failed to create uniform buffer '" << name << "', invalid size "
            << size << "."
        )};
    if (usage_str.rfind("GL_", 0) != 0)
        throw std::runtime_error{STR(
            "Failed to create uniform buffer '" << name << "', invalid usage "
            << usage_str << "."
        )};

    // Check max uniform buffer bindings. Query into a GLint (the type
    // glGetIntegerv writes) rather than punning a GLuint through a C cast.
    // The 0 default makes a failed query refuse creation instead of
    // comparing against garbage.
    GLint max_bindings_value = 0;
    glGetIntegerv(GL_MAX_UNIFORM_BUFFER_BINDINGS, &max_bindings_value);
    auto const max_uniform_buffer_bindings =
        static_cast<GLuint>(max_bindings_value);
    if (uniform_buffer_binding_next_ >= max_uniform_buffer_bindings)
        throw std::runtime_error{STR(
            "Failed to bind uniform buffer '" << name << "'; max bindings of "
            << max_uniform_buffer_bindings << " exceeded."
        )};

    // Create uniform buffer.
    UniformBuffer uniform_buffer{};

    // Generate and bind. (Previously terminated with a stray ',' that
    // chained this call onto the next statement via the comma operator.)
    glGenBuffers(1, &uniform_buffer.buffer);
    glBindBuffer(GL_UNIFORM_BUFFER, uniform_buffer.buffer);

    // Set size and usage.
    uniform_buffer.size = size;
    uniform_buffer.usage = usage;
    glBufferData(
        GL_UNIFORM_BUFFER, uniform_buffer.size, nullptr, uniform_buffer.usage
    );

    // Allocate binding and bind.
    uniform_buffer.binding = uniform_buffer_binding_next_++;
    glBindBufferBase(
        GL_UNIFORM_BUFFER, uniform_buffer.binding, uniform_buffer.buffer
    );

    // Save in cache and return.
    uniform_buffer_cache_.emplace(name, uniform_buffer);
    return uniform_buffer.buffer;
}