#include <lib94/lib94.hpp>
#include <algorithm>
#include <charconv>
#include <cstdio>
#include <fstream>
#include <iostream>
#include <sstream>
#include <string>
#include <system_error>
#include <vector>
#include <mpi.h>

const int default_rounds_per_chunk = 250;
const int steps_to_tie = 1000000;

int error(std::string msg, int rank) {
  if (rank == 0) {
    std::cerr << msg << std::endl;
    return 1;
  }
  return 0;
}

int main(int argc, char **argv) {

  MPI_Init(&argc, &argv);

  int rank;
  int size;

  MPI_Comm_rank(MPI_COMM_WORLD, &rank);
  MPI_Comm_size(MPI_COMM_WORLD, &size);

  if (size < 2)
    error("this must be run under mpi with at least two processes.", rank);

  std::vector<std::string> filenames = {};
  int rounds_per_chunk = default_rounds_per_chunk;

  for (int i = 1; i < argc; ++i) {
    if (atoi(argv[i]) > 0)
      rounds_per_chunk = atoi(argv[i]);
    else
      filenames.push_back(argv[i]);
  }

  if (filenames.size() == 0)
    return error("no files specified.", rank);
  if (filenames.size() == 1)
    return error("only one file specified.", rank);

  int count = filenames.size();
  lib94::warrior **warriors = new lib94::warrior *[count];
  for (int i = 0; i < count; ++i) {
    std::ifstream file(filenames[i]);
    if (!file)
      return error("could not open " + filenames[i] + ".", rank);
    std::stringstream stream;
    stream << file.rdbuf();
    try {
      warriors[i] = lib94::compile_warrior(stream.str());
    }
    catch (const lib94::compiler_exception &ex) {
      return error("could not compile " + filenames[i] + ": " + ex.message + " on line " + std::to_string(ex.source_line_number) + ".", rank);
    }
  }

  //w1 * count + w2
  int *placements = new int[count * count];
  for (int i = 0; i < count; ++i)
    for (int j = 0; j < count; ++j) {
      int p = LIB94_CORE_SIZE - warriors[i]->instructions.size() - warriors[j]->instructions.size() + 1;
      if (p <= 0)
        return error(filenames[i] + " and " + filenames[j] + " do not fit in core together.", rank);
      placements[i * count + j] = p;
    }

  if (rank == 0) {

    std::cerr << "\x1b""7\x1b[?47h\x1b[2J" << std::flush;

    //w1 * count + w2
    int *wins_array = new int[count * count];
    for (int i = 0; i < count * count; ++i)
      wins_array[i] = 0;

    for (int i = 0; i < count; ++i)
      for (int j = 0; j < count; ++j) {
        if (i == j)
          continue;
        int rounds = placements[i * count + j];
        for (int r = 0; r < rounds; r += rounds_per_chunk) {
          int e = std::min(r + rounds_per_chunk, rounds);

          MPI_Status status;
          int buffer[4];
          MPI_Recv(buffer, 4, MPI_INT, MPI_ANY_SOURCE, 1, MPI_COMM_WORLD, &status);
          int source = status.MPI_SOURCE;

          wins_array[buffer[0] * count + buffer[1]] += buffer[2];
          wins_array[buffer[1] * count + buffer[0]] += buffer[3];

          buffer[0] = i;
          buffer[1] = j;
          buffer[2] = r;
          buffer[3] = e;
          MPI_Send(buffer, 4, MPI_INT, source, 0, MPI_COMM_WORLD);

          std::cerr << "\x1b[" << source << ";1Hworker " << source << ": " << warriors[i]->name << " vs " << warriors[j]->name << " rounds " << r + 1 << " to " << e << " of " << rounds << ".\x1b[0K" << std::flush;
        }
      }

    for (int i = 1; i < size; ++i) {
      MPI_Status status;
      int buffer[4];
      MPI_Recv(buffer, 4, MPI_INT, MPI_ANY_SOURCE, 1, MPI_COMM_WORLD, &status);
      int source = status.MPI_SOURCE;

      wins_array[buffer[0] * count + buffer[1]] += buffer[2];
      wins_array[buffer[1] * count + buffer[0]] += buffer[3];

      buffer[3] = 0;
      MPI_Send(buffer, 4, MPI_INT, source, 0, MPI_COMM_WORLD);
      std::cerr << "\x1b[" << source << ";1Hworker " << source << ": complete.\x1b[0K" << std::flush;
    }

    std::cerr << "\x1b[?47l\x1b""8";

    int col_width = 13;
    for (int i = 0; i < count; ++i)
      if ((int)warriors[i]->name.size() > col_width)
        col_width = (int)warriors[i]->name.size();

    printf(" %*s", col_width, "");
    for (int i = 0; i < count; ++i)
      printf(" | %*s", col_width, warriors[i]->name.c_str());
    putchar('\n');
    for (int k = 0; k < col_width + 2; ++k)
      putchar('-');
    for (int i = 0; i < count; ++i) {
      putchar('+');
      for (int k = 0; k < col_width + 2; ++k)
        putchar('-');
    }
    putchar('\n');
    for (int j = 0; j < count; ++j) {
      printf(" %*s", col_width, warriors[j]->name.c_str());
      for (int i = 0; i < count; ++i)
        if (i == j)
          printf(" | %*s", col_width, "");
        else
          printf(" | %5d / %5d%*s", wins_array[i * count + j], placements[i * count + j] * 2, col_width - 13, "");
      putchar('\n');
    }

  }

  else {

    lib94::instruction core_background = {
      .op = lib94::DAT,
      .mod = lib94::F,
      .amode = lib94::DIRECT,
      .bmode = lib94::DIRECT,
      .anumber = 0,
      .bnumber = 0
    };

    int buffer[4] = { 0, 0, 0, 0 };
    MPI_Send(buffer, 4, MPI_INT, 0, 1, MPI_COMM_WORLD);

    while (true) {

      MPI_Recv(buffer, 4, MPI_INT, 0, 0, MPI_COMM_WORLD, 0);
      if (buffer[3] == 0)
        break;

      auto w1 = warriors[buffer[0]];
      auto w2 = warriors[buffer[1]];
      int start_offset = w1->instructions.size();
      int w1_wins = 0;
      int w2_wins = 0;

      const lib94::warrior *const wlist[2] = { w1, w2 };
      int offsets[2] = {};

      for (int round = buffer[2]; round < buffer[3]; ++round) {

        offsets[1] = start_offset + round;
        lib94::clear_core(core_background);
        lib94::init_round(wlist, 2, offsets, false);

        for (int step = 0; step < steps_to_tie; ++step) {
          auto lost = lib94::single_step<false>();
          if (lost == w1) {
            ++w2_wins;
            break;
          }
          if (lost == w2) {
            ++w1_wins;
            break;
          }
        }

      }

      buffer[2] = w1_wins;
      buffer[3] = w2_wins;
      MPI_Send(buffer, 4, MPI_INT, 0, 1, MPI_COMM_WORLD);

    }

  }

  MPI_Finalize();
  return 0;

}