#include <hilbert/kernel/app-memory.hpp>
#include <hilbert/kernel/paging.hpp>
#include <hilbert/kernel/panic.hpp>

namespace hilbert::kernel {

  //sets up an empty address space: a fresh p4 and one fresh p3 are mapped
  //into the kernel, every table entry and bookkeeping slot is cleared, and
  //then the p3 is linked into p4[0] and p4[511] is set to paging::kernel_p4e.
  app_memory::app_memory() {

    //the p3's physical address is only needed to build the p4 entry below,
    //so it stays a local
    uint64_t p3_paddr;
    paging::map_new_kernel_page(p3, p3_paddr);
    paging::map_new_kernel_page(p4, p4_paddr);

    //clear both hardware tables
    for (int slot = 0; slot != 512; ++slot) {
      p4[slot] = 0;
      p3[slot] = 0;
    }

    //no p2s, p1 lists or free-on-exit lists exist yet
    for (int slot = 0; slot != 512; ++slot) {
      p2s[slot] = 0;
      p1s[slot] = 0;
      pram_pages_to_free_on_exit[slot] = 0;
    }

    p4[511] = paging::kernel_p4e;
    p4[0] = paging::encode_pte(p3_paddr, true, true, true);

  }

  //tears the address space down from the leaves up: frames flagged as
  //free-on-exit first, then each p1, each p2, and finally the p3 and p4.
  //for every table, the physical page is released with free_pram_page and
  //the kernel mapping of it with unmap_kernel_page; the bookkeeping arrays
  //allocated in map_page are released with delete[].
  app_memory::~app_memory() {

    for (int i3 = 0; i3 < 512; ++i3) {

      //no p2 was ever created for this p3 entry
      if (!p3[i3])
        continue;

      for (int i2 = 0; i2 < 512; ++i2) {

        //no p1 was ever created for this p2 entry
        if (!p2s[i3][i2])
          continue;

        //release the frames this p1 maps that were flagged free-on-exit
        for (int i1 = 0; i1 < 512; ++i1) {
          if (!pram_pages_to_free_on_exit[i3][i2][i1])
            continue;
          paging::free_pram_page(
            paging::pte_to_paddr(p1s[i3][i2][i1]));
        }

        //release the p1 itself and its free-on-exit list
        paging::free_pram_page(paging::pte_to_paddr(p2s[i3][i2]));
        paging::unmap_kernel_page((uint64_t)p1s[i3][i2]);
        delete[] pram_pages_to_free_on_exit[i3][i2];

      }

      //release the p2, its list of p1s, and its list of free-on-exit lists
      paging::free_pram_page(paging::pte_to_paddr(p3[i3]));
      paging::unmap_kernel_page((uint64_t)p2s[i3]);
      delete[] p1s[i3];
      delete[] pram_pages_to_free_on_exit[i3];

    }

    //last of all, the p3 (found through p4[0]) and the p4
    paging::free_pram_page(paging::pte_to_paddr(p4[0]));
    paging::unmap_kernel_page((uint64_t)p3);
    paging::free_pram_page(p4_paddr);
    paging::unmap_kernel_page((uint64_t)p4);

  }

  void app_memory::map_page(uint64_t vaddr, uint64_t paddr,
    bool write, bool execute, bool free_pram_on_exit) {

    int p1i = (vaddr >> 12) & 511;
    int p2i = (vaddr >> 21) & 511;
    int p3i = (vaddr >> 30) & 511;

    if (p2s[p3i] == 0) {
      uint64_t new_p2_paddr;
      paging::map_new_kernel_page(p2s[p3i], new_p2_paddr);
      p1s[p3i] = new v_page_table[512];
      pram_pages_to_free_on_exit[p3i] = new bool *[512];
      for (int i = 0; i < 512; ++i) {
        p2s[p3i][i] = 0;
        p1s[p3i][i] = 0;
        pram_pages_to_free_on_exit[p3i][i] = 0;
      }
      p3[p3i] = paging::encode_pte(new_p2_paddr, true, true, true);
    }

    if (p1s[p3i][p2i] == 0) {
      uint64_t new_p1_paddr;
      paging::map_new_kernel_page(p1s[p3i][p2i], new_p1_paddr);
      pram_pages_to_free_on_exit[p3i][p2i] = new bool[512];
      for (int i = 0; i < 512; ++i) {
        p1s[p3i][p2i][i] = 0;
        pram_pages_to_free_on_exit[p3i][p2i][i] = false;
      }
      p2s[p3i][p2i] = paging::encode_pte(new_p1_paddr, true, true, true);
    }

    p1s[p3i][p2i][p1i] = paging::encode_pte(paddr, true, write, execute);
    pram_pages_to_free_on_exit[p3i][p2i][p1i] = free_pram_on_exit;

  }

  //removes the mapping of the 4 KiB page at vaddr. if the page was mapped
  //with free_pram_on_exit set, its frame is handed to paging::free_pram_page
  //and the flag is cleared so the destructor does not free it again.
  //does nothing when no p1 table covers vaddr. only bits 12 through 38 of
  //vaddr select the entry.
  void app_memory::unmap_page(uint64_t vaddr) {
    int p1i = (vaddr >> 12) & 511;
    int p2i = (vaddr >> 21) & 511;
    int p3i = (vaddr >> 30) & 511;

    //the per-p3 and per-p2 arrays are only allocated by map_page; without
    //this check, unmapping an address in a never-mapped region would index
    //through a null pointer.
    if (!p1s[p3i] || !p1s[p3i][p2i])
      return;

    if (pram_pages_to_free_on_exit[p3i][p2i][p1i]) {
      pram_pages_to_free_on_exit[p3i][p2i][p1i] = false;
      paging::free_pram_page(paging::pte_to_paddr(p1s[p3i][p2i][p1i]));
    }
    p1s[p3i][p2i][p1i] = 0;
  }

  //reports whether every page touched by [vaddr_start, vaddr_end) has a
  //p1 entry in this address space.
  //  returns false if vaddr_start > vaddr_end, if vaddr_end is above
  //  0x8000000000, or if any page in the range fails the check.
  //  and_write selects the per-entry test below.
  bool app_memory::valid_to_read(
    uint64_t vaddr_start, uint64_t vaddr_end, bool and_write) const {
    if (vaddr_start > vaddr_end || vaddr_end > 0x8000000000)
      return false;

    //widen the range to whole pages
    vaddr_start = (vaddr_start / 4096) * 4096;
    vaddr_end = (((vaddr_end - 1) / 4096) + 1) * 4096;

    //all bytes of a page share one p1 entry, so one lookup per page is
    //enough; stepping byte by byte repeated each lookup 4096 times.
    for (uint64_t vaddr = vaddr_start; vaddr < vaddr_end; vaddr += 4096) {
      int p1i = (vaddr >> 12) & 511;
      int p2i = (vaddr >> 21) & 511;
      int p3i = (vaddr >> 30) & 511;
      //NOTE(review): with and_write set this tests bit 0 of the entry. on
      //x86-64 bit 0 is the present bit and bit 1 is the writable bit, so
      //this may not actually check write permission - confirm against
      //paging::encode_pte before changing the mask.
      if (!p1s[p3i] || !p1s[p3i][p2i] || !(and_write
            ? (p1s[p3i][p2i][p1i] & 0x1) : p1s[p3i][p2i][p1i]))
        return false;
    }
    return true;
  }

  //finds the lowest address, starting from 0x1000, at which count
  //consecutive pages are all unmapped according to valid_to_read, and
  //returns it. nothing is mapped or reserved by this call. panics with
  //0x9af5e6 if the search would pass 0x4000000000.
  uint64_t app_memory::get_free_vaddr_pages(uint64_t count) {

    uint64_t candidate = 0x1000; //start of the run being tested
    uint64_t free_pages = 0;     //unmapped pages found so far in that run

    while (free_pages != count) {

      //bounds of the next page to look at
      const uint64_t page_start = candidate + free_pages * 4096;
      const uint64_t page_end = page_start + 4096;

      if (page_end > 0x4000000000)
        //TODO: handle out of virtual memory
        panic(0x9af5e6);

      if (valid_to_read(page_start, page_end, false)) {
        //this page is taken, so restart the run just past it
        candidate = page_end;
        free_pages = 0;
      }
      else
        ++free_pages;

    }

    return candidate;

  }

  //allocates a new stack. the range from 0x4000000000 up to 0x8000000000 is
  //treated as a row of 0x1000000-byte slots; the first slot whose second
  //page is unmapped is taken. every page of the slot except the lowest one
  //is backed by a freshly zeroed frame, mapped writable, not executable and
  //freed on exit. the lowest page stays unmapped (presumably as a guard
  //page). returns the address one past the end of the slot. panics with
  //0x9af5e6 if every slot is taken.
  uint64_t app_memory::map_new_stack() {

    for (uint64_t slot_base = 0x4000000000;
         slot_base < 0x8000000000; slot_base += 0x1000000) {

      //a mapped second page means this slot already holds a stack
      if (valid_to_read(slot_base + 4096, slot_base + 8192, false))
        continue;

      const uint64_t slot_top = slot_base + 0x1000000;

      for (uint64_t vaddr = slot_base + 4096;
           vaddr < slot_top; vaddr += 4096) {

        //map the new frame into the kernel just long enough to zero it
        uint8_t *kvaddr;
        uint64_t paddr;
        paging::map_new_kernel_page(kvaddr, paddr);
        for (uint8_t *byte = kvaddr; byte != kvaddr + 4096; ++byte)
          *byte = 0;
        paging::unmap_kernel_page(kvaddr);

        map_page(vaddr, paddr, true, false, true);

      }

      return slot_top;

    }

    //TODO: handle out of stacks
    panic(0x9af5e6);

  }

  //unmaps the 0xfff000 bytes of pages that end at top, one page at a time
  //via unmap_page. top is expected to be a value returned by map_new_stack.
  void app_memory::unmap_stack(uint64_t top) {
    uint64_t vaddr = top - 0xfff000;
    while (vaddr < top) {
      unmap_page(vaddr);
      vaddr += 4096;
    }
  }

  uint64_t app_memory::count_mapped_vram_pages() const {
    uint64_t count = 0;
    for (int p3i = 0; p3i < 512; ++p3i)
      if (p3[p3i])
        for (int p2i = 0; p2i < 512; ++p2i)
          if (p2s[p3i][p2i])
            for (int p1i = 0; p1i < 512; ++p1i)
              if (p1s[p3i][p2i][p1i])
                ++count;
    return count;
  }

}