/*******************************************************************************
* Copyright (C) 2022 Intel Corporation
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

//@HEADER
// ***************************************************
//
// HPCG: High Performance Conjugate Gradient Benchmark
//
// Contact:
// Michael A. Heroux ( maherou@sandia.gov)
// Jack Dongarra     (dongarra@eecs.utk.edu)
// Piotr Luszczek    (luszczek@eecs.utk.edu)
//
// ***************************************************
//@HEADER

/*!
 @file main_test_kernels.cpp

 Test the custom kernels used in the HPCG benchmark
 */

#ifndef HPCG_NO_MPI
#include <mpi.h>
#endif

#include <fstream>
#include <iostream>
#include <iomanip>
#include <cstdlib>
#ifdef HPCG_DETAILED_DEBUG
using std::cin;
#endif
//using std::endl;
#include <vector>
#include <tuple>
#include <utility>
#include <map>

#include "hpcg.hpp"

#include "CheckAspectRatio.hpp"
#include "GenerateGeometry.hpp"
#include "GenerateProblem.hpp"
#include "GenerateCoarseProblem.hpp"
#include "SetupHalo.hpp"
#include "CheckProblem.hpp"
#include "ExchangeHalo.hpp"
#include "OptimizeProblem.hpp"
#include "WriteProblem.hpp"
#include "ReportResults.hpp"
#include "mytimer.hpp"
#include "ComputeSPMV_ref.hpp"
#include "ComputeMG_ref.hpp"
#include "ComputeResidual.hpp"
#include "CG.hpp"
#include "CG_ref.hpp"
#include "Geometry.hpp"
#include "SparseMatrix.hpp"
#include "Vector.hpp"
#include "CGData.hpp"
#include "TestCG.hpp"
#include "TestSymmetry.hpp"
#include "TestNorms.hpp"
#include "UsmUtil.hpp"
#include "VeryBasicProfiler.hpp"

#include "TestCustomKernels.hpp"

#include <cmath>
#include <cfloat>

/*!
  Builds the HPCG linear system and its coarse-grid hierarchy, optimizes it, runs the
  functional tests for the custom sparse kernels, and shuts the run down.

  HPCG_Finalize() and, unless HPCG_NO_MPI is defined, MPI_Finalize() are called on every
  exit path of this function, so the caller must not use HPCG output or MPI afterwards.

  @param[in] main_queue SYCL queue handed to the problem generation, optimization, test and cleanup routines
  @param[in] rank       Rank of this process; only rank 0 prints the banner, timing and verdict
  @param[in] size       Number of processes, forwarded to GenerateGeometry and TestCustomKernels
  @param[in] params     Run parameters; nx, ny, nz, numThreads, pz, zl, zu, npx, npy, npz and runRealRef are read

  @return The non-zero code from CheckAspectRatio if the local problem or the process grid is rejected,
          1 if the custom kernel tests recorded at least one failure on this rank, and 0 otherwise.
*/
int run_custom_kernel_tests( sycl::queue &main_queue, int rank, int size, HPCG_Params &params) {

  // Common shutdown sequence, used by the early-exit paths as well as the normal one
  const auto finalize_runtime = []() {
    HPCG_Finalize();
#ifndef HPCG_NO_MPI
    MPI_Finalize();
#endif
  };

  if (rank == 0) {
    std::cout << "###########################################################################" << std::endl;
    std::cout << "########## Running Functional Tests for Custom Sparse Kernels #############" << std::endl;
#ifdef HPCG_TEST_NO_HALO_EXCHANGE
    std::cout << "########## Using HPCG_TEST_NO_HALO_EXCHANGE                    ############" << std::endl;
#endif
    std::cout << "###########################################################################" << std::endl;
  }

#ifdef HPCG_DETAILED_DEBUG
  if (size < 100 && rank==0) HPCG_fout << "Process "<<rank<<" of "<<size<<" is alive with " << params.numThreads << " threads." <<std::endl;

  // Rank 0 waits for a key press (e.g. to attach a debugger); the other ranks wait at the barrier
  if (rank==0) {
    char c;
    std::cout << "Press key to continue"<< std::endl;
    std::cin.get(c);
  }
#ifndef HPCG_NO_MPI
  MPI_Barrier(MPI_COMM_WORLD);
#endif
#endif

  local_int_t nx = static_cast<local_int_t>(params.nx);
  local_int_t ny = static_cast<local_int_t>(params.ny);
  local_int_t nz = static_cast<local_int_t>(params.nz);

  // Reject a badly proportioned local problem before doing any setup work
  int ierr = CheckAspectRatio(0.125, nx, ny, nz, "local problem", rank==0);
  if (ierr) {
    finalize_runtime();
    return ierr;
  }

  /////////////////////////
  // Problem setup Phase //
  /////////////////////////

  // Construct the geometry and linear system
  Geometry * geom = new Geometry;
  GenerateGeometry(size, rank, params.numThreads, params.pz, params.zl, params.zu, nx, ny, nz, params.npx, params.npy, params.npz, geom, main_queue);
  ierr = CheckAspectRatio(0.125, geom->npx, geom->npy, geom->npz, "process grid", rank==0);
  if (ierr) {
    // NOTE(review): geom is not released on this path; the matching geometry
    // deleter is not visible from this file, and the process exits right after.
    finalize_runtime();
    return ierr;
  }

  SparseMatrix A;
  InitializeSparseMatrix(A, geom);
  Vector b, x, xexact;
  GenerateProblem(A, &b, &x, &xexact, main_queue, params.runRealRef);
  SetupHalo(A, main_queue);

  // Build the coarse levels by repeatedly coarsening the most recently created matrix
  const int numberOfMgLevels = 4; // Number of levels including first
  SparseMatrix * curLevelMatrix = &A;
  for (int level = 1; level < numberOfMgLevels; ++level) {
      GenerateCoarseProblem(*curLevelMatrix, main_queue, params.runRealRef);
      curLevelMatrix = curLevelMatrix->Ac; // Make the just-constructed coarse grid the next level
  }

  double t7 = 0.0; // Passed to OptimizeProblem and reported below as its duration
  const bool need_MKL_matrix = true;
  OptimizeProblem(&A, &b, t7, need_MKL_matrix, main_queue);
  if (rank == 0) {
      std::cout << "OptimizeProblem took " << t7 << " seconds" << std::endl;
  }

  ////////////////////////////////////////////
  // Custom Kernel Validation Testing Phase //
  ////////////////////////////////////////////

  TestCustomKernelsData testck_data;
  TestCustomKernels(A, b, x, rank, size, testck_data, main_queue);

  // NOTE(review): whether count_fail is aggregated across ranks is not visible here;
  // each rank reports its own view through its return value.
  const bool verification_passed = (testck_data.count_fail == 0);

  if (rank == 0) {
    if (verification_passed) {
        std::cout << "Custom Kernel Verification success" << std::endl;
    }
    else {
        std::cout << "Custom Kernel Verification failed" << std::endl;
    }
  }

  // Clean up
  DeleteMatrix(A, main_queue); // This delete will recursively delete all coarse grid data
  DeleteVector(x, main_queue);
  DeleteVector(b, main_queue);
  DeleteVector(xexact, main_queue);

  // Finish up
  finalize_runtime();

  // Propagate a verification failure to the caller (and thus to the process exit status)
  return verification_passed ? 0 : 1;

}


/*!
  Main driver program for the custom kernel tests: performs the common setup from main_test_base.cxx, then runs the functional tests for the custom sparse kernels and exits.

  @param[in]  argc Standard argument count.  Should equal 1 (no arguments passed in) or 4 (nx, ny, nz passed in)
  @param[in]  argv Standard argument array.  If argc==1, argv is unused.  If argc==4, argv[1], argv[2], argv[3] will be interpreted as nx, ny, nz, resp.

  @return Returns zero on success and a non-zero value otherwise.

*/
int main(int argc, char * argv[]) {

  // Shared test-driver preamble, textually included into the body of main.
  // NOTE(review): presumably declares main_queue, rank, size and params used below -- confirm in main_test_base.cxx.
#include "main_test_base.cxx"

  // Run the functional tests for the custom kernels; their status becomes the exit code.
  const int test_status = run_custom_kernel_tests(main_queue, rank, size, params);
  return test_status;
}
