/*
    Preparations for a function plotter program: expression evaluation.

    This file is part of Simple Graphics Framework (SGF).

    Copyright 2023 Arnold Beiland

    Simple Graphics Framework is free software: you can redistribute it and/or
    modify it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or (at your
    option) any later version.

    This program is distributed in the hope that it will be useful, but
    WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY
    or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for
    more details.

    You should have received a copy of the GNU General Public License along with
    this program (look for a file named COPYING in the top directory). If not, see
    <https://www.gnu.org/licenses/>.
*/

#include "../sgf.h"
#include <cctype>
#include <cmath>
#include <cstdio>
#include <fstream>
#include <iostream>
#include <stdexcept>
#include <string>
using namespace std;

const int max_children = 20;    // maximum number of operands of one Plus/Mul node
const int max_functions = 100;  // maximum number of input lines that are kept
const int text_size = 11;

// Palette cycled through by render(). The array is deliberately left unsized:
// declaring it as [max_functions] would value-initialize the unused entries,
// so n_colors would be 100 instead of the number of listed colors and every
// function after the 7th would be drawn with a default (presumably black) Color.
const Color colors_to_use[] = {
    {255,0,0}, {0,200,0}, {0,0,255}, {154,72,0}, {153,51,255}, {0,153,153}, {102,102,153} };
const int n_colors = sizeof(colors_to_use) / sizeof(colors_to_use[0]);

// Kind of an expression-tree node; see Node for which fields each kind uses.
enum NodeType {
    Plus,    // sum of all children
    Mul,     // product of all children
    X,       // the variable x
    Number,  // constant held in numeric_value
    Inv,     // 1 / children[0]
    Neg,     // -children[0]
    Error,   // parse failure described by error_message
    Sin      // sin(children[0])
};

struct Node {
    NodeType type;
    int nr_children;
    Node *children[max_children];

    double numeric_value;
    string error_message;
};

// Parsed input: function_strings[i] holds line i of plotter-input.txt and
// expression_roots[i] its expression tree. Only the first n_functions entries
// are meaningful.
int n_functions = 0;
string function_strings[max_functions];
Node *expression_roots[max_functions] = {};

void parse_function_strings();
double evaluate_expression(Node *node, double x);

/*
    Expression manipulation: a recursive-descent parser. Each function starts
    reading at str[pos] and advances pos past what it consumed.
      parse_expr  : <term> + <term> + ... + <term>     ('-' also allowed)
      parse_term  : <token> * <token> * ... * <token>  ('/' also allowed)
      parse_token : x, X, 12.3, 0.5, -9, sin(<expression>), (<expression>)
*/
Node *parse_expr(const string &str, int &pos);
Node *parse_term(const string &str, int &pos);
Node *parse_token(const string &str, int &pos);

/*
    Reads plotter-input.txt, one function per line, into function_strings and
    parses every line into expression_roots. If the file cannot be opened the
    failure is logged and no functions are loaded. Lines beyond max_functions
    are dropped (and logged).
*/
void initialize()
{
    srand(time(NULL));

    ifstream fin("plotter-input.txt");
    if (!fin) {
        logfile << "plotter-input.txt could not be opened" << endl;
        return;
    }

    n_functions = 0;
    string s;
    while (getline(fin, s)) {
        if (n_functions >= max_functions) {
            // Keep the first max_functions lines and still parse them: leaving
            // the function here would skip parse_function_strings() while
            // n_functions is already max_functions, and render() would then
            // dereference expression_roots entries that were never assigned.
            logfile << "plotter-input.txt contains too many functions" << endl;
            break;
        }

        function_strings[n_functions] = s;

        n_functions++;
    }

    parse_function_strings();
}

void render(int, int)
{
    for (int i = 0; i < n_functions; ++i) {
        Color color = colors_to_use[i % n_colors];

        string message = "f(x) = " + function_strings[i];

        if (expression_roots[i]->type == NodeType::Error) {
            message += "    Error: " + expression_roots[i]->error_message + ". ";
        }
        else {
            double f_minus2 = evaluate_expression(expression_roots[i], -2.0);
            double f_minus1 = evaluate_expression(expression_roots[i], -1.0);
            double f_0 = evaluate_expression(expression_roots[i], 0.0);
            double f_1 = evaluate_expression(expression_roots[i], 1.0);
            double f_2 = evaluate_expression(expression_roots[i], 2.0);

            char values[200];
            sprintf(values, "    %f %f %f %f %f", f_minus2, f_minus1, f_0, f_1, f_2);

            message += values;
        }

        draw_text(20 + 20*i, 20, color, text_size, message.c_str());
    }
}


// Input callbacks. The plotter does not react to any input yet, so every
// handler ignores its arguments and returns false (presumably meaning
// "event not handled / nothing to redraw" -- confirm against sgf.h).
bool on_scroll(int)
{
    return false;
}

bool on_move(int, int)
{
    return false;
}

bool on_mouse_down()
{
    return false;
}

bool on_mouse_up()
{
    return false;
}

bool on_key_down(int)
{
    return false;
}

bool on_key_up(int)
{
    return false;
}

void parse_function_strings()
{
    for (int i = 0; i < n_functions; ++i) {
        int pos = 0;
        expression_roots[i] = parse_expr(function_strings[i], pos);
    }
}

/*
    Evaluates the expression tree rooted at node with the variable set to x.
    Division by zero (an Inv node whose operand is 0) yields NAN, as does an
    Error node or any other unrecognized node type.
*/
double evaluate_expression(Node *node, double x)
{
    switch (node->type) {
    case NodeType::X:
        return x;

    case NodeType::Number:
        return node->numeric_value;

    case NodeType::Neg:
        return -evaluate_expression(node->children[0], x);

    case NodeType::Sin:
        return sin(evaluate_expression(node->children[0], x));

    case NodeType::Inv: {
        const double denominator = evaluate_expression(node->children[0], x);
        return denominator == 0 ? NAN : 1.0 / denominator;
    }

    case NodeType::Plus: {
        // Accumulate left to right, starting from the additive identity.
        double sum = 0.0;
        for (int k = 0; k < node->nr_children; ++k)
            sum += evaluate_expression(node->children[k], x);
        return sum;
    }

    case NodeType::Mul: {
        // Accumulate left to right, starting from the multiplicative identity.
        double product = 1.0;
        for (int k = 0; k < node->nr_children; ++k)
            product *= evaluate_expression(node->children[k], x);
        return product;
    }

    default:
        return NAN;
    }
}


/*
    Returns true once pos no longer indexes a character of str, i.e. when
    pos is at or beyond the end of the string.
*/
bool is_end(const string &str, int pos)
{
    const int length = static_cast<int>(str.size());
    return !(pos < length);
}

/*
    Advances pos past any whitespace in str, stopping at the first
    non-whitespace character or at the end of the string.
*/
void skip_space(const string &str, int &pos)
{
    // isspace() is undefined for negative arguments other than EOF, which is
    // what a plain (signed) char holds for non-ASCII bytes such as UTF-8
    // sequences -- hence the conversion to unsigned char.
    while (pos < (int)str.size() && isspace(static_cast<unsigned char>(str[pos])))
        ++pos;
}

/*
    Allocates a node of the given type with no children. The caller owns the
    returned node (release it with delete_expression). Fields other than the
    type and child count are left as Node's default construction leaves them.
*/
Node *create_empty_node(NodeType type)
{
    Node *const node = new Node;
    node->nr_children = 0;
    node->type = type;
    return node;
}

/*
    Allocates a childless Error node carrying the given message. The caller
    owns the returned node.
*/
Node *create_error_node(string message)
{
    Node *const error = new Node;
    error->type = NodeType::Error;
    error->nr_children = 0;
    error->error_message = message;
    return error;
}

/*
    Frees the whole tree rooted at root: first every child subtree, in order,
    then root itself. Passing nullptr is a no-op.
*/
void delete_expression(Node *root)
{
    if (!root)
        return;

    Node **const first = root->children;
    Node **const last = first + root->nr_children;
    for (Node **child = first; child != last; ++child)
        delete_expression(*child);

    delete root;
}

/*
    Parses a sum starting at str[pos]:
        <term> (('+' | '-') <term>)*
    A term that follows '-' is wrapped in a Neg node. A single term is returned
    unchanged. Two or more terms are collected under one Plus node, which holds
    at most max_children operands.

    On success pos is left after the last term and any whitespace after it.
    On failure the returned node has type Error with error_message set; it may
    still own the children parsed so far (delete_expression frees them).
*/
Node *parse_expr(const string &str, int &pos)
{
    bool result_is_single_token = true;
    Node *result = parse_term(str, pos);
    if (result->type == NodeType::Error)
        return result;

    skip_space(str, pos);
    while (!is_end(str, pos) && (str[pos] == '+' || str[pos] == '-')) {
        bool is_minus = str[pos] == '-';
        ++pos;

        // The limit applies only to the Plus node built below. While result is
        // still the first term it can be any node -- e.g. a parenthesized sum
        // that already has max_children operands -- and wrapping it gives the
        // new Plus node just one child, so it must not be counted here.
        if (!result_is_single_token && result->nr_children == max_children) {
            result->type = NodeType::Error;
            result->error_message = "expression too long";
            break;
        }

        Node *term = parse_term(str, pos);
        if (term->type == NodeType::Error) {
            result->type = NodeType::Error;
            result->error_message = term->error_message;
            delete_expression(term);
            break;
        }

        if (is_minus) {
            Node *parent = create_empty_node(NodeType::Neg);
            parent->nr_children = 1;
            parent->children[0] = term;
            term = parent;
        }

        if (result_is_single_token) {
            result_is_single_token = false;
            Node *parent = create_empty_node(NodeType::Plus);
            parent->nr_children = 1;
            parent->children[0] = result;
            result = parent;
        }

        result->children[result->nr_children] = term;
        ++result->nr_children;

        skip_space(str, pos);
    }

    return result;
}


/*
    Parses a product starting at str[pos]:
        <token> (('*' | '/') <token>)*
    A token that follows '/' is wrapped in an Inv node. A single token is
    returned unchanged. Two or more tokens are collected under one Mul node,
    which holds at most max_children operands.

    On success pos is left after the last token and any whitespace after it.
    On failure the returned node has type Error with error_message set; it may
    still own the children parsed so far (delete_expression frees them).
*/
Node *parse_term(const string &str, int &pos)
{
    bool result_is_single_token = true;
    Node *result = parse_token(str, pos);
    if (result->type == NodeType::Error)
        return result;

    skip_space(str, pos);
    while (!is_end(str, pos) && (str[pos] == '*' || str[pos] == '/')) {
        char operation = str[pos];
        ++pos;

        // The limit applies only to the Mul node built below. While result is
        // still the first token it can be any node -- e.g. a parenthesized
        // expression that already has max_children operands -- and wrapping it
        // gives the new Mul node just one child.
        if (!result_is_single_token && result->nr_children == max_children) {
            result->type = NodeType::Error;
            result->error_message = "expression too long";
            break;
        }

        Node *token = parse_token(str, pos);
        if (token->type == NodeType::Error) {
            result->type = NodeType::Error;
            result->error_message = token->error_message;
            delete_expression(token);
            break;
        }

        if (operation == '/') {
            Node *parent = create_empty_node(NodeType::Inv);
            parent->nr_children = 1;
            parent->children[0] = token;
            token = parent;
        }

        if (result_is_single_token) {
            result_is_single_token = false;
            Node *parent = create_empty_node(NodeType::Mul);
            parent->nr_children = 1;
            parent->children[0] = result;
            result = parent;
        }

        result->children[result->nr_children] = token;
        ++result->nr_children;

        skip_space(str, pos);
    }

    return result;
}


/*
    Parses one operand starting at str[pos], after skipping whitespace:
        x or X        -> X node
        number        -> Number node (whatever stod accepts: 12.3, .5, -9, 1e3, ...)
        -<token>      -> Neg node, when the '-' does not begin a number
                         ("-x", "-(x+1)", "-sin(x)")
        (<expr>)      -> the parsed sub-expression
        sin(<expr>)   -> Sin node
    Advances pos past the operand. On failure returns an Error node whose
    error_message describes the problem.
*/
Node *parse_token(const string &str, int &pos)
{
    skip_space(str, pos);
    if (is_end(str, pos))
        return create_error_node("unexpected end of string");

    const char c = str[pos];

    if (c == 'x' || c == 'X') {
        ++pos;
        return create_empty_node(NodeType::X);
    }

    // isdigit() is undefined for negative char values (non-ASCII bytes),
    // hence the conversion to unsigned char.
    if (isdigit(static_cast<unsigned char>(c)) || c == '.' || c == '-') {
        bool not_a_number = false;
        try {
            size_t nr_read;
            const double value = stod(str.substr(pos), &nr_read);
            Node *result = create_empty_node(NodeType::Number);
            result->numeric_value = value;
            pos += static_cast<int>(nr_read);
            return result;
        }
        catch (const invalid_argument &) {
            not_a_number = true;  // no conversion could be performed
        }
        catch (const out_of_range &) {
            // the literal does not fit in a double; reported below
        }

        if (!not_a_number || c != '-')
            return create_error_node("invalid number at position " + to_string(pos));

        // A '-' that does not start a number literal is a unary minus applied
        // to the operand that follows it.
        ++pos;
        Node *operand = parse_token(str, pos);
        if (operand->type == NodeType::Error)
            return operand;

        Node *parent = create_empty_node(NodeType::Neg);
        parent->nr_children = 1;
        parent->children[0] = operand;
        return parent;
    }

    if (c == '(') {
        ++pos;
        Node *result = parse_expr(str, pos);
        if (result->type == NodeType::Error)
            return result;

        skip_space(str, pos);
        if (is_end(str, pos) || str[pos] != ')') {
            delete_expression(result);
            return create_error_node("missing ')'");
        }

        ++pos;
        return result;
    }

    // sin(...): compare() yields a mismatch when fewer than 4 characters remain.
    if (str.compare(pos, 4, "sin(") == 0) {
        pos += 3;  // stop on '(' so the argument is parsed as a parenthesized token

        Node *argument = parse_token(str, pos);
        if (argument->type == NodeType::Error)
            return argument;

        Node *parent = create_empty_node(NodeType::Sin);
        parent->nr_children = 1;
        parent->children[0] = argument;
        return parent;
    }

    return create_error_node(string("unexpected character '") + c + "'");
}
