#!/usr/bin/env python

import nsysstats

class SyncMemcpy(nsysstats.Report):
    """Report GPU memcpy activities issued by synchronous cudaMemcpy*() calls.

    Rows are ordered by duration, longest first, and capped at
    ``rows=<limit>`` (``ROW_LIMIT`` by default).
    """

    ROW_LIMIT = 50

    usage = f"""{{SCRIPT}}[:rows=<limit>] -- SyncMemcpy

    Options:
        rows=<limit> - Maximum number of rows returned by the query.
            Default is {ROW_LIMIT}.

    Output: All time values default to nanoseconds
        Duration : Duration of memcpy on GPU
        Start : Start time of memcpy on GPU
        Src Kind : Memcpy source memory kind
        Dst Kind : Memcpy destination memory kind
        Bytes : Number of bytes transferred
        PID : Process identifier
        Device ID : GPU device identifier
        Context ID : Context identifier
        Stream ID : Stream identifier
        API Name : Name of runtime API function

    This rule identifies memory transfers made by synchronous cudaMemcpy*()
    runtime API calls (those without an Async suffix). It does not report
    copies within a single device, nor host-to-device (H2D) copies of
    64 KB or less.
"""

    # Notes on the query below:
    # * `sync` holds the StringIds of runtime API names starting with
    #   "cudaMemcpy" that do not contain "Async".
    # * `memcpy` drops the exclusions listed in the usage text:
    #     - copyKind = 1 is host-to-device (CUPTI_ACTIVITY_MEMCPY_KIND_HTOD);
    #       those of at most 65536 bytes (64 KiB, the "64 KB" in the usage
    #       text) are excluded;
    #     - copies whose source and destination device IDs match are
    #       excluded (rows with a NULL srcDeviceId are kept).
    # * CUPTI_ACTIVITY_KIND_RUNTIME is joined *before* `sync` because the
    #   `sync` join condition reads runtime.nameId. An ON clause that
    #   references a table joined later is non-standard SQL; SQLite only
    #   tolerates it for inner joins.
    # * Columns are qualified with their source table so adding columns to
    #   any joined table cannot make them ambiguous.
    # * ORDER BY 1 sorts on the first output column (Duration).
    # NOTE(review): the runtime/memcpy join matches on correlationId alone;
    # if correlation IDs are only unique per process, multi-process traces
    # could attach the wrong API name -- confirm against the export schema.
    query_sync_memcpy = """
    WITH
        {MEM_KIND_STRS_CTE}
        sync AS (
            SELECT
                id,
                value
            FROM
                StringIds
            WHERE
                value LIKE 'cudaMemcpy%'
                AND value NOT LIKE '%Async%'
        ),
        memcpy AS (
            SELECT
                *
            FROM
                CUPTI_ACTIVITY_KIND_MEMCPY
            WHERE
                NOT (bytes <= 65536 AND copyKind = 1)
                AND NOT (srcDeviceId IS NOT NULL
                        AND srcDeviceId = dstDeviceId)
        )
    SELECT
        memcpy.end - memcpy.start AS "Duration:dur_ns",
        memcpy.start AS "Start:ts_ns",
        msrck.name AS "Src Kind",
        mdstk.name AS "Dst Kind",
        memcpy.bytes AS "Bytes:mem_B",
        (memcpy.globalPid >> 24) & 0x00FFFFFF AS "PID",
        memcpy.deviceId AS "Device ID",
        memcpy.contextId AS "Context ID",
        memcpy.streamId AS "Stream ID",
        sync.value AS "API Name",
        memcpy.globalPid AS "_Global PID",
        memcpy.copyKind AS "_Copy Kind"
    FROM
        memcpy
    JOIN
        CUPTI_ACTIVITY_KIND_RUNTIME AS runtime
        ON runtime.correlationId = memcpy.correlationId
    JOIN
        sync
        ON sync.id = runtime.nameId
    LEFT JOIN
        MemKindStrs AS msrck
        ON memcpy.srcKind = msrck.id
    LEFT JOIN
        MemKindStrs AS mdstk
        ON memcpy.dstKind = mdstk.id
    ORDER BY
        1 DESC
    LIMIT {ROW_LIMIT}
"""

    # Required tables, mapped to the message reported when one is missing.
    table_checks = {
        'CUPTI_ACTIVITY_KIND_RUNTIME':
            "{DBFILE} could not be analyzed because it does not contain CUDA trace data.",
        'CUPTI_ACTIVITY_KIND_MEMCPY':
            "{DBFILE} could not be analyzed because it does not contain CUDA trace data."
    }

    def setup(self):
        """Validate script arguments and build ``self.query``.

        Returns:
            The error from ``Report.setup()`` unchanged when it reports one;
            otherwise None, after ``self.query`` has been set.

        Raises:
            SystemExit: with ``EXIT_INVALID_ARG`` for any argument that is
                not ``rows=<non-negative decimal integer>``.
        """
        err = super().setup()
        if err is not None:
            return err

        row_limit = self.ROW_LIMIT
        for arg in self.args:
            key, sep, value = arg.partition('=')
            # isdecimal() (unlike isdigit(), which accepts e.g. '²') only
            # admits characters int() can parse, so int() cannot fail here
            # and the LIMIT clause always receives a plain integer.
            if key == 'rows' and sep and value.isdecimal():
                row_limit = int(value)
            else:
                raise SystemExit(self.EXIT_INVALID_ARG)

        self.query = self.query_sync_memcpy.format(
            MEM_KIND_STRS_CTE = self.MEM_KIND_STRS_CTE,
            ROW_LIMIT = row_limit)

def _main():
    """Run the SyncMemcpy report as a command-line script."""
    SyncMemcpy.Main()


if __name__ == "__main__":
    _main()
