!===============================================================================
! Copyright (C) 2025 Intel Corporation
!
! This software and the related documents are Intel copyrighted  materials,  and
! your use of  them is  governed by the  express license  under which  they were
! provided to you (License).  Unless the License provides otherwise, you may not
! use, modify, copy, publish, distribute,  disclose or transmit this software or
! the related documents without Intel's prior written permission.
!
! This software and the related documents  are provided as  is,  with no express
! or implied  warranties,  other  than those  that are  expressly stated  in the
! License.
!===============================================================================

!  Content:
!      Intel(R) oneAPI Math Kernel Library (oneMKL)
!      FORTRAN OpenMP offload example for SGETRF_BATCH
!*******************************************************************************


include "mkl_omp_offload.f90"

program sgetrf_batch_example
    !
    ! Batched single-precision LU factorization (SGETRF_BATCH, group API).
    !
    ! The program
    !   1. builds group_count groups of random square matrices,
    !   2. factorizes a copy of them with a plain call to sgetrf_batch
    !      (the reference result),
    !   3. maps the matrices to the device and factorizes them again through
    !      the OpenMP "dispatch" variant of sgetrf_batch,
    !   4. compares factors, pivots and info values of both runs.
    !
    ! It prints "PASSED" and terminates normally when both results agree;
    ! otherwise it prints the mismatches and terminates with stop code 1.
    !

! Preprocessor directives must start in column 1.
#if defined(MKL_ILP64)
    use onemkl_lapack_omp_offload_ilp64
#else
    use onemkl_lapack_omp_offload_lp64
#endif
    use, intrinsic :: ISO_C_BINDING

    implicit none

    ! Group API size variables
    integer, parameter :: group_count = 4
    integer :: m(group_count), n(group_count), lda(group_count), group_size(group_count)

    ! Group API data variables (device run).
    ! a_array_dev / ipiv_array_dev hold one device address per matrix / pivot vector.
    integer(KIND=C_SIZE_T), allocatable :: a_array_dev(:), ipiv_array_dev(:)
    real,    allocatable, target :: a(:,:,:)
    integer, allocatable, target :: ipiv(:,:), info(:)

    ! Same layout for the reference run; the address arrays hold host addresses.
    integer(KIND=C_SIZE_T), allocatable :: a_ref_array(:), ipiv_ref_array(:)
    real,    allocatable :: a_ref(:,:,:)
    integer, allocatable :: ipiv_ref(:,:), info_ref(:)

    real,    pointer :: tmp_a(:,:)
    integer, pointer :: tmp_ipiv(:)

    ! Largest absolute difference tolerated between device and reference factors
    real, parameter :: thresh = 0.001

    integer :: max_lda, max_m, max_n, max_pivlen, batch_size
    integer :: status, alloc_stat, imat, i, j, diff
    real    :: err

    !
    ! Initialize matrix dimensions and group size for each group:
    ! group i holds (i + 4) square matrices of order (i + 10).
    !
    do i = 1, group_count
      m(i) = i + 10
      n(i) = i + 10
      lda(i) = m(i)
      group_size(i) = i + 4
    end do

    batch_size = sum(group_size)
    max_lda    = maxval(lda)
    max_m      = maxval(m)
    max_n      = maxval(n)
    max_pivlen = maxval(min(m, n))

    !
    ! Hold all batch matrix data in a rank-3 block of memory.
    ! The dimensions of this array are determined by the dimensions
    ! of the largest matrix in the batch.
    !
    ! stat= is required here: without it a failed allocate terminates the
    ! program immediately and the error reporting below could never run.
    !
    allocate(a(max_lda, max_n, batch_size),     &
             ipiv(max_pivlen, batch_size),      &
             info(batch_size),                  &
             a_ref(max_lda, max_n, batch_size), &
             ipiv_ref(max_pivlen, batch_size),  &
             info_ref(batch_size), stat=alloc_stat)
    call check_alloc(alloc_stat, "Cannot allocate matrices")

    allocate(a_ref_array(batch_size), ipiv_ref_array(batch_size), stat=alloc_stat)
    call check_alloc(alloc_stat, "Cannot allocate array of pointers")

    allocate(a_array_dev(batch_size), ipiv_array_dev(batch_size), stat=alloc_stat)
    call check_alloc(alloc_stat, "Cannot allocate array of device pointers")

    !
    ! Initialize matrix data; both runs start from identical input
    !
    call random_number(a)
    a_ref = a
    ipiv(:,:) = 0
    ipiv_ref(:,:) = 0
    info(:) = 0
    info_ref(:) = 0
    status = 0

    !
    ! Compute reference solution (call without the dispatch construct)
    !
    do i = 1, batch_size
        a_ref_array(i)    = LOC(a_ref(1,1,i))
        ipiv_ref_array(i) = LOC(ipiv_ref(1,i))
    end do
    call sgetrf_batch(m, n, a_ref_array, lda, ipiv_ref_array, group_count, group_size, info_ref)

    !
    ! Map each matrix to the device and store the device pointers into arrays
    !
    do i = 1, batch_size
        !$omp target enter data map(to:a(:,:,i), ipiv(:,i))
        tmp_a => a(:,:,i)
        tmp_ipiv => ipiv(:,i)
        ! Inside this region tmp_a / tmp_ipiv refer to the device copies,
        ! so LOC yields the addresses the offloaded routine has to work on.
        !$omp target data use_device_addr(tmp_a, tmp_ipiv)
            a_array_dev(i) = LOC(tmp_a)
            ipiv_array_dev(i) = LOC(tmp_ipiv)
        !$omp end target data
    end do

    !
    ! Compute batched LU factorization on the device by calling the OpenMP offload variant dispatch
    !
    !$omp target data map(tofrom:a_array_dev, ipiv_array_dev) map(from:info)
    !$omp dispatch
        call sgetrf_batch(m, n, a_array_dev, lda, ipiv_array_dev, group_count, group_size, info)
    !$omp end target data

    !
    ! Retrieve solution from the device to the host
    !
    do i = 1, batch_size
    !$omp target exit data map(from:a(:,:,i), ipiv(:,i))
    end do
    ! NOTE(review): info is already mapped with map(from:) by the target data
    ! region above, so this directive looks redundant - confirm before removing.
    !$omp target exit data map(from:info)

    !
    ! Validate results
    !
    do imat = 1, batch_size
        ! Check matrix factorizations. The first index is the row (bounded by
        ! max_m), the second the column (bounded by max_n); rows are traversed
        ! in the inner loop to follow column-major storage.
        do j = 1, max_n
            do i = 1, max_m
                err = abs(a(i,j,imat) - a_ref(i,j,imat))
                ! Written as .not.(<=) so that a NaN difference is reported too
                if (.not. (err .le. thresh)) then
                    print 100, imat, i, j, a_ref(i,j,imat), a(i,j,imat), err
                    status = 1
                end if
            end do
        end do

        ! Check pivot arrays
        do i = 1, max_pivlen
            diff = abs(ipiv(i,imat) - ipiv_ref(i,imat))
            if (diff .gt. 0) then
                print 101, imat, i, ipiv_ref(i,imat), ipiv(i,imat), diff
                status = 1
            end if
        end do

        ! Check info value: it must be zero and agree with the reference run
        diff = abs(info(imat) - info_ref(imat))
        if ((info(imat) .ne. 0) .or. (diff .ne. 0)) then
            print 102, imat, info_ref(imat), info(imat), diff
            status = 1
        end if
    end do

    !
    ! Clean up
    !
    deallocate(a)
    deallocate(ipiv)
    deallocate(info)
    deallocate(a_ref)
    deallocate(ipiv_ref)
    deallocate(info_ref)
    deallocate(a_ref_array)
    deallocate(ipiv_ref_array)
    deallocate(a_array_dev)
    deallocate(ipiv_array_dev)

    ! i0 is used for the indices because the batch holds more than 9 matrices
    ! and the matrices have more than 9 rows/columns (i1 would print '*').
    100 format(7x, 'Error at matrix ',i0,', index (',i0,',',i0,'), expected = ',f10.6, &
               ', computed = ',f10.6,', difference = ',f10.6)
    101 format(7x, 'Error at ipiv vector ',i2,', index (',i2,'), expected = ',i4,', computed = ',i4,', difference = ',i4)
    102 format(7x, 'Error at info (',i2,'), expected = ',i4,', computed = ',i4,', difference = ',i4)

    !
    ! Report the outcome through the stop code as well as on standard output
    !
    if (status .ne. 0) then
        print *, "FAILED"
        stop 1
    end if

    print *, "PASSED"

contains

    ! Terminate with stop code 1 if an allocate statement reported failure.
    !   alloc_status : value returned through stat= by the allocate statement
    !   what         : message identifying the arrays that could not be allocated
    subroutine check_alloc(alloc_status, what)
        integer,          intent(in) :: alloc_status
        character(len=*), intent(in) :: what

        if (alloc_status .ne. 0) then
            print *, what
            print *, 'Error: cannot allocate memory'
            stop 1
        end if
    end subroutine check_alloc

end program sgetrf_batch_example
