/*******************************************************************************
* Copyright 2023 Intel Corporation.
*
* This software and the related documents are Intel copyrighted  materials,  and
* your use of  them is  governed by the  express license  under which  they were
* provided to you (License).  Unless the License provides otherwise, you may not
* use, modify, copy, publish, distribute,  disclose or transmit this software or
* the related documents without Intel's prior written permission.
*
* This software and the related documents  are provided as  is,  with no express
* or implied  warranties,  other  than those  that are  expressly stated  in the
* License.
*******************************************************************************/

//@HEADER
// ***************************************************
//
// HPCG: High Performance Conjugate Gradient Benchmark
//
// Contact:
// Michael A. Heroux ( maherou@sandia.gov)
// Jack Dongarra     (dongarra@eecs.utk.edu)
// Piotr Luszczek    (luszczek@eecs.utk.edu)
//
// ***************************************************
//@HEADER

#pragma once

#include "../Helpers.hpp"
#include "../EsimdHelpers.hpp"

// Compile-time unrolled axpby step: y = alpha * x + beta * y.
//
// Each instantiation handles chunk number `s`, i.e. the block_size elements
// starting at element offset s * block_size from the given x / y pointers,
// and then instantiates itself for chunk s + 1. The chain stops once s
// reaches `uroll`, so a call with the default s = 0 covers `uroll`
// consecutive chunks.
//
//   x, y         base pointers of the span to process (y is updated in place)
//   alpha, beta  scalar coefficients applied to x and y respectively
//
// The trailing template arguments of the load/store helpers (presumably
// cache hints; they are defined outside this file) are kept exactly as given.
template <local_int_t block_size, local_int_t uroll, local_int_t s = 0>
static inline void axpby_impl(const double *x, double* y, const double alpha, const double beta) {
    if constexpr (s >= uroll) {
        // Terminal instantiation: every chunk has already been emitted.
        return;
    }
    else {
        // Element offset of the chunk owned by this instantiation.
        constexpr auto chunk_start = s * block_size;

        // y is read before x, then overwritten with the combined result.
        auto acc = esimd_lsc_block_load<double, local_int_t, block_size, ca, ca>(y, chunk_start);
        auto src = esimd_lsc_block_load<double, local_int_t, block_size, st, uc>(x, chunk_start);
        acc = alpha * src + beta * acc;
        esimd_lsc_block_store<double, local_int_t, block_size, nc, nc>(y, chunk_start, acc);

        // Continue with the next chunk of the unrolled sequence.
        axpby_impl<block_size, uroll, s + 1>(x, y, alpha, beta);
    }
}

// Per-work-item body of the axpby kernel; meant to be called from inside an
// ESIMD kernel. Overall effect across all work-items:
//   y[i] = alpha * x[i] + beta * y[i], for i = 0, ... n - 1
//
// The work-item's global id selects a span of uroll * block_size elements.
//   item         1-D nd_item; only its global id is used
//   x, y         input vector and in/out vector
//   alpha, beta  scalar coefficients
//   n            number of elements to process
//   nBlocks      number of spans; ids >= nBlocks perform no work
//
// Every id below nBlocks - 1 processes a full unrolled span. The last id does
// the same when its span ends exactly at n; otherwise it walks the remaining
// elements one block_size chunk at a time.
//
// NOTE(review): the remainder loop always issues whole block_size chunks, so
// when n is not a multiple of block_size the final chunk extends past index
// n - 1. Presumably callers size/pad x and y accordingly -- confirm.
template <local_int_t block_size, local_int_t uroll>
static inline void axpby_body(sycl::nd_item<1> item, const double *x, double* y,
                              const double alpha, const double beta, const local_int_t n, const local_int_t nBlocks) {
    const local_int_t span_id = item.get_global_id(0);
    const auto last_span = nBlocks - 1;
    auto pos = span_id * uroll * block_size;

    // Ids past the last span have nothing to do.
    if (span_id > last_span)
        return;

    // Interior spans are always complete; the last one only if it ends at n.
    const bool complete_span = (span_id < last_span) || (pos + uroll * block_size == n);
    if (complete_span) {
        axpby_impl<block_size, uroll>(x + pos, y + pos, alpha, beta);
        return;
    }

    // Remainder of the last span: one non-unrolled chunk per iteration.
    while (pos < n) {
        axpby_impl<block_size, 1>(x + pos, y + pos, alpha, beta);
        pos += block_size;
    }
}
