bits 64

global load_gdt_and_idt

section .rodata

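;GDT selector layout. Kernel code is at 0x28 to match the selector
;Limine already uses for kernel code.

;0x18 - tss (16-byte descriptor, spans 0x18-0x27)
;0x28 - kernel code
;0x30 - kernel data
;0x38 - user data
;0x40 - user code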

;64-bit task state segment (104 bytes = 0x68). Only IST1 is populated;
;set_isr points every gate it installs at IST1, so rsp0-2 are left zero.
;The I/O map base equals the TSS size, so there is no I/O permission
;bitmap and port I/O from ring 3 faults (unless IOPL allows it). A base of
;0 would make the TSS's own mostly-zero bytes act as an "allow" bitmap.
tss:
  dd 0                          ;0x00 reserved
  dq 0, 0, 0                    ;0x04 rsp0, rsp1, rsp2
  dq 0                          ;0x1c reserved
  dq 0xffffffffffeff000         ;0x24 ist1
  dq 0, 0, 0, 0, 0, 0           ;0x2c ist2 - ist7
  dq 0                          ;0x5c reserved
  dw 0                          ;0x64 reserved
  dw 0x68                       ;0x66 i/o map base = sizeof(tss): no bitmap

;lgdt / lidt operands: 16-bit limit (table size in bytes - 1), 64-bit base
gdtr:
  dw 9 * 8 - 1                  ;9 eight-byte gdt slots
  dq gdt

idtr:
  dw 256 * 16 - 1               ;256 sixteen-byte gate descriptors
  dq idt

section .bss

;interrupt descriptor table: 256 gates of 16 bytes each. It lives in .bss,
;so any gate set_isr never writes stays zero (present bit clear).
idt:
  resb 256 * 16

;Register snapshot filled in by exception_common before it jumps to
;print_exception. Field order and sizes are the layout its consumer reads.
;NOTE(review): presumably mirrored by a C struct elsewhere - keep in sync.
global exception_info
exception_info:
.rax:               resq 1
.rbx:               resq 1
.rcx:               resq 1
.rdx:               resq 1
.rdi:               resq 1
.rsi:               resq 1
.rbp:               resq 1
.rsp:               resq 1
.r8:                resq 1
.r9:                resq 1
.r10:               resq 1
.r11:               resq 1
.r12:               resq 1
.r13:               resq 1
.r14:               resq 1
.r15:               resq 1
.cr2:               resq 1
.cr3:               resq 1
.rip:               resq 1
.rflags:            resq 1
.error:             resq 1      ;only meaningful when .has_error is 1
.has_error:         resb 1
.exception_number:  resb 1

section .rodata

;has_error_code[n] = 1 when the CPU pushes an error code for vector n
;(vectors 0x08 and 0x0a-0x0e within the range handled here)
has_error_code:
  times 8 db 0                  ;vectors 0x00-0x07
  db 1, 0, 1, 1, 1, 1, 1, 0     ;vectors 0x08-0x0f

;entry stub for each of the first 16 exception vectors, indexed by vector
exception_isrs:
  dq exception_00               ;#DE divide error
  dq exception_01               ;#DB debug
  dq exception_02               ;NMI
  dq exception_03               ;#BP breakpoint
  dq exception_04               ;#OF overflow
  dq exception_05               ;#BR bound range exceeded
  dq exception_06               ;#UD invalid opcode
  dq exception_07               ;#NM device not available
  dq exception_08               ;#DF double fault
  dq exception_09               ;coprocessor segment overrun (legacy)
  dq exception_0a               ;#TS invalid tss
  dq exception_0b               ;#NP segment not present
  dq exception_0c               ;#SS stack-segment fault
  dq exception_0d               ;#GP general protection
  dq exception_0e               ;#PF page fault
  dq exception_0f               ;reserved

section .text

extern print_exception

;Per-vector exception entry points. Each one records its vector number in
;exception_info.exception_number and joins exception_common, leaving every
;register and the hardware frame untouched.
%macro EXCEPTION_STUB 1
exception_%1:
  mov byte [exception_info.exception_number], %{1}h
  jmp exception_common
%endmacro

EXCEPTION_STUB 00
EXCEPTION_STUB 01
EXCEPTION_STUB 02
EXCEPTION_STUB 03
EXCEPTION_STUB 04
EXCEPTION_STUB 05
EXCEPTION_STUB 06
EXCEPTION_STUB 07
EXCEPTION_STUB 08
EXCEPTION_STUB 09
EXCEPTION_STUB 0a
EXCEPTION_STUB 0b
EXCEPTION_STUB 0c
EXCEPTION_STUB 0d
EXCEPTION_STUB 0e
EXCEPTION_STUB 0f

;Common exception tail. Entered by jmp from an exception_NN stub, with the
;interrupted rax intact and exception_info.exception_number already set.
;Copies the general registers, the error code (if this vector has one),
;rip / rflags / rsp from the interrupt frame, and cr2 / cr3 into
;exception_info, then tail-jumps to print_exception.
;
;print_exception must not return: after the frame pops below, [rsp] holds
;the interrupted SS, not a return address.
exception_common:
  mov qword [exception_info.rax], rax

  ;does this vector push an error code?
  movzx rax, byte [exception_info.exception_number]
  mov al, byte [has_error_code + rax]
  test al, al
  jz .no_error_code

  mov byte [exception_info.has_error], 1
  pop rax                       ;error code sits on top of the hardware frame
  mov qword [exception_info.error], rax
  jmp .post_error_code

.no_error_code:
  mov byte [exception_info.has_error], 0

.post_error_code:
  mov qword [exception_info.rbx], rbx
  mov qword [exception_info.rcx], rcx
  mov qword [exception_info.rdx], rdx
  mov qword [exception_info.rdi], rdi
  mov qword [exception_info.rsi], rsi
  mov qword [exception_info.rbp], rbp
  mov qword [exception_info.r8], r8
  mov qword [exception_info.r9], r9
  mov qword [exception_info.r10], r10
  mov qword [exception_info.r11], r11
  mov qword [exception_info.r12], r12
  mov qword [exception_info.r13], r13
  mov qword [exception_info.r14], r14
  mov qword [exception_info.r15], r15

  ;unwind the hardware frame: rip, cs (discarded), rflags, rsp; ss stays on top
  pop rax
  mov qword [exception_info.rip], rax
  pop rax
  pop rax
  mov qword [exception_info.rflags], rax
  pop rax
  mov qword [exception_info.rsp], rax

  mov rax, cr2
  mov qword [exception_info.cr2], rax
  mov rax, cr3
  mov qword [exception_info.cr3], rax

  ;interrupt gates do not clear DF, and the interrupted code may have set it;
  ;print_exception is presumably C, and the SysV ABI requires DF = 0 on entry
  cld
  jmp print_exception

set_isr:
;Install a present, DPL 0, 64-bit interrupt gate in idt.
;rdi - vector index (0-255)
;rsi - handler address
;The gate uses code selector 0x28 and switches to IST1.
;Clobbers rdi, rsi and flags. The reserved upper dword of the gate is
;never written; this relies on idt starting out zeroed.

  shl rdi, 4                    ;gates are 16 bytes
  add rdi, idt                  ;rdi = &idt[index]

  mov word [rdi + 2], 0x28      ;code segment selector
  mov word [rdi + 4], 0x8e01    ;byte 4: ist = 1; byte 5: present, dpl 0, interrupt gate

  mov word [rdi], si            ;offset bits 0-15
  shr rsi, 16
  mov word [rdi + 6], si        ;offset bits 16-31
  shr rsi, 16
  mov dword [rdi + 8], esi      ;offset bits 32-63

  ret

section .data

;Global descriptor table (9 eight-byte slots). It sits in .data because
;load_gdt_and_idt patches the tss base into .tss at runtime, and ltr marks
;that descriptor busy.
gdt:
  times 3 dq 0                  ;0x00 null, 0x08 and 0x10 unused
.tss:                           ;0x18: 64-bit tss descriptor, 2 slots wide
  dq 0x0000e90000000067         ;limit 0x67, available 64-bit tss, dpl 3, base patched in
  dq 0                          ;base bits 32-63 (patched in), reserved
  dq 0x002f98000000ffff         ;0x28 kernel code: dpl 0, long mode
  dq 0x002f92000000ffff         ;0x30 kernel data: dpl 0
  dq 0x002ff2000000ffff         ;0x38 user data: dpl 3
  dq 0x002ff8000000ffff         ;0x40 user code: dpl 3, long mode

section .bss

section .text

;Shared IRQ prologue, entered with `call isr_start` as the first
;instruction of a handler. Saves rax, rcx, rdx, rdi, rsi and r8-r11 beneath
;the interrupt frame, then returns to the handler with this layout:
;  [rsp +  0] r11    [rsp + 32] rsi    [rsp + 64] rax
;  [rsp +  8] r10    [rsp + 40] rdi    [rsp + 72] frame rip
;  [rsp + 16] r9     [rsp + 48] rdx    [rsp + 80] frame cs ...
;  [rsp + 24] r8     [rsp + 56] rcx
;Must be paired with `call isr_end` before iretq. Leaves rdi = frame SS and
;clobbers flags; isr_end restores rdi and iretq restores rflags.
isr_start:

  push rcx                      ;[rsp] = rcx, [rsp + 8] = our return address
  mov rcx, qword [rsp + 8]      ;rcx = return address
  mov qword [rsp + 8], rax      ;park rax where the return address was
  push rdx
  push rdi
  push rsi
  push r8
  push r9
  push r10
  push r11
  push rcx                      ;return address back on top for ret

  ;interrupt gates do not clear DF; handlers call into C, and the SysV ABI
  ;requires DF = 0 on function entry
  cld

  ;workaround: an interrupted SS of 0x38 (user data, RPL 0) is rewritten to
  ;0x3b (RPL 3) before the eventual iretq. SS is 4 qwords above the frame
  ;rip, which is 10 qwords above rsp here.
  ;NOTE(review): where an RPL-0 user SS comes from is not visible here.
  mov rdi, qword [rsp + 10 * 8 + 4 * 8]
  cmp rdi, 0x38
  jne .done
  mov qword [rsp + 10 * 8 + 4 * 8], 0x3b

.done:
  ret

;Shared IRQ epilogue, entered with `call isr_end` right before iretq.
;Undoes isr_start: reloads rax, rcx, rdx, rdi, rsi and r8-r11 from the save
;area and returns with rsp pointing at the interrupt frame's rip.
;On entry: [rsp] = our return address, [rsp + 8 .. + 72] = r11 .. rax.
isr_end:

  mov rcx, qword [rsp]          ;our return address
  mov rax, qword [rsp + 9 * 8]  ;saved rax
  mov qword [rsp + 9 * 8], rcx  ;return address now sits directly below the frame
  mov r11, qword [rsp + 1 * 8]
  mov r10, qword [rsp + 2 * 8]
  mov r9, qword [rsp + 3 * 8]
  mov r8, qword [rsp + 4 * 8]
  mov rsi, qword [rsp + 5 * 8]
  mov rdi, qword [rsp + 6 * 8]
  mov rdx, qword [rsp + 7 * 8]
  mov rcx, qword [rsp + 8 * 8]
  lea rsp, [rsp + 9 * 8]        ;drop the save area; lea leaves flags untouched

  ret

extern on_rtc_interrupt

;IRQ 8 handler (vector 0x28, registered by load_gdt_and_idt).
;Calls on_rtc_interrupt, then sends a non-specific EOI (0x20) to both the
;master and the slave PIC, because IRQ 8 arrives through the slave.
;NOTE(review): the RTC only raises further interrupts after status
;register C is read; presumably on_rtc_interrupt does that - confirm.
rtc_isr:

  call isr_start

  call on_rtc_interrupt

  mov al, 0x20
  out 0x20, al
  out 0xa0, al

  call isr_end

  iretq

extern on_keyboard_interrupt

;IRQ 1 handler (vector 0x21). Reads one byte from the PS/2 data port and
;passes it, zero-extended in edi (SysV first argument), to
;on_keyboard_interrupt, then sends a non-specific EOI to the master PIC.
;If the controller's output buffer is empty (a stale or spurious IRQ), the
;callback is skipped rather than spinning here. The interrupt gate has
;cleared IF, so spinning would stall every other IRQ.
keyboard_isr:

  call isr_start

  in al, 0x64
  test al, 0x01                 ;status bit 0: output buffer full
  jz .eoi

  in al, 0x60
  movzx edi, al                 ;clean argument; rdi's upper bits hold leftovers from isr_start

  call on_keyboard_interrupt

.eoi:
  mov al, 0x20
  out 0x20, al

  call isr_end

  iretq

extern on_mouse_interrupt

;IRQ 12 handler (vector 0x2c). Reads one byte from the PS/2 data port and
;passes it, zero-extended in edi (SysV first argument), to
;on_mouse_interrupt, then sends a non-specific EOI to both PICs, because
;IRQ 12 arrives through the slave.
;If the controller's output buffer is empty, the callback is skipped rather
;than spinning with IF cleared by the interrupt gate. This covers a stale or
;spurious IRQ, possibly one latched while load_gdt_and_idt polled for mouse
;ACKs.
mouse_isr:

  call isr_start

  in al, 0x64
  test al, 0x01                 ;status bit 0: output buffer full
  jz .eoi

  in al, 0x60
  movzx edi, al                 ;clean argument; rdi's upper bits hold leftovers from isr_start

  call on_mouse_interrupt

.eoi:
  mov al, 0x20
  out 0x20, al
  out 0xa0, al

  call isr_end

  iretq

;Spin until the PS/2 controller can accept a byte, i.e. status bit 1
;(input buffer full) is clear. Returns with al = last status read.
;Clobbers al and flags.
wait_send_ps2:
.poll:
  in al, 0x64
  test al, 0x02
  jnz .poll
  ret

;Spin until the PS/2 controller has a byte to read, i.e. status bit 0
;(output buffer full) is set. Returns with al = last status read.
;Clobbers al and flags.
wait_read_ps2:
.poll:
  in al, 0x64
  test al, 0x01
  jz .poll                      ;keep polling this bit; jumping to wait_send_ps2 returned with no data
  ret

;void load_gdt_and_idt(void)
;Installs the kernel's descriptor tables and interrupt sources:
;  1. points idt vectors 0x00-0x0f at the exception stubs
;  2. remaps the 8259 PICs to vectors 0x20-0x2f, unmasking IRQs 1, 2, 8, 12
;  3. registers the rtc (0x28), keyboard (0x21) and mouse (0x2c) handlers
;  4. enables both PS/2 ports' interrupts and turns on mouse reporting
;  5. patches the tss address into its gdt descriptor
;  6. loads the gdt, the idt and the task register (selector 0x18)
;Does not execute sti. Segment registers are not reloaded after lgdt; this
;relies on the selectors already loaded (presumably Limine's 0x28 / 0x30)
;being valid in this gdt.
;Clobbers rax, rcx, rdi, rsi and flags (all caller-saved under SysV).
load_gdt_and_idt:

  ;fill exception entries in idt, vector 15 down to 0

  mov ecx, 16

.fill_exceptions:

  lea rdi, [rcx - 1]
  mov rsi, qword [exception_isrs + rdi * 8]
  call set_isr                  ;preserves rcx

  dec ecx
  jnz .fill_exceptions

  ;reset pic and map irqs to 0x20 - 0x2f

  mov al, 0x11                  ;icw1: init, icw4 follows
  out 0x20, al
  mov al, 0x20                  ;icw2: master base vector
  out 0x21, al
  mov al, 0x04                  ;icw3: slave on irq 2
  out 0x21, al
  mov al, 0x01                  ;icw4: 8086 mode
  out 0x21, al
  mov al, 0xf9                  ;mask all but irqs 1 and 2
  out 0x21, al

  mov al, 0x11                  ;icw1: init, icw4 follows
  out 0xa0, al
  mov al, 0x28                  ;icw2: slave base vector
  out 0xa1, al
  mov al, 0x02                  ;icw3: cascade identity 2
  out 0xa1, al
  mov al, 0x01                  ;icw4: 8086 mode
  out 0xa1, al
  mov al, 0xee                  ;mask all but irqs 8 and 12
  out 0xa1, al

  ;register rtc interrupt

  mov rdi, 0x28
  mov rsi, rtc_isr
  call set_isr

  ;register keyboard and mouse interrupts

  mov rdi, 0x21
  mov rsi, keyboard_isr
  call set_isr

  mov rdi, 0x2c
  mov rsi, mouse_isr
  call set_isr

  ;drain any byte already waiting in the controller's output buffer, so the
  ;ACK reads below consume replies to our own commands instead of stale data

.flush_ps2:
  in al, 0x64
  test al, 0x01
  jz .ps2_flushed
  in al, 0x60
  jmp .flush_ps2

.ps2_flushed:

  ;set ps2 config: command 0x60 (write config byte), value 0x03 enables
  ;the port 1 and port 2 interrupts

  call wait_send_ps2
  mov al, 0x60
  out 0x64, al

  call wait_send_ps2
  mov al, 0x03
  out 0x60, al

  ;set mouse defaults (0xd4 forwards the next data byte to port 2)

  call wait_send_ps2
  mov al, 0xd4
  out 0x64, al

  call wait_send_ps2
  mov al, 0xf6
  out 0x60, al

  call wait_read_ps2
  in al, 0x60                   ;consume the mouse's reply

  ;enable mouse reporting

  call wait_send_ps2
  mov al, 0xd4
  out 0x64, al

  call wait_send_ps2
  mov al, 0xf4
  out 0x60, al

  call wait_read_ps2
  in al, 0x60                   ;consume the mouse's reply

  ;make tss entry in gdt: scatter the tss address into the descriptor's
  ;base fields (bits 0-15 at +2, 16-23 at +4, 24-31 at +7, 32-63 at +8)

  mov rax, tss

  mov word [gdt.tss + 2], ax
  shr rax, 16
  mov byte [gdt.tss + 4], al
  mov byte [gdt.tss + 7], ah
  shr rax, 16
  mov dword [gdt.tss + 8], eax

  ;load gdt, idt, tss

  lgdt [gdtr]
  lidt [idtr]
  mov ax, 0x18                  ;tss selector
  ltr ax

  ret