diff options
author | Benji Dial <benji@benjidial.net> | 2024-05-06 03:05:16 -0400 |
---|---|---|
committer | Benji Dial <benji@benjidial.net> | 2024-05-06 03:05:16 -0400 |
commit | 379b35093097f0091e08e128c161000ead3d1e19 (patch) | |
tree | a20a390aec708c2176bee77e7e93367d63e3655c /source.cpp | |
parent | f05e422f697c94edc2983b857dc814aedb7aaa9b (diff) | |
download | bracket-bot-379b35093097f0091e08e128c161000ead3d1e19.tar.gz |
new version
Diffstat (limited to 'source.cpp')
-rw-r--r-- | source.cpp | 384 |
1 files changed, 384 insertions, 0 deletions
diff --git a/source.cpp b/source.cpp new file mode 100644 index 0000000..9b5a122 --- /dev/null +++ b/source.cpp @@ -0,0 +1,384 @@ +#include <curl/curl.h> +#include <algorithm> +#include <sqlite3.h> +#include <iostream> +#include <optional> +#include <unistd.h> +#include <fstream> +#include <random> +#include <vector> + +#include "json.hpp" + +std::default_random_engine prng; +sqlite3 *dbc; +CURL *curl_handle; + +std::string bot_token; +std::string channel_id; +int polls_per_day; +long post_time_utc; + +void init_db() { + + sqlite3_exec(dbc, + "begin transaction; " + "create table misc (" + "on_round INTEGER NOT NULL); " + "insert into misc (on_round) " + "values (0); " + "create table entries (" + "name TEXT UNIQUE NOT NULL, " + "round INTEGER NOT NULL, " + "in_active_poll INTEGER NOT NULL); " + "create table active_polls (" + "msg_id TEXT UNIQUE NOT NULL); " + "create table past_polls (" + "round INTEGER NOT NULL, " + "entry_1 TEXT NOT NULL, " + "entry_2 TEXT NOT NULL, " + "users_1 TEXT NOT NULL, " + "users_2 TEXT NOT NULL)", + 0, 0, 0); + + sqlite3_stmt *insert_entry; + sqlite3_prepare_v2(dbc, + "insert into entries (name, round, in_active_poll)" + "values (?, 1, 0)", -1, &insert_entry, 0); + + std::ifstream entries("entries.txt"); + if (!entries) { + std::cerr << "Please put the list of entries into entries.txt, separated by newlines." << std::endl; + exit(1); + } + + std::string entry; + while (std::getline(entries, entry)) { + sqlite3_bind_text(insert_entry, 1, entry.c_str(), -1, SQLITE_TRANSIENT); + while (sqlite3_step(insert_entry) != SQLITE_DONE) + ; + sqlite3_reset(insert_entry); + } + + sqlite3_exec(dbc, "end transaction", 0, 0, 0); + sqlite3_finalize(insert_entry); + +} + +int current_round_number; +std::vector<std::string> left_for_this_round; + +void load_on_round() { + current_round_number = -1; + sqlite3_exec(dbc, + "select on_round from misc limit 1", + [](void *, int, char **row, char **) { + current_round_number = atoi(row[0]); + return 0; + }, 0, 0); +} + +void save_on_round() { + sqlite3_stmt *update; + sqlite3_prepare_v2(dbc, "update misc set on_round = ?", -1, &update, 0); + sqlite3_bind_int(update, 1, current_round_number); + while (sqlite3_step(update) != SQLITE_DONE) + ; + sqlite3_finalize(update); +} + +void load_left_for_round() { + sqlite3_stmt *select_left; + sqlite3_prepare_v2(dbc, + "select name from entries where round = ? and in_active_poll = 0", + -1, &select_left, 0); + sqlite3_bind_int(select_left, 1, current_round_number); + + left_for_this_round.clear(); + int result; + while ((result = sqlite3_step(select_left)) != SQLITE_DONE) + if (result == SQLITE_ROW) + left_for_this_round.push_back( + std::string((const char *)sqlite3_column_text(select_left, 0))); + + sqlite3_finalize(select_left); + std::shuffle(left_for_this_round.begin(), left_for_this_round.end(), prng); +} + +void set_round(const std::string &str, int round) { + sqlite3_stmt *update; + sqlite3_prepare_v2(dbc, "update entries set round = ? where name = ?", -1, &update, 0); + sqlite3_bind_int(update, 1, round); + sqlite3_bind_text(update, 2, str.c_str(), -1, SQLITE_TRANSIENT); + while (sqlite3_step(update) != SQLITE_DONE) + ; + sqlite3_finalize(update); +} + +void set_in_poll(const std::string &str, bool value) { + sqlite3_stmt *update; + sqlite3_prepare_v2(dbc, "update entries set in_active_poll = ? where name = ?", -1, &update, 0); + sqlite3_bind_int(update, 1, value ? 1 : 0); + sqlite3_bind_text(update, 2, str.c_str(), -1, SQLITE_TRANSIENT); + while (sqlite3_step(update) != SQLITE_DONE) + ; + sqlite3_finalize(update); +} + +std::string curl_reciept; + +size_t write_callback(void *ptr, size_t, size_t count, void *) { + size_t offset = curl_reciept.size(); + curl_reciept.resize(offset + count); + memcpy(curl_reciept.data() + offset, ptr, count); + return count; +} + +nlohmann::json api(const std::string &endpoint, bool is_post, const std::string &read_from = "") { + curl_reciept.clear(); + curl_easy_setopt(curl_handle, CURLOPT_URL, ("https://discord.com/api/v10" + endpoint).c_str()); + curl_easy_setopt(curl_handle, is_post ? CURLOPT_POST : CURLOPT_HTTPGET, 1); + curl_slist *sl = 0; + sl = curl_slist_append(sl, ("Authorization: Bot " + bot_token).c_str()); + if (is_post) + sl = curl_slist_append(sl, "Content-Type: application/json"); + curl_easy_setopt(curl_handle, CURLOPT_HTTPHEADER, sl); + curl_easy_setopt(curl_handle, CURLOPT_WRITEFUNCTION, &write_callback); + if (is_post) + curl_easy_setopt(curl_handle, CURLOPT_POSTFIELDS, read_from.c_str()); + curl_easy_perform(curl_handle); + curl_slist_free_all(sl); + std::cout << "\n" << curl_reciept << std::endl; + auto as_json = nlohmann::json::parse(curl_reciept); + if (as_json.contains("message") && as_json["message"] == "You are being rate limited.") { + sleep((int)(as_json["retry_after"].template get<double>() + 1) * 2); + return api(endpoint, is_post, read_from); + } + return as_json; +} + +std::vector<std::string> get_users(const std::string &msg_id, int answer_id) { + std::string base = "/channels/" + channel_id + "/polls/" + msg_id + "/answers/" + std::to_string(answer_id) + "?limit=100"; + std::vector<std::string> all_users; + nlohmann::json returned_list = api(base, false)["users"]; + while (returned_list.size() > 0) { + for (auto user : returned_list) + all_users.push_back(user["id"].template get<std::string>()); + returned_list = api(base + "&after=" + all_users.back(), false)["users"]; + } + return all_users; +} + +void process_all_pending() { + std::vector<std::string> pending_msgs; + sqlite3_exec(dbc, + "select msg_id from active_polls", + [](void *ptr, int, char **row, char **) { + ((std::vector<std::string> *)ptr)->push_back(std::string(row[0])); + return 0; + }, &pending_msgs, 0); + + for (const std::string &msg_id : pending_msgs) { + + nlohmann::json result; + int sleep_time = 1; + while (true) { + result = api("/channels/" + channel_id + "/messages/" + msg_id, false); + if (result["poll"]["results"]["is_finalized"].template get<bool>()) + break; + sleep(sleep_time); + sleep_time *= 2; + } + + auto e1 = result["poll"]["answers"][0]; + auto e2 = result["poll"]["answers"][1]; + + std::string e1_text = e1["poll_media"]["text"]; + std::string e2_text = e2["poll_media"]["text"]; + + std::vector<std::string> e1_users = get_users(msg_id, e1["answer_id"]); + std::vector<std::string> e2_users = get_users(msg_id, e2["answer_id"]); + + std::string e1_users_text = e1_users.size() > 0 ? e1_users[0] : ""; + for (unsigned i = 1; i < e1_users.size(); ++i) + e1_users_text += "," + e1_users[i]; + + std::string e2_users_text = e2_users.size() > 0 ? e2_users[0] : ""; + for (unsigned i = 1; i < e2_users.size(); ++i) + e2_users_text += "," + e2_users[i]; + + bool e1_advances = e1_users.size() >= e2_users.size(); + bool e2_advances = e2_users.size() >= e1_users.size(); + + sqlite3_exec(dbc, "begin transaction", 0, 0, 0); + + set_round(e1_text, e1_advances ? current_round_number + 1 : 0); + set_round(e2_text, e2_advances ? current_round_number + 1 : 0); + set_in_poll(e1_text, false); + set_in_poll(e2_text, false); + + sqlite3_stmt *delete_active; + sqlite3_prepare_v2(dbc, + "delete from active_polls where msg_id = ?", + -1, &delete_active, 0); + sqlite3_bind_text(delete_active, 1, msg_id.c_str(), -1, SQLITE_TRANSIENT); + while (sqlite3_step(delete_active) != SQLITE_DONE) + ; + sqlite3_finalize(delete_active); + + sqlite3_stmt *insert_complete; + sqlite3_prepare_v2(dbc, + "insert into past_polls (round, entry_1, entry_2, users_1, users_2) " + "values (?, ?, ?, ?, ?)", + -1, &insert_complete, 0); + sqlite3_bind_int(insert_complete, 1, current_round_number); + sqlite3_bind_text(insert_complete, 2, e1_text.c_str(), -1, SQLITE_TRANSIENT); + sqlite3_bind_text(insert_complete, 3, e2_text.c_str(), -1, SQLITE_TRANSIENT); + sqlite3_bind_text(insert_complete, 4, e1_users_text.c_str(), -1, SQLITE_TRANSIENT); + sqlite3_bind_text(insert_complete, 5, e2_users_text.c_str(), -1, SQLITE_TRANSIENT); + while (sqlite3_step(insert_complete) != SQLITE_DONE) + ; + sqlite3_finalize(insert_complete); + + sqlite3_exec(dbc, "end transaction", 0, 0, 0); + + } +} + +void post_poll(std::string e1, std::string e2, int poll_no) { + nlohmann::json body = { + { "poll", { + { "question", { { "text", "Today's Poll #" + std::to_string(poll_no) } } }, + { "answers", { + { { "poll_media", { { "text", e1 } } } }, + { { "poll_media", { { "text", e2 } } } } } }, + { "duration", 23 } } } }; + + std::string msg_id = api("/channels/" + channel_id + "/messages", true, body.dump())["id"]; + + sqlite3_exec(dbc, "begin transaction", 0, 0, 0); + + set_in_poll(e1, true); + set_in_poll(e2, true); + + sqlite3_stmt *insert_active; + sqlite3_prepare_v2(dbc, + "insert into active_polls (msg_id) values (?)", + -1, &insert_active, 0); + sqlite3_bind_text(insert_active, 1, msg_id.c_str(), -1, SQLITE_TRANSIENT); + while (sqlite3_step(insert_active) != SQLITE_DONE) + ; + sqlite3_finalize(insert_active); + + sqlite3_exec(dbc, "end transaction", 0, 0, 0); +} + +void send_message(std::string str) { + nlohmann::json body = { { "content", str } }; + api("/channels/" + channel_id + "/messages", true, body.dump()); +} + +void sleep_until_midnight() { + sleep(2); + long timestamp = time(0); + int time = (post_time_utc - timestamp) % 86400; + sleep(time < 0 ? time + 86400 : time); +} + +int main() { + + prng.seed(time(0)); + + const char *env = getenv("BOT_TOKEN"); + if (!env) { + std::cerr << "Please set the BOT_TOKEN environment variable." << std::endl; + exit(1); + } + bot_token = std::string(env); + + env = getenv("CHANNEL_ID"); + if (!env) { + std::cerr << "Please set the CHANNEL_ID environment variable." << std::endl; + exit(1); + } + channel_id = std::string(env); + + env = getenv("POLLS_PER_DAY"); + if (!env) { + std::cerr << "Please set the POLLS_PER_DAY environment variable." << std::endl; + exit(1); + } + polls_per_day = atoi(env); + if (polls_per_day <= 0) { + std::cerr << "POLLS_PER_DAY must be a positive integer." << std::endl; + exit(1); + } + + env = getenv("POST_TIME_UTC"); + if (!env) { + std::cerr << "Please set the POST_TIME_UTC environment variable (seconds past midnight)." << std::endl; + exit(1); + } + post_time_utc = atoi(env); + + curl_global_init(CURL_GLOBAL_DEFAULT); + curl_handle = curl_easy_init(); + + sqlite3_open("database.db", &dbc); + + bool have_entries = false; + sqlite3_exec(dbc, + "select 1 from sqlite_master where type = 'table' and name = 'entries'", + [](void *ptr, int, char **, char **) { *(bool *)ptr = true; return 0; }, + &have_entries, 0); + + if (!have_entries) + init_db(); + + load_on_round(); + load_left_for_round(); + + while (true) { + sleep_until_midnight(); + process_all_pending(); + + std::optional<std::string> advanced = {}; + if (left_for_this_round.size() == 1) { + set_round(left_for_this_round[0], current_round_number + 1); + advanced = std::move(left_for_this_round[0]); + left_for_this_round.clear(); + } + + if (left_for_this_round.size() == 0) { + ++current_round_number; + save_on_round(); + load_left_for_round(); + + if (left_for_this_round.size() == 1) { + for (char &c : left_for_this_round[0]) + c = toupper(c); + send_message("# WINNER WINNER\n## " + left_for_this_round[0]); + exit(0); + } + + std::string msg = "# ROUND " + std::to_string(current_round_number) + '\n'; + if (advanced) + msg += "The only entry left in the previous round was \"" + *advanced + "\", so it advanced. "; + msg += "There are " + std::to_string(left_for_this_round.size()) + " entries remaining:"; + for (int i = left_for_this_round.size() - 1; i >= 0; --i) + msg += "\n* \"" + left_for_this_round[i] + "\""; + send_message(msg); + } + + for (int poll_no = 1; poll_no <= polls_per_day && left_for_this_round.size() >= 2; ++poll_no) { + std::string e1 = std::move(left_for_this_round.back()); + left_for_this_round.pop_back(); + std::string e2 = std::move(left_for_this_round.back()); + left_for_this_round.pop_back(); + post_poll(e1, e2, poll_no); + } + } + + return 0; + +} |