#include <iostream>
#include <thread>
#include <sys/socket.h>
#include <linux/if_tun.h>
#include <unistd.h>
#include <cstring>
#include <nlohmann/json.hpp>
#include <linux/ipv6.h>
#include <vector>
#include <linux/ip.h>

#include "checksum.h"
#include "Config.h"
#include "Router.h"

using json = nlohmann::json;

// Tunnel header placed directly after the outer IP header; inbound() and
// outbound() overlay this struct on raw packet bytes.
struct Meta {
    unsigned char src_id;    // id of the sending router
    unsigned char dst_id;    // id of the receiving router (compared with config.local_id)
    unsigned short reserved; // must be zero; inbound() drops packets where it is not
};

// The struct is cast onto packet bytes, so its size is part of the wire format.
static_assert(sizeof(Meta) == 4, "Meta must be exactly 4 bytes with no padding");

// Process-wide settings. Assigned once in main() from the JSON in argv[1]
// before any worker thread is started; read afterwards by inbound() and
// outbound() (e.g. config.local_id).
Config config;

// internet -> tun
void inbound(int raw) {
    unsigned char buffer[ETH_DATA_LEN];
    sockaddr_storage address;
    socklen_t address_length = sizeof(address);
    size_t packet_length;
    while ((packet_length = recvfrom(raw, buffer, sizeof(buffer), 0, (sockaddr *) &address, &address_length)) >= 0) {
        auto *packet = (iphdr *) buffer;
        auto overhead = packet->ihl * 4;
        auto payload = buffer + overhead;
        auto meta = (Meta *) payload;
        if (!(Router::all.contains(meta->src_id) && meta->dst_id == config.local_id && meta->reserved == 0)) continue;
        auto router = Router::all[meta->src_id];
        auto inner = (payload + sizeof(Meta));
        auto payload_length = packet_length - overhead - sizeof(Meta);
        router->decrypt(inner, payload_length);
        switch (((ipv6hdr *) inner)->version) {
            case 4:
                if (csum((uint16_t *) inner, ((iphdr *) inner)->ihl * 4)) continue;
                break;
            case 6:
                // ipv6 don't have checksum, do nothing
                break;
            default:
                continue;
        }
        router->remote_addr = address;

        if (write(router->tun, inner, payload_length) < 0) {
            perror("inbound write");
        }
    }
    perror("inbound read");
}

// tun -> internet
void outbound(int tun) {
    auto router = Router::tuns[tun];
    unsigned char buffer[ETH_DATA_LEN];
    auto meta = (Meta *) buffer;
    meta->src_id = config.local_id;
    meta->dst_id = router->config.remote_id;
    meta->reserved = 0;
    auto inner = buffer + sizeof(Meta);
    size_t packet_length;
    while ((packet_length = read(tun, inner, sizeof(buffer) - sizeof(Meta))) >= 0) {
        if (!router->remote_addr.ss_family) continue;
        router->encrypt(inner, packet_length);
        if (setsockopt(router->raw, SOL_SOCKET, SO_MARK, &router->config.mark, sizeof(router->config.mark)) < 0) {
            perror("setsockopt error");
        }
        if (sendto(router->raw, buffer, packet_length + sizeof(Meta), 0, (sockaddr *) &router->remote_addr,
                   sizeof(router->remote_addr)) < 0) {
            perror("outbound write");
        }
    }
    perror("outbound read");
}

int main(int argc, char *argv[]) {
    json data = json::parse(argv[1]);
    config = data.get<Config>();
    Router::create_secret(config.local_secret, Router::local_secret);
    for (const auto &item: config.routers) new Router(item);

    std::vector<std::thread> threads;
    for (auto &[_, router]: Router::all)threads.emplace_back(outbound, router->tun);
    for (auto &[_, raw]: Router::raws)threads.emplace_back(inbound, raw);
    for (auto &thread: threads) thread.join();

    return 0;
}
