#include <unistd.h>

#include <algorithm>
#include <cctype>
#include <cstdlib>
#include <cstring>
#include <ctime>
#include <fstream>
#include <iostream>
#include <optional>
#include <random>
#include <stdexcept>
#include <string>
#include <vector>

#include <curl/curl.h>
#include <sqlite3.h>

#include "json.hpp"

// Random engine used to shuffle the entries of a round; seeded from the
// clock at the start of main().
std::default_random_engine prng;
// SQLite connection to database.db; opened in main().
sqlite3 *dbc;
// libcurl easy handle created in main() and reused by api() for every request.
CURL *curl_handle;

// Configuration read from environment variables in main().
// BOT_TOKEN: sent by api() in the "Authorization: Bot ..." header.
std::string bot_token;
// CHANNEL_ID: channel that polls and messages are posted to.
std::string channel_id;
// POLLS_PER_DAY: upper bound on the polls posted per pass of the main loop.
int polls_per_day;
// POST_TIME_UTC: time of day to post at, in seconds past midnight.
long post_time_utc;

void init_db() {

  sqlite3_exec(dbc,
    "begin transaction; "
    "create table misc ("
      "on_round INTEGER NOT NULL); "
    "insert into misc (on_round) "
      "values (0); "
    "create table entries ("
      "name TEXT UNIQUE NOT NULL, "
      "round INTEGER NOT NULL, "
      "in_active_poll INTEGER NOT NULL); "
    "create table active_polls ("
      "msg_id TEXT UNIQUE NOT NULL); "
    "create table past_polls ("
      "round INTEGER NOT NULL, "
      "entry_1 TEXT NOT NULL, "
      "entry_2 TEXT NOT NULL, "
      "users_1 TEXT NOT NULL, "
      "users_2 TEXT NOT NULL)",
    0, 0, 0);

  sqlite3_stmt *insert_entry;
  sqlite3_prepare_v2(dbc,
    "insert into entries (name, round, in_active_poll)"
      "values (?, 1, 0)", -1, &insert_entry, 0);

  std::ifstream entries("entries.txt");
  if (!entries) {
    std::cerr << "Please put the list of entries into entries.txt, separated by newlines." << std::endl;
    exit(1);
  }

  std::string entry;
  while (std::getline(entries, entry)) {
    sqlite3_bind_text(insert_entry, 1, entry.c_str(), -1, SQLITE_TRANSIENT);
    while (sqlite3_step(insert_entry) != SQLITE_DONE)
      ;
    sqlite3_reset(insert_entry);
  }

  sqlite3_exec(dbc, "end transaction", 0, 0, 0);
  sqlite3_finalize(insert_entry);

}

// Round currently in progress; loaded from and saved to misc.on_round.
int current_round_number;
// Shuffled names of the current round's entries that are not in an active
// poll; main() pairs them up by popping from the back.
std::vector<std::string> left_for_this_round;

// Reads misc.on_round into current_round_number. The value is left at -1
// when the query yields no row.
void load_on_round() {
  current_round_number = -1;

  // Invoked by sqlite3_exec once per result row; the first column holds on_round.
  const auto store_round = [](void *, int, char **columns, char **) -> int {
    current_round_number = atoi(columns[0]);
    return 0;
  };

  sqlite3_exec(dbc, "select on_round from misc limit 1", store_round, 0, 0);
}

void save_on_round() {
  sqlite3_stmt *update;
  sqlite3_prepare_v2(dbc, "update misc set on_round = ?", -1, &update, 0);
  sqlite3_bind_int(update, 1, current_round_number);
  while (sqlite3_step(update) != SQLITE_DONE)
    ;
  sqlite3_finalize(update);
}

void load_left_for_round() {
  sqlite3_stmt *select_left;
  sqlite3_prepare_v2(dbc,
    "select name from entries where round = ? and in_active_poll = 0",
    -1, &select_left, 0);
  sqlite3_bind_int(select_left, 1, current_round_number);

  left_for_this_round.clear();
  int result;
  while ((result = sqlite3_step(select_left)) != SQLITE_DONE)
    if (result == SQLITE_ROW)
      left_for_this_round.push_back(
        std::string((const char *)sqlite3_column_text(select_left, 0)));

  sqlite3_finalize(select_left);
  std::shuffle(left_for_this_round.begin(), left_for_this_round.end(), prng);
}

void set_round(const std::string &str, int round) {
  sqlite3_stmt *update;
  sqlite3_prepare_v2(dbc, "update entries set round = ? where name = ?", -1, &update, 0);
  sqlite3_bind_int(update, 1, round);
  sqlite3_bind_text(update, 2, str.c_str(), -1, SQLITE_TRANSIENT);
  while (sqlite3_step(update) != SQLITE_DONE)
    ;
  sqlite3_finalize(update);
}

void set_in_poll(const std::string &str, bool value) {
  sqlite3_stmt *update;
  sqlite3_prepare_v2(dbc, "update entries set in_active_poll = ? where name = ?", -1, &update, 0);
  sqlite3_bind_int(update, 1, value ? 1 : 0);
  sqlite3_bind_text(update, 2, str.c_str(), -1, SQLITE_TRANSIENT);
  while (sqlite3_step(update) != SQLITE_DONE)
    ;
  sqlite3_finalize(update);
}

// Body of the HTTP response currently being received; api() clears it
// before each request and parses it afterwards.
std::string curl_reciept;

// libcurl write callback: appends the received chunk to curl_reciept.
//
// ptr    start of the chunk (not NUL-terminated)
// size   size of one element in bytes
// count  number of elements in the chunk
// Returns the number of bytes consumed, which is always the whole chunk.
size_t write_callback(void *ptr, size_t size, size_t count, void *) {
  const size_t bytes = size * count;
  curl_reciept.append(static_cast<const char *>(ptr), bytes);
  return bytes;
}

// Performs one Discord API request and returns the parsed JSON response.
//
// endpoint   path appended to https://discord.com/api/v10
// is_post    true for a POST with a JSON body, false for a GET
// read_from  request body; only used when is_post is true
//
// When the response is Discord's rate-limit message, sleeps based on its
// retry_after value and repeats the request until it gets through.
// Throws std::runtime_error when the transfer itself fails, and lets
// nlohmann::json exceptions propagate when the response is not valid JSON.
nlohmann::json api(const std::string &endpoint, bool is_post, const std::string &read_from = "") {
  const std::string url = "https://discord.com/api/v10" + endpoint;
  const std::string auth_header = "Authorization: Bot " + bot_token;

  while (true) {
    curl_reciept.clear();
    curl_easy_setopt(curl_handle, CURLOPT_URL, url.c_str());
    // These options are read as `long` through varargs, so the argument
    // must be a long literal rather than an int.
    curl_easy_setopt(curl_handle, is_post ? CURLOPT_POST : CURLOPT_HTTPGET, 1L);

    curl_slist *headers = nullptr;
    headers = curl_slist_append(headers, auth_header.c_str());
    if (is_post)
      headers = curl_slist_append(headers, "Content-Type: application/json");
    curl_easy_setopt(curl_handle, CURLOPT_HTTPHEADER, headers);
    curl_easy_setopt(curl_handle, CURLOPT_WRITEFUNCTION, &write_callback);
    if (is_post)
      curl_easy_setopt(curl_handle, CURLOPT_POSTFIELDS, read_from.c_str());

    const CURLcode transfer_result = curl_easy_perform(curl_handle);
    curl_slist_free_all(headers);

    // Without this check a failed transfer would only show up as a JSON
    // parse error on an empty or partial body.
    if (transfer_result != CURLE_OK)
      throw std::runtime_error(
        "Request to " + endpoint + " failed: " + curl_easy_strerror(transfer_result));

    std::cout << "\n" << curl_reciept << std::endl;
    auto as_json = nlohmann::json::parse(curl_reciept);

    const bool rate_limited =
      as_json.contains("message") && as_json["message"] == "You are being rate limited.";
    if (!rate_limited)
      return as_json;

    // Wait twice the suggested time (rounded up) before trying again.
    sleep(static_cast<int>(as_json["retry_after"].get<double>() + 1) * 2);
  }
}

std::vector<std::string> get_users(const std::string &msg_id, int answer_id) {
  std::string base = "/channels/" + channel_id + "/polls/" + msg_id + "/answers/" + std::to_string(answer_id) + "?limit=100";
  std::vector<std::string> all_users;
  nlohmann::json returned_list = api(base, false)["users"];
  while (returned_list.size() > 0) {
    for (auto user : returned_list)
      all_users.push_back(user["id"].template get<std::string>());
    returned_list = api(base + "&after=" + all_users.back(), false)["users"];
  }
  return all_users;
}

void process_all_pending() {
  std::vector<std::string> pending_msgs;
  sqlite3_exec(dbc,
    "select msg_id from active_polls",
    [](void *ptr, int, char **row, char **) {
      ((std::vector<std::string> *)ptr)->push_back(std::string(row[0]));
      return 0;
    }, &pending_msgs, 0);

  for (const std::string &msg_id : pending_msgs) {

    nlohmann::json result;
    int sleep_time = 1;
    while (true) {
      result = api("/channels/" + channel_id + "/messages/" + msg_id, false);
      if (result["poll"]["results"]["is_finalized"].template get<bool>())
        break;
      sleep(sleep_time);
      sleep_time *= 2;
    }

    auto e1 = result["poll"]["answers"][0];
    auto e2 = result["poll"]["answers"][1];

    std::string e1_text = e1["poll_media"]["text"];
    std::string e2_text = e2["poll_media"]["text"];

    std::vector<std::string> e1_users = get_users(msg_id, e1["answer_id"]);
    std::vector<std::string> e2_users = get_users(msg_id, e2["answer_id"]);

    std::string e1_users_text = e1_users.size() > 0 ? e1_users[0] : "";
    for (unsigned i = 1; i < e1_users.size(); ++i)
      e1_users_text += "," + e1_users[i];

    std::string e2_users_text = e2_users.size() > 0 ? e2_users[0] : "";
    for (unsigned i = 1; i < e2_users.size(); ++i)
      e2_users_text += "," + e2_users[i];

    bool e1_advances = e1_users.size() >= e2_users.size();
    bool e2_advances = e2_users.size() >= e1_users.size();

    sqlite3_exec(dbc, "begin transaction", 0, 0, 0);

    set_round(e1_text, e1_advances ? current_round_number + 1 : 0);
    set_round(e2_text, e2_advances ? current_round_number + 1 : 0);
    set_in_poll(e1_text, false);
    set_in_poll(e2_text, false);

    sqlite3_stmt *delete_active;
    sqlite3_prepare_v2(dbc,
      "delete from active_polls where msg_id = ?",
      -1, &delete_active, 0);
    sqlite3_bind_text(delete_active, 1, msg_id.c_str(), -1, SQLITE_TRANSIENT);
    while (sqlite3_step(delete_active) != SQLITE_DONE)
      ;
    sqlite3_finalize(delete_active);

    sqlite3_stmt *insert_complete;
    sqlite3_prepare_v2(dbc,
      "insert into past_polls (round, entry_1, entry_2, users_1, users_2) "
        "values (?, ?, ?, ?, ?)",
      -1, &insert_complete, 0);
    sqlite3_bind_int(insert_complete, 1, current_round_number);
    sqlite3_bind_text(insert_complete, 2, e1_text.c_str(), -1, SQLITE_TRANSIENT);
    sqlite3_bind_text(insert_complete, 3, e2_text.c_str(), -1, SQLITE_TRANSIENT);
    sqlite3_bind_text(insert_complete, 4, e1_users_text.c_str(), -1, SQLITE_TRANSIENT);
    sqlite3_bind_text(insert_complete, 5, e2_users_text.c_str(), -1, SQLITE_TRANSIENT);
    while (sqlite3_step(insert_complete) != SQLITE_DONE)
      ;
    sqlite3_finalize(insert_complete);

    sqlite3_exec(dbc, "end transaction", 0, 0, 0);

  }
}

// Returns how many seconds remain until the clock next reaches
// post_time_utc seconds past midnight, in the range [0, 86400).
// Despite the name, the target is the configured post time, not midnight.
int seconds_to_midnight() {
  const int seconds_per_day = 86400;
  int remaining = (post_time_utc - time(0)) % seconds_per_day;
  // The remainder of a negative difference is negative; wrap it into the day.
  if (remaining < 0)
    remaining += seconds_per_day;
  return remaining;
}

void post_poll(std::string e1, std::string e2, int poll_no) {
  nlohmann::json body = {
    { "poll", {
      { "question", { { "text", "Today's Poll #" + std::to_string(poll_no) } } },
      { "answers", {
        { { "poll_media", { { "text", e1 } } } },
        { { "poll_media", { { "text", e2 } } } } } },
      { "duration", seconds_to_midnight() / 3600 } } } };

  std::string msg_id = api("/channels/" + channel_id + "/messages", true, body.dump())["id"];

  sqlite3_exec(dbc, "begin transaction", 0, 0, 0);

  set_in_poll(e1, true);
  set_in_poll(e2, true);

  sqlite3_stmt *insert_active;
  sqlite3_prepare_v2(dbc,
    "insert into active_polls (msg_id) values (?)",
    -1, &insert_active, 0);
  sqlite3_bind_text(insert_active, 1, msg_id.c_str(), -1, SQLITE_TRANSIENT);
  while (sqlite3_step(insert_active) != SQLITE_DONE)
    ;
  sqlite3_finalize(insert_active);

  sqlite3_exec(dbc, "end transaction", 0, 0, 0);
}

void send_message(std::string str) {
  nlohmann::json body = { { "content", str } };
  api("/channels/" + channel_id + "/messages", true, body.dump());
}

int main() {

  prng.seed(time(0));

  const char *env = getenv("BOT_TOKEN");
  if (!env) {
    std::cerr << "Please set the BOT_TOKEN environment variable." << std::endl;
    exit(1);
  }
  bot_token = std::string(env);

  env = getenv("CHANNEL_ID");
  if (!env) {
    std::cerr << "Please set the CHANNEL_ID environment variable." << std::endl;
    exit(1);
  }
  channel_id = std::string(env);

  env = getenv("POLLS_PER_DAY");
  if (!env) {
    std::cerr << "Please set the POLLS_PER_DAY environment variable." << std::endl;
    exit(1);
  }
  polls_per_day = atoi(env);
  if (polls_per_day <= 0) {
    std::cerr << "POLLS_PER_DAY must be a positive integer." << std::endl;
    exit(1);
  }

  env = getenv("POST_TIME_UTC");
  if (!env) {
    std::cerr << "Please set the POST_TIME_UTC environment variable (seconds past midnight)." << std::endl;
    exit(1);
  }
  post_time_utc = atoi(env);

  curl_global_init(CURL_GLOBAL_DEFAULT);
  curl_handle = curl_easy_init();

  sqlite3_open("database.db", &dbc);

  bool have_entries = false;
  sqlite3_exec(dbc,
    "select 1 from sqlite_master where type = 'table' and name = 'entries'",
    [](void *ptr, int, char **, char **) { *(bool *)ptr = true; return 0; },
    &have_entries, 0);

  if (!have_entries)
    init_db();

  load_on_round();
  load_left_for_round();

  while (true) {
    sleep(seconds_to_midnight());
    sleep(2);

    process_all_pending();

    std::optional<std::string> advanced = {};
    if (left_for_this_round.size() == 1) {
      set_round(left_for_this_round[0], current_round_number + 1);
      advanced = std::move(left_for_this_round[0]);
      left_for_this_round.clear();
    }

    if (left_for_this_round.size() == 0) {
      ++current_round_number;
      save_on_round();
      load_left_for_round();

      if (left_for_this_round.size() == 1) {
        for (char &c : left_for_this_round[0])
          c = toupper(c);
        send_message("# WINNER WINNER\n## " + left_for_this_round[0]);
        exit(0);
      }

      std::string msg = "# ROUND " + std::to_string(current_round_number) + '\n';
      if (advanced)
        msg += "The only entry left in the previous round was \"" + *advanced + "\", so it advanced. ";
      msg += "There are " + std::to_string(left_for_this_round.size()) + " entries remaining:";
      std::string msg_backup = msg;
      for (int i = left_for_this_round.size() - 1; i >= 0; --i)
        msg += "\n* \"" + left_for_this_round[i] + "\"";
      if (msg.size() > 2000) {
        msg = std::move(msg_backup);
        msg.back() = '.';
      }
      send_message(msg);
    }

    for (int poll_no = 1; poll_no <= polls_per_day && left_for_this_round.size() >= 2; ++poll_no) {
      std::string e1 = std::move(left_for_this_round.back());
      left_for_this_round.pop_back();
      std::string e2 = std::move(left_for_this_round.back());
      left_for_this_round.pop_back();
      post_poll(e1, e2, poll_no);
    }
  }

  return 0;

}