/*
    Function plotter with continuous lines.

    This file is part of Simple Graphics Framework (SGF).

    Copyright 2023 Arnold Beiland

    Simple Graphics Framework is free software: you can redistribute it and/or
    modify it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or (at your
    option) any later version.

    This program is distributed in the hope that it will be useful, but
    WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
    or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
    more details.

    You should have received a copy of the GNU General Public License along with
    this program (look for a file named COPYING in the top directory). If not, see
    <https://www.gnu.org/licenses/>.
*/

#include "../sgf.h"
#include <cctype>
#include <iostream>
#include <cmath>
#include <string>
#include <fstream>
using namespace std;

const int max_children = 20;    // capacity of Node::children; parse_expr/parse_term enforce it
const int max_functions = 100;  // maximum number of lines read from plotter-input.txt
const int text_size = 11;       // size argument passed to draw_text (also used as the text height)

// Palette cycled through by function index (i % n_colors).
// The array is sized by its initializer, not by max_functions: with an explicit
// [max_functions] bound the tail was zero-filled, n_colors became 100, and every
// function from the 8th on was drawn with a zero (black) color instead of cycling.
const Color colors_to_use[] = {
    {255,0,0}, {0,200,0}, {0,0,255}, {154,72,0}, {153,51,255}, {0,153,153}, {102,102,153} };
const int n_colors = sizeof(colors_to_use) / sizeof(Color);

enum NodeType { Plus, Mul, X, Number, Inv, Neg, Error, Sin };

// One node of a parsed expression tree.
//   Plus / Mul   : n-ary sum / product of children[0 .. nr_children)
//   Neg / Inv    : -child / 1/child, operand in children[0]
//   Sin          : sin(child), operand in children[0]
//   Number       : constant stored in numeric_value
//   X            : the plotted variable
//   Error        : parse failure, reason in error_message (may still own children
//                  when a partially built Plus/Mul node was turned into an error)
// Every member has a default so a freshly allocated node (`new Node`) never
// contains indeterminate values; previously type, children and numeric_value did.
struct Node {
    NodeType type = NodeType::Error;
    int nr_children = 0;
    Node *children[max_children] = {};

    double numeric_value = 0.0;
    string error_message;
};

// Parsed input: the raw text of each function and its expression tree
// (an Error node when parsing failed).
int n_functions;
string function_strings[max_functions];
Node *expression_roots[max_functions];

void parse_function_strings();
double evaluate_expression(Node *node, double x);

/* Recursive-descent parser; each routine advances pos past what it consumed. */
Node *parse_expr(const string &str, int &pos);  // term  { ('+' | '-') term  }
Node *parse_term(const string &str, int &pos);  // token { ('*' | '/') token }
Node *parse_token(const string &str, int &pos); // x | X | number | (expr) | sin(...)

// View state: visible width of the x axis in world units, the last cursor
// position (-1 until the first mouse move) and the canvas size of the latest render().
double x_range{10.0};
int mouse_row{-1};
int mouse_col{-1};
int current_width;
int current_height;

// Drawing helpers.
void draw_coordinate_grid();
void draw_function_strings();
void draw_mouse_position();
void draw_function_line(Node *expression, Color color);
void draw_horizontal_line(int row, Color color);
void draw_vertical_line(int col, Color color);
void draw_number_bottom_left(int val, int row, int col);
Color get_lighter_color(Color base_color, double rate);

// Reads one function per line from plotter-input.txt and parses them all.
// Failures are reported to logfile. When the file has more than max_functions
// lines, the first max_functions are still parsed: previously the early return
// skipped parse_function_strings(), leaving expression_roots null so that
// render() dereferenced null pointers.
void initialize()
{
    srand(time(NULL));

    ifstream fin("plotter-input.txt");
    if (!fin) {
        logfile << "plotter-input.txt could not be opened" << endl;
        return;
    }

    n_functions = 0;
    string line;
    while (getline(fin, line)) {
        if (n_functions >= max_functions) {
            logfile << "plotter-input.txt contains too many functions" << endl;
            break; // keep (and parse) the functions read so far
        }

        // getline keeps the '\r' of CRLF line endings; drop it so it does not
        // show up in the legend text.
        if (!line.empty() && line.back() == '\r')
            line.pop_back();

        function_strings[n_functions] = line;
        n_functions++;
    }

    parse_function_strings();
}

void render(int width, int height)
{
    current_width = width;
    current_height = height;

    draw_function_strings();
    draw_coordinate_grid();
    draw_mouse_position();

    for (int i = 0; i < n_functions; ++i)
        if (expression_roots[i]->type != NodeType::Error) {
            draw_function_line(expression_roots[i], colors_to_use[i%n_colors]);
        }
}

void draw_function_strings()
{
    for (int i = 0; i < n_functions; ++i) {
        Color color = colors_to_use[i % n_colors];

        string message = "f(x) = " + function_strings[i];

        if (expression_roots[i]->type == NodeType::Error) {
            message += "    Error: " + expression_roots[i]->error_message + ". ";
        }

        draw_text(20 + 20*i, 20, color, text_size, message.c_str());
    }
}

// Draws the black x/y axes through the canvas center, plus gray grid lines at every
// integer coordinate in view, each labelled next to its axis.
void draw_coordinate_grid()
{
    // With a zero-width canvas y_range below becomes inf/NaN, and converting that
    // to int is undefined behaviour. Nothing would be visible anyway.
    if (current_width <= 0 || current_height <= 0)
        return;

    int width_2 = current_width / 2;
    int height_2 = current_height / 2;
    double width_xrange = current_width / x_range; // pixels per world unit, used for both axes
    Color gray {200, 200, 200};

    draw_horizontal_line(height_2, {0,0,0}); // Ox
    draw_vertical_line(width_2, {0,0,0}); // Oy
    draw_number_bottom_left(0, height_2, width_2); // O

    // vertical grid lines, labelled along the x axis
    int x_max = floor(x_range / 2);
    for (int x = -x_max; x <= x_max; ++x)
        if (x != 0) {
            int c = width_2 + x * width_xrange;
            draw_vertical_line(c, gray);
            draw_number_bottom_left(x, height_2, c);
        }

    // horizontal grid lines, labelled along the y axis; the visible y extent
    // follows from the canvas aspect ratio because both axes share one scale
    double y_range = x_range * current_height / current_width;
    int y_max = floor(y_range / 2);
    for (int y = -y_max; y <= y_max; ++y)
        if (y != 0) {
            int r = height_2 - y * width_xrange;
            draw_horizontal_line(r, gray);
            draw_number_bottom_left(y, r, width_2);
        }
}

// Paints the full-width row `row`; rows outside the canvas are ignored.
void draw_horizontal_line(int row, Color color)
{
    if (row < 0 || row >= current_height)
        return;

    for (int col = 0; col < current_width; ++col)
        draw_pixel(row, col, color);
}

// Paints the full-height column `col`; columns outside the canvas are ignored.
void draw_vertical_line(int col, Color color)
{
    if (col < 0 || col >= current_width)
        return;

    int row = 0;
    while (row < current_height) {
        draw_pixel(row, col, color);
        ++row;
    }
}

// Draws `val` in black just below and to the left of (row, col), as used for axis labels.
// The label is skipped when its computed anchor falls outside the canvas.
void draw_number_bottom_left(int val, int row, int col)
{
    const string str = to_string(val);
    const int r = row + text_size; // height needed
    // Approximate width of 7 columns per character. Cast before subtracting:
    // col - str.size() * 7 was evaluated in size_t and relied on wrap-around
    // when converted back to int.
    const int c = col - static_cast<int>(str.size()) * 7;

    if (r >= 0 && r < current_height && c >= 0 && c < current_width)
        draw_text(r, c, {0,0,0}, text_size, str.c_str());
}

// Shows the world coordinates under the cursor near the bottom-left corner,
// once on_move() has recorded a position.
void draw_mouse_position()
{
    if (mouse_row == -1 || mouse_col == -1)
        return;

    // World units per pixel; the canvas center maps to the origin.
    const double units_per_pixel = x_range / current_width;
    const double x = (mouse_col - current_width / 2) * units_per_pixel;
    const double y = (current_height / 2 - mouse_row) * units_per_pixel;

    string label = "(";
    label += to_string(x);
    label += ", ";
    label += to_string(y);
    label += ")";
    draw_text(current_height - 10, 10, {0,0,0}, text_size, label.c_str());
}

// Plots `expression` by sampling it once per canvas column. Where consecutive
// samples land on different rows, the vertical gap is bridged with a fading
// gradient across columns c-1 and c so the curve looks continuous. NaN results
// (e.g. 1/0 from an Inv node) break the curve.
void draw_function_line(Node *expression, Color color)
{
    const double xrange_width = x_range / current_width; // world units per pixel
    const int width_2 = current_width / 2;
    const int height_2 = current_height / 2;

    bool is_prev_nan = true; // no usable sample in the previous column
    int prev_row = 0;

    for (int c = 0; c < current_width; ++c) {
        const double x = (c - width_2) * xrange_width;
        const double y = evaluate_expression(expression, x);
        if (isnan(y)) {
            is_prev_nan = true;
            continue;
        }

        // Clamp to one row beyond either edge before converting to int: a huge
        // or infinite y would otherwise overflow the conversion (undefined
        // behaviour). Every off-screen row is treated alike below, so the
        // drawn output is unchanged for in-range values.
        const double exact_row = height_2 - y / xrange_width;
        const int r = static_cast<int>(
            min(max(exact_row, -1.0), static_cast<double>(current_height)));

        const bool r_visible = r >= 0 && r < current_height;
        const bool prev_visible = prev_row >= 0 && prev_row < current_height;

        if (is_prev_nan || prev_row == r) {
            if (r_visible)
                draw_pixel(r, c, color);
        }
        else if (prev_visible || r_visible) {
            // fill rows on c-1 from prev_row to r with decreasing color intensity
            // fill rows on c from prev_row to r with increasing color intensity
            const int from = min(max(prev_row, 0), current_height - 1);
            const int to = min(max(r, 0), current_height - 1);
            const double d_from_to = abs(from - to);
            const int step = to > from ? 1 : -1;

            // Both rows can clamp onto the same edge row (e.g. prev_row = -1,
            // r = 0). Avoid 0/0 = NaN intensities there: draw that single row
            // in the full base color.
            for (int i = from; ; i += step) {
                const double fade_out = d_from_to > 0 ? abs(i - to) / d_from_to : 1.0;
                const double fade_in = d_from_to > 0 ? abs(i - from) / d_from_to : 1.0;
                draw_pixel(i, c - 1, get_lighter_color(color, fade_out));
                draw_pixel(i, c, get_lighter_color(color, fade_in));
                if (i == to)
                    break;
            }
        }

        is_prev_nan = false;
        prev_row = r;
    }
}

// Blends base_color toward white (255 per channel). rate 1 yields base_color
// unchanged; rate 0 yields the palest shade, which keeps 30% of the base color.
Color get_lighter_color(Color base_color, double rate)
{
    // Remap [0, 1] onto [0.3, 1] so the result never becomes pure white.
    const double weight = 0.3 + 0.7 * rate;
    const auto blend = [weight](double channel) {
        return channel * weight + 255 * (1 - weight);
    };

    Color result;
    result.r = blend(base_color.r);
    result.g = blend(base_color.g);
    result.b = blend(base_color.b);
    return result;
}

// Mouse-move callback: remembers the cursor position for draw_mouse_position().
// Returns true (presumably "redraw needed" in the SGF callback convention — confirm in sgf.h).
bool on_move(int row, int col)
{
    mouse_col = col;
    mouse_row = row;
    return true;
}

// Events this plotter does not react to; each simply returns false.
bool on_scroll(int)
{
    return false;
}

bool on_mouse_down()
{
    return false;
}

bool on_mouse_up()
{
    return false;
}

bool on_key_down(int)
{
    return false;
}

bool on_key_up(int)
{
    return false;
}

void parse_function_strings()
{
    for (int i = 0; i < n_functions; ++i) {
        int pos = 0;
        expression_roots[i] = parse_expr(function_strings[i], pos);
    }
}

// Evaluates the expression tree at `x`.
// Returns NaN for division by exactly zero (Inv of 0) and for Error nodes;
// NaN propagates through the arithmetic of enclosing nodes.
double evaluate_expression(Node *node, double x)
{
    switch (node->type) {
    case NodeType::Number:
        return node->numeric_value;

    case NodeType::X:
        return x;

    case NodeType::Neg:
        return -evaluate_expression(node->children[0], x);

    case NodeType::Inv: {
        const double denominator = evaluate_expression(node->children[0], x);
        return denominator == 0 ? NAN : 1.0 / denominator;
    }

    case NodeType::Plus: {
        double sum = 0.0;
        for (int k = 0; k < node->nr_children; ++k)
            sum += evaluate_expression(node->children[k], x);
        return sum;
    }

    case NodeType::Sin:
        return sin(evaluate_expression(node->children[0], x));

    case NodeType::Mul: {
        double product = 1.0;
        for (int k = 0; k < node->nr_children; ++k)
            product *= evaluate_expression(node->children[k], x);
        return product;
    }

    default:
        // Error nodes (or any unknown type) have no numeric value.
        return NAN;
    }
}


// True once `pos` is at or past the end of `str` (a negative pos is never "end").
bool is_end(const string &str, int pos)
{
    const int length = static_cast<int>(str.size());
    return !(pos < length);
}

// Advances `pos` past any whitespace in `str`, stopping at the first other
// character or at the end of the string.
void skip_space(const string &str, int &pos)
{
    // isspace() is undefined for negative char values (e.g. bytes of UTF-8 text
    // where char is signed), so widen through unsigned char first.
    while (!is_end(str, pos) && isspace(static_cast<unsigned char>(str[pos])))
        ++pos;
}

// Heap-allocates a childless node of the given type; release it with delete_expression().
Node *create_empty_node(NodeType type)
{
    Node *node = new Node;
    node->nr_children = 0;
    node->type = type;
    return node;
}

// Builds a childless Error node carrying `message` as the reason shown to the user.
Node *create_error_node(string message)
{
    Node *error_node = create_empty_node(NodeType::Error);
    error_node->error_message.assign(message);
    return error_node;
}

// Frees a whole expression tree, children before their parent; null is a no-op.
void delete_expression(Node *root)
{
    if (root != nullptr) {
        int k = 0;
        while (k < root->nr_children) {
            delete_expression(root->children[k]);
            ++k;
        }
        delete root;
    }
}

// Parses  <term> { ('+' | '-') <term> }  starting at pos and advances pos past it.
// A lone term is returned as-is; otherwise the terms become children of a Plus
// node, with each subtracted term wrapped in a Neg node.
// On failure the returned node has type Error with a message.
Node *parse_expr(const string &str, int &pos)
{
    bool result_is_single_token = true;
    Node *result = parse_term(str, pos);
    if (result->type == NodeType::Error)
        return result;

    skip_space(str,pos);
    while (!is_end(str, pos) && (str[pos] == '+' || str[pos] == '-')) {
        bool is_minus = str[pos] == '-';
        ++pos;

        // The child limit applies only to the Plus node built by this loop.
        // While result is still the first term it may itself hold max_children
        // children (a long product, or a parenthesised sum). It is about to be
        // wrapped, not extended, so it must not trigger "expression too long".
        if (!result_is_single_token && result->nr_children == max_children) {
            result->type = NodeType::Error;
            result->error_message = "expression too long";
            break;
        }

        Node *term = parse_term(str, pos);
        if (term->type == NodeType::Error)
        {
            result->type = NodeType::Error;
            result->error_message = term->error_message;
            delete_expression(term);
            break;
        }

        if (is_minus) {
            Node *parent = create_empty_node(NodeType::Neg);
            parent->nr_children = 1;
            parent->children[0] = term;
            term = parent;
        }

        // Second term seen: wrap the first one in a new Plus node.
        if (result_is_single_token) {
            result_is_single_token = false;
            Node *parent = create_empty_node(NodeType::Plus);
            parent->nr_children = 1;
            parent->children[0] = result;
            result = parent;
        }

        result->children[result->nr_children] = term;
        ++result->nr_children;

        skip_space(str,pos);
    }

    return result;
}


// Parses  <token> { ('*' | '/') <token> }  starting at pos and advances pos past it.
// A lone token is returned as-is; otherwise the tokens become children of a Mul
// node, with each divisor wrapped in an Inv node.
// On failure the returned node has type Error with a message.
Node *parse_term(const string &str, int &pos)
{
    bool result_is_single_token = true;
    Node *result = parse_token(str, pos);
    if (result->type == NodeType::Error)
        return result;

    skip_space(str,pos);
    while (!is_end(str, pos) && (str[pos] == '*' || str[pos] == '/')) {
        char operation = str[pos];
        ++pos;

        // Only the Mul node built by this loop is limited. The first token may be
        // a parenthesised sum with max_children children; it is about to be
        // wrapped, not extended, so it must not trigger "expression too long".
        if (!result_is_single_token && result->nr_children == max_children) {
            result->type = NodeType::Error;
            result->error_message = "expression too long";
            break;
        }

        Node *token = parse_token(str, pos);
        if (token->type == NodeType::Error)
        {
            result->type = NodeType::Error;
            result->error_message = token->error_message;
            delete_expression(token);
            break;
        }

        if (operation == '/') {
            Node *parent = create_empty_node(NodeType::Inv);
            parent->nr_children = 1;
            parent->children[0] = token;
            token = parent;
        }

        // Second factor seen: wrap the first one in a new Mul node.
        if (result_is_single_token) {
            result_is_single_token = false;
            Node *parent = create_empty_node(NodeType::Mul);
            parent->nr_children = 1;
            parent->children[0] = result;
            result = parent;
        }

        result->children[result->nr_children] = token;
        ++result->nr_children;

        skip_space(str,pos);
    }

    return result;
}


// Parses one operand starting at pos (after optional whitespace) and advances pos past it:
//   x | X | <number> | ( <expr> ) | sin( <expr> ) | - <token>
// On failure returns an Error node with a message.
Node *parse_token(const string &str, int &pos)
{
    skip_space(str, pos);
    if (is_end(str, pos))
        return create_error_node("unexpected end of string");

    const char ch = str[pos];

    if (ch == 'x' || ch == 'X') {
        ++pos;
        return create_empty_node(NodeType::X);
    }

    // Unary minus in front of anything other than a numeric literal: -x, -(x+1),
    // -sin(x). These used to be handed to stod and rejected as "invalid number".
    // A '-' directly followed by a digit or '.' is still read as part of the literal.
    const bool minus_starts_literal =
        ch == '-' && pos + 1 < (int)str.size() &&
        (isdigit(static_cast<unsigned char>(str[pos + 1])) || str[pos + 1] == '.');
    if (ch == '-' && !minus_starts_literal) {
        ++pos;
        Node *operand = parse_token(str, pos);
        if (operand->type == NodeType::Error)
            return operand;

        Node *parent = create_empty_node(NodeType::Neg);
        parent->nr_children = 1;
        parent->children[0] = operand;
        return parent;
    }

    // isdigit() needs an unsigned char value; a negative char is undefined behaviour.
    if (isdigit(static_cast<unsigned char>(ch)) || ch == '.' || ch == '-') {
        Node *result = create_empty_node(NodeType::Number);

        try {
            size_t nr_read;
            result->numeric_value = stod(str.substr(pos), &nr_read);
            pos += nr_read;
        }
        catch (...) {
            // stod throws invalid_argument / out_of_range; report either as a parse error.
            result->type = NodeType::Error;
            result->error_message = "invalid number at position " + to_string(pos);
        }

        return result;
    }

    if (ch == '(') {
        ++pos;
        Node *result = parse_expr(str, pos);
        if (result->type == NodeType::Error)
            return result;

        skip_space(str, pos);
        if (is_end(str, pos) || str[pos] != ')') {
            delete_expression(result);
            return create_error_node("missing ')'");
        }

        ++pos;
        return result;
    }

    // sin(...): skip "sin" and let the recursive call parse the parenthesised argument.
    if (pos <= (int)str.size()-5 && str[pos] == 's' && str[pos+1] == 'i' && str[pos+2] == 'n' && str[pos+3] == '(') {
        pos+=3;

        Node *result = parse_token(str, pos);
        if (result->type == NodeType::Error)
            return result;

        Node *parent = create_empty_node(NodeType::Sin);
        parent->nr_children = 1;
        parent->children[0] = result;
        return parent;
    }

    return create_error_node(string("unexpected character '") + str[pos] + "'");
}
