/*******************************************************************************
* Copyright (C) 2014 Intel Corporation
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

//@HEADER
// ***************************************************
//
// HPCG: High Performance Conjugate Gradient Benchmark
//
// Contact:
// Michael A. Heroux ( maherou@sandia.gov)
// Jack Dongarra     (dongarra@eecs.utk.edu)
// Piotr Luszczek    (luszczek@eecs.utk.edu)
//
// ***************************************************
//@HEADER

/*!
 @file CG.cpp

 HPCG routine
 */

#include <fstream>
#include <iomanip>

#include <cmath>
#include <cfloat>

#include "hpcg.hpp"

#include "CG.hpp"
#include "CG_ref.hpp"
#include "mytimer.hpp"
#include "ComputeSPMV.hpp"
#include "ComputeSPMV_ref.hpp"
#include "ComputeMG.hpp"
#include "ComputeDotProduct_ref.hpp"
#include "ComputeDotProduct.hpp"
#include "ComputeWAXPBY.hpp"
#include "UsmUtil.hpp"
#include "Helpers.hpp"
#include "CustomKernels.hpp" // for axpy kernels

#ifndef HPCG_NO_MPI
#include "ExchangeHalo.hpp"
#endif

#include "VeryBasicProfiler.hpp"

#ifdef BASIC_PROFILING

// Profiling hooks forwarding to the profiler object reachable through
// optData->profiler; they expect a local pointer named `optData` in scope.
// Each hook expands to exactly one statement (do { ... } while (0)) so the
// trailing ';' at the call site is always well-formed, including inside an
// unbraced if/else.
#define BEGIN_PROFILE(n) do { optData->profiler->begin(n); } while (0)
#define END_PROFILE(n) do { optData->profiler->end(n); } while (0)
// Waits on `event` before closing section `n`, so the section covers the
// device work itself and not only its submission.
#define END_PROFILE_WAIT(n, event) do { (event).wait(); optData->profiler->end(n); } while (0)

// could turn off TICK/TOCK here, or somehow combine them

#else

// Profiling disabled: the hooks become empty statements. Note that
// END_PROFILE_WAIT then does not wait on its event at all.
#define BEGIN_PROFILE(n) do { } while (0)
#define END_PROFILE(n) do { } while (0)
#define END_PROFILE_WAIT(n, event) do { } while (0)

#endif

// Use TICK and TOCK to time a code section in MATLAB-like fashion.
// Both expect a local double named `t0`: TICK overwrites it, TOCK adds the
// time elapsed since the last TICK to `t`. They only read mytimer() and do
// not synchronize with the device.
#define TICK()  t0 = mytimer() //!< record current time in 't0'
#define TOCK(t) t += mytimer() - t0 //!< store time difference in 't' using time in 't0'

using namespace sycl;
/*!
  Routine to compute an approximate solution to Ax = b

  The optimized path keeps the CG scalars (r'z, p'Ap, ||r||^2) in device
  memory and chains the SYCL kernels/copies through events. Global sums go
  through MPI either directly on device pointers (HPCG_USE_MPI_OFFLOAD) or
  through the host staging buffer optData->global_result_host.

  @param[in]    A    The known system matrix
  @param[inout] data The data structure with all necessary CG vectors preallocated
  @param[in]    b    The known right hand side vector
  @param[inout] x    On entry: the initial guess; on exit: the new approximate solution
  @param[in]    max_iter  The maximum number of iterations to perform, even if tolerance is not met.
  @param[in]    tolerance The stopping criterion: iteration stops once
                          normr/normr0 - tolerance*(1 + 1e-6) drops below DBL_EPSILON.
  @param[out]   niters    The number of iterations actually performed (0 if the loop body never runs).
  @param[out]   normr     The 2-norm of the residual vector after the last iteration.
  @param[out]   normr0    The 2-norm of the residual vector before the first iteration.
  @param[inout] times     Timing accumulators; entries 0-5 (total, dot product, WAXPBY,
                          SpMV, Allreduce, preconditioner) are incremented by this call.
  @param[in]    doPreconditioning The flag to indicate whether the preconditioner should be invoked at each iteration.
  @param[in]    main_queue The SYCL queue on which all device work of this routine is submitted.

  @return Zero in the optimized path (error codes written to the local ierr are not
          propagated); when HPCG_LOCAL_LONG_LONG is defined, the result of CG_ref().

  @see CG_ref()
*/
int CG(const SparseMatrix & A, CGData & data, const Vector & b, Vector & x,
    const int max_iter, const double tolerance, int & niters, double & normr, double & normr0,
    double * times, bool doPreconditioning, sycl::queue & main_queue) {

#ifdef HPCG_LOCAL_LONG_LONG
    (void)main_queue; // unused by the reference implementation
    return CG_ref(A, data, b, x, max_iter, tolerance, niters, normr, normr0, times, doPreconditioning);
#else
    double ff = 0.0;
    double t0 = 0.0, t1 = 0.0, t2 = 0.0, t3 = 0.0, t4 = 0.0, t5 = 0.0;
    double t_begin = mytimer();  // Start timing right away
    struct optData *optData = (struct optData *)A.optimizationData;
    BEGIN_PROFILE("preiteration");

#ifndef HPCG_NO_MPI
    // Handle for the non-blocking residual-norm reduction; a local suffices.
    MPI_Request request;
#endif
    normr = 0.0;
    niters = 0; // stays 0 if the loop body never executes
    const local_int_t nrow = A.localNumberOfRows;

    Vector & r = data.r; // Residual vector
    Vector & z = data.z; // Preconditioned residual vector
    Vector & p = data.p; // Direction vector (in MPI mode ncol>=nrow)
    Vector & Ap = data.Ap;

    if (!doPreconditioning && A.geom->rank==0) HPCG_fout << "WARNING: PERFORMING UNPRECONDITIONED ITERATIONS OPT" << std::endl;

#ifdef HPCG_DEBUG
    int print_freq = 1;
    if (print_freq>50) print_freq=50;
    if (print_freq<1)  print_freq=1;
#endif

    // Device-resident CG scalars
    double *normr_dev          = optData->normr_dev;
    double *pAp_dev            = optData->pAp_dev;
    double *pAp_loc_dev        = optData->pAp_loc_dev;
    double *rtz_dev            = optData->rtz_dev;
    double *rtz_loc_dev        = optData->rtz_loc_dev;
    double *oldrtz_dev         = optData->oldrtz_dev;

    // Host-side mirrors / MPI staging (some are unused in certain build configurations)
    double *normr_host                          = optData->normr_host;
    [[maybe_unused]] double *rtz_loc_host       = optData->rtz_loc_host;
    [[maybe_unused]] double *pAp_loc_host       = optData->pAp_loc_host;
    [[maybe_unused]] double *global_result_host = optData->global_result_host;

    TICK();

    // r <- b
    auto set_r_ev = main_queue.memcpy(r.values, b.values, nrow * sizeof(double));

    // wait because we require normr_dev to be computed for MPI
    auto update_normr_ev = ComputeDotProductLocal(nrow, r, r, normr_dev, main_queue, {set_r_ev});
    main_queue.memcpy(normr_host, normr_dev, sizeof(double), {update_normr_ev}).wait();

    TOCK(t2);
    TICK();
#ifndef HPCG_NO_MPI
    global_result_host[0] = 0.0;
    MPI_Allreduce(normr_host, global_result_host, 1, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
    normr_host[0] = global_result_host[0];
#endif
    normr = std::sqrt(normr_host[0]);
    // Record initial residual for convergence testing
    normr0 = normr;
    // Convergence check accepts an error of no more than 6 significant digits of tolerance
    ff = normr/normr0-tolerance*(1.0 + 1e-6);
    TOCK(t4);

    int converge_flag = 0;
    if ( ff <= 0.0 )
    {
        converge_flag = 1;
    }

#ifdef HPCG_DEBUG
    if (A.geom->rank==0) HPCG_fout << "Initial Residual = "<< normr << std::endl;
#endif
    END_PROFILE("preiteration");

    // Start iterations
    int ierr = 0;
    // Event of the latest x <- x + alpha*p kernel. It is left running while the
    // next iteration starts; since it reads rtz_dev, pAp_dev and p, every later
    // write to those must be ordered after it.
    sycl::event update_x_ev;
    // NOTE(review): besides "k <= max_iter and not converged", the loop keeps
    // going for up to 50 iterations when the initial residual already met the
    // tolerance (converge_flag). Kept as-is; confirm this is intended.
    for (int k=1; (k<=max_iter && ff >= DBL_EPSILON) || (converge_flag == 1 && k <= 50); k++ )
    {
        sycl::event apply_preconditioner_ev;
        TICK(); BEGIN_PROFILE("preconditioner");
        if (doPreconditioning) {
            apply_preconditioner_ev = ComputeMG(A, r, z, main_queue, ierr); // Apply preconditioner
        }
        else {
            apply_preconditioner_ev = CopyVector(r, z, main_queue);  // copy r to z (no preconditioning)
        }
        apply_preconditioner_ev.wait();
        END_PROFILE_WAIT("preconditioner", apply_preconditioner_ev); TOCK(t5);

        sycl::event update_p_ev;
        if (k == 1) {
            TICK(); BEGIN_PROFILE("dot");
            auto update_rtz_loc_ev = ComputeDotProductLocal(nrow, r, z, rtz_loc_dev, main_queue, {apply_preconditioner_ev}); // rtz_loc <- dot(r,z)
            END_PROFILE_WAIT("dot", update_rtz_loc_ev); TOCK(t1);

            TICK(); BEGIN_PROFILE("Allreduce");
#ifndef HPCG_NO_MPI
#ifdef HPCG_USE_MPI_OFFLOAD
            update_rtz_loc_ev.wait();
            MPI_Allreduce(rtz_loc_dev, rtz_dev, 1, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
            sycl::event update_rtz_ev = update_rtz_loc_ev;
#else // HPCG_USE_MPI_OFFLOAD
            auto copy_rtz_loc_host_ev = main_queue.memcpy(rtz_loc_host, rtz_loc_dev, sizeof(double), {update_rtz_loc_ev});
            copy_rtz_loc_host_ev.wait();
            global_result_host[0] = 0.0;
            MPI_Allreduce(rtz_loc_host, global_result_host, 1, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
            auto update_rtz_ev = main_queue.memcpy(rtz_dev, global_result_host, sizeof(double)); // MPI_Allreduce above is blocking
#endif // HPCG_USE_MPI_OFFLOAD
#else // HPCG_NO_MPI
            auto update_rtz_ev = main_queue.memcpy(rtz_dev, rtz_loc_dev, sizeof(double), {update_rtz_loc_ev});
#endif // HPCG_NO_MPI
            END_PROFILE_WAIT("Allreduce", update_rtz_ev); TOCK(t4);

            TICK(); BEGIN_PROFILE("WAXPBY");
            update_p_ev = main_queue.memcpy(p.values, z.values, nrow * sizeof(double), {update_rtz_ev}); // p <- z
            END_PROFILE_WAIT("WAXPBY", update_p_ev); TOCK(t2);
        } else {
            // oldrtz <- rtz (must complete before rtz_dev is overwritten below)
            auto update_oldrtz_ev = main_queue.memcpy(oldrtz_dev, rtz_dev, sizeof(double), {apply_preconditioner_ev});

            TICK(); BEGIN_PROFILE("dot");
            auto update_rtz_loc_ev = ComputeDotProductLocal(nrow, r, z, rtz_loc_dev, main_queue, {apply_preconditioner_ev}); // rtz_loc <- dot(r,z)
            END_PROFILE_WAIT("dot", update_rtz_loc_ev); TOCK(t1);

            TICK(); BEGIN_PROFILE("Allreduce");
#ifndef HPCG_NO_MPI
#ifdef HPCG_USE_MPI_OFFLOAD
            // MPI writes rtz_dev outside the queue, so every queued reader of it
            // (the oldrtz copy and the previous x update) must be finished first.
            update_oldrtz_ev.wait();
            update_x_ev.wait();
            update_rtz_loc_ev.wait();
            MPI_Allreduce(rtz_loc_dev, rtz_dev, 1, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
            sycl::event update_rtz_ev = update_rtz_loc_ev;
#else // HPCG_USE_MPI_OFFLOAD
            auto copy_rtz_loc_host_ev = main_queue.memcpy(rtz_loc_host, rtz_loc_dev, sizeof(double), {update_rtz_loc_ev});
            copy_rtz_loc_host_ev.wait();
            global_result_host[0] = 0.0;
            MPI_Allreduce(rtz_loc_host, global_result_host, 1, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
            // Do not overwrite rtz_dev before its pending readers are done.
            auto update_rtz_ev = main_queue.memcpy(rtz_dev, global_result_host, sizeof(double), {update_oldrtz_ev, update_x_ev});
#endif // HPCG_USE_MPI_OFFLOAD
#else // HPCG_NO_MPI
            auto update_rtz_ev = main_queue.memcpy(rtz_dev, rtz_loc_dev, sizeof(double), {update_rtz_loc_ev, update_oldrtz_ev, update_x_ev});
#endif
            END_PROFILE_WAIT("Allreduce", update_rtz_ev); TOCK(t4);

            // p <- z + (rtz/oldrtz) * p; ordered after update_x_ev through update_rtz_ev
            // (or through the host waits in the offload path).
            TICK(); BEGIN_PROFILE("WAXPBY");
            update_p_ev = custom::SpAXPBY_ker1(main_queue, rtz_dev, oldrtz_dev, nrow, p, z, {update_rtz_ev});
            END_PROFILE_WAIT("WAXPBY", update_p_ev); TOCK(t2);
        }

        sycl::event update_pAp_ev;
        if ( A.geom->size > 1 )
        {
            TICK(); BEGIN_PROFILE("SPMV");
            auto update_pAp_loc_ev = ComputeSPMV_DOT(A, p, Ap, pAp_loc_dev, main_queue, ierr, {update_p_ev}); // Ap <- A*p, pAp_loc_dev <- dot(p,Ap)
            END_PROFILE_WAIT("SPMV", update_pAp_loc_ev); TOCK(t3);

            TICK(); BEGIN_PROFILE("Allreduce");
#ifndef HPCG_NO_MPI
#ifdef HPCG_USE_MPI_OFFLOAD
            update_pAp_loc_ev.wait();
            MPI_Allreduce(pAp_loc_dev, pAp_dev, 1, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
            update_pAp_ev = update_pAp_loc_ev;
#else // HPCG_USE_MPI_OFFLOAD
            auto copy_pAp_loc_host_ev = main_queue.memcpy(pAp_loc_host, pAp_loc_dev, sizeof(double), {update_pAp_loc_ev});
            copy_pAp_loc_host_ev.wait();
            // Reset the staging buffer only now: the rtz_dev copy that reads it is
            // an ancestor of copy_pAp_loc_host_ev and has therefore completed.
            global_result_host[0] = 0.0;
            MPI_Allreduce(pAp_loc_host, global_result_host, 1, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD);
            update_pAp_ev = main_queue.memcpy(pAp_dev, global_result_host, sizeof(double)); // MPI_Allreduce above is blocking
#endif // HPCG_USE_MPI_OFFLOAD
#else // HPCG_NO_MPI
            update_pAp_ev = main_queue.memcpy(pAp_dev, pAp_loc_dev, sizeof(double), {update_pAp_loc_ev});
#endif
            END_PROFILE_WAIT("Allreduce", update_pAp_ev); TOCK(t4);
        }
        else {
            TICK(); BEGIN_PROFILE("SPMV");
            update_pAp_ev = ComputeSPMV_DOT(A, p, Ap, pAp_dev, main_queue, ierr, {update_p_ev}); // Ap <- A * p, pAp <- dot(p,Ap)
            END_PROFILE_WAIT("SPMV", update_pAp_ev); TOCK(t3);
        }

        TICK(); BEGIN_PROFILE("WAXPBY");

        // r <- r - (rtz/pAp) * Ap
        auto update_r_ev = custom::SpAXPBY_ker2(main_queue, rtz_dev, pAp_dev, nrow, r, Ap, {update_pAp_ev});

        auto update_normr_ev = ComputeDotProductLocal(nrow, r, r, normr_dev, main_queue, {update_r_ev}); // normr = ||r||^2
        auto update_normr_host_ev = main_queue.memcpy(normr_host, normr_dev, sizeof(double), {update_normr_ev});
        END_PROFILE_WAIT("WAXPBY", update_normr_host_ev); TOCK(t2);

        TICK(); BEGIN_PROFILE("Allreduce");
#ifndef HPCG_NO_MPI
        // Wait before touching global_result_host: this event transitively covers
        // update_pAp_ev, which may still be reading that staging buffer.
        update_normr_host_ev.wait();
        global_result_host[0] = 0.0;
        MPI_Iallreduce(normr_host, global_result_host, 1, MPI_DOUBLE, MPI_SUM, MPI_COMM_WORLD, &request);
#endif

        // x <- x + (rtz/pAp) * p, overlapped with the residual-norm reduction.
        // update_pAp_ev makes sure p.values is up-to-date before kernel launch;
        // update_x_ev orders it after the previous x update (if any).
        update_x_ev = custom::SpAXPBY_ker3(main_queue, rtz_dev, pAp_dev, nrow, x, p, {update_pAp_ev, update_x_ev});

#ifndef HPCG_NO_MPI
        MPI_Wait(&request, MPI_STATUS_IGNORE);
        normr_host[0] = global_result_host[0];
#else
        update_normr_host_ev.wait();
#endif
        normr = std::sqrt(normr_host[0]); // normr = ||r||
        ff = normr/normr0-tolerance*(1.0 + 1e-6);
        niters = k;
        END_PROFILE("Allreduce"); TOCK(t4);
#ifdef HPCG_DEBUG
        if (A.geom->rank==0 && ((k % print_freq) == 0 || k == max_iter))
            HPCG_fout << "Iteration = "<< k <<" " << tolerance <<" " << A.geom->rank << "   Scaled Residual = "<< normr/normr0 << std::endl;
#endif
        // NOTE(review): printed every iteration regardless of HPCG_DEBUG, and the
        // std::setprecision(15) stays set on HPCG_fout afterwards.
        if (A.geom->rank==0)
            HPCG_fout << "Iteration = "<< k <<" " << std::setprecision(15) << tolerance <<" " << A.geom->rank << "   Scaled Residual = "<< normr/normr0 << std::endl;
    }

    // The last x update is still asynchronous: finish it so the caller sees the
    // final solution and the total time accounts for it.
    update_x_ev.wait();

    // Store times
    times[0] += mytimer() - t_begin;  // Total time. All done...
    times[1] += t1; // dot-product time
    times[2] += t2; // WAXPBY time
    times[3] += t3; // SPMV time
    times[4] += t4; // AllReduce time
    times[5] += t5; // preconditioner apply time

#endif
    return 0;
}
