/***********************************************************************************
                 OpenCS Project: www.daetools.com
                 Copyright (C) Dragan Nikolic
************************************************************************************
OpenCS is free software; you can redistribute it and/or modify it under the terms
of the GNU Lesser General Public License version 3 as published by the Free Software
Foundation. OpenCS is distributed in the hope that it will be useful, but WITHOUT
ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A
PARTICULAR PURPOSE. See the GNU Lesser General Public License for more details.
You should have received a copy of the GNU Lesser General Public License along with
the OpenCS software; if not, see <http://www.gnu.org/licenses/>.
***********************************************************************************/
#ifndef BRUSSELATOR_GROUPS_MODEL_H
#define BRUSSELATOR_GROUPS_MODEL_H

#include <cmath>
#include <cstdint>
#include <cstdio>
#include <iostream>
#include <map>
#include <memory>
#include <set>
#include <stdexcept>
#include <string>
#include <vector>
#include <math.h>
#include <OpenCS/models/cs_number.h>
#include <OpenCS/models/cs_kernel.h>
#include <OpenCS/cs_model.h>
using namespace cs;

// Simulation options as JSON text, embedded at compile time from
// "dae_example_3_groups-simulation_options.json". The included file must expand
// to a string literal (presumably a raw string literal wrapping the JSON; TODO
// confirm against the .json file).
// NOTE(review): both pointers are non-const with external linkage and are defined
// in a header, so including this header from more than one translation unit will
// produce multiple-definition link errors.
const char* simulation_options =
#include "dae_example_3_groups-simulation_options.json"
;
// Simulation options for the MPI variant (per the file name), embedded the same way
// from "dae_example_3_groups-simulation_options_mpi.json".
const char* simulation_options_mpi =
#include "dae_example_3_groups-simulation_options_mpi.json"
;

/// 2D Brusselator reaction-diffusion model discretised with finite differences
/// on a uniform Nx x Ny grid over [x0,x1] x [y0,y1] = [0,1] x [0,1].
///
/// The unknown vector holds all u values followed by all v values; within each
/// component, grid point (x,y) maps to the flat index Ny*x + y (see getIndex).
/// Equations are assigned to three groups: boundary conditions (group 1),
/// interior u-equations (group 2) and interior v-equations (group 3).
class BrusselatorGroups_2D
{
public:
    /// @param nx         Number of grid points in x (must be >= 3).
    /// @param ny         Number of grid points in y (must be >= 3).
    /// @param bc_u_flux  Value subtracted from the normal derivative of u in its Neumann BCs.
    /// @param bc_v_flux  Value subtracted from the normal derivative of v in its Neumann BCs.
    /// The flux values are copied into the model, so temporaries may be passed.
    /// @throws std::runtime_error if nx < 3 or ny < 3: the one-sided boundary stencils
    ///         need three points and dx/dy would otherwise be undefined.
    BrusselatorGroups_2D(int nx, int ny, const csNumber_t& bc_u_flux, const csNumber_t& bc_v_flux):
        Nx(nx), Ny(ny), u_flux_bc(bc_u_flux), v_flux_bc(bc_v_flux)
    {
        if(Nx < 3 || Ny < 3)
            throw std::runtime_error("BrusselatorGroups_2D: Nx and Ny must be at least 3");

        Nequations = 2*Nx*Ny;

        x0 = 0.0;
        x1 = 1.0;
        y0 = 0.0;
        y1 = 1.0;
        dx = (x1-x0) / (Nx-1);
        dy = (y1-y0) / (Ny-1);

        x_domain.resize(Nx);
        y_domain.resize(Ny);
        for(int x = 0; x < Nx; x++)
            x_domain[x] = x0 + x * dx;
        for(int y = 0; y < Ny; y++)
            y_domain[y] = y0 + y * dy;

        u_start_index = 0*Nx*Ny;
        v_start_index = 1*Nx*Ny;

        eps1 = 0.002;
        eps2 = 0.002;
        A    = 1.000;
        B    = 3.400;

        // Raw views into the values/derivatives vectors; set by CreateEquations.
        u_data     = nullptr;
        v_data     = nullptr;
        du_dt_data = nullptr;
        dv_dt_data = nullptr;
    }

    /// Fills uv0 (resized to Nequations) with the initial profile:
    /// u = 1 - 0.5*cos(pi*y/Ly), v = 3.5 - 2.5*cos(pi*x/Lx).
    void SetInitialConditions(std::vector<real_t>& uv0)
    {
        uv0.assign(Nequations, 0.0);

        const real_t pi = 3.1415926535898;
        const real_t Lx = x1 - x0;
        const real_t Ly = y1 - y0;

        real_t* u_0 = &uv0[u_start_index];
        real_t* v_0 = &uv0[v_start_index];

        for(int ix = 0; ix < Nx; ix++)
        {
            for(int iy = 0; iy < Ny; iy++)
            {
                const int index = getIndex(ix,iy);

                const real_t x = x_domain[ix];
                const real_t y = y_domain[iy];

                u_0[index] = 1.0 - 0.5 * std::cos(pi * y / Ly);
                v_0[index] = 3.5 - 2.5 * std::cos(pi * x / Lx);
            }
        }
    }

    /// Fills names (resized to Nequations) with "u(x,y)" for the u block followed
    /// by "v(x,y)" for the v block, in the same order as the unknown vector.
    void GetVariableNames(std::vector<std::string>& names)
    {
        const int bsize = 32;
        char buffer[bsize];
        int index = 0;

        names.resize(Nequations);
        for(int x = 0; x < Nx; x++)
        {
            for(int y = 0; y < Ny; y++)
            {
                std::snprintf(buffer, bsize, "%s(%d,%d)", "u", x, y);
                names[index] = buffer;
                index++;
            }
        }
        for(int x = 0; x < Nx; x++)
        {
            for(int y = 0; y < Ny; y++)
            {
                std::snprintf(buffer, bsize, "%s(%d,%d)", "v", x, y);
                names[index] = buffer;
                index++;
            }
        }
    }

    /// Builds one residual equation per unknown and appends them to `equations`
    /// (all u equations first, then all v equations).
    ///
    /// @param values     Model variables (u block followed by v block); at least Nequations items.
    /// @param derivs     Time derivatives with the same layout as `values`.
    /// @param time       Current time (not used by this model).
    /// @param equations  Output; 2*Nx*Ny equations are appended.
    /// @throws std::runtime_error if `values` or `derivs` hold fewer than Nequations items.
    void CreateEquations(const std::vector<csNumber_t>& values,
                         const std::vector<csNumber_t>& derivs,
                         csNumber_t time,
                         std::vector<csEquation_t>& equations)
    {
        if(values.size() < static_cast<std::size_t>(Nequations) ||
           derivs.size() < static_cast<std::size_t>(Nequations))
            throw std::runtime_error("CreateEquations: values/derivs have fewer items than Nequations");

        u_data     = &values[u_start_index];
        v_data     = &values[v_start_index];
        du_dt_data = &derivs[u_start_index];
        dv_dt_data = &derivs[v_start_index];

        equations.reserve(equations.size() + Nequations);

        // Set groups as data members since they get out of scope once the CreateEquations function returns.
        group_BCs = csGroupPtr(new csGroup_t("BoundaryConditions", 1)); // the group for boundary conditions (group 1)
        group_u   = csGroupPtr(new csGroup_t("Brusselator_u",      2)); // the group for u-component (group 2)
        group_v   = csGroupPtr(new csGroup_t("Brusselator_v",      3)); // the group for v-component (group 3)

        csEquation_t equation(group_BCs.get()); // Belongs to the group "BoundaryConditions".
        csEquation_t equation_u(group_u.get()); // Belongs to the group "Brusselator_u"
        csEquation_t equation_v(group_v.get()); // Belongs to the group "Brusselator_v"

        /* u component */
        for(int x = 0; x < Nx; x++)
        {
            for(int y = 0; y < Ny; y++)
            {
                if(x == 0 || x == Nx-1)        // Left/right BCs: Neumann BCs
                {
                    csNumber_t bc = du_dx(x,y) - u_flux_bc;

                    equation[ u(x,y) ] = bc;
                    equations.push_back(equation);
                }
                else if(y == 0 || y == Ny-1)   // Bottom/top BCs: Neumann BCs
                {
                    csNumber_t bc = du_dy(x,y) - u_flux_bc;

                    equation[ u(x,y) ] = bc;
                    equations.push_back(equation);
                }
                else
                {
                    // Interior points
                    csNumber_t u_pde = du_dt(x,y)                                   /* accumulation term */
                                       - eps1 * (d2u_dx2(x,y) + d2u_dy2(x,y))       /* diffusion term    */
                                       - (u(x,y)*u(x,y)*v(x,y) - (B+1)*u(x,y) + A); /* generation term   */

                    equation_u[ u(x,y) ] = u_pde;
                    equations.push_back(equation_u);
                }
            }
        }

        /* v component */
        for(int x = 0; x < Nx; x++)
        {
            for(int y = 0; y < Ny; y++)
            {
                if(x == 0 || x == Nx-1)        // Left/right BCs: Neumann BCs
                {
                    csNumber_t bc = dv_dx(x,y) - v_flux_bc;

                    equation[ v(x,y) ] = bc;
                    equations.push_back(equation);
                }
                else if(y == 0 || y == Ny-1)   // Bottom/top BCs: Neumann BCs
                {
                    csNumber_t bc = dv_dy(x,y) - v_flux_bc;

                    equation[ v(x,y) ] = bc;
                    equations.push_back(equation);
                }
                else
                {
                    // Interior points
                    csNumber_t v_pde = dv_dt(x,y)                             /* accumulation term */
                                       - eps2 * (d2v_dx2(x,y) + d2v_dy2(x,y)) /* diffusion term    */
                                       + u(x,y)*u(x,y)*v(x,y) - B*u(x,y);     /* generation term   */

                    equation_v[ v(x,y) ] = v_pde;
                    equations.push_back(equation_v);
                }
            }
        }
    }

protected:
    // Accessors into the arrays set by CreateEquations (valid only during that call).
    csNumber_t u(int x, int y)
    {
        return u_data[getIndex(x,y)];
    }
    csNumber_t v(int x, int y)
    {
        return v_data[getIndex(x,y)];
    }
    csNumber_t du_dt(int x, int y)
    {
        return du_dt_data[getIndex(x,y)];
    }
    csNumber_t dv_dt(int x, int y)
    {
        return dv_dt_data[getIndex(x,y)];
    }

    // First order partial derivative per x: second-order one-sided stencils at the
    // left/right boundaries, central difference in the interior.
    csNumber_t du_dx(int x, int y)
    {
        if(x == 0) // left
        {
            const csNumber_t& u0 = u(0, y);
            const csNumber_t& u1 = u(1, y);
            const csNumber_t& u2 = u(2, y);
            return (-3*u0 + 4*u1 - u2) / (2*dx);
        }
        else if(x == Nx-1) // right
        {
            const csNumber_t& un  = u(Nx-1,   y);
            const csNumber_t& un1 = u(Nx-1-1, y);
            const csNumber_t& un2 = u(Nx-1-2, y);
            return (3*un - 4*un1 + un2) / (2*dx);
        }
        else
        {
            const csNumber_t& u1 = u(x+1, y);
            const csNumber_t& u2 = u(x-1, y);
            return (u1 - u2) / (2*dx);
        }
    }
    csNumber_t dv_dx(int x, int y)
    {
        if(x == 0) // left
        {
            const csNumber_t& v0 = v(0, y);
            const csNumber_t& v1 = v(1, y);
            const csNumber_t& v2 = v(2, y);
            return (-3*v0 + 4*v1 - v2) / (2*dx);
        }
        else if(x == Nx-1) // right
        {
            const csNumber_t& vn  = v(Nx-1,   y);
            const csNumber_t& vn1 = v(Nx-1-1, y);
            const csNumber_t& vn2 = v(Nx-1-2, y);
            return (3*vn - 4*vn1 + vn2) / (2*dx);
        }
        else
        {
            const csNumber_t& v1 = v(x+1, y);
            const csNumber_t& v2 = v(x-1, y);
            return (v1 - v2) / (2*dx);
        }
    }

    // First order partial derivative per y: one-sided at bottom/top, central inside.
    csNumber_t du_dy(int x, int y)
    {
        if(y == 0) // bottom
        {
            const csNumber_t& u0 = u(x, 0);
            const csNumber_t& u1 = u(x, 1);
            const csNumber_t& u2 = u(x, 2);
            return (-3*u0 + 4*u1 - u2) / (2*dy);
        }
        else if(y == Ny-1) // top
        {
            const csNumber_t& un  = u(x, Ny-1  );
            const csNumber_t& un1 = u(x, Ny-1-1);
            const csNumber_t& un2 = u(x, Ny-1-2);
            return (3*un - 4*un1 + un2) / (2*dy);
        }
        else
        {
            const csNumber_t& ui1 = u(x, y+1);
            const csNumber_t& ui2 = u(x, y-1);
            return (ui1 - ui2) / (2*dy);
        }
    }
    csNumber_t dv_dy(int x, int y)
    {
        if(y == 0) // bottom
        {
            const csNumber_t& v0 = v(x, 0);
            const csNumber_t& v1 = v(x, 1);
            const csNumber_t& v2 = v(x, 2);
            return (-3*v0 + 4*v1 - v2) / (2*dy);
        }
        else if(y == Ny-1) // top
        {
            const csNumber_t& vn  = v(x, Ny-1  );
            const csNumber_t& vn1 = v(x, Ny-1-1);
            const csNumber_t& vn2 = v(x, Ny-1-2);
            return (3*vn - 4*vn1 + vn2) / (2*dy);
        }
        else
        {
            const csNumber_t& vi1 = v(x, y+1);
            const csNumber_t& vi2 = v(x, y-1);
            return (vi1 - vi2) / (2*dy);
        }
    }

    // Second order partial derivative per x (interior points only).
    csNumber_t d2u_dx2(int x, int y)
    {
        if(x == 0 || x == Nx-1)
            throw std::runtime_error("d2u_dx2 called at the boundary");

        const csNumber_t& ui1 = u(x+1, y);
        const csNumber_t& ui  = u(x,   y);
        const csNumber_t& ui2 = u(x-1, y);
        return (ui1 - 2*ui + ui2) / (dx*dx);
    }
    csNumber_t d2v_dx2(int x, int y)
    {
        if(x == 0 || x == Nx-1)
            throw std::runtime_error("d2v_dx2 called at the boundary");

        const csNumber_t& vi1 = v(x+1, y);
        const csNumber_t& vi  = v(x,   y);
        const csNumber_t& vi2 = v(x-1, y);
        return (vi1 - 2*vi + vi2) / (dx*dx);
    }

    // Second order partial derivative per y (interior points only).
    csNumber_t d2u_dy2(int x, int y)
    {
        if(y == 0 || y == Ny-1)
            throw std::runtime_error("d2u_dy2 called at the boundary");

        const csNumber_t& ui1 = u(x, y+1);
        const csNumber_t& ui  = u(x,   y);
        const csNumber_t& ui2 = u(x, y-1);
        return (ui1 - 2*ui + ui2) / (dy*dy);
    }
    csNumber_t d2v_dy2(int x, int y)
    {
        if(y == 0 || y == Ny-1)
            throw std::runtime_error("d2v_dy2 called at the boundary");

        const csNumber_t& vi1 = v(x, y+1);
        const csNumber_t& vi  = v(x,   y);
        const csNumber_t& vi2 = v(x, y-1);
        return (vi1 - 2*vi + vi2) / (dy*dy);
    }

    /// Flat index of grid point (x,y) within one component block (row-major in x).
    /// @throws std::runtime_error if (x,y) lies outside the grid.
    int getIndex(int x, int y)
    {
        if(x < 0 || x >= Nx)
            throw std::runtime_error("Invalid x index");
        if(y < 0 || y >= Ny)
            throw std::runtime_error("Invalid y index");
        return Ny*x + y;
    }

public:
    int    Nequations;
    int    Nx;
    int    Ny;
    real_t eps1;
    real_t eps2;
    real_t A;
    real_t B;

protected:
    int    u_start_index;
    int    v_start_index;
    real_t x0, x1;
    real_t y0, y1;
    real_t dx;
    real_t dy;
    std::vector<real_t> x_domain;
    std::vector<real_t> y_domain;
    const csNumber_t* u_data;
    const csNumber_t* v_data;
    const csNumber_t* du_dt_data;
    const csNumber_t* dv_dt_data;
    // Stored by value: reference members bound to constructor arguments would
    // dangle whenever the caller passes a temporary.
    csNumber_t u_flux_bc;
    csNumber_t v_flux_bc;

    csGroupPtr group_BCs;
    csGroupPtr group_u;
    csGroupPtr group_v;
};

// Problem-specific custom graph partitioner.
// Adds complete rows to a partition so that each partition gets equations from all groups.
// The u equation of grid point (x,y) has index Ny*x + y and its v equation has index
// Nx*Ny + Ny*x + y; both always land in the same partition.
class csGraphPartitioner_dae_example_3: public csGraphPartitioner_t
{
public:
    csGraphPartitioner_dae_example_3(uint32_t Nx_, uint32_t Ny_):
        Nx(Nx_), Ny(Ny_)
    {
    }

    std::string GetName()
    {
        return "dae_example_3";
    }

    /// Splits the 2*Nx*Ny equations into Npe equally sized partitions: partition k
    /// receives grid points [k*Nx*Ny/Npe, (k+1)*Nx*Ny/Npe), each with its u and v equation.
    ///
    /// @param Npe         Number of partitions; must be > 0 and divide both Nx and Ny.
    /// @param Nvertices   Number of equations; must equal 2*Nx*Ny.
    /// @param partitions  Output: replaced by Npe sets of equation indices.
    /// The other arguments are not used by this partitioner.
    /// @return 0; invalid input is reported through csThrowException.
    virtual int Partition(int32_t                               Npe,
                          int32_t                               Nvertices,
                          int32_t                               Nconstraints,
                          std::vector<uint32_t>&                rowIndices,
                          std::vector<uint32_t>&                colIndices,
                          std::vector< std::vector<int32_t> >&  vertexWeights,
                          std::vector<uint16_t>&                groupIDs,
                          std::vector< std::set<int32_t> >&     partitions)
    {
        // Guard first: Npe is used as a divisor below.
        if(Npe <= 0)
            csThrowException("Npe: the number of partitions must be positive");
        if(Nvertices < 0 || static_cast<uint32_t>(Nvertices) != 2*Nx*Ny)
            csThrowException("Invalid number of equations or grid points");

        const uint32_t Npe_u = static_cast<uint32_t>(Npe);
        if(Nx % Npe_u != 0)
            csThrowException("Nx: Graph partitioner can handle only grids that are multiples of Npe");
        if(Ny % Npe_u != 0)
            csThrowException("Ny: Graph partitioner can handle only grids that are multiples of Npe");

        // assign() rather than resize(): entries left from a previous call must not survive.
        partitions.assign(static_cast<std::size_t>(Npe), std::set<int32_t>());

        const uint32_t Npoints       = Nx*Ny;
        const uint32_t Nequations_pe = Npoints / Npe_u;
        printf("(%d,%d) Nequations_pe = %d, Nvertices = %d\n", (int)Nx, (int)Ny, (int)Nequations_pe, (int)Nvertices);

        // The point counter runs in the same (x-major) order as the model's getIndex.
        for(uint32_t point = 0; point < Npoints; point++)
        {
            // Unsigned integer division already truncates; no floating-point floor needed.
            const uint32_t pe = point / Nequations_pe;

            std::set<int32_t>& partition = partitions[pe];
            partition.insert(static_cast<int32_t>(point));           // u equation
            partition.insert(static_cast<int32_t>(point + Npoints)); // v equation
        }
        return 0;
    }

    // Grid size
    uint32_t Nx;
    uint32_t Ny;
};

std::shared_ptr<csGraphPartitioner_t> createGraphPartitioner_dae_example_3(uint32_t Nx, uint32_t Ny)
{
    return std::shared_ptr<csGraphPartitioner_t>(new csGraphPartitioner_dae_example_3(Nx, Ny));
}


#endif