#include <ctype.h>
#include <string.h>
#include <strings.h>
#include <stdio.h>
#include <float.h>
#include <math.h>
#include <stdlib.h>

#include "match.h"
#include "bonus.h"

#include "../config.h"

/*
 * Step over at most one terminal escape sequence at the start of haystack.
 *
 * Returns a pointer just past the sequence, or haystack unchanged when it
 * does not begin with ESC (0x1b) or BEL (0x07). If the string ends in the
 * middle of a sequence, the result points at the terminating NUL. Only one
 * sequence is consumed per call; back-to-back sequences are not collapsed.
 *
 * Forms recognised by the state machine below:
 *   BEL                            consumed on its own
 *   ESC <single-byte command>      the command bytes listed under
 *                                  state_escaped; both bytes consumed
 *   ESC <one of % ( ) # 0 3 5 6> <any>
 *                                  three bytes consumed
 *   ESC [ <0-9 ; ?>... <final>     CSI: parameter bytes, then any other
 *                                  byte is taken as the final byte
 *   ESC ] <0-9> ... BEL            OSC terminated by BEL
 *   ESC ] <0-9> ... ESC <any>      OSC terminated by ESC plus one byte
 *                                  (e.g. the '\' of ST)
 *   ESC ] <non-digit>              three bytes consumed
 *   ESC <anything else>            only the ESC is consumed
 */
const char *ignore_escapes(const char *haystack) {
	enum {
		state_default,     /* nothing consumed yet */
		state_escaped,     /* just consumed ESC */
		state_in_csi,      /* inside the parameters of ESC [ */
		state_start_osc,   /* just consumed ESC ] */
		state_in_osc,      /* inside an OSC body, waiting for BEL or ESC */
		state_ignore_next, /* the current byte ends the sequence */
	} state = state_default;
	do switch (state) {
		case state_default: switch (*haystack) {
			case '\x1b':
				state = state_escaped; break;
			case '\x7':
				return ++haystack;
			default:
				/* Not an escape: leave haystack where it is. */
				return haystack;
		}; break;
		case state_escaped: switch (*haystack) {
			case '[':
				state = state_in_csi; break;
			case ']':
				state = state_start_osc; break;
			/* Two-byte commands: swallow one more byte after this one. */
			case '%': case '(': case ')': case '#':
			case '0': case '3': case '5': case '6':
				state = state_ignore_next; break;
			/* Single-byte commands: the sequence ends here. */
			case '<': case '=': case '>': case '\x7':
			case '1': case '2': case '7': case '8':
			case 'c': case 's': case 'u':
			case 'A': case 'B': case 'C': case 'D': case 'E':
			case 'H': case 'I': case 'J': case 'K': case 'M':
			case 'N': case 'O': case 'S': case 'T': case 'Z':
				return ++haystack;
			default:
				/* Unknown: drop just the ESC and resume at this byte. */
				return haystack;
		}; break;
		case state_in_csi: switch (*haystack) {
			/* Parameter bytes keep the CSI open. */
			case ';': case '?':
			case '0': case '1': case '2': case '3': case '4':
			case '5': case '6': case '7': case '8': case '9':
				break;
			default:
				/* Any other byte is treated as the CSI final byte. */
				return ++haystack;
		}; break;
		case state_start_osc: switch (*haystack) {
			case '0': case '1': case '2': case '3': case '4':
			case '5': case '6': case '7': case '8': case '9':
				state = state_in_osc; break;
			default:
				return ++haystack;
		}; break;
		case state_in_osc: switch (*haystack) {
			/* Every byte other than BEL / ESC is part of the OSC body. */
			case '\x7':
				return ++haystack;
			case '\x1b':
				state = state_ignore_next; break;
		}; break;
		case state_ignore_next:
			return ++haystack;
	} while (*++haystack); /* advance; stop at the NUL terminator */
	return haystack;
}

/*
 * Find the first character of s that matches c, where c also matches its
 * upper-case form (strpbrk over the set {c, toupper(c)}). A c that is
 * already upper case, or has no case, matches only itself.
 *
 * Returns a pointer into s, or NULL if no such character exists.
 */
char *strcasechr(const char *s, char c) {
	/* <ctype.h> functions are undefined for negative arguments other than
	 * EOF, so convert through unsigned char first; this matters for bytes
	 * >= 0x80 on platforms where plain char is signed. */
	const char accept[3] = {c, (char)toupper((unsigned char)c), '\0'};
	return strpbrk(s, accept);
}

int has_match(const char *needle, const char *haystack) {
	haystack = ignore_escapes(haystack);
	while (*needle) {
		char nch = *needle++;

		if (!(haystack = strcasechr(haystack, nch))) {
			return 0;
		}
		haystack = ignore_escapes(++haystack);
	}
	return 1;
}

/* Larger of a and b. As a function-like macro it may evaluate an argument
 * twice, so only pass expressions without side effects. */
#define max(a, b) (((a) > (b)) ? (a) : (b))

/*
 * Scratch data shared by match() and match_positions(), filled in by
 * setup_match_struct().
 */
struct match_struct {
	int needle_len;   /* length of needle as recorded by setup_match_struct() */
	int haystack_len; /* length of haystack as recorded by setup_match_struct() */

	/* tolower() copies of the inputs. Only the first needle_len /
	 * haystack_len bytes are written, they are not NUL-terminated, and
	 * they are left unset when setup_match_struct() bails out early. */
	char lower_needle[MATCH_MAX_LEN];
	char lower_haystack[MATCH_MAX_LEN];

	/* match_bonus[j]: COMPUTE_BONUS of haystack[j] and the byte before it */
	score_t match_bonus[MATCH_MAX_LEN];
};

static void precompute_bonus(const char *haystack, score_t *match_bonus) {
	/*
	 * Fill match_bonus with one entry per character of haystack, each
	 * computed by COMPUTE_BONUS from the preceding character and the
	 * character itself. The character before position 0 is taken to be
	 * '/'. Presumably this rewards positions that start a word; the
	 * actual weights live in COMPUTE_BONUS.
	 */
	char before = '/';
	for (const char *p = haystack; *p != '\0'; p++) {
		const char here = *p;
		*match_bonus++ = COMPUTE_BONUS(before, here);
		before = here;
	}
}

/*
 * Prepare `match` for scoring needle against haystack.
 *
 * needle_len and haystack_len are always recorded. A length greater than
 * MATCH_MAX_LEN is stored as MATCH_MAX_LEN + 1, so it cannot wrap when
 * narrowed to int and still fails the callers' bounds check. If the
 * haystack is too long, or the needle is longer than the haystack, the
 * function stops there, and callers must check the lengths before using
 * any other field. Otherwise it also fills lower_needle / lower_haystack
 * (tolower() copies, not NUL-terminated) and the match_bonus table.
 */
static void setup_match_struct(struct match_struct *match, const char *needle, const char *haystack) {
	size_t needle_len = strlen(needle);
	size_t haystack_len = strlen(haystack);

	/* Clamp before narrowing to int. Without this, a huge strlen() could
	 * wrap into a small or negative int that passes the check below, and
	 * precompute_bonus() (which walks haystack to its NUL, not to
	 * haystack_len) would then overrun match_bonus. */
	match->needle_len = needle_len > MATCH_MAX_LEN ? MATCH_MAX_LEN + 1 : (int)needle_len;
	match->haystack_len = haystack_len > MATCH_MAX_LEN ? MATCH_MAX_LEN + 1 : (int)haystack_len;

	if (match->haystack_len > MATCH_MAX_LEN || match->needle_len > match->haystack_len) {
		return;
	}

	/* tolower() is undefined for negative arguments other than EOF, so
	 * convert through unsigned char (bytes >= 0x80 with signed char). */
	for (int i = 0; i < match->needle_len; i++)
		match->lower_needle[i] = tolower((unsigned char)needle[i]);

	for (int i = 0; i < match->haystack_len; i++)
		match->lower_haystack[i] = tolower((unsigned char)haystack[i]);

	precompute_bonus(haystack, match->match_bonus);
}

/*
 * Compute one row of the scoring DP: needle character `row` against every
 * haystack position j.
 *
 *   curr_D[j]  best score for needle[0..row] with needle[row] matched at
 *              haystack[j]. It is SCORE_MIN when the (lower-cased)
 *              characters differ, or when row > 0 and j == 0.
 *   curr_M[j]  best score for needle[0..row] within haystack[0..j]: the
 *              larger of curr_D[j] and curr_M[j - 1] plus a gap penalty.
 *
 * last_D / last_M hold row - 1; their values only influence the result
 * when row > 0. Each last_*[j] is read before curr_*[j] is written, and the
 * diagonal value last_*[j - 1] is carried in prev_D / prev_M. The caller
 * may therefore pass the same arrays as curr and last to update a single
 * row in place, which match() does.
 */
static inline void match_row(const struct match_struct *match, int row, score_t *curr_D, score_t *curr_M, const score_t *last_D, const score_t *last_M) {
	int n = match->needle_len;
	int m = match->haystack_len;
	int i = row;

	const char *lower_needle = match->lower_needle;
	const char *lower_haystack = match->lower_haystack;
	const score_t *match_bonus = match->match_bonus;

	/* Running value of curr_M[j - 1]; nothing precedes column 0. */
	score_t prev_score = SCORE_MIN;
	/* The final needle row uses SCORE_GAP_TRAILING, earlier rows use
	 * SCORE_GAP_INNER (presumably so haystack after the last matched
	 * character is penalised differently -- see the constants). */
	score_t gap_score = i == n - 1 ? SCORE_GAP_TRAILING : SCORE_GAP_INNER;

	/* These will not be used with this value, but not all compilers see it */
	score_t prev_M = SCORE_MIN, prev_D = SCORE_MIN;

	for (int j = 0; j < m; j++) {
		if (lower_needle[i] == lower_haystack[j]) {
			score_t score = SCORE_MIN;
			if (!i) {
				/* First needle char: j leading haystack chars skipped. */
				score = (j * SCORE_GAP_LEADING) + match_bonus[j];
			} else if (j) { /* i > 0 && j > 0*/
				score = max(
						prev_M + match_bonus[j],

						/* consecutive match, doesn't stack with match_bonus */
						prev_D + SCORE_MATCH_CONSECUTIVE);
			}
			/* Save row - 1 at column j before overwriting it; these are
			 * the diagonal values for column j + 1. */
			prev_D = last_D[j];
			prev_M = last_M[j];
			curr_D[j] = score;
			curr_M[j] = prev_score = max(score, prev_score + gap_score);
		} else {
			prev_D = last_D[j];
			prev_M = last_M[j];
			curr_D[j] = SCORE_MIN;
			curr_M[j] = prev_score = prev_score + gap_score;
		}
	}
}

score_t match(const char *needle, const char *haystack) {
	if (!*needle)
		return SCORE_MIN;

	struct match_struct match;
	setup_match_struct(&match, needle, haystack);

	int n = match.needle_len;
	int m = match.haystack_len;

	if (m > MATCH_MAX_LEN || n > m) {
		/*
		 * Unreasonably large candidate: return no score
		 * If it is a valid match it will still be returned, it will
		 * just be ranked below any reasonably sized candidates
		 */
		return SCORE_MIN;
	} else if (n == m) {
		/* Since this method can only be called with a haystack which
		 * matches needle. If the lengths of the strings are equal the
		 * strings themselves must also be equal (ignoring case).
		 */
		return SCORE_MAX;
	}

	/*
	 * D[][] Stores the best score for this position ending with a match.
	 * M[][] Stores the best possible score at this position.
	 */
	score_t D[MATCH_MAX_LEN], M[MATCH_MAX_LEN];

	for (int i = 0; i < n; i++) {
		match_row(&match, i, D, M, D, M);
	}

	return M[m - 1];
}

/*
 * Score needle against haystack exactly as match() does. Additionally,
 * record in positions[i] the haystack index that needle[i] matched on one
 * optimal path.
 *
 * positions may be NULL, in which case only the score is computed. When it
 * is non-NULL it must have room for strlen(needle) entries. It is left
 * untouched when the function returns SCORE_MIN early: empty needle,
 * oversized candidate, or failure to allocate the score matrices.
 *
 * As with match(), the caller is expected to have checked that haystack
 * matches needle.
 */
score_t match_positions(const char *needle, const char *haystack, size_t *positions) {
	if (!*needle)
		return SCORE_MIN;

	struct match_struct match;
	setup_match_struct(&match, needle, haystack);

	int n = match.needle_len;
	int m = match.haystack_len;

	if (m > MATCH_MAX_LEN || n > m) {
		/*
		 * Unreasonably large candidate: return no score.
		 * If it is a valid match it will still be returned; it will
		 * just be ranked below any reasonably sized candidates.
		 */
		return SCORE_MIN;
	} else if (n == m) {
		/* This function is only called with a haystack that matches
		 * needle, so if the lengths are equal the strings themselves
		 * must also be equal (ignoring case).
		 */
		if (positions)
			for (int i = 0; i < n; i++)
				positions[i] = i;
		return SCORE_MAX;
	}

	/*
	 * D[i][j] stores the best score for this position ending with a match.
	 * M[i][j] stores the best possible score at this position.
	 * Unlike match(), every row is kept because the backtrace needs them.
	 */
	score_t (*D)[MATCH_MAX_LEN] = malloc((size_t)n * sizeof *D);
	score_t (*M)[MATCH_MAX_LEN] = malloc((size_t)n * sizeof *M);
	if (!D || !M) {
		/* Out of memory: report no score rather than dereference NULL. */
		free(D);
		free(M);
		return SCORE_MIN;
	}

	match_row(&match, 0, D[0], M[0], D[0], M[0]);
	for (int i = 1; i < n; i++) {
		match_row(&match, i, D[i], M[i], D[i - 1], M[i - 1]);
	}

	/* backtrace to find the positions of optimal matching */
	if (positions) {
		int match_required = 0;
		for (int i = n - 1, j = m - 1; i >= 0; i--) {
			for (; j >= 0; j--) {
				/*
				 * There may be multiple paths which result in
				 * the optimal weight.
				 *
				 * For simplicity, we will pick the first one
				 * we encounter, the latest in the candidate
				 * string.
				 */
				if (D[i][j] != SCORE_MIN &&
				    (match_required || D[i][j] == M[i][j])) {
					/* If this score was determined using
					 * SCORE_MATCH_CONSECUTIVE, the
					 * previous character MUST be a match
					 */
					match_required =
					    i && j &&
					    M[i][j] == D[i - 1][j - 1] + SCORE_MATCH_CONSECUTIVE;
					positions[i] = j--;
					break;
				}
			}
		}
	}

	score_t result = M[n - 1][m - 1];

	free(M);
	free(D);

	return result;
}