/*******************************************************************************
* Copyright (C) 2023 Intel Corporation
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

//@HEADER
// ***************************************************
//
// HPCG: High Performance Conjugate Gradient Benchmark
//
// Contact:
// Michael A. Heroux ( maherou@sandia.gov)
// Jack Dongarra     (dongarra@eecs.utk.edu)
// Piotr Luszczek    (luszczek@eecs.utk.edu)
//
// ***************************************************
//@HEADER


#ifndef CUSTOMKERNELS_HPP
#define CUSTOMKERNELS_HPP



#include <sycl/sycl.hpp>
#include "UsmUtil.hpp"
#include "Helpers.hpp"


/******************************************************************
 Functions we need for the algorithm with graph permutation:

	• fillDiagData for sparse_matrix<local_int_t, double> & mat
	• SpTRSV lower kernel: solve   (L+D) * y = x
	• SpTRSV upper kernel: solve   (D+U) * y = x
	• SpTRMV lower kernel: compute y = L * x
	• SpTRMV upper kernel: compute y = U * x
	• SpGEMV:   compute y = A * x
	• SpGEMV_DOT: compute y = A * x,  d = dot(x,y)

We will use sparse::sort_matrix(), sparse::gemv(B), and
  sparse::update_diagonal_data() to update matrix.values back and forth.

Note: For the algorithm without graph permutation, we will use
  sparse:: functionality for everything.

*******************************************************************/

namespace custom {


    // Device properties to be cached, so they are looked up once rather than
    // per kernel launch. The actual lookup happens in init(), which is defined
    // out of line.
    //
    // Every member has an in-class default (sentinel) value. Both constructors
    // therefore leave the object in a well-defined state, even if init() does
    // not assign every field.
    struct deviceInfo {

        // Default constructor: not initialized, all device fields hold their
        // sentinel values (-1 for the int fields, 0 for the unsigned ones).
        deviceInfo() = default;

        // Construct from a queue: runs init(queue), then marks the object as
        // initialized.
        // NOTE(review): this constructor is not `explicit`, so a sycl::queue
        // converts implicitly to deviceInfo. It is kept that way so existing
        // call sites continue to compile.
        deviceInfo(sycl::queue &queue) {
            init(queue);
            this->isInitialized = true;
        }

        // public members
        bool isInitialized = false;                 // set to true by the queue constructor after init() returns
        int device_num_threads = -1;                // do I need this?  looks like a maybe yes for spmv merge-based algorithm
        int device_global_mem_cache_line_size = -1;
        // NOTE(review): SYCL reports the global memory cache size as a 64-bit
        // value. An `int` would overflow for caches of 2 GiB or more.
        int device_global_mem_cache_size = -1;
        std::uint32_t device_alignment = 0;
        std::uint64_t device_max_mem_alloc_size = 0;

        // Public member function for initialization. Presumably queries the
        // device behind `queue` and fills the device_* fields. It is defined
        // out of line; TODO confirm which fields it assigns.
        void init(sycl::queue &queue);

    };


    struct sparseMatrix
    {
        // constructor with initialization
        sparseMatrix()
        :
        nrows(0), ncols(0),
        diagDataIsFilled(false),
        block_size(0), nVectors(0), nBlocks(0),
        nColors(0)
        {

            // ESB format
            this->esbblockptr = nullptr;
            this->esbcolind = nullptr;
            this->esbvalues = nullptr;

            // Diagonal stuff and ESB helpers for triangular parts
            this->esblastLower = nullptr;
            this->esbfirstUpper = nullptr;
            this->esblastUpper = nullptr;
            this->esbfirstNonloc = nullptr;
            this->diags = nullptr;

            // coloring for SpTRSV
            this->xcolors_host = nullptr;
            this->xcolors_dev = nullptr;

            // device info
            this->p_info = nullptr;

        }

        // destructor
        ~sparseMatrix()
        { }

        void initDiagonal(double *diags, local_int_t *esblastLower, local_int_t *esbfirstUpper, local_int_t *esblastUpper, local_int_t *esbfirstNonloc)
        {
            this->diags = diags;
            this->esblastLower = esblastLower;
            this->esbfirstUpper = esbfirstUpper;
            this->esblastUpper = esblastUpper;
            this->esbfirstNonloc = esbfirstNonloc;
            this->diagDataIsFilled = true;
        }

        void initMatrix(local_int_t nrows, local_int_t ncols, local_int_t block_size, local_int_t nVectors, local_int_t nBlocks, local_int_t *esbblockptr, local_int_t *esbcolind, double *esbvalues)
        {
            this->nrows = nrows;
            this->ncols = ncols;
            this->block_size = block_size;
            this->nVectors = nVectors;
            this->nBlocks = nBlocks;
            this->esbblockptr = esbblockptr;
            this->esbcolind = esbcolind;
            this->esbvalues = esbvalues;
        }

        void initColoring(local_int_t nColors, local_int_t *xcolors_host, local_int_t *xcolors_dev) {
            this->nColors = nColors;
            this->xcolors_host = xcolors_host;
            this->xcolors_dev = xcolors_dev;
        }

        void setDevInfo(deviceInfo *devInfo)
        {
            this->p_info = devInfo;
        }

        // matrix sizes
        local_int_t nrows;
        local_int_t ncols;

        // arrays for handling diagonals and triangular parts
        bool diagDataIsFilled;   // bool diag data is filled or not
        double *diags;           // (nrows) USM device
        local_int_t *esblastLower;  // (nBlocks) USM device
        local_int_t *esbfirstUpper; // (nBlocks) USM device
        local_int_t *esblastUpper; // (nBlocks) USM device
        local_int_t *esbfirstNonloc; // (nBlocks) USM device

        // arrays for reordering blocks
        local_int_t *mv_reorder;
        local_int_t *trmv_l_reorder;
        local_int_t *trmv_u_reorder;

        bool applyL3CacheMVReorder = false;
        bool applyL3CacheTRMVLReorder = false;
        bool applyL3CacheTRMVUReorder = false;

        // ESB matrix parameters
        local_int_t block_size;
        local_int_t nVectors;
        local_int_t nBlocks;
        local_int_t *esbblockptr; // (nBlocks + 1) USM device
        local_int_t *esbcolind; // (block_size * nVectors) USM device
        double *esbvalues;      // (block_size * nVectors) USM device

        // Triangular Solve coloring
        local_int_t nColors;
        local_int_t *xcolors_host;
        local_int_t *xcolors_dev;

        // store a pointer to device data here
        deviceInfo * p_info;

    };



    /// Custom "uplo" selector for the HPCG kernels: picks which part of the
    /// matrix (strict lower L, diagonal D, strict upper U) a kernel such as
    /// SpTRMV / SpTRSV / SpTRSV_FUSED works with.
    /// The underlying type is `char` and the numeric values are fixed.
    enum class uplo : char
    {
        upper = 0,            ///< U
        upper_diagonal = 1,   ///< D + U
        lower = 2,            ///< L
        lower_diagonal = 3,   ///< L + D
        lower_update = 4,     ///< presumably the SpTRMV "(L+D) * x + y" mode -- TODO confirm
        upper_nonlocal = 5    ///< not documented here; presumably the nonlocal columns -- TODO confirm
    };

    /// Sparse general matrix-vector product, the custom counterpart of
    /// sparse::gemv() (intended to use an ESIMD kernel).
    ///
    /// Computes y = A * x.
    /// @param queue        queue the work is submitted to
    /// @param matM         the matrix A
    /// @param x            input vector; should be USM device allocated
    /// @param y            output vector; should be USM device allocated
    /// @param dependencies events this submission depends on
    /// @return event of the submitted work (presumably its completion event)
    sycl::event SpGEMV(sycl::queue &queue,
                       sparseMatrix *matM,
                       const double *x,
                       double *y,
                       const std::vector<sycl::event> &dependencies = {});


    /// Sparse matrix-vector product fused with a dot product:
    /// y = A * x together with dot(x, y) = x^T A x.
    /// @param queue        queue the work is submitted to
    /// @param matA         the matrix A
    /// @param x            input vector
    /// @param y            output vector A * x
    /// @param xAx          presumably receives dot(x, y) -- TODO confirm host/device placement
    /// @param dependencies events this submission depends on
    /// @return event of the submitted work (presumably its completion event)
    sycl::event SpGEMV_DOT(sycl::queue &queue,
                           sparseMatrix *matA,
                           const double *x,
                           double *y,
                           double *xAx,
                           const std::vector<sycl::event> &dependencies = {});

    /// Sparse triangular matrix-vector product, the custom counterpart of
    /// sparse::trmv(). Intended to use an ESIMD kernel driven by
    /// esblastLower / esbfirstUpper, or to switch to a merge-based kernel
    /// once that is ready.
    ///
    /// The operation is selected by `mode`:
    ///   uplo::lower_diagonal  ->  y = (L+D) * x
    ///   uplo::lower           ->  y = L * x
    ///   uplo::upper_diagonal  ->  y = (D+U) * x
    ///   uplo::upper           ->  y = r - U * x
    ///   uplo::lower_update    ->  y = (L+D) * x + y
    ///     (documented previously as "uplo::lower_diagonal_update", which is
    ///      not an enumerator of uplo; TODO confirm this mapping)
    /// @param y1 role not documented here -- TODO confirm
    /// @return event of the submitted work (presumably its completion event)
    sycl::event SpTRMV(sycl::queue &queue,
                       sparseMatrix *matM,
                       uplo mode,
                       const double *x,
                       const double *r,
                       double *y,
                       double *y1,
                       const std::vector<sycl::event> &dependencies = {});



    /// Sparse triangular solve, the custom counterpart of sparse::trsv().
    /// The lower solve uses the kernel from the GCA team
    /// (Kernel_trsv_sycl_by_uncached_load()), but without reordering into a
    /// new ordering inside the kernel.
    ///
    /// The system is selected by `mode`:
    ///   uplo::lower_diagonal  ->  solve (L+D) * y = x
    ///   uplo::upper_diagonal  ->  solve (D+U) * y = x
    /// @return event of the submitted work (presumably its completion event)
    sycl::event SpTRSV(sycl::queue &queue,
                       sparseMatrix *matM,
                       uplo mode,
                       double *x,
                       double *y,
                       const std::vector<sycl::event> &dependencies = {});

    /// Triangular solve fused with diagonal scaling steps.
    ///
    /// The sequence is selected by `mode`:
    ///   uplo::lower_diagonal  ->  t = y;  solve (L+D) * y = x;  x = t + diag * y
    ///   uplo::upper_diagonal  ->  x = diag * x;  solve (D+U) * y = x
    /// Both x and y may be written.
    /// @return event of the submitted work (presumably its completion event)
    sycl::event SpTRSV_FUSED(sycl::queue &queue,
                             sparseMatrix *matM,
                             uplo mode,
                             double *x,
                             double *y,
                             const std::vector<sycl::event> &dependencies = {});


    /// CG direction update with on-device beta:
    ///   beta = rtz_dev[0] / oldrtz_dev[0]
    ///   p    = beta * p + z
    /// @param nrow presumably the number of vector entries updated -- TODO confirm
    sycl::event SpAXPBY_ker1(sycl::queue &queue,
                             double *rtz_dev,
                             double *oldrtz_dev,
                             local_int_t nrow,
                             Vector &p,
                             Vector &z,
                             const std::vector<sycl::event> &dependencies = {});

    /// CG residual update with on-device alpha:
    ///   alpha = rtz_dev[0] / pAp_dev[0]
    ///   r     = r - alpha * Ap
    /// @param nrow presumably the number of vector entries updated -- TODO confirm
    sycl::event SpAXPBY_ker2(sycl::queue &queue,
                             double *rtz_dev,
                             double *pAp_dev,
                             local_int_t nrow,
                             Vector &r,
                             Vector &Ap,
                             const std::vector<sycl::event> &dependencies = {});

    /// CG solution update with on-device alpha:
    ///   alpha = rtz_dev[0] / pAp_dev[0]
    ///   x     = x + alpha * p
    /// @param nrow presumably the number of vector entries updated -- TODO confirm
    sycl::event SpAXPBY_ker3(sycl::queue &queue,
                             double *rtz_dev,
                             double *pAp_dev,
                             local_int_t nrow,
                             Vector &x,
                             Vector &p,
                             const std::vector<sycl::event> &dependencies = {});

} // namespace custom

#endif  // CUSTOMKERNELS_HPP
