#include "OrdinaryBasisFunction.h"

#include "../Utils/DoubleEquality.h"
#include <cmath>
#include <vector>
#include <sstream>
#include <iomanip>

namespace cagd
{
    // Default state: no polynomial power, and zero exponential and trigonometric
    // coefficients, i.e. no factor at all. getDerivative evaluates this as the
    // constant function 1, and getHTMLRepresentation renders it as "1".
    OrdinaryBasisFunction::OrdinaryBasisFunction()
        : xPow(0u)
        , expCoef(0.0)
        , trigCoef(0.0)
        , isSine(false)
    {
    }

    // Stores the parameters of  x^xPow * e^(expCoef * x) * trig(trigCoef * x),
    // where trig is sin when isSine is true and cos otherwise.
    // Other members treat xPow == 0, or a coefficient that doubleEquals 0, as
    // "factor absent".
    // In each initializer the outer name is the member and the inner name is the
    // same-named constructor parameter.
    OrdinaryBasisFunction::OrdinaryBasisFunction(unsigned xPow, double expCoef, double trigCoef, bool isSine)
        : xPow(xPow)
        , expCoef(expCoef)
        , trigCoef(trigCoef)
        , isSine(isSine)
    {
    }

    // Returns the order-th derivative, evaluated at `input`, of
    //     x^xPow * e^(expCoef * x) * (sin|cos)(trigCoef * x)
    // A factor is treated as absent (the constant 1) when xPow is 0, or when its
    // coefficient doubleEquals 0.
    // A single present factor is differentiated directly. A product of two or
    // three factors is expanded with the general Leibniz rule (generalLeibniz).
    double OrdinaryBasisFunction::getDerivative(unsigned order, double input) const
    {
        const bool hasPoly = (xPow != 0);
        const bool hasExp  = !doubleEquals(expCoef, 0);
        const bool hasTrig = !doubleEquals(trigCoef, 0);

        // Products of several factors.
        if (hasPoly && hasExp && hasTrig)
            return generalLeibniz(order, input,
                                  &OrdinaryBasisFunction::polyDerivative,
                                  &OrdinaryBasisFunction::expDerivative,
                                  &OrdinaryBasisFunction::trigDerivative);
        if (hasPoly && hasExp)
            return generalLeibniz(order, input,
                                  &OrdinaryBasisFunction::polyDerivative,
                                  &OrdinaryBasisFunction::expDerivative);
        if (hasPoly && hasTrig)
            return generalLeibniz(order, input,
                                  &OrdinaryBasisFunction::polyDerivative,
                                  &OrdinaryBasisFunction::trigDerivative);
        if (hasExp && hasTrig)
            return generalLeibniz(order, input,
                                  &OrdinaryBasisFunction::expDerivative,
                                  &OrdinaryBasisFunction::trigDerivative);

        // Exactly one factor.
        if (hasPoly)
            return polyDerivative(order, input);
        if (hasExp)
            return expDerivative(order, input);
        if (hasTrig)
            return trigDerivative(order, input);

        // No factor at all: the constant 1, whose derivatives of order >= 1 vanish.
        return order == 0 ? 1 : 0;
    }

    // order-th derivative of x^xPow at `input`:
    //     xPow * (xPow - 1) * ... * (xPow - order + 1) * input^(xPow - order),
    // or 0 once the order exceeds the power.
    double OrdinaryBasisFunction::polyDerivative(unsigned order, double input) const
    {
        if (order > xPow)
            return 0;

        const unsigned remainingPow = xPow - order;

        // Falling factorial xPow^(order), multiplied from xPow downwards.
        double fallingFactorial = 1;
        for (unsigned factor = xPow; factor > remainingPow; --factor)
            fallingFactorial *= factor;

        return fallingFactorial * std::pow(input, remainingPow);
    }

    // order-th derivative of e^(expCoef * x) at `input`:
    //     expCoef^order * e^(expCoef * input)
    double OrdinaryBasisFunction::expDerivative(unsigned order, double input) const
    {
        const double chainFactor = std::pow(expCoef, order);
        const double exponential = std::exp(expCoef * input);
        return chainFactor * exponential;
    }

    // order-th derivative of sin(trigCoef * x) (isSine) or cos(trigCoef * x) at `input`.
    // Each differentiation advances the function one quarter period along the
    // cycle  sin -> cos -> -sin -> -cos -> sin  and contributes a factor trigCoef.
    // cos sits one step after sin on that cycle, so it starts at phase 1.
    double OrdinaryBasisFunction::trigDerivative(unsigned order, double input) const
    {
        const double arg = trigCoef * input;
        const unsigned phase = (order + (isSine ? 0u : 1u)) % 4u;

        double cyclic;
        switch (phase) {
            case 0:  cyclic =  std::sin(arg); break;
            case 1:  cyclic =  std::cos(arg); break;
            case 2:  cyclic = -std::sin(arg); break;
            default: cyclic = -std::cos(arg); break; // phase == 3
        }

        return std::pow(trigCoef, order) * cyclic;
    }

    // order-th derivative at `input` of the product f1 * f2, by the general
    // Leibniz rule:
    //     (f1 f2)^(n) = sum_{k=0..n} C(n, k) * f1^(k) * f2^(n-k)
    // f1 and f2 are member functions returning the requested derivative of a
    // single factor at `input`.
    //
    // The binomial coefficients are built incrementally,
    //     C(n, k+1) = C(n, k) * (n - k) / (k + 1),
    // rather than as n! / (k! (n-k)!). The factorials overflow to inf once
    // n > 170, and inf / inf is NaN, which made every derivative of such an
    // order NaN. The incremental form only grows as large as the coefficients
    // themselves.
    double OrdinaryBasisFunction::generalLeibniz(
        unsigned order, double input,
        FunctionTerm f1, FunctionTerm f2) const
    {
        double binom = 1; // C(order, 0)
        double sum = 0;
        for (unsigned ordFirst = 0; ordFirst <= order; ++ordFirst) {
            const unsigned ordSecond = order - ordFirst;
            const double first = (this->*f1)(ordFirst, input);

            // An exactly-zero first factor makes the whole term vanish. One
            // example is a polynomial derivative above the polynomial's degree.
            // Skipping it saves evaluating f2, and avoids 0 * inf = NaN when
            // f2's derivative overflows.
            if (first != 0.0)
                sum += binom * first * (this->*f2)(ordSecond, input);

            binom = binom * ordSecond / (ordFirst + 1); // -> C(order, ordFirst + 1)
        }
        return sum;
    }

    // order-th derivative at `input` of the product f1 * f2 * f3, by the general
    // Leibniz rule for three factors:
    //     (f1 f2 f3)^(n) = sum_{a+b+c=n} n! / (a! b! c!) * f1^(a) * f2^(b) * f3^(c)
    // f1, f2 and f3 are member functions returning the requested derivative of
    // a single factor at `input`.
    //
    // The multinomial coefficient is formed as C(n, a) * C(n - a, b). Both
    // binomials are built incrementally via C(m, k+1) = C(m, k) * (m - k) / (k + 1)
    // instead of from factorials. The factorials overflow to inf for n > 170,
    // and inf / inf turned every such derivative into NaN.
    double OrdinaryBasisFunction::generalLeibniz(
        unsigned order, double input,
        FunctionTerm f1, FunctionTerm f2, FunctionTerm f3) const
    {
        double sum = 0;
        double outerBinom = 1; // C(order, ordFirst), starting at C(order, 0)
        for (unsigned ordFirst = 0; ordFirst <= order; ++ordFirst) {
            const unsigned remains = order - ordFirst;

            // f1 depends only on ordFirst, so evaluate it once per outer step.
            const double first = (this->*f1)(ordFirst, input);

            // An exactly-zero first factor (e.g. a polynomial derivative above
            // the degree) makes every term of the inner sum vanish. Skipping
            // them saves work and avoids 0 * inf = NaN when another factor
            // overflows.
            if (first != 0.0) {
                double innerBinom = 1; // C(remains, ordSecond), starting at C(remains, 0)
                for (unsigned ordSecond = 0; ordSecond <= remains; ++ordSecond) {
                    const unsigned ordThird = remains - ordSecond;
                    sum += outerBinom * innerBinom
                        * first
                        * (this->*f2)(ordSecond, input)
                        * (this->*f3)(ordThird, input);
                    innerBinom = innerBinom * ordThird / (ordSecond + 1);
                }
            }

            outerBinom = outerBinom * remains / (ordFirst + 1);
        }
        return sum;
    }

    // Renders the function as HTML text in terms of `variable`. Present factors
    // are joined by "·":
    //   - variable, with a superscript exponent when xPow != 1;
    //   - e with a superscript exponent: "-variable", "variable" or
    //     "coef·variable";
    //   - sin( / cos( followed by "coef·" (omitted when the coefficient
    //     doubleEquals 1), the variable and ")".
    // With no factor present the result is "1".
    std::string OrdinaryBasisFunction::getHTMLRepresentation(std::string variable) const
    {
        std::stringstream out;

        // Empty before the first factor, "·" afterwards. It is still empty at
        // the end only if nothing was emitted.
        const char* separator = "";

        if (xPow != 0) {
            out << separator << variable;
            if (xPow != 1)
                out << BEGIN_ALIGN_SUPER << xPow << END_ALIGN_SUPER;
            separator = "·";
        }

        if (!doubleEquals(expCoef, 0)) {
            out << separator << "e" << BEGIN_ALIGN_SUPER;
            if (doubleEquals(expCoef, -1))
                out << "-";
            else if (!doubleEquals(expCoef, 1))
                out << formatTrimmed(expCoef) << "·";
            out << variable << END_ALIGN_SUPER;
            separator = "·";
        }

        if (!doubleEquals(trigCoef, 0)) {
            out << separator << (isSine ? "sin(" : "cos(");
            if (!doubleEquals(trigCoef, 1))
                out << formatTrimmed(trigCoef) << "·";
            out << variable << ")";
            separator = "·";
        }

        // No factor emitted: the empty product is the constant 1.
        if (*separator == '\0')
            out << "1";

        return out.str();
    }

    // Formats `value` in fixed notation with 6 decimals. It then drops trailing
    // '0' characters and, if reached, the trailing '.' (e.g. 2.5 -> "2.5",
    // 3.0 -> "3"). The first character of the formatted text is never removed.
    std::string OrdinaryBasisFunction::formatTrimmed(double value) const
    {
        std::stringstream formatter;
        formatter.setf(std::ios_base::fixed, std::ios_base::floatfield);
        formatter.precision(6);
        formatter << value;
        const std::string text = formatter.str();

        std::string::size_type end = text.size();
        while (end > 1 && text[end - 1] == '0')
            --end;
        if (end > 1 && text[end - 1] == '.')
            --end;

        return text.substr(0, end);
    }
}

