#include "nak_private.h"
#include "nir.h"
#include "nir_builder.h"
#include "nir_search.h"
#include "nir_search_helpers.h"

/* What follows is NIR algebraic transform code for the following 9
 * transforms:
 *    ('imin', 'a', 'b') => ('bcsel', ('ilt', 'a', 'b'), 'a', 'b')
 *    ('imax', 'a', 'b') => ('bcsel', ('ilt', 'a', 'b'), 'b', 'a')
 *    ('umin', 'a', 'b') => ('bcsel', ('ult', 'a', 'b'), 'a', 'b')
 *    ('umax', 'a', 'b') => ('bcsel', ('ult', 'a', 'b'), 'b', 'a')
 *    ('iadd', 'a@64', ('ineg', 'b@64')) => ('isub', 'a', 'b')
 *    ('iadd', ('iadd(is_used_once)', 'a(is_not_const)', '#b'), 'c(is_not_const)') => ('iadd3', 'a', 'b', 'c')
 *    ('iadd', ('iadd(is_used_once)', 'a(is_not_const)', 'b(is_not_const)'), '#c') => ('iadd3', 'a', 'b', 'c')
 *    ('iadd(is_used_by_non_ldc_nv)', 'a@32', ('ishl', 'b@32', '#s@32')) => ('lea_nv', 'a', 'b', 's')
 *    ('iadd', 'a@64', ('ishl', 'b@64', '#s@32')) => ('lea_nv', 'a', 'b', 's')
 */


/*
 * Flat pool of every search/replace pattern node used by the transforms
 * listed in the comment at the top of this file.
 *
 * Nodes refer to each other by position in this array: the brace-enclosed
 * source list of an expression (e.g. "{ 3, 0, 1 }") holds indices of other
 * entries, and a sub-pattern that already exists is not emitted again (see
 * the "-> N in the cache" comments).  The "[N]" comments below give the
 * index of each entry; nak_nir_lower_algebraic_late() statically asserts
 * that the array has exactly 40 entries.
 *
 * The initializers are positional, so their exact meaning depends on the
 * layout of the .variable and .expression members of nir_search_value_union,
 * which are declared outside this file.  What can be read off the table
 * itself:
 *  - the second number in the inner "{ kind, N }" pair is 64 / 32 for the
 *    '@64' / '@32' patterns, 1 for the ilt/ult comparisons and negative
 *    everywhere else (presumably "bit size not fixed" -- confirm against
 *    the struct declaration);
 *  - for variables, the number tagged with the variable name is 0 for a,
 *    1 for b and 2 for c or s; the boolean after it is true only for the
 *    '#'-prefixed variables of the patterns; the number after
 *    nir_type_invalid is 0 exactly for the "(is_not_const)" variables and
 *    -1 otherwise, which matches an index into
 *    nak_nir_lower_algebraic_late_variable_cond;
 *  - for expressions, the last number is 0 for 'iadd(is_used_once)', 1 for
 *    'iadd(is_used_by_non_ldc_nv)' and -1 otherwise, which matches an index
 *    into nak_nir_lower_algebraic_late_expression_cond.
 */
static const nir_search_value_union nak_nir_lower_algebraic_late_values[] = {
   /* ('imin', 'a', 'b') => ('bcsel', ('ilt', 'a', 'b'), 'a', 'b') */
   /* [0] 'a' */
   { .variable = {
      { nir_search_value_variable, -2 },
      0, /* a */
      false,
      nir_type_invalid,
      -1,
      {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
   } },
   /* [1] 'b' */
   { .variable = {
      { nir_search_value_variable, -2 },
      1, /* b */
      false,
      nir_type_invalid,
      -1,
      {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
   } },
   /* [2] search0: ('imin', [0], [1]) */
   { .expression = {
      { nir_search_value_expression, -2 },
      false,
      false,
      false,
      false,
      false,
      false,
      false,
      false,
      -1,
      nir_op_imin,
      0, 1,
      { 0, 1 },
      -1,
   } },

   /* replace0_0_0 -> 0 in the cache */
   /* replace0_0_1 -> 1 in the cache */
   /* [3] replace0_0: ('ilt', [0], [1]) */
   { .expression = {
      { nir_search_value_expression, 1 },
      false,
      false,
      false,
      false,
      false,
      false,
      false,
      false,
      -1,
      nir_op_ilt,
      -1, 0,
      { 0, 1 },
      -1,
   } },
   /* replace0_1 -> 0 in the cache */
   /* replace0_2 -> 1 in the cache */
   /* [4] replace0: ('bcsel', [3], [0], [1]) */
   { .expression = {
      { nir_search_value_expression, -2 },
      false,
      false,
      false,
      false,
      false,
      false,
      false,
      false,
      -1,
      nir_op_bcsel,
      -1, 0,
      { 3, 0, 1 },
      -1,
   } },

   /* ('imax', 'a', 'b') => ('bcsel', ('ilt', 'a', 'b'), 'b', 'a') */
   /* search1_0 -> 0 in the cache */
   /* search1_1 -> 1 in the cache */
   /* [5] search1: ('imax', [0], [1]) */
   { .expression = {
      { nir_search_value_expression, -2 },
      false,
      false,
      false,
      false,
      false,
      false,
      false,
      false,
      -1,
      nir_op_imax,
      0, 1,
      { 0, 1 },
      -1,
   } },

   /* replace1_0_0 -> 0 in the cache */
   /* replace1_0_1 -> 1 in the cache */
   /* replace1_0 -> 3 in the cache */
   /* replace1_1 -> 1 in the cache */
   /* replace1_2 -> 0 in the cache */
   /* [6] replace1: ('bcsel', [3], [1], [0]) -- operands swapped vs. [4] */
   { .expression = {
      { nir_search_value_expression, -2 },
      false,
      false,
      false,
      false,
      false,
      false,
      false,
      false,
      -1,
      nir_op_bcsel,
      -1, 0,
      { 3, 1, 0 },
      -1,
   } },

   /* ('umin', 'a', 'b') => ('bcsel', ('ult', 'a', 'b'), 'a', 'b') */
   /* search2_0 -> 0 in the cache */
   /* search2_1 -> 1 in the cache */
   /* [7] search2: ('umin', [0], [1]) */
   { .expression = {
      { nir_search_value_expression, -2 },
      false,
      false,
      false,
      false,
      false,
      false,
      false,
      false,
      -1,
      nir_op_umin,
      0, 1,
      { 0, 1 },
      -1,
   } },

   /* replace2_0_0 -> 0 in the cache */
   /* replace2_0_1 -> 1 in the cache */
   /* [8] replace2_0: ('ult', [0], [1]) */
   { .expression = {
      { nir_search_value_expression, 1 },
      false,
      false,
      false,
      false,
      false,
      false,
      false,
      false,
      -1,
      nir_op_ult,
      -1, 0,
      { 0, 1 },
      -1,
   } },
   /* replace2_1 -> 0 in the cache */
   /* replace2_2 -> 1 in the cache */
   /* [9] replace2: ('bcsel', [8], [0], [1]) */
   { .expression = {
      { nir_search_value_expression, -2 },
      false,
      false,
      false,
      false,
      false,
      false,
      false,
      false,
      -1,
      nir_op_bcsel,
      -1, 0,
      { 8, 0, 1 },
      -1,
   } },

   /* ('umax', 'a', 'b') => ('bcsel', ('ult', 'a', 'b'), 'b', 'a') */
   /* search3_0 -> 0 in the cache */
   /* search3_1 -> 1 in the cache */
   /* [10] search3: ('umax', [0], [1]) */
   { .expression = {
      { nir_search_value_expression, -2 },
      false,
      false,
      false,
      false,
      false,
      false,
      false,
      false,
      -1,
      nir_op_umax,
      0, 1,
      { 0, 1 },
      -1,
   } },

   /* replace3_0_0 -> 0 in the cache */
   /* replace3_0_1 -> 1 in the cache */
   /* replace3_0 -> 8 in the cache */
   /* replace3_1 -> 1 in the cache */
   /* replace3_2 -> 0 in the cache */
   /* [11] replace3: ('bcsel', [8], [1], [0]) -- operands swapped vs. [9] */
   { .expression = {
      { nir_search_value_expression, -2 },
      false,
      false,
      false,
      false,
      false,
      false,
      false,
      false,
      -1,
      nir_op_bcsel,
      -1, 0,
      { 8, 1, 0 },
      -1,
   } },

   /* ('iadd', 'a@64', ('ineg', 'b@64')) => ('isub', 'a', 'b') */
   /* [12] 'a@64' */
   { .variable = {
      { nir_search_value_variable, 64 },
      0, /* a */
      false,
      nir_type_invalid,
      -1,
      {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
   } },
   /* [13] 'b@64' */
   { .variable = {
      { nir_search_value_variable, 64 },
      1, /* b */
      false,
      nir_type_invalid,
      -1,
      {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
   } },
   /* [14] search4_1: ('ineg', [13]) */
   { .expression = {
      { nir_search_value_expression, 64 },
      false,
      false,
      false,
      false,
      false,
      false,
      false,
      false,
      -1,
      nir_op_ineg,
      -1, 0,
      { 13 },
      -1,
   } },
   /* [15] search4: ('iadd', [12], [14]) */
   { .expression = {
      { nir_search_value_expression, 64 },
      false,
      false,
      false,
      false,
      false,
      false,
      false,
      false,
      -1,
      nir_op_iadd,
      0, 1,
      { 12, 14 },
      -1,
   } },

   /* replace4_0 -> 12 in the cache */
   /* replace4_1 -> 13 in the cache */
   /* [16] replace4: ('isub', [12], [13]) */
   { .expression = {
      { nir_search_value_expression, 64 },
      false,
      false,
      false,
      false,
      false,
      false,
      false,
      false,
      -1,
      nir_op_isub,
      -1, 0,
      { 12, 13 },
      -1,
   } },

   /* ('iadd', ('iadd(is_used_once)', 'a(is_not_const)', '#b'), 'c(is_not_const)') => ('iadd3', 'a', 'b', 'c') */
   /* [17] 'a(is_not_const)' */
   { .variable = {
      { nir_search_value_variable, -3 },
      0, /* a */
      false,
      nir_type_invalid,
      0,
      {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
   } },
   /* [18] '#b' */
   { .variable = {
      { nir_search_value_variable, -3 },
      1, /* b */
      true,
      nir_type_invalid,
      -1,
      {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
   } },
   /* [19] search5_0: ('iadd(is_used_once)', [17], [18]) */
   { .expression = {
      { nir_search_value_expression, -3 },
      false,
      false,
      false,
      false,
      false,
      false,
      false,
      false,
      -1,
      nir_op_iadd,
      1, 1,
      { 17, 18 },
      0,
   } },
   /* [20] 'c(is_not_const)' */
   { .variable = {
      { nir_search_value_variable, -3 },
      2, /* c */
      false,
      nir_type_invalid,
      0,
      {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
   } },
   /* [21] search5: ('iadd', [19], [20]) */
   { .expression = {
      { nir_search_value_expression, -3 },
      false,
      false,
      false,
      false,
      false,
      false,
      false,
      false,
      -1,
      nir_op_iadd,
      0, 2,
      { 19, 20 },
      -1,
   } },

   /* [22] replacement-side 'a' (no condition attached) */
   { .variable = {
      { nir_search_value_variable, -3 },
      0, /* a */
      false,
      nir_type_invalid,
      -1,
      {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
   } },
   /* [23] replacement-side 'b' */
   { .variable = {
      { nir_search_value_variable, -3 },
      1, /* b */
      false,
      nir_type_invalid,
      -1,
      {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
   } },
   /* [24] replacement-side 'c' */
   { .variable = {
      { nir_search_value_variable, -3 },
      2, /* c */
      false,
      nir_type_invalid,
      -1,
      {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
   } },
   /* [25] replace5: ('iadd3', [22], [23], [24]); also reused as replace6 */
   { .expression = {
      { nir_search_value_expression, -3 },
      false,
      false,
      false,
      false,
      false,
      false,
      false,
      false,
      -1,
      nir_op_iadd3,
      0, 1,
      { 22, 23, 24 },
      -1,
   } },

   /* ('iadd', ('iadd(is_used_once)', 'a(is_not_const)', 'b(is_not_const)'), '#c') => ('iadd3', 'a', 'b', 'c') */
   /* search6_0_0 -> 17 in the cache */
   /* [26] 'b(is_not_const)' */
   { .variable = {
      { nir_search_value_variable, -3 },
      1, /* b */
      false,
      nir_type_invalid,
      0,
      {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
   } },
   /* [27] search6_0: ('iadd(is_used_once)', [17], [26]) */
   { .expression = {
      { nir_search_value_expression, -3 },
      false,
      false,
      false,
      false,
      false,
      false,
      false,
      false,
      -1,
      nir_op_iadd,
      1, 1,
      { 17, 26 },
      0,
   } },
   /* [28] '#c' */
   { .variable = {
      { nir_search_value_variable, -3 },
      2, /* c */
      true,
      nir_type_invalid,
      -1,
      {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
   } },
   /* [29] search6: ('iadd', [27], [28]) */
   { .expression = {
      { nir_search_value_expression, -3 },
      false,
      false,
      false,
      false,
      false,
      false,
      false,
      false,
      -1,
      nir_op_iadd,
      0, 2,
      { 27, 28 },
      -1,
   } },

   /* replace6_0 -> 22 in the cache */
   /* replace6_1 -> 23 in the cache */
   /* replace6_2 -> 24 in the cache */
   /* replace6 -> 25 in the cache */

   /* ('iadd(is_used_by_non_ldc_nv)', 'a@32', ('ishl', 'b@32', '#s@32')) => ('lea_nv', 'a', 'b', 's') */
   /* [30] 'a@32' */
   { .variable = {
      { nir_search_value_variable, 32 },
      0, /* a */
      false,
      nir_type_invalid,
      -1,
      {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
   } },
   /* [31] 'b@32' */
   { .variable = {
      { nir_search_value_variable, 32 },
      1, /* b */
      false,
      nir_type_invalid,
      -1,
      {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
   } },
   /* [32] '#s@32' */
   { .variable = {
      { nir_search_value_variable, 32 },
      2, /* s */
      true,
      nir_type_invalid,
      -1,
      {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
   } },
   /* [33] search7_1: ('ishl', [31], [32]) */
   { .expression = {
      { nir_search_value_expression, 32 },
      false,
      false,
      false,
      false,
      false,
      false,
      false,
      false,
      -1,
      nir_op_ishl,
      -1, 0,
      { 31, 32 },
      -1,
   } },
   /* [34] search7: ('iadd(is_used_by_non_ldc_nv)', [30], [33]) */
   { .expression = {
      { nir_search_value_expression, 32 },
      false,
      false,
      false,
      false,
      false,
      false,
      false,
      false,
      -1,
      nir_op_iadd,
      0, 1,
      { 30, 33 },
      1,
   } },

   /* replace7_0 -> 30 in the cache */
   /* replace7_1 -> 31 in the cache */
   /* [35] replacement-side 's' (same variable as [32], flag is false here) */
   { .variable = {
      { nir_search_value_variable, 32 },
      2, /* s */
      false,
      nir_type_invalid,
      -1,
      {0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15},
   } },
   /* [36] replace7: ('lea_nv', [30], [31], [35]) */
   { .expression = {
      { nir_search_value_expression, 32 },
      false,
      false,
      false,
      false,
      false,
      false,
      false,
      false,
      -1,
      nir_op_lea_nv,
      -1, 0,
      { 30, 31, 35 },
      -1,
   } },

   /* ('iadd', 'a@64', ('ishl', 'b@64', '#s@32')) => ('lea_nv', 'a', 'b', 's') */
   /* search8_0 -> 12 in the cache */
   /* search8_1_0 -> 13 in the cache */
   /* search8_1_1 -> 32 in the cache */
   /* [37] search8_1: ('ishl', [13], [32]) */
   { .expression = {
      { nir_search_value_expression, 64 },
      false,
      false,
      false,
      false,
      false,
      false,
      false,
      false,
      -1,
      nir_op_ishl,
      -1, 0,
      { 13, 32 },
      -1,
   } },
   /* [38] search8: ('iadd', [12], [37]) */
   { .expression = {
      { nir_search_value_expression, 64 },
      false,
      false,
      false,
      false,
      false,
      false,
      false,
      false,
      -1,
      nir_op_iadd,
      0, 1,
      { 12, 37 },
      -1,
   } },

   /* replace8_0 -> 12 in the cache */
   /* replace8_1 -> 13 in the cache */
   /* replace8_2 -> 35 in the cache */
   /* [39] replace8: ('lea_nv', [12], [13], [35]) */
   { .expression = {
      { nir_search_value_expression, 64 },
      false,
      false,
      false,
      false,
      false,
      false,
      false,
      false,
      -1,
      nir_op_lea_nv,
      -1, 0,
      { 12, 13, 35 },
      -1,
   } },

};

/* Predicates attached to search expressions.  The last field of an
 * expression entry in nak_nir_lower_algebraic_late_values selects one of
 * these by index, so the slot numbers are spelled out explicitly.
 */
UNUSED static const nir_search_expression_cond nak_nir_lower_algebraic_late_expression_cond[] = {
   [0] = is_used_once,
   [1] = is_used_by_non_ldc_nv,
};

/* Predicates attached to search variables; slot 0 is the only one and is
 * the one referenced by the "(is_not_const)" variables in
 * nak_nir_lower_algebraic_late_values.
 */
static const nir_search_variable_cond nak_nir_lower_algebraic_late_variable_cond[] = {
   [0] = is_not_const,
};

/*
 * Sentinel-terminated groups of transforms.  Each group starts at one of the
 * offsets stored in nak_nir_lower_algebraic_late_transform_offsets; offset 0
 * is the leading sentinel, i.e. an empty group.
 *
 * The triples are positional (struct transform is declared outside this
 * file).  Judging by the numbers they are (search root, replacement root,
 * condition): the first two match indices into
 * nak_nir_lower_algebraic_late_values (e.g. 2 = imin pattern, 4 = its bcsel
 * replacement) and the third stays within the four condition_flags slots
 * filled in by nak_nir_lower_algebraic_late() -- confirm the field names
 * against the struct declaration.
 */
static const struct transform nak_nir_lower_algebraic_late_transforms[] = {
   { ~0, ~0, ~0 }, /* Sentinel */

   /* offset 1 (state 2): imin -> bcsel/ilt */
   { 2, 4, 1 },
   { ~0, ~0, ~0 }, /* Sentinel */

   /* offset 3 (state 3): imax -> bcsel/ilt */
   { 5, 6, 1 },
   { ~0, ~0, ~0 }, /* Sentinel */

   /* offset 5 (state 4): umin -> bcsel/ult */
   { 7, 9, 1 },
   { ~0, ~0, ~0 }, /* Sentinel */

   /* offset 7 (state 5): umax -> bcsel/ult */
   { 10, 11, 1 },
   { ~0, ~0, ~0 }, /* Sentinel */

   /* offset 9 (state 10): iadd(iadd(a, #b), c) -> iadd3 */
   { 21, 25, 2 },
   { ~0, ~0, ~0 }, /* Sentinel */

   /* offset 11 (state 11): 64-bit iadd(a, ineg(b)) -> isub */
   { 15, 16, 0 },
   { ~0, ~0, ~0 }, /* Sentinel */

   /* offset 13 (state 12): 32-bit and 64-bit iadd(a, ishl(b, #s)) -> lea_nv */
   { 34, 36, 3 },
   { 38, 39, 3 },
   { ~0, ~0, ~0 }, /* Sentinel */

   /* offset 16 (state 13): iadd(iadd(a, b), #c) -> iadd3 */
   { 29, 25, 2 },
   { ~0, ~0, ~0 }, /* Sentinel */

   /* offset 18 (state 14): both iadd3 patterns */
   { 21, 25, 2 },
   { 29, 25, 2 },
   { ~0, ~0, ~0 }, /* Sentinel */

   /* offset 21 (state 15): isub */
   { 15, 16, 0 },
   { ~0, ~0, ~0 }, /* Sentinel */

   /* offset 23 (state 16): both lea_nv patterns */
   { 34, 36, 3 },
   { 38, 39, 3 },
   { ~0, ~0, ~0 }, /* Sentinel */

   /* offset 26 (state 17): isub, then first iadd3 pattern */
   { 15, 16, 0 },
   { 21, 25, 2 },
   { ~0, ~0, ~0 }, /* Sentinel */

   /* offset 29 (state 18): first iadd3 pattern, then both lea_nv patterns */
   { 21, 25, 2 },
   { 34, 36, 3 },
   { 38, 39, 3 },
   { ~0, ~0, ~0 }, /* Sentinel */

   /* offset 33 (state 19): isub, then both lea_nv patterns */
   { 15, 16, 0 },
   { 34, 36, 3 },
   { 38, 39, 3 },
   { ~0, ~0, ~0 }, /* Sentinel */

};

/*
 * Per-opcode state tables for the opcodes that appear in the search
 * patterns.  Opcodes not listed here are left zero-initialized.
 *
 * For each opcode:
 *  - .filter, when non-NULL, has one entry per state (20 entries, the same
 *    length as nak_nir_lower_algebraic_late_transform_offsets) and maps a
 *    state to a smaller "filtered" id below .num_filtered_states;
 *  - .table has num_filtered_states ^ (number of sources) entries, each of
 *    which is a state number.
 * How the matcher walks these tables is implemented outside this file;
 * presumably the table is indexed by the filtered states of the sources in
 * source order -- confirm against the nir_search implementation.
 */
static const struct per_op_table nak_nir_lower_algebraic_late_pass_op_table[nir_num_search_ops] = {
   [nir_op_imin] = {
      .filter = NULL,
      
      .num_filtered_states = 1,
      .table = (const uint16_t []) {
      
         /* state 2: its transform offset selects the imin transform */
         2,
      },
   },
   [nir_op_imax] = {
      .filter = NULL,
      
      .num_filtered_states = 1,
      .table = (const uint16_t []) {
      
         /* state 3: its transform offset selects the imax transform */
         3,
      },
   },
   [nir_op_umin] = {
      .filter = NULL,
      
      .num_filtered_states = 1,
      .table = (const uint16_t []) {
      
         /* state 4: its transform offset selects the umin transform */
         4,
      },
   },
   [nir_op_umax] = {
      .filter = NULL,
      
      .num_filtered_states = 1,
      .table = (const uint16_t []) {
      
         /* state 5: its transform offset selects the umax transform */
         5,
      },
   },
   [nir_op_iadd] = {
      /* One entry per state 0..19; six distinct filtered ids (0..5). */
      .filter = (const uint16_t []) {
         0,
         1,
         0,
         0,
         0,
         0,
         2,
         3,
         4,
         5,
         2,
         2,
         2,
         3,
         3,
         3,
         3,
         2,
         2,
         2,
      },
      
      .num_filtered_states = 6,
      /* 6 x 6 entries; the matrix is symmetric, consistent with iadd being
       * matched with its two sources in either order.
       */
      .table = (const uint16_t []) {
      
         /* row 0 */
         6,
         7,
         6,
         10,
         11,
         12,
         /* row 1 */
         7,
         7,
         13,
         14,
         15,
         16,
         /* row 2 */
         6,
         13,
         6,
         10,
         11,
         12,
         /* row 3 */
         10,
         14,
         10,
         10,
         17,
         18,
         /* row 4 */
         11,
         15,
         11,
         17,
         11,
         19,
         /* row 5 */
         12,
         16,
         12,
         18,
         19,
         12,
      },
   },
   [nir_op_ineg] = {
      .filter = NULL,
      
      .num_filtered_states = 1,
      .table = (const uint16_t []) {
      
         /* state 8 has transform offset 0: ineg only feeds the isub pattern */
         8,
      },
   },
   [nir_op_ishl] = {
      /* Only state 1 is distinguished from the rest. */
      .filter = (const uint16_t []) {
         0,
         1,
         0,
         0,
         0,
         0,
         0,
         0,
         0,
         0,
         0,
         0,
         0,
         0,
         0,
         0,
         0,
         0,
         0,
         0,
      },
      
      .num_filtered_states = 2,
      /* 2 x 2 entries: state 9 whenever the second index is 1, else 0. */
      .table = (const uint16_t []) {
      
         0,
         9,
         0,
         9,
      },
   },
};

/* For every automaton state, the index in
 * nak_nir_lower_algebraic_late_transforms at which that state's
 * sentinel-terminated list of transforms begins.  Index 0 is the leading
 * sentinel, so a value of 0 means the state has no transforms.
 */
static const uint16_t nak_nir_lower_algebraic_late_transform_offsets[] = {
   /* States without any transform. */
   [0]  = 0,
   [1]  = 0,

   /* imin / imax / umin / umax. */
   [2]  = 1,
   [3]  = 3,
   [4]  = 5,
   [5]  = 7,

   /* Intermediate states without any transform. */
   [6]  = 0,
   [7]  = 0,
   [8]  = 0,
   [9]  = 0,

   /* iadd states: single patterns first, then combinations. */
   [10] = 9,
   [11] = 11,
   [12] = 13,
   [13] = 16,
   [14] = 18,
   [15] = 21,
   [16] = 23,
   [17] = 26,
   [18] = 29,
   [19] = 33,
};

/* Bundles all of the tables above into the single descriptor that
 * nak_nir_lower_algebraic_late() hands to nir_algebraic_impl().
 */
static const nir_algebraic_table nak_nir_lower_algebraic_late_table = {
   /* Pattern nodes and the predicates they reference. */
   .values = nak_nir_lower_algebraic_late_values,
   .expression_cond = nak_nir_lower_algebraic_late_expression_cond,
   .variable_cond = nak_nir_lower_algebraic_late_variable_cond,

   /* State tables and the transforms they select. */
   .pass_op_table = nak_nir_lower_algebraic_late_pass_op_table,
   .transform_offsets = nak_nir_lower_algebraic_late_transform_offsets,
   .transforms = nak_nir_lower_algebraic_late_transforms,
};

/**
 * Applies the late NAK algebraic lowering table to every function
 * implementation in a shader.
 *
 * @param shader  Shader to transform; its compiler options are consulted
 *                for has_iadd3.
 * @param nak     Compiler description; only its sm field is read here.
 * @return true if nir_algebraic_impl() reported progress for at least one
 *         function implementation.
 */
bool
nak_nir_lower_algebraic_late(nir_shader *shader,
                             const struct nak_compiler *nak)
{
   const nir_shader_compiler_options *options = shader->options;
   const shader_info *info = &shader->info;
   (void) options;
   (void) info;

   /* Keep the hand-off in sync with the generated value table. */
   STATIC_ASSERT(40 == ARRAY_SIZE(nak_nir_lower_algebraic_late_values));

   /* One slot per distinct transform condition; the third number of each
    * entry in nak_nir_lower_algebraic_late_transforms selects a slot.
    */
   bool condition_flags[4] = {
      true,
      nak->sm >= 70 && nak->sm < 73,
      options->has_iadd3,
      nak->sm >= 70,
   };

   bool progress = false;
   nir_foreach_function_impl(impl, shader) {
      /* Always run the pass; only the accumulation is conditional. */
      if (nir_algebraic_impl(impl, condition_flags, &nak_nir_lower_algebraic_late_table))
         progress = true;
   }

   return progress;
}
