/*
 * Copyright (C) 2025-2026 Intel Corporation
 *
 * SPDX-License-Identifier: MIT
 *
 */

#include "level_zero/core/source/mutable_cmdlist/mutable_indirect_data.h"

#include "shared/source/debug_settings/debug_settings_manager.h"
#include "shared/source/helpers/debug_helpers.h"
#include "shared/source/helpers/string.h"

#include <cinttypes>

namespace L0::MCL {

// Stores the first addressSize bytes of `address` into the indirect data at `offset`.
//
// Nothing happens when `offset` is not defined. Destination selection:
//   - inlineData.begin() is null         -> crossThreadData.begin() + offset
//   - offset <  inlineData.size()        -> inlineData.begin() + offset
//   - offset >= inlineData.size()        -> crossThreadData.begin() + offset - inlineData.size()
//
// offset      - destination offset; the call is a no-op when it is undefined
// address     - value whose bytes are copied
// addressSize - number of bytes to copy; must not exceed sizeof(address)
void MutableIndirectData::setAddress(CrossThreadDataOffset offset, uint64_t address, size_t addressSize) {
    if (!isDefined(offset)) {
        return;
    }

    // The copy below passes addressSize as both the destination size and the byte count,
    // so that call alone does not bound the read from the 8-byte `address` parameter.
    UNRECOVERABLE_IF(addressSize > sizeof(address));

    void *destination = nullptr;
    if (inlineData.begin() == nullptr) {
        PRINT_STRING(NEO::debugManager.flags.PrintMclData.get(), stderr, "MCL store address value %" PRIx64 " size %zu in cross-thread at offset %" PRIu16 "\n", address, addressSize, offset);
        destination = reinterpret_cast<void *>(crossThreadData.begin() + offset);
    } else if (offset < inlineData.size()) {
        PRINT_STRING(NEO::debugManager.flags.PrintMclData.get(), stderr, "MCL store address value %" PRIx64 " size %zu in inline at offset %" PRIu16 "\n", address, addressSize, offset);
        destination = reinterpret_cast<void *>(inlineData.begin() + offset);
    } else {
        // offset lies beyond the inline part, so rebase it by the inline data size
        PRINT_STRING(NEO::debugManager.flags.PrintMclData.get(), stderr, "MCL store address value %" PRIx64 " size %zu in cross-thread minus inline at offset %" PRIu16 "\n", address, addressSize, offset);
        destination = reinterpret_cast<void *>(crossThreadData.begin() + offset - inlineData.size());
    }

    memcpy_s(destination, addressSize, &address, addressSize);
}

inline void MutableIndirectData::setIfDefined(const CrossThreadDataOffset (&offsets)[3], MaxChannelsArray data) {
    if (isDefined(offsets[0])) {
        size_t sizeToCopy = sizeof(data[0]) * (1 + !!(offsets[1] != undefined<CrossThreadDataOffset>)+!!(offsets[2] != undefined<CrossThreadDataOffset>));

        // check inline data is present
        if (inlineData.begin() != nullptr) {
            // check first offset begins in inline data
            // assuming all offsets are consecutively layout in memory
            if (offsets[0] < inlineData.size()) {
                // check if all data fits in inline data
                if (offsets[0] + sizeToCopy <= inlineData.size()) {
                    PRINT_STRING(NEO::debugManager.flags.PrintMclData.get(), stderr, "MCL store data in inline at offset %" PRIu16 "\n", offsets[0]);
                    memcpy_s(reinterpret_cast<void *>(inlineData.begin() + offsets[0]), sizeToCopy, data.data(), sizeToCopy);
                } else {
                    PRINT_STRING(NEO::debugManager.flags.PrintMclData.get(), stderr, "MCL store data in inline split at offset %" PRIu16 "\n", offsets[0]);
                    // data is split between inline and crossthread
                    size_t inlineDataCopySize = inlineData.size() - offsets[0];
                    memcpy_s(reinterpret_cast<void *>(inlineData.begin() + offsets[0]), inlineDataCopySize, data.data(), inlineDataCopySize);

                    size_t crossThreadDataCopy = sizeToCopy - inlineDataCopySize;
                    auto srcOffsetDataAddress = reinterpret_cast<uintptr_t>(data.data()) + inlineDataCopySize;
                    memcpy_s(reinterpret_cast<void *>(crossThreadData.begin()), crossThreadDataCopy, reinterpret_cast<void *>(srcOffsetDataAddress), crossThreadDataCopy);
                }
            } else {
                // offset does not start in existing inline, decrease crossthread offset by inline data size
                PRINT_STRING(NEO::debugManager.flags.PrintMclData.get(), stderr, "MCL store data in cross-thread minus inline at offset %" PRIu16 "\n", offsets[0]);
                memcpy_s(reinterpret_cast<void *>(crossThreadData.begin() + offsets[0] - inlineData.size()), sizeToCopy, data.data(), sizeToCopy);
            }
        } else {
            PRINT_STRING(NEO::debugManager.flags.PrintMclData.get(), stderr, "MCL store data in cross-thread at offset %" PRIu16 "\n", offsets[0]);
            memcpy_s(reinterpret_cast<void *>(crossThreadData.begin() + offsets[0]), sizeToCopy, data.data(), sizeToCopy);
        }
    }
}

// Passes the three local work size values to PRINT_STRING (with the PrintMclData debug flag),
// then forwards them to setIfDefined together with offsets->localWorkSize.
void MutableIndirectData::setLocalWorkSize(MaxChannelsArray localWorkSize) {
    PRINT_STRING(NEO::debugManager.flags.PrintMclData.get(), stderr, "MCL mutation set lws %u %u %u\n", localWorkSize[0], localWorkSize[1], localWorkSize[2]);
    const auto &channelOffsets = offsets->localWorkSize;
    this->setIfDefined(channelOffsets, localWorkSize);
}

// Passes the three values of the second local work size set to PRINT_STRING (with the
// PrintMclData debug flag), then forwards them to setIfDefined together with offsets->localWorkSize2.
void MutableIndirectData::setLocalWorkSize2(MaxChannelsArray localWorkSize2) {
    PRINT_STRING(NEO::debugManager.flags.PrintMclData.get(), stderr, "MCL mutation set lws2 %u %u %u\n", localWorkSize2[0], localWorkSize2[1], localWorkSize2[2]);
    const auto &channelOffsets = offsets->localWorkSize2;
    this->setIfDefined(channelOffsets, localWorkSize2);
}

// Forwards the enqueued local work size values to setIfDefined together with
// offsets->enqLocalWorkSize. Unlike the neighbouring setters, no PRINT_STRING trace is emitted here.
void MutableIndirectData::setEnqLocalWorkSize(MaxChannelsArray enqLocalWorkSize) {
    const auto &channelOffsets = offsets->enqLocalWorkSize;
    this->setIfDefined(channelOffsets, enqLocalWorkSize);
}

// Passes the three global work size values to PRINT_STRING (with the PrintMclData debug flag),
// then forwards them to setIfDefined together with offsets->globalWorkSize.
void MutableIndirectData::setGlobalWorkSize(MaxChannelsArray globalWorkSize) {
    PRINT_STRING(NEO::debugManager.flags.PrintMclData.get(), stderr, "MCL mutation set gws %u %u %u\n", globalWorkSize[0], globalWorkSize[1], globalWorkSize[2]);
    const auto &channelOffsets = offsets->globalWorkSize;
    this->setIfDefined(channelOffsets, globalWorkSize);
}

// Passes the three work-group count values to PRINT_STRING (with the PrintMclData debug flag),
// then forwards them to setIfDefined together with offsets->numWorkGroups.
void MutableIndirectData::setNumWorkGroups(MaxChannelsArray numWorkGroups) {
    PRINT_STRING(NEO::debugManager.flags.PrintMclData.get(), stderr, "MCL mutation set num wgs %u %u %u\n", numWorkGroups[0], numWorkGroups[1], numWorkGroups[2]);
    const auto &channelOffsets = offsets->numWorkGroups;
    this->setIfDefined(channelOffsets, numWorkGroups);
}

void MutableIndirectData::setWorkDimensions(uint32_t workDimensions) {
    if (isDefined(offsets->workDimensions)) {
        if (inlineData.begin() != nullptr) {
            if (offsets->workDimensions < inlineData.size()) {
                *reinterpret_cast<uint32_t *>(inlineData.begin() + offsets->workDimensions) = workDimensions;
            } else {
                *reinterpret_cast<uint32_t *>(crossThreadData.begin() + offsets->workDimensions - inlineData.size()) = workDimensions;
            }
        } else {
            *reinterpret_cast<uint32_t *>(crossThreadData.begin() + offsets->workDimensions) = workDimensions;
        }
    }
}

// Forwards the global work offset values to setIfDefined together with
// offsets->globalWorkOffset. No PRINT_STRING trace is emitted here.
void MutableIndirectData::setGlobalWorkOffset(MaxChannelsArray globalWorkOffset) {
    const auto &channelOffsets = offsets->globalWorkOffset;
    this->setIfDefined(channelOffsets, globalWorkOffset);
}

// Copies the supplied bytes to the start of this object's per-thread data.
// The parameter shadows the member of the same name, so the member is reached through `this`.
// Triggers UNRECOVERABLE_IF when the input is larger than the member's size.
void MutableIndirectData::setPerThreadData(ArrayRef<const uint8_t> perThreadData) {
    auto &destination = this->perThreadData;
    UNRECOVERABLE_IF(destination.size() < perThreadData.size());

    PRINT_STRING(NEO::debugManager.flags.PrintMclData.get(), stderr, "MCL copy local IDs into per-thread %p\n", destination.begin());
    memcpy_s(destination.begin(), destination.size(), perThreadData.begin(), perThreadData.size());
}

// Copies the supplied bytes to the start of this object's cross-thread data.
// The parameter shadows the member of the same name, so the member is reached through `this`.
// Triggers UNRECOVERABLE_IF when the input is larger than the member's size.
void MutableIndirectData::setCrossThreadData(ArrayRef<const uint8_t> crossThreadData) {
    auto &destination = this->crossThreadData;
    UNRECOVERABLE_IF(destination.size() < crossThreadData.size());

    memcpy_s(destination.begin(), destination.size(), crossThreadData.begin(), crossThreadData.size());
}

}; // namespace L0::MCL
