#include "analyzer.h"
#include "pvGenerator.h"
#include "eval/evalFactory.h"
#include "eval/eval.h"

#include <boost/ptr_container/ptr_vector.hpp>
#include <boost/program_options.hpp>
#include <boost/thread/thread.hpp>
#include <boost/scoped_ptr.hpp>

#include <iostream>
#include <stdexcept>

namespace po = boost::program_options;
using namespace osl;
using namespace gpsshogi;

// Capacity of the per-thread arrays in validate(); the number of parallel
// validation jobs must not exceed this.
const int MaxThreads(64);
// Copied into KisenAnalyzer::OtherConfig for every validation run.
static int cross_validation_randomness(1);

static void validate(double search_window_for_validation,
		     int max_progress, int min_rating,
		     size_t num_records,
		     size_t split, int cross_start,
		     const std::vector<std::string> &kisen_filenames,
		     gpsshogi::Eval *eval, std::ostream &os)
{
  const KisenAnalyzer::OtherConfig config =
    {
      search_window_for_validation, max_progress, cross_validation_randomness,
      min_rating, eval
    };
  boost::ptr_vector<boost::thread> threads;
      
  KisenAnalyzer::RecordConfig cross_configs[MaxThreads];
  PVGenerator::Result results[MaxThreads];
  KisenAnalyzer::distributeJob(split, &cross_configs[0], cross_start,
			       num_records,
			       kisen_filenames, min_rating);
  for (size_t j=0; j<split; ++j) {
    threads.push_back(new boost::thread(Validator(cross_configs[j], config, results+j)));
  }
  stat::Average werror, order_lb, order_ub, toprated, toprated_strict;
  int record_processed = 0;
  for (size_t j=0; j<split; ++j) {
    threads[j].join();
    werror.merge(results[j].werrors);
    order_lb.merge(results[j].order_lb);
    order_ub.merge(results[j].order_ub);
    toprated.merge(results[j].toprated);
    toprated_strict.merge(results[j].toprated_strict);
    record_processed += results[j].record_processed;
  }
  os << "\n  mean errors in search " << werror.getAverage() << "\n";
  os << "  average order (lb) " << order_lb.getAverage() << "\n";
  os << "  average order (ub) " << order_ub.getAverage() << "\n";
  os << "  ratio that best move is recorded move " << toprated.getAverage() << "\n";
  os << "  ratio that best move is recorded move (strict) " << toprated_strict.getAverage() << "\n";
  os << "  (#records) " << record_processed << "\n";
}

int main(int argc, char **argv)
{
  double search_window_for_validation;
  size_t num_records, num_cpus;
  std::string eval_type, initial_value;
  int max_progress;
  bool high_rating_only;
  int cross_validation_start;

  po::options_description options("all_options");

  options.add_options()
    ("num-records,n",
     po::value<size_t>(&num_records)->default_value(0),
     "number of records to be analyzed (all if 0)")
    ("num-cpus,N",
     po::value<size_t>(&num_cpus)->default_value(1),
     "number cpus to be used")
    ("kisen-file,k", po::value<std::vector<std::string> >(),
     "filename for records to be analyzed")
    ("eval,e",
     po::value<std::string>(&eval_type)->default_value(std::string("piece")),
     "evaluation function (piece, rich0, rich1)")
    ("vwindow",
     po::value<double>(&search_window_for_validation)->default_value(8),
     "search window for validation, relative to pawn value")
    ("initial-value-file,f",
     po::value<std::string>(&initial_value)->default_value(""),
     "File with initial eval values")
    ("max-progress",
     po::value<int>(&max_progress)->default_value(6),
     "When non-negative, only use states where progress is less than this "
      "value.")
    ("high-rating-only",
     po::value<bool>(&high_rating_only)->default_value(true),
      "When true only consider plays where both player have at least "
      "1500 rating value")
    ("cross-validation-start",
     po::value<int>(&cross_validation_start)->default_value(200000),
     "Start ID of record in kisen file to do cross validation")
    ("help", "produce help message")
    ;

  std::vector<std::string> kisen_filenames;
  po::variables_map vm;
  try
  {
    po::store(po::parse_command_line(argc, argv, options), vm);
    po::notify(vm);
    if (vm.count("kisen-file"))
      kisen_filenames = vm["kisen-file"].as<std::vector<std::string> >();
    else
      kisen_filenames.push_back("../../../data/kisen/01.kif");
  }
  catch (std::exception& e)
  {
    std::cerr << "error in parsing options" << std::endl
	      << e.what() << std::endl;
    std::cerr << options << std::endl;
    return 1;
  }
  if (vm.count("help")) {
    std::cerr << options << std::endl;
    return 0;
  }

  boost::scoped_ptr<gpsshogi::Eval>
    eval(gpsshogi::EvalFactory::newEval(eval_type));
  if (eval == NULL)
  {
    std::cerr << "unknown eval type " << eval_type << "\n";
    throw std::runtime_error("unknown eval type");
  }
  eval->load(initial_value.c_str());
  size_t min_rating = high_rating_only ? 1500 : 0;

  validate(search_window_for_validation, max_progress,
	   min_rating, num_records, num_cpus,
	   cross_validation_start, kisen_filenames, eval.get(), std::cout);

  return 0;
}

// ;;; Local Variables:
// ;;; mode:c++
// ;;; c-basic-offset:2
// ;;; End:
