/***********************************************************************************
                 OpenCS Project: www.daetools.com
                 Copyright (C) Dragan Nikolic
************************************************************************************
OpenCS is free software; you can redistribute it and/or modify it under the terms
of the GNU Lesser General Public License version 3 as published by the Free Software
Foundation. OpenCS is distributed in the hope that it will be useful, but WITHOUT
ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public License along with
the OpenCS software; if not, see <http://www.gnu.org/licenses/>.
***********************************************************************************/
#if !defined(BURGERS_EQUATIONS_CV_2D_MODEL_H)
#define BURGERS_EQUATIONS_CV_2D_MODEL_H

#include <cstdio>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>

#include <OpenCS/models/cs_number.h>
#include <OpenCS/models/cs_kernel.h>
using namespace cs;

// Default simulation options (JSON text) embedded at compile time. The included
// file must expand to an expression convertible to `const char*` -- presumably a
// single (raw) string literal wrapping the JSON; confirm in the .json file.
// NOTE(review): this is a non-inline definition of a non-const variable with
// external linkage placed in a header, so including this header from more than
// one translation unit would cause a multiple-definition link error. `static`
// or `inline` (C++17) would avoid that -- check for `extern` uses elsewhere first.
const char* simulation_options =
#include "dae_example_4_cv-simulation_options.json"
;

class burgers_groups_2D
{
public:
    typedef csNumber_t (burgers_groups_2D:: *cs_function_ptr)(int, int);

    real_t u0;
    real_t v0;
    real_t w0;
    real_t ni;
    real_t eps;
    real_t dx;
    real_t dy;

    int Nx;
    int Ny;
    int Nequations;

    // Current simulation time
    csNumber_t t;

    // Variable values
    csNumber_t* u_data;
    csNumber_t* v_data;
    csNumber_t* uman_data;
    csNumber_t* vman_data;

    // Variable derivatives
    csNumber_t* du_dt_data;
    csNumber_t* dv_dt_data;

    std::vector<real_t> x_domain;
    std::vector<real_t> y_domain;

    int u_start_index;
    int v_start_index;
    int uman_start_index;
    int vman_start_index;

    burgers_groups_2D(int nx, int ny)
    {
        u0  = (1.0);
        v0  = (1.0);
        w0  = (0.1);
        ni  = (0.7);
        eps = (0.001);

        Nx  = nx+1;
        Ny  = ny+1;
        Nequations = 4 * Nx * Ny;

        dx  = (0.7 + 0.1) / (Nx-1);
        dy  = (0.8 - 0.2) / (Ny-1);

        u_start_index    = 0*Nx*Ny;
        v_start_index    = 1*Nx*Ny;
        uman_start_index = 2*Nx*Ny;
        vman_start_index = 3*Nx*Ny;

        x_domain.resize(Nx);
        y_domain.resize(Ny);
        for(int x = 0; x < Nx; x++)
            x_domain[x] = -0.1 + x * dx;
        for(int y = 0; y < Ny; y++)
            y_domain[y] = 0.2 + y * dy;

        u_data         = NULL;
        v_data         = NULL;
        du_dt_data     = NULL;
        dv_dt_data     = NULL;
        uman_data      = NULL;
        vman_data      = NULL;
    }

    // First order partial derivative per x (centered finite difference).
    csNumber_t df_dx(cs_function_ptr f, int x, int y)
    {
        return ( ((this->*f)(x+1,  y) - (this->*f)(x-1,  y)) / (2*dx) );
    }

    // First order partial derivative per y (centered finite difference).
    csNumber_t df_dy(cs_function_ptr f, int x, int y)
    {
        return ( ((this->*f)(x,  y+1) - (this->*f)(x,  y-1)) / (2*dy) );
    }

    // Second order partial derivative per x (centered finite difference).
    csNumber_t d2f_dx2(cs_function_ptr f, int x, int y)
    {
        return ( ((this->*f)(x+1,  y) - 2*(this->*f)(x,y) + (this->*f)(x-1,  y)) / (dx*dx) );
    }

    // Second order partial derivative per y (centered finite difference).
    csNumber_t d2f_dy2(cs_function_ptr f, int x, int y)
    {
        return ( ((this->*f)(x,  y+1) - 2*(this->*f)(x,y) + (this->*f)(x,  y-1)) / (dy*dy) );
    }

    // Partial derivative per x of a product of two variables: d(f1*f2)/dx = (df1/dx)*f2 + f1*(df2/dx).
    csNumber_t product_rule_dx(cs_function_ptr f1, cs_function_ptr f2, int x, int y)
    {
        return ( df_dx(f1,x,y)*(this->*f2)(x,y) + (this->*f1)(x,y)*df_dx(f2,x,y) );
    }

    // Partial derivative per y of a product of two variables: d(f1*f2)/dy = (df1/dy)*f2 + f1*(df2/dy).
    csNumber_t product_rule_dy(cs_function_ptr f1, cs_function_ptr f2, int x, int y)
    {
        return ( df_dy(f1,x,y)*(this->*f2)(x,y) + (this->*f1)(x,y)*df_dy(f2,x,y) );
    }

    // The arrays used by a DAE solver are flat 1D.
    // However, the variables are distributed on a rectangular 2D domain.
    // This function returns a position in the 1D array for given x,y indexes.
    int getIndex(int x, int y)
    {
        if(x < 0 || x >= Nx)
            throw std::runtime_error("Invalid x index");
        if(y < 0 || y >= Ny)
            std::runtime_error("Invalid y index");
        return Ny*x + y;
    }

    // Variable values/derivativs access functions.
    csNumber_t u(int x, int y)
    {
        int index = getIndex(x,y);
        return u_data[index];
    }
    csNumber_t v(int x, int y)
    {
        int index = getIndex(x,y);
        return v_data[index];
    }
    csNumber_t du_dt(int x, int y)
    {
        int index = getIndex(x,y);
        return du_dt_data[index];
    }
    csNumber_t dv_dt(int x, int y)
    {
        int index = getIndex(x,y);
        return dv_dt_data[index];
    }

    csNumber_t u_man(int x, int y)
    {
        int index = getIndex(x,y);
        return uman_data[index];
    }
    csNumber_t v_man(int x, int y)
    {
        int index = getIndex(x,y);
        return vman_data[index];
    }

    // Values of x and y domains.
    real_t xd           (int x)  { return x_domain[x]; }
    real_t yd           (int y)  { return y_domain[y]; }

    // Centered finite difference derivative functions.
    csNumber_t duu_dx   (int x, int y) { return product_rule_dx(&burgers_groups_2D::u, &burgers_groups_2D::u, x,y); }
    csNumber_t duv_dy   (int x, int y) { return product_rule_dy(&burgers_groups_2D::u, &burgers_groups_2D::v, x,y); }
    csNumber_t d2u_dx2  (int x, int y) { return d2f_dx2(&burgers_groups_2D::u, x,y); }
    csNumber_t d2u_dy2  (int x, int y) { return d2f_dy2(&burgers_groups_2D::u, x,y); }

    csNumber_t dvu_dx   (int x, int y) { return product_rule_dx(&burgers_groups_2D::v, &burgers_groups_2D::u, x,y); }
    csNumber_t dvv_dy   (int x, int y) { return product_rule_dy(&burgers_groups_2D::v, &burgers_groups_2D::v, x,y); }
    csNumber_t d2v_dx2  (int x, int y) { return d2f_dx2(&burgers_groups_2D::v, x,y); }
    csNumber_t d2v_dy2  (int x, int y) { return d2f_dy2(&burgers_groups_2D::v, x,y); }

    csNumber_t Su       (int x, int y) { return dum_dt(x,y) + (dumum_dx(x,y) + dumvm_dy(x,y)) - ni * (d2um_dx2(x,y) + d2um_dy2(x,y)); }
    csNumber_t Sv       (int x, int y) { return dvm_dt(x,y) + (dvmum_dx(x,y) + dvmvm_dy(x,y)) - ni * (d2vm_dx2(x,y) + d2vm_dy2(x,y)); }

    // Manufactured solutions and source terms.
    csNumber_t um       (int x, int y) { return u0 * (cs::sin(xd(x)*xd(x) + yd(y)*yd(y) + w0*t) + eps); } // [*] kernel issue, f(x,y)
    csNumber_t vm       (int x, int y) { return v0 * (cs::cos(xd(x)*xd(x) + yd(y)*yd(y) + w0*t) + eps); } // [*] kernel issue, f(x,y)

    // Derivatives are symbolic
    csNumber_t dumum_dx (int x, int y)
    {
        return 2 * um(x,y) * dum_dx(x,y);
    }
    csNumber_t dumvm_dy (int x, int y)
    {
        return um(x,y) * dvm_dy(x,y) + dum_dy(x,y) * vm(x,y);
    }
    csNumber_t d2um_dx2 (int x, int y)
    {
        // The chain rule: df(n)/dx = (df/dn) * (dn/dx)
        csNumber_t fn = xd(x)*xd(x) + yd(y)*yd(y) + w0*t;
        csNumber_t f_prim   = 2 * xd(x);
        csNumber_t sin_prim = cs::cos(fn);
        csNumber_t dum_dx = u0 * f_prim * sin_prim;

        csNumber_t f_prim_prim = 2;
        csNumber_t sin_prim_prim = f_prim * (-cs::sin(fn));
        //return u0 * (f_prim * sin_prim_prim + f_prim_prim * sin_prim);
        return u0 * 2 * (-2*xd(x)*xd(x)*sin(fn) + cos(fn));
        //          2             2    2                2    2
        //  2⋅(- 2⋅x ⋅sin(t⋅w₀ + x  + y ) + cos(t⋅w₀ + x  + y ))
    }
    csNumber_t d2um_dy2 (int x, int y)
    {
        // The chain rule: df(n)/dx = (df/dn) * (dn/dx)
        csNumber_t fn = xd(x)*xd(x) + yd(y)*yd(y) + w0*t;
        csNumber_t f_prim   = 2 * yd(y);
        csNumber_t sin_prim = cs::cos(fn);
        csNumber_t dum_dy = u0 * f_prim * sin_prim;

        csNumber_t f_prim_prim = 2;
        csNumber_t sin_prim_prim = f_prim * (-cs::sin(fn));
        //return u0 * (f_prim * sin_prim_prim + f_prim_prim * sin_prim);
        return u0 * 2 * (-2*yd(y)*yd(y)*sin(fn) + cos(fn));
        //          2             2    2                2    2
        //  2⋅(- 2⋅y ⋅sin(t⋅w₀ + x  + y ) + cos(t⋅w₀ + x  + y ))
    }

    csNumber_t dvmum_dx (int x, int y)
    {
        return vm(x,y) * dum_dx(x,y) + dvm_dx(x,y) * um(x,y);
    }
    csNumber_t dvmvm_dy (int x, int y)
    {
        return 2 * vm(x,y) * dvm_dy(x,y);
    }
    csNumber_t d2vm_dx2 (int x, int y)
    {
        // The chain rule: df(n)/dx = (df/dn) * (dn/dx)
        csNumber_t fn = xd(x)*xd(x) + yd(y)*yd(y) + w0*t;
        csNumber_t f_prim   = 2 * xd(x);
        csNumber_t cos_prim = -cs::sin(fn);
        csNumber_t dvm_dx = v0 * f_prim * cos_prim;

        csNumber_t f_prim_prim = 2;
        csNumber_t cos_prim_prim = f_prim * (-cs::cos(fn));
        //return v0 * (f_prim * cos_prim_prim + f_prim_prim * cos_prim);
        return - v0 * 2 * (2*xd(x)*xd(x)*cos(fn) + sin(fn));
        //        2             2    2                2    2
        // -2⋅(2⋅x ⋅cos(t⋅w₀ + x  + y ) + sin(t⋅w₀ + x  + y ))
    }
    csNumber_t d2vm_dy2 (int x, int y)
    {
        // The chain rule: df(n)/dx = (df/dn) * (dn/dx)
        csNumber_t fn = xd(x)*xd(x) + yd(y)*yd(y) + w0*t;
        csNumber_t f_prim   = 2 * yd(y);
        csNumber_t cos_prim = -cs::sin(fn);
        csNumber_t dvm_dy = v0 * f_prim * cos_prim;

        csNumber_t f_prim_prim = 2;
        csNumber_t cos_prim_prim = f_prim * (-cs::cos(fn));
        //return v0 * (f_prim * cos_prim_prim + f_prim_prim * cos_prim);
        return - v0 * 2 * (2*yd(y)*yd(y)*cos(fn) + sin(fn));
        //        2             2    2                2    2
        // -2⋅(2⋅y ⋅cos(t⋅w₀ + x  + y ) + sin(t⋅w₀ + x  + y ))
    }

    csNumber_t dum_dt(int x, int y)
    {
        // The chain rule: df(n)/dx = (df/dn) * (dn/dx)
        // f(t) = x^2 + y^2 + w0*t
        // sin( f(t) )' = sin'( f(t) ) * f'(t) = cos( f(t) ) * w0
        csNumber_t fn = xd(x)*xd(x) + yd(y)*yd(y) + w0*t;
        csNumber_t f_prim   = w0;
        csNumber_t sin_prim = cs::cos(fn);
        //return u0 * f_prim * sin_prim;
        return u0 * w0 * cos(fn);
        //                 2    2
        //  w₀⋅cos(t⋅w₀ + x  + y )
    }
    csNumber_t dvm_dt(int x, int y)
    {
        // The chain rule: df(n)/dx = (df/dn) * (dn/dx)
        // f(t) = x^2 + y^2 + w0*t
        // cos( f(t) )' = cos'( f(t) ) * f'(t) = -sin( f(t) ) * w0
        csNumber_t fn = xd(x)*xd(x) + yd(y)*yd(y) + w0*t;
        csNumber_t f_prim   = w0;
        csNumber_t cos_prim = -cs::sin(fn);
        //return v0 * f_prim * cos_prim;
        return -v0 * w0 * sin(fn);
        //                 2    2
        // -w₀⋅sin(t⋅w₀ + x  + y )
    }
    csNumber_t dum_dx (int x, int y)
    {
        // The chain rule: df(n)/dx = (df/dn) * (dn/dx)
        csNumber_t fn = xd(x)*xd(x) + yd(y)*yd(y) + w0*t;
        csNumber_t f_prim   = 2 * xd(x);
        csNumber_t sin_prim = cs::cos(fn);
        //return u0 * f_prim * sin_prim;
        return u0 * 2 * xd(x) * cos(fn);
        //                 2    2
        // 2⋅x⋅cos(t⋅w₀ + x  + y )
    }
    csNumber_t dum_dy (int x, int y)
    {
        // The chain rule: df(n)/dx = (df/dn) * (dn/dx)
        csNumber_t fn = xd(x)*xd(x) + yd(y)*yd(y) + w0*t;
        csNumber_t f_prim   = 2 * yd(x);
        csNumber_t sin_prim = cs::cos(fn);
        //return u0 * f_prim * sin_prim;
        return u0 * 2 * yd(y) * cos(fn);
        //                 2    2
        // 2⋅y⋅cos(t⋅w₀ + x  + y )
    }
    csNumber_t dvm_dx(int x, int y)
    {
        // The chain rule: df(n)/dx = (df/dn) * (dn/dx)
        csNumber_t fn = xd(x)*xd(x) + yd(y)*yd(y) + w0*t;
        csNumber_t f_prim   = 2 * xd(x);
        csNumber_t cos_prim = -cs::sin(fn);
        //return v0 * f_prim * cos_prim;
        return -v0 * 2 * xd(x) * sin(fn);
        //                 2    2
        //-2⋅x⋅sin(t⋅w₀ + x  + y )
    }
    csNumber_t dvm_dy(int x, int y)
    {
        // The chain rule: df(n)/dx = (df/dn) * (dn/dx)
        csNumber_t fn = xd(x)*xd(x) + yd(y)*yd(y) + w0*t;
        csNumber_t f_prim   = 2 * yd(y);
        csNumber_t cos_prim = -cs::sin(fn);
        //return v0 * f_prim * cos_prim;
        return -v0 * 2 * yd(y) * sin(fn);
        //                 2    2
        //-2⋅y⋅sin(t⋅w₀ + x  + y )
    }

    void SetInitialConditions(std::vector<real_t>& x0)
    {
        int xi, yi;
        x0.assign(Nequations, 0.0);

        real_t* u_0  = &x0[u_start_index];
        real_t* v_0  = &x0[v_start_index];
        real_t* um_0 = &x0[uman_start_index];
        real_t* vm_0 = &x0[vman_start_index];

        for(xi = 0; xi < Nx; xi++)
        {
            for(yi = 0; yi < Ny; yi++)
            {
                int index = getIndex(xi,yi);
                real_t x = x_domain[xi];
                real_t y = y_domain[yi];

                u_0[index]  = u0 * (sin(x*x + y*y) + eps);
                v_0[index]  = v0 * (cos(x*x + y*y) + eps);
                um_0[index] = u0 * (sin(x*x + y*y) + eps);
                vm_0[index] = v0 * (cos(x*x + y*y) + eps);
            }
        }
    }

    void GetVariableNames(std::vector<std::string>& names)
    {
        const int bsize = 32;
        char buffer[bsize];
        int index = 0;

        names.resize(Nequations);
        for(int x = 0; x < Nx; x++)
        {
            for(int y = 0; y < Ny; y++)
            {
                std::snprintf(buffer, bsize, "%s(%d,%d)", "u", x, y);
                names[index] = buffer;
                index++;
            }
        }
        for(int x = 0; x < Nx; x++)
        {
            for(int y = 0; y < Ny; y++)
            {
                std::snprintf(buffer, bsize, "%s(%d,%d)", "v", x, y);
                names[index] = buffer;
                index++;
            }
        }
        for(int x = 0; x < Nx; x++)
        {
            for(int y = 0; y < Ny; y++)
            {
                std::snprintf(buffer, bsize, "%s(%d,%d)", "um", x, y);
                names[index] = buffer;
                index++;
            }
        }
        for(int x = 0; x < Nx; x++)
        {
            for(int y = 0; y < Ny; y++)
            {
                std::snprintf(buffer, bsize, "%s(%d,%d)", "vm", x, y);
                names[index] = buffer;
                index++;
            }
        }
    }

    void CreateEquations(const csNumber_t& time, const std::vector<csNumber_t>& values, const std::vector<csNumber_t>& derivs, std::vector<csEquation_t>& equations)
    {
        t = time;

        // Variable values:
        u_data        = const_cast<csNumber_t*>(&values[u_start_index]);
        v_data        = const_cast<csNumber_t*>(&values[v_start_index]);
        uman_data     = const_cast<csNumber_t*>(&values[uman_start_index]);
        vman_data     = const_cast<csNumber_t*>(&values[vman_start_index]);
        // Variable derivatives:
        du_dt_data    = const_cast<csNumber_t*>(&derivs[u_start_index]);
        dv_dt_data    = const_cast<csNumber_t*>(&derivs[v_start_index]);

        int x, y;

        //group_BCs = cs::csGroupPtr(new csGroup_t("BoundaryConditions", 1)); // the group for boundary conditions (group 1)
        //group_u   = cs::csGroupPtr(new csGroup_t("Burgers_u",          2)); // the group for u-component (group 2)
        //group_v   = cs::csGroupPtr(new csGroup_t("Burgers_v",          3)); // the group for v-component (group 3)
        //group_um  = cs::csGroupPtr(new csGroup_t("Burgers_um",         4)); // the group for uman-component (group 4)
        //group_vm  = cs::csGroupPtr(new csGroup_t("Burgers_vm",         5)); // the group for vman-component (group 5)

        //csEquation_t equation(group_BCs.get());   // Belongs to the group "BoundaryConditions"
        //csEquation_t equation_u(group_u.get());   // Belongs to the group "Burgers_u"
        //csEquation_t equation_v(group_v.get());   // Belongs to the group "Burgers_v"
        //csEquation_t equation_um(group_um.get()); // Belongs to the group "Burgers_um"
        //csEquation_t equation_vm(group_vm.get()); // Belongs to the group "Burgers_vm"

        csEquation_t equation(nullptr);

        for(x = 0; x < Nx; x++)
        {
            for(y = 0; y < Ny; y++)
            {
                if(x == 0)
                {
                    csNumber_t bc = (u(x,y) - um(x,y));
                    equation[ u(x,y) ] = bc;
                    equations.push_back(equation);
                }
                else if(x == Nx-1)
                {
                    csNumber_t bc = (u(x,y) - um(x,y));
                    equation[ u(x,y) ] = bc;
                    equations.push_back(equation);
                }
                else if(y == 0)
                {
                    csNumber_t bc = (u(x,y) - um(x,y));
                    equation[ u(x,y) ] = bc;
                    equations.push_back(equation);
                }
                else if(y == Ny-1)
                {
                    csNumber_t bc = (u(x,y) - um(x,y));
                    equation[ u(x,y) ] = bc;
                    equations.push_back(equation);
                }
                else
                {
                    csNumber_t pde = (du_dt(x,y) + (duu_dx(x,y) + duv_dy(x,y)) - ni * (d2u_dx2(x,y) + d2u_dy2(x,y)) - Su(x,y));
                    equation[ u(x,y) ] = pde;
                    equation.SetGridPoint(x, y);
                    equations.push_back(equation);
                }
            }
        }

        // v velocity component.
        for(x = 0; x < Nx; x++)
        {
            for(y = 0; y < Ny; y++)
            {
                if(x == 0)
                {
                    csNumber_t bc = v(x,y) - vm(x,y);
                    equation[ v(x,y) ] = bc;
                    equations.push_back(equation);
                }
                else if(x == Nx-1)
                {
                    csNumber_t bc = v(x,y) - vm(x,y);
                    equation[ v(x,y) ] = bc;
                    equations.push_back(equation);
                }
                else if(y == 0)
                {
                    csNumber_t bc = v(x,y) - vm(x,y);
                    equation[ v(x,y) ] = bc;
                    equations.push_back(equation);
                }
                else if(y == Ny-1)
                {
                    csNumber_t bc = v(x,y) - vm(x,y);
                    equation[ v(x,y) ] = bc;
                    equations.push_back(equation);
                }
                else
                {
                    csNumber_t pde = (dv_dt(x,y) + (dvu_dx(x,y) + dvv_dy(x,y)) - ni * (d2v_dx2(x,y) + d2v_dy2(x,y)) - Sv(x,y));
                    equation[ v(x,y) ] = pde;
                    equation.SetGridPoint(x, y);
                    equations.push_back(equation);
                }
            }
        }

        // Manufactured solutions for u and v velocity components (uman, vman).
        for(x = 0; x < Nx; x++)
        {
            for(y = 0; y < Ny; y++)
            {
                csNumber_t eq_um = u_man(x,y) - um(x,y);
                equation[ u_man(x,y) ] = eq_um;
                equations.push_back(equation);
            }
        }
        for(x = 0; x < Nx; x++)
        {
            for(y = 0; y < Ny; y++)
            {
                csNumber_t eq_vm = v_man(x,y) - vm(x,y);
                equation[ v_man(x,y) ] = eq_vm;
                equations.push_back(equation);
            }
        }
    }

    cs::csGroupPtr group_BCs;
    cs::csGroupPtr group_u;
    cs::csGroupPtr group_v;
    cs::csGroupPtr group_um;
    cs::csGroupPtr group_vm;
};

#endif