#!/usr/bin/env python3
# SPDX-License-Identifier: GPL-3.0-or-later
# Copyright (C) 2018 Takashi Sakamoto

from sys import exit
from time import sleep
from signal import signal, SIGINT
from json import dumps

from hinawa_utils.misc.cui_kit import CuiKit
from hinawa_utils.dice.dice_extended_unit import DiceExtendedUnit

def _print_stream_params(direction, index, params):
    print('{0} stream {1}:'.format(direction.upper(), index))
    print('  pcm:    {0}'.format(params['pcm']))
    print('  midi:   {0}'.format(params['midi']))
    print('  formation:')
    for i, name in enumerate(params['formation']):
        print('    {0}: {1}'.format(i, name))
    print('  ac3:    0x{0:08x}'.format(params['ac3']))

def handle_current_status(unit, args):
    """Print a summary of the unit's current state and return True.

    The report lists the capabilities per category, the stream parameters
    at the current sampling rate, the standalone clock setting, the source
    assigned to every output / tx stream / mixer input, and the mixer gain
    of both channels for every routed mixer input.  Targets whose source
    is the string 'None' are left out.

    args is not used by this handler.
    """
    print('Capabilities:')
    for category in ('general', 'router', 'mixer'):
        print('  {0}:'.format(category))
        for cap_name, cap_value in unit.get_caps(category).items():
            print('    {0}: {1}'.format(cap_name, cap_value))

    print('Stream configurations:')
    current_rate = unit.get_sampling_rate()
    for direction, entries in unit.get_stream_params(current_rate).items():
        for index, entry in enumerate(entries):
            _print_stream_params(direction, index, entry)

    print('Standalone setting:')
    clock = unit.get_standalone_clock_source()
    print('  clock source: {0}'.format(clock))
    clock_params = unit.get_standalone_clock_source_params(clock)
    for key, val in clock_params.items():
        print('  {0}: {1}'.format(key, val))

    # The three "sources" sections share one shape: list the targets,
    # look up the source of each, and show only the assigned ones.
    sections = (
        ('Output sources:', unit.get_output_labels, unit.get_output_source),
        ('Capture sources:', unit.get_tx_stream_labels,
         unit.get_tx_stream_source),
        ('Mixer sources:', unit.get_mixer_input_labels,
         unit.get_mixer_source),
    )
    for heading, list_targets, lookup_source in sections:
        print(heading)
        for target in list_targets():
            assigned = lookup_source(target)
            if assigned != 'None':
                print('  {0} {1}'.format(target, assigned))

    print('Mixer setting:')
    for mixer_out in unit.get_mixer_output_labels():
        for mixer_in in unit.get_mixer_input_labels():
            assigned = unit.get_mixer_source(mixer_in)
            if assigned != 'None':
                # Channel 0 is read first, then channel 1.
                gain_ch0 = unit.get_mixer_gain(mixer_out, mixer_in, 0)
                gain_ch1 = unit.get_mixer_gain(mixer_out, mixer_in, 1)
                print('  {0} {1}({2}): {3} dB, {4} dB'.format(
                    mixer_out, mixer_in, assigned, gain_ch0, gain_ch1))

    return True

def _print_dot_formatted_text(inputs, tx_stream_map, rx_streams, output_map,
                              mixer_in_map, mixer_out_map):
    input_labels = {}
    tx_stream_labels = {}
    rx_stream_labels = {}
    mixer_in_labels = {}
    mixer_out_labels = {}
    output_labels = {}
    print('digraph {')
    print('  labelloc = "t"')
    print('  rankdir = "LR"')
    print('  subgraph cluster_input {')
    print('    label = "inputs"')
    print('    graph [rank = "min"]')
    for i, input in enumerate(inputs):
        if input != 'None':
            label = 'in{0}'.format(i)
            input_labels[input] = label
            print('    {0} [label = "{1}"]'.format(label, input))
    print('  }')
    print('  subgraph cluster_1394bus {')
    print('    label = "1394 Bus"')
    print('    graph [rank = "source"]')
    print('    subgraph cluster_1394_bus_out {')
    print('      label = "1394 bus out"')
    print('      graph [rank = "min"]')
    for i, tx_stream in enumerate(tx_stream_map):
        label = 'tx{0}'.format(i)
        tx_stream_labels[tx_stream] = label
        print('      {0} [label = "{1}"]'.format(label, tx_stream))
    print('    }')
    print('    subgraph cluster_1394_bus_in {')
    print('      label = "1394 bus in"')
    print('      graph [rank = "max"]')
    for i, rx_stream in enumerate(rx_streams):
        label = 'rx{0}'.format(i)
        rx_stream_labels[rx_stream] = label
        print('      {0} [label = "{1}"]'.format(label, rx_stream))
    print('    }')
    print('  }')
    print('  subgraph cluster_mixer {')
    print('    label = "mixer"')
    print('    graph [rank = "sink"]')
    print('    subgraph cluster_mixer_input {')
    print('      label = "inputs"')
    for i, mixer_in in enumerate(mixer_in_map):
        label = 'mix_in{0}'.format(i)
        mixer_in_labels[mixer_in] = label
        print('      {0} [label = "{1}"]'.format(label, mixer_in))
    print('    }')
    print('    subgraph cluster_mixer_output {')
    print('      label = "outputs"')
    for i, mixer_out in enumerate(mixer_out_map):
        label = 'mix_out{0}'.format(i)
        mixer_out_labels[mixer_out] = label
        print('      {0} [label = "{1}"]'.format(label, mixer_out))
    print('    }')
    print('  }')
    print('  subgraph cluster_output {')
    print('    label = "outputs"')
    print('    graph [rank = "max"]')
    for i, output in enumerate(output_map):
        label = 'out{0}'.format(i)
        output_labels[output] = label
        print('    out{0} [label = "{1}"]'.format(i, output))
    print('  }')

    for mixer_in, mixer_in_src in mixer_in_map.items():
        if mixer_in_src == 'None':
            continue
        if mixer_in_src in input_labels:
            input_label = input_labels[mixer_in_src]
        elif mixer_in_src in rx_stream_labels:
            input_label = rx_stream_labels[mixer_in_src]
        mixer_in_label = mixer_in_labels[mixer_in]
        print('  {0} -> {1}'.format(input_label, mixer_in_label))

    for tx_stream, tx_stream_src in tx_stream_map.items():
        if tx_stream_src == 'None':
            continue
        if tx_stream_src not in input_labels:
            continue
        input_label = input_labels[tx_stream_src]
        tx_stream_label = tx_stream_labels[tx_stream]
        print('  {0} -> {1}'.format(input_label, tx_stream_label))

    for mixer_out, mixer_ins in mixer_out_map.items():
        for mixer_in in mixer_ins:
            if mixer_in in mixer_in_labels:
                mixer_in_label = mixer_in_labels[mixer_in]
                mixer_out_label = mixer_out_labels[mixer_out]
                print('  {0} -> {1}'.format(mixer_in_label, mixer_out_label))

    for output, output_src in output_map.items():
        output_label = output_labels[output]
        if output_src in mixer_out_labels:
            input_label = mixer_out_labels[output_src]
        elif output_src in input_labels:
            input_label = input_labels[output_src]
        elif output_src in rx_stream_labels:
            input_label = rx_stream_labels[output_src]
        else:
            continue
        print('  {0} -> {1}'.format(input_label, output_label))

    print('}')

def _print_json_formatted_text(inputs, tx_stream_map, rx_streams, output_map,
                               mixer_in_map, mixer_out_map):
    data = {
        'inputs':           inputs,
        'tx-stream-map':    tx_stream_map,
        'rx-streams':       rx_streams,
        'output-map':       output_map,
        'mixer-in-map':     mixer_in_map,
        'mixer-out-map':    mixer_out_map,
    }
    print(dumps(data))

def _print_pretty_format(inputs, tx_stream_map, rx_streams, output_map,
                         mixer_in_map, mixer_out_map):
    print('Inputs:')
    for input in inputs:
        print('  {0}'.format(input))
    print('Tx streams:')
    for tx_stream, tx_stream_src in tx_stream_map.items():
        print('  {0: <16} -> {1}'.format(tx_stream_src, tx_stream))
    print('Rx streams:')
    for rx_stream in rx_streams:
        print('  {0}'.format(rx_stream))
    print('Outputs')
    for output, output_src in output_map.items():
        print('  {0: <16} -> {1}'.format(output_src, output))
    print('Mixer-inputs:')
    for mixer_in, mixer_in_src in mixer_in_map.items():
        print('  {0: <16} -> {1}'.format(mixer_in_src, mixer_in))
    print('Mixer-outputs:')
    for mixer_out, mixer_out_srcs in mixer_out_map.items():
        for mixer_out_src in mixer_out_srcs:
            print('  {0: <16} -> {1}'.format(mixer_out_src, mixer_out))

def handle_current_routing(unit, args):
    """Collect the unit's current routing and print it in a chosen format.

    args[0] selects the printer: 'graphviz', 'json' or 'pretty'.  Returns
    True after printing; with a missing or unknown format it prints usage
    and returns False.

    Mixer source labels starting with 'Stream-' are treated as rx streams,
    all others as inputs.  A mixer input counts as feeding a mixer output
    when it has a source other than 'None' and the gain of at least one of
    its two channels differs from negative infinity.
    """
    ops = {
        'graphviz': _print_dot_formatted_text,
        'json':     _print_json_formatted_text,
        'pretty':   _print_pretty_format,
    }
    if not (len(args) >= 1 and args[0] in ops):
        print('Argument for print-routing command')
        print('  print-routing OP')
        print('    OP: [{0}]'.format('|'.join(ops)))
        return False

    op = args[0]

    tx_stream_map = {
        label: unit.get_tx_stream_source(label)
        for label in unit.get_tx_stream_labels()
    }
    output_map = {
        label: unit.get_output_source(label)
        for label in unit.get_output_labels()
    }
    mixer_in_map = {
        label: unit.get_mixer_source(label)
        for label in unit.get_mixer_input_labels()
    }

    inputs, rx_streams = [], []
    for label in unit.get_mixer_source_labels():
        bucket = rx_streams if label.startswith('Stream-') else inputs
        bucket.append(label)

    silent = -float('inf')
    mixer_out_map = {}
    for mixer_out in unit.get_mixer_output_labels():
        feeders = mixer_out_map.setdefault(mixer_out, [])
        for mixer_in, mixer_in_src in mixer_in_map.items():
            if mixer_in_src == 'None':
                continue
            # Channel 1 is only queried when channel 0 is silent.
            if any(unit.get_mixer_gain(mixer_out, mixer_in, ch) != silent
                   for ch in (0, 1)):
                feeders.append(mixer_in)

    printer = ops[op]
    printer(inputs, tx_stream_map, rx_streams, output_map, mixer_in_map,
            mixer_out_map)
    return True

def handle_stream_params(unit, args):
    """Print the stream parameters the unit reports for a sampling rate.

    args[0] must be one of the rates returned by
    unit.get_supported_sampling_rates().  Returns True after printing the
    parameters of every stream; if the argument is missing, is not an
    integer or is not a supported rate, prints usage and returns False.
    """
    rates = unit.get_supported_sampling_rates()
    try:
        rate = int(args[0])
    except (IndexError, ValueError):
        # Missing or non-numeric RATE: show the usage text instead of
        # aborting with a traceback.
        rate = None
    if rate is not None and rate in rates:
        stream_params = unit.get_stream_params(rate)
        for direction, params in stream_params.items():
            for i, param in enumerate(params):
                _print_stream_params(direction, i, param)
        return True
    print('Argument for stream-params command:')
    print('  stream-params RATE')
    print('    RATE:   [{0}]'.format('|'.join(map(str, rates))))
    return False

def handle_router_entries(unit, args):
    """Print the router entries the unit reports for a sampling rate.

    args[0] must be one of the rates returned by
    unit.get_supported_sampling_rates().  Each entry is printed as
    'src -> dst' using its 'src' and 'dst' keys.  Returns True on success;
    if the argument is missing, is not an integer or is not a supported
    rate, prints usage and returns False.
    """
    rates = unit.get_supported_sampling_rates()
    try:
        rate = int(args[0])
    except (IndexError, ValueError):
        # Missing or non-numeric RATE: show the usage text instead of
        # aborting with a traceback.
        rate = None
    if rate is not None and rate in rates:
        for entry in unit.get_router_entries(rate):
            print('  {0: <16} -> {1}'.format(entry['src'], entry['dst']))
        return True
    print('Argument for router-entries command:')
    print('  router-entries RATE')
    print('    RATE:   [{0}]'.format('|'.join(map(str, rates))))
    return False

def _handle_target_source(unit, args, cmd, targets_func, sources_func,
                          set_func, get_func):
    ops = ('set', 'get')
    targets = targets_func()
    srcs = sources_func()
    if len(args) >= 2:
        target, op = args[0:2]
        if target in targets and op in ops:
            if op == 'set' and len(args) >= 3 and args[2] in srcs:
                src = args[2]
                set_func(target, src)
                return True
            elif op == 'get':
                print(get_func(target))
                return True
    print('Arguments for {0} command:'.format(cmd))
    print('  {0} TARGET OP [SOURCE]'.format(cmd))
    print('    TARGET:    [{0}]'.format('|'.join(targets)))
    print('    OP:        [{0}]'.format('|'.join(ops)))
    print('    SOURCE:    [{0}] if OP=set'.format('|'.join(srcs)))
    return False

def handle_output_source(unit, args):
    """Handle 'output-source': get or set the source of an output.

    Delegates to _handle_target_source with the unit's output accessors
    and returns its result.
    """
    return _handle_target_source(
        unit, args, 'output-source',
        targets_func=unit.get_output_labels,
        sources_func=unit.get_output_source_labels,
        set_func=unit.set_output_source,
        get_func=unit.get_output_source,
    )

def handle_capture_source(unit, args):
    """Handle 'capture-source': get or set the source of a tx stream.

    Delegates to _handle_target_source with the unit's tx stream accessors
    and returns its result.
    """
    return _handle_target_source(
        unit, args, 'capture-source',
        targets_func=unit.get_tx_stream_labels,
        sources_func=unit.get_tx_stream_source_labels,
        set_func=unit.set_tx_stream_source,
        get_func=unit.get_tx_stream_source,
    )

def handle_mixer_source(unit, args):
    """Handle 'mixer-source': get or set the source of a mixer input.

    Delegates to _handle_target_source with the unit's mixer accessors and
    returns its result.
    """
    return _handle_target_source(
        unit, args, 'mixer-source',
        targets_func=unit.get_mixer_input_labels,
        sources_func=unit.get_mixer_source_labels,
        set_func=unit.set_mixer_source,
        get_func=unit.get_mixer_source,
    )

def handle_mixer_gain(unit, args):
    """Handle 'mixer-gain': get or set the gain of one mixer channel.

    args is [OUTPUT, INPUT, CH, OP] plus a dB value when OP is 'set'.
    OUTPUT and INPUT must be mixer output/input labels of the unit, CH is
    '0' or '1' and OP is 'set' or 'get'.  'get' prints the gain with three
    decimals.

    Returns True when the operation was carried out.  Otherwise (missing
    or invalid arguments, including a dB value that is not a number) it
    prints the usage text followed by the mixer saturation state and
    returns False.
    """
    chs = ('0', '1')
    ops = ('set', 'get')
    outputs = unit.get_mixer_output_labels()
    inputs = unit.get_mixer_input_labels()
    if len(args) >= 4:
        output, mixer_in, ch, op = args[0:4]
        if (output in outputs and mixer_in in inputs and ch in chs and
                op in ops):
            ch = int(ch)
            if op == 'get':
                db = unit.get_mixer_gain(output, mixer_in, ch)
                print('{:.3f}'.format(db))
                return True
            if len(args) >= 5:
                try:
                    db = float(args[4])
                except ValueError:
                    # Malformed dB: fall through to the usage text.
                    db = None
                if db is not None:
                    unit.set_mixer_gain(output, mixer_in, ch, db)
                    return True
    print('Arguments for mixer-gain command:')
    # The synopsis lists exactly the arguments parsed above.
    print('  mixer-gain OUTPUT INPUT CH OP [dB]')
    print('    OUTPUT: [{0}]'.format('|'.join(outputs)))
    # Only mixer inputs with an assigned source are advertised.
    srcs = []
    for mixer_in in inputs:
        src = unit.get_mixer_source(mixer_in)
        if src != 'None':
            srcs.append('{0}(={1})'.format(mixer_in, src))
    print('    INPUT:  [{0}]'.format('|'.join(srcs)))
    print('    CH:     [{0}]'.format('|'.join(chs)))
    print('    OP:     [{0}]'.format('|'.join(ops)))
    print('    dB:     [-inf,-84.00..4.00] if OP=set ')
    print('Saturation state:')
    for output, saturation in unit.get_mixer_saturations().items():
        print('  {0}: {1}/{2}'.format(output, saturation[0], saturation[1]))
    return False

def handle_mixer_panning(unit, args):
    """Handle 'mixer-panning': get or set the balance of one mixer channel.

    args is [OUTPUT, INPUT, CH, OP] plus a balance value when OP is 'set'.
    OUTPUT and INPUT must be mixer output/input labels of the unit, CH is
    '0' or '1' and OP is 'set' or 'get'.  'get' prints the balance with
    three decimals.

    Returns True when the operation was carried out.  Otherwise (missing
    or invalid arguments, including a balance that is not a number) it
    prints the usage text and returns False.
    """
    chs = ('0', '1')
    ops = ('set', 'get')
    outputs = unit.get_mixer_output_labels()
    inputs = unit.get_mixer_input_labels()
    if len(args) >= 4:
        output, mixer_in, ch, op = args[0:4]
        if (output in outputs and mixer_in in inputs and ch in chs and
                op in ops):
            ch = int(ch)
            if op == 'get':
                balance = unit.get_mixer_balance(output, mixer_in, ch)
                print('{:.3f}'.format(balance))
                return True
            if len(args) >= 5:
                try:
                    balance = float(args[4])
                except ValueError:
                    # Malformed balance: fall through to the usage text.
                    balance = None
                if balance is not None:
                    unit.set_mixer_balance(output, mixer_in, ch, balance)
                    return True
    print('Arguments for mixer-panning command:')
    # The synopsis lists exactly the arguments parsed above.
    print('  mixer-panning OUTPUT INPUT CH OP [BALANCE]')
    print('    OUTPUT: [{0}]'.format('|'.join(outputs)))
    # Only mixer inputs with an assigned source are advertised.
    srcs = []
    for mixer_in in inputs:
        src = unit.get_mixer_source(mixer_in)
        if src != 'None':
            srcs.append('{0}(={1})'.format(mixer_in, src))
    print('    INPUT:  [{0}]'.format('|'.join(srcs)))
    print('    CH:     [{0}]'.format('|'.join(chs)))
    print('    OP:     [{0}]'.format('|'.join(ops)))
    print('    BALANCE:[0.00...100.00] (left-right) if OP=set ')
    return False

def handle_standalone_clock(unit, args):
    """Handle 'standalone-clock': get or set the standalone clock source.

    args is ['get'] or ['set', CLOCK], where CLOCK must be one of the
    names returned by unit.get_supported_clock_sources().  Returns True
    when the operation was carried out, otherwise prints the usage text
    and returns False.
    """
    ops = ('set', 'get')
    clocks = unit.get_supported_clock_sources()

    requested = args[0] if len(args) >= 1 else None
    if requested == 'get':
        print(unit.get_standalone_clock_source())
        return True
    if requested == 'set' and len(args) >= 2 and args[1] in clocks:
        unit.set_standalone_clock_source(args[1])
        return True

    for line in (
        'Arguments for standalone-clock command:',
        '  standalone-clock OP [CLOCK]',
        '    OP:     [{0}]'.format('|'.join(ops)),
        '    CLOCK:  [{0}]'.format('|'.join(clocks)),
    ):
        print(line)
    return False

def handle_standalone_params(unit, args):
    """Handle 'standalone-params': get or set parameters of a clock source.

    args is [CLOCK, 'get'] or [CLOCK, 'set', MODE], where CLOCK is one of
    unit.get_supported_clock_sources() and MODE is a comma-separated list
    of 'key=value' entries.  Entries that are not of that form, or whose
    key is not among the options reported for CLOCK, are ignored.  A value
    is converted to int when the option for its key is an int or a list
    starting with an int; otherwise it is passed on as a string.

    Returns True when the operation was carried out.  Otherwise (missing
    or invalid arguments, including a value that cannot be converted to
    int) it prints the usage text and returns False.
    """
    ops = ('set', 'get')
    clocks = unit.get_supported_clock_sources()
    if len(args) >= 2:
        clk, op = args[0:2]
        if clk in clocks and op in ops:
            param_options = unit.get_standalone_clock_source_param_options(clk)
            if op == 'get':
                params = unit.get_standalone_clock_source_params(clk)
                for name, value in params.items():
                    print('{0}={1}'.format(name, value))
                return True
            if len(args) >= 3:
                params = {}
                try:
                    for entry in args[2].split(','):
                        if entry.count('=') != 1:
                            continue
                        key, _, val = entry.partition('=')
                        if not key or key not in param_options:
                            continue
                        options = param_options[key]
                        # The 'options' guard keeps an empty list from
                        # raising IndexError; it is treated as non-integer.
                        wants_int = (
                            isinstance(options, int) or
                            (isinstance(options, list) and options and
                             isinstance(options[0], int)))
                        params[key] = int(val) if wants_int else val
                except ValueError:
                    # A value that should be an integer is not one: show
                    # the usage text instead of aborting with a traceback.
                    params = None
                if params is not None:
                    unit.set_standalone_clock_source_params(clk, params)
                    return True
    print('Arguments for standalone-params command:')
    print('  standalone-params CLOCK OP [MODE]')
    print('    CLOCK:  [{0}]'.format('|'.join(clocks)))
    print('    OP:     [{0}]'.format('|'.join(ops)))
    print('    MODE:')
    for clk in clocks:
        params = unit.get_standalone_clock_source_param_options(clk)
        for mode, options in params.items():
            if isinstance(options, int):
                vals = 'integer(up to {0})'.format(options)
            else:
                # str() each option so integer lists can be joined too.
                vals = '|'.join(map(str, options))
            print('      {0}=[{1}] if CLOCK={2}'.format(mode, vals, clk))
    return False

def handle_listen_metering(unit, args):
    """Print the unit's meters roughly every 0.1 seconds until interrupted.

    Each cycle reads unit.get_metering() once and prints one line per
    meter as 'src_blk_id src_blk_ch dst_blk_id dst_blk_ch peak', followed
    by an empty line.  The metering data is iterated as a mapping nested
    four levels deep (source block -> source channel -> destination block
    -> destination channel -> peak).

    The loop has no exit condition: a SIGINT handler is installed that
    calls exit(), so the function does not return normally.  args is not
    used.
    """
    def handle_unix_signal(signum, frame):
        # Runs as the SIGINT handler, outside the polling loop's flow.
        exit()
    signal(SIGINT, handle_unix_signal)
    while True:
        # Read the meters once per cycle and print that snapshot.
        meters = unit.get_metering()
        for src_blk_id, src_blk_params in meters.items():
            for src_blk_ch, dst_params in src_blk_params.items():
                for dst_blk_id, dst_blk_params in dst_params.items():
                    for dst_blk_ch, peak in dst_blk_params.items():
                        print(src_blk_id, src_blk_ch, dst_blk_id, dst_blk_ch,
                              peak)
        print('')
        sleep(0.1)

def handle_storage_operation(unit, args):
    """Handle 'storage-operation': store to or load from the unit's storage.

    args[0] is 'store' (calls unit.store_to_storage) or 'load' (calls
    unit.load_from_storage).  Every category returned by the call is
    reported as processed.  Returns True on success; with a missing or
    unknown operation it prints the usage text and returns False.
    """
    ops = {
        'store':    unit.store_to_storage,
        'load':     unit.load_from_storage
    }
    if len(args) >= 1 and args[0] in ops:
        categories = ops[args[0]]()
        for category in categories:
            print('{0} is processed.'.format(category))
        return True
    # The usage text uses the name this handler is registered under.
    print('Arguments for storage-operation command:')
    print('  storage-operation OP')
    print('    OP:     [{0}]'.format('|'.join(ops)))
    return False

# Command name -> handler; every handler is called as handler(unit, args).
cmds = {
    'current-status': handle_current_status,
    'current-routing': handle_current_routing,
    'stream-params': handle_stream_params,
    'router-entries': handle_router_entries,
    'output-source': handle_output_source,
    'capture-source': handle_capture_source,
    'mixer-source': handle_mixer_source,
    'mixer-gain': handle_mixer_gain,
    'mixer-panning': handle_mixer_panning,
    'standalone-clock': handle_standalone_clock,
    'standalone-params': handle_standalone_params,
    'listen-metering': handle_listen_metering,
}

fullpath = CuiKit.seek_snd_unit_path()
if fullpath:
    unit = DiceExtendedUnit(fullpath)
    # The storage command is offered only when all four capabilities are
    # reported; the checks stop at the first one that is missing.
    storable = all(
        unit.get_caps(category)[capability]
        for category, capability in (
            ('general', 'storage-available'),
            ('general', 'storable-stream-conf'),
            ('router', 'is-storable'),
            ('mixer', 'is-storable'),
        )
    )
    if storable:
        cmds['storage-operation'] = handle_storage_operation
    CuiKit.dispatch_command(unit, cmds)
