#include "shader.hpp"

#include <vector>
#include <string>
#include <fstream>
#include <stdexcept>
#include <iostream>

#include <GL/glew.h>

#include "str.hpp"


// Runs a GL action on an object and throws if the action's status is false.
//
// glAction           - the action to perform, called as glAction(object).
// object             - the GL object name passed to every call below.
// status_enum        - the parameter queried through glGetObjectiv to decide
//                      success (non-zero) or failure (zero).
// glGetObjectiv      - parameter query function for this kind of object.
// glGetObjectInfoLog - info log query function for this kind of object.
// what               - message used as the start of the exception text.
//
// Throws std::runtime_error containing `what`, followed by a newline and the
// object's info log when that log is non-empty.
static void checked_action(
    void (glAction)(
        GLuint object
    ),
    GLuint object,
    GLenum status_enum,
    void (glGetObjectiv)(
        GLuint object, GLenum pname, GLint *params
    ),
    void (glGetObjectInfoLog)(
        GLuint object, GLsizei maxLength, GLsizei *length, GLchar *infoLog
    ),
    std::string const & what
)
{
    // Perform action.
    glAction(object);

    // Check status. Initialized so that a query that writes nothing is
    // treated as a failure instead of reading an indeterminate value.
    GLint status = GL_FALSE;
    glGetObjectiv(object, status_enum, &status);
    if (status)
        return;

    // Get info log length.
    GLint info_log_length = 0;
    glGetObjectiv(object, GL_INFO_LOG_LENGTH, &info_log_length);

    // Get info log content. The buffer is zero-filled on construction, so it
    // is a valid (possibly empty) C string even if the query writes nothing.
    std::string info_log;
    if (info_log_length > 0) {
        std::vector<char> info_log_buf(info_log_length);
        glGetObjectInfoLog(
            object, info_log_length, nullptr, info_log_buf.data()
        );
        // Stream the fetched buffer, not the still-empty result string.
        info_log = STR("\n" << info_log_buf.data());
    }

    // Throw.
    throw std::runtime_error{STR(what << info_log)};
}


// Maps a shader file extension (without the dot) to the matching GL shader
// type enum. Returns 0 when the extension is not one of the six known ones.
static GLenum shader_type(std::string type_str)
{
    if (type_str == "vert") return GL_VERTEX_SHADER;
    if (type_str == "tesc") return GL_TESS_CONTROL_SHADER;
    if (type_str == "tese") return GL_TESS_EVALUATION_SHADER;
    if (type_str == "geom") return GL_GEOMETRY_SHADER;
    if (type_str == "frag") return GL_FRAGMENT_SHADER;
    if (type_str == "comp") return GL_COMPUTE_SHADER;

    // Unknown extension.
    return 0;
}


// Constructs a shader program from the given source file paths and builds it
// immediately via new_().
//
// paths - shader source files; taken by value and moved into the object.
// name  - display name used in labels and error messages; when empty, a name
//         is generated by joining the quoted paths with ", ".
Shader::Shader(std::vector<std::string> paths, std::string name)
:
    program_{0},
    paths_{std::move(paths)},
    // The fallback reads paths_ (the parameter has already been moved from),
    // so it relies on paths_ being initialized before name_.
    // NOTE(review): that requires paths_ to be declared before name_ in the
    // class -- confirm against the header.
    name_{name.empty()
        ? STR_JOIN(", ", "'" << it << "'", paths_)
        : std::move(name)
    },
    uniform_location_cache_{}
{
    // Create, compile and link the program.
    new_();
}


// Copy-constructs a shader: copies the paths and name of `other`, then builds
// a separate program object from those paths via new_().
Shader::Shader(Shader const & other)
:
    program_               {0},
    paths_                 {other.paths_},
    name_                  {other.name_},
    // Uniform locations are assigned when a program is linked, so the ones
    // cached for other's program are not guaranteed to match the program
    // linked below. Start empty; uniform_location_() refills the cache on
    // demand.
    uniform_location_cache_{}
{
    new_();
}


// Copy-assigns from `other` by building a complete copy first and then
// moving it into *this.
//
// Building the copy before touching *this means that if creating, compiling
// or linking the new program throws, *this still holds its previous, valid
// program instead of a deleted one. Self-assignment is safe and, as before,
// rebuilds the program from its paths.
Shader & Shader::operator=(Shader const & other)
{
    // May throw; *this is untouched until it succeeds.
    Shader copy(other);

    // Releases the old program and takes over the freshly built one.
    return *this = std::move(copy);
}


// Move-constructs a shader by taking over the program, paths, name and
// uniform location cache of `other`.
Shader::Shader(Shader && other)
:
    program_               {other.program_},
    paths_                 {std::move(other.paths_)},
    name_                  {std::move(other.name_)},
    uniform_location_cache_{std::move(other.uniform_location_cache_)}
{
    // Clear the source's program name so that its delete_() does not delete
    // the program now held by this object.
    other.program_ = 0;
}


// Move-assigns from `other`: deletes the program currently held, then takes
// over the program, paths, name and uniform location cache of `other`.
Shader & Shader::operator=(Shader && other)
{
    // Self-move must be a no-op; otherwise the program would be deleted and
    // then program_ reset to 0 through the aliasing `other`.
    if (this == &other)
        return *this;

    // Release the program currently held.
    delete_();

    // Take over other's program and clear its name so that other's
    // delete_() does not delete the program now held here.
    program_       = other.program_;
    other.program_ = 0;

    paths_                  = std::move(other.paths_);
    name_                   = std::move(other.name_);
    uniform_location_cache_ = std::move(other.uniform_location_cache_);
    return *this;
}


// Destroys the shader, deleting its program through delete_().
Shader::~Shader()
{
    delete_();
}


// Passes this shader's program to glUseProgram.
// Returns *this so that calls can be chained.
Shader & Shader::use()
{
    glUseProgram(program_);
    return *this;
}


// Validates the program with glValidateProgram and throws
// std::runtime_error (through checked_action) when GL_VALIDATE_STATUS is
// false. Returns *this so that calls can be chained.
Shader & Shader::validate()
{
    // Message used as the start of the exception text on failure.
    std::string const failure_message =
        STR("Failed to validate shader program " << name_ << ".");

    checked_action(
        glValidateProgram,
        program_,
        GL_VALIDATE_STATUS,
        glGetProgramiv,
        glGetProgramInfoLog,
        failure_message
    );
    return *this;
}


void Shader::new_()
{
    try
    {
        // Create program.
        program_ = glCreateProgram();
        if (!program_)
            throw std::runtime_error{STR(
                "Failed to create shader program for " << name_ << "."
            )};

        // Label program.
        glObjectLabel(GL_PROGRAM, program_, -1, name_.c_str());

        // Process shader paths.
        std::vector<GLuint> shaders;
        for (auto const & path : paths_)
        {
            // Infer shader type from path extension.
            // https://www.khronos.org/opengles/sdk/tools/Reference-Compiler/
            auto pos = path.rfind(".");
            if (pos == path.npos)
                throw std::runtime_error{STR(
                    "Failed to infer shader type of '" << path <<
                    "' of shader program " << name_ << "; no file extension."
                )};
            auto type_str = path.substr(pos + 1);
            auto type = shader_type(type_str);
            if (!type)
                throw std::runtime_error{STR(
                    "Failed to infer shader type of '" << path <<
                    "' of shader program " << name_ <<
                    "; unknown file extension '" << type_str << "'."
                )};

            // Create, attach, and flag shader for deletion when detached.
            auto shader = glCreateShader(type);
            if (!shader)
                throw std::runtime_error{STR(
                    "Failed to create " << type_str << " shader for " << path
                    << " of shader program " << name_ << "."
                )};
            glAttachShader(program_, shader);
            glDeleteShader(shader);
            shaders.push_back(shader);

            // Label shader.
            glObjectLabel(GL_SHADER, shader, -1, path.c_str());

            // Set shader source.
            std::ifstream source_file(path);
            if (!source_file)
                throw std::runtime_error{STR(
                    "Failed to open " << type_str << " shader '" << path <<
                    "' of shader program " << name_ << "."
                )};
            auto source_str = STR(source_file.rdbuf());
            std::vector<char const *> sources = { source_str.c_str() };
            glShaderSource(shader, sources.size(), &sources[0], nullptr);

            // Compile shader and check for errors.
            checked_action(
                glCompileShader, shader,
                GL_COMPILE_STATUS, glGetShaderiv, glGetShaderInfoLog,
                STR(
                    "Failed to compile " << type_str << " shader '" << path <<
                    "' of shader program " << name_ << "."
                )
            );
        }

        // Link program and check for errors.
        checked_action(
            glLinkProgram, program_,
            GL_LINK_STATUS, glGetProgramiv, glGetProgramInfoLog,
            STR("Failed to link shader program " << name_ << ".")
        );

        // Detach shaders.
        for (auto shader : shaders)
            glDetachShader(program_, shader);
    }
    catch (...)
    {
        delete_();
        throw;
    }
}


// Deletes the program held by this object, if any, and forgets its name.
// Safe to call repeatedly.
void Shader::delete_()
{
    // Delete program (and detach and delete shaders).
    // glDeleteProgram silently ignores a program name of 0.
    glDeleteProgram(program_);

    // Forget the deleted name so that a later delete_() (for example from
    // the destructor after a failed rebuild) does not delete it again.
    program_ = 0;
}


void Shader::ensure_current_(
    std::string const & operation,
    std::string const & name
)
{
    GLuint current_program;
    glGetIntegerv(GL_CURRENT_PROGRAM, (GLint *)&current_program);
    if (current_program != program_) {
        auto action = name.empty()
            ? operation
            : STR(operation << " '" << name << "'");
        throw std::runtime_error{STR(
            "Failed to " << action << " of shader program " << name_ <<
            "; program is not current."
        )};
    }
}


// Returns the location of the uniform called `name` in this program, using
// uniform_location_cache_ to avoid repeated glGetUniformLocation queries.
//
// Throws std::runtime_error when glGetUniformLocation reports -1; such
// failures are not cached.
GLint Shader::uniform_location_(std::string const & name)
{
    auto const cached = uniform_location_cache_.find(name);
    if (cached == uniform_location_cache_.end())
    {
        // Cache miss: ask OpenGL.
        GLint const queried = glGetUniformLocation(program_, name.c_str());
        if (queried == -1)
            throw std::runtime_error{STR(
                "Failed to get location of uniform '" << name <<
                "' of shader program " << name_ << "."
            )};

        // Remember the answer for subsequent lookups.
        uniform_location_cache_.emplace(name, queried);
        return queried;
    }

    // Cache hit.
    return cached->second;
}