#include <cerrno>
#include <cstdint>
#include <cstdio>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <thread>

#include <sys/socket.h>
#include <netinet/in.h>
#include <fcntl.h>
#include <linux/if.h>
#include <linux/ip.h>
#include <linux/if_tun.h>
#include <sys/ioctl.h>
#include <unistd.h>
#include <arpa/inet.h>
#include <netdb.h>
//#include <net/checksum.h>

#include <boost/program_options.hpp>

/*
 * Tunnel header placed between the outer IPv4 header and the (encrypted)
 * inner packet. The receiver casts raw packet bytes to this type, so its
 * size and field order ARE the wire format; the static_assert below makes
 * any accidental layout change a compile error.
 */
struct Meta {
    unsigned char src_id;    // sender's id (outbound() writes local_id here)
    unsigned char dst_id;    // intended receiver's id (outbound() writes remote_id here)
    unsigned short reversed; // reserved, always 0; inbound() drops packets where it is not.
                             // (spelling kept as-is so existing code keeps compiling)
};
static_assert(sizeof(Meta) == 4, "Meta is a 4-byte on-wire header");

void encrypt_package(unsigned char *buffer, size_t length);

/*
 * Decrypts `length` bytes of `buffer` in place. The cipher is a plain XOR with
 * the shared secret, which is its own inverse, so decryption is the same
 * operation as encryption. A real function (instead of the former
 * `#define decrypt_package encrypt_package`) keeps the name type-checked and
 * scoped like any other function.
 */
inline void decrypt_package(unsigned char *buffer, size_t length) {
    encrypt_package(buffer, length);
}

// Tunnel peer ids, carried in Meta::src_id / Meta::dst_id and set once in main().
unsigned char local_id;
unsigned char remote_id;

// Peer to send tunnelled packets to. It is set in main() when `endpoint` is
// configured, and overwritten by inbound() with the source of every accepted
// packet. outbound() drops packets while its address is still 0.
// NOTE(review): inbound() writes and outbound() reads this from different
// threads without synchronization (data race) — consider guarding it.
sockaddr_in remote_addr{.sin_family = AF_INET};

/*
 * Internet (RFC 1071) one's-complement checksum of `packlen` bytes at `packet`.
 *
 * The result is in the same byte order as the data, so it can be stored
 * directly into a header field. Checksumming a header whose checksum field is
 * already correct yields 0.
 *
 * Words are read with memcpy, so `packet` does not need to be 2-byte aligned
 * and no strict-aliasing rule is violated. An odd trailing byte is zero-padded
 * in memory order, which is correct on both little- and big-endian hosts (the
 * old code added it as the low-order byte, which is wrong on big-endian).
 * A non-positive `packlen` checksums nothing and returns 0xffff.
 */
uint16_t csum(uint16_t *packet, int packlen) {
    const auto *bytes = reinterpret_cast<const unsigned char *>(packet);
    unsigned long sum = 0;

    for (; packlen > 1; bytes += 2, packlen -= 2) {
        uint16_t word;
        std::memcpy(&word, bytes, sizeof(word));
        sum += word;
    }

    if (packlen > 0) {
        uint16_t word = 0;
        std::memcpy(&word, bytes, 1); // pad with a zero byte after the last one
        sum += word;
    }

    // Fold carries back into the low 16 bits (end-around carry).
    while (sum >> 16)
        sum = (sum & 0xffff) + (sum >> 16);

    return static_cast<uint16_t>(~sum);
}

// Shared tunnel key (main() takes it from the `secret` environment variable)
// and its length in bytes.
char *secret;
size_t secret_length;

/*
 * XORs `length` bytes of `buffer` in place with the secret, repeating the key
 * as needed. Applying it twice restores the original bytes, which is why
 * decryption uses the same routine.
 *
 * If no secret is set (null or zero length), the buffer is left unchanged
 * instead of taking `i % 0`, which is undefined behavior.
 *
 * NOTE(review): repeating-key XOR only obfuscates the traffic; it provides
 * neither confidentiality against an attacker nor integrity.
 */
void encrypt_package(unsigned char *buffer, size_t length) {
    if (secret == nullptr || secret_length == 0) return;
    // size_t index: matches `length` and cannot overflow like the old `int`.
    for (size_t i = 0; i < length; ++i) {
        buffer[i] ^= static_cast<unsigned char>(secret[i % secret_length]);
    }
}

// internet -> tun
void inbound(int raw, int tun) {
    unsigned char buffer[ETH_DATA_LEN];
    sockaddr_in address{.sin_family = AF_INET};
    socklen_t address_length;
    size_t packet_length;
    while ((packet_length = recvfrom(raw, buffer, sizeof(buffer), 0, (sockaddr *) &address, &address_length)) >= 0) {
        auto *packet = (iphdr *) buffer;
        auto overhead = packet->ihl * 4;
        auto payload = buffer + overhead;
        auto meta = (Meta *) payload;
        if (!(meta->src_id == remote_id && meta->dst_id == local_id && meta->reversed == 0)) continue;
        auto inner = (iphdr *) (payload + sizeof(Meta));
        auto payload_length = packet_length - overhead - sizeof(Meta);
        decrypt_package((unsigned char *) inner, payload_length);
//        if (ip_fast_csum(inner, inner->ihl)) continue;
        if (csum((uint16_t *) inner, inner->ihl * 4)) continue;
//        std::cout << "packet_length " << packet_length
//                  << " tot_len " << ntohs(packet->tot_len)
//                  << " inner->tot_len " << ntohs(inner->tot_len)
//                  << " from " << inet_ntoa(address.sin_addr) << std::endl;
        remote_addr = address;

        if (write(tun, inner, payload_length) < 0) {
            perror("inbound write");
        }
    }
    perror("inbound read");
}

// tun -> internet
void outbound(int raw, int tun) {
    unsigned char buffer[ETH_DATA_LEN];
    auto meta = (Meta *) buffer;
    meta->src_id = local_id;
    meta->dst_id = remote_id;
    meta->reversed = 0;
    auto inner = buffer + sizeof(Meta);
    size_t packet_length;
    while ((packet_length = read(tun, inner, sizeof(buffer) - sizeof(Meta))) >= 0) {
        if (!remote_addr.sin_addr.s_addr) continue;
        encrypt_package(inner, packet_length);
        if (sendto(raw, buffer, packet_length + sizeof(Meta), 0, (sockaddr *) &remote_addr, sizeof(remote_addr)) < 0) {
            perror("outbound write");
        }
    }
    perror("outbound read");
}

/* Returns getenv(name), printing an error when the variable is unset. */
static char *require_env(const char *name) {
    char *value = getenv(name);
    if (value == nullptr) {
        std::cerr << "missing environment variable: " << name << '\n';
    }
    return value;
}

/*
 * Parses `text` as a decimal integer in [0, 255] into `out`. On failure it
 * prints an error mentioning `name` and returns false. This replaces atoi,
 * which silently accepted garbage and wrapped out-of-range values.
 */
static bool parse_byte(const char *name, const char *text, unsigned char &out) {
    char *end = nullptr;
    errno = 0;
    const long value = strtol(text, &end, 10);
    if (end == text || *end != '\0' || errno == ERANGE || value < 0 || value > 255) {
        std::cerr << "invalid value for " << name << ": " << text << '\n';
        return false;
    }
    out = static_cast<unsigned char>(value);
    return true;
}

/*
 * Resolves `host` to an IPv4 address with getaddrinfo. The previous code
 * reinterpreted hostent::h_name (the host *name string*) as an in_addr and
 * dereferenced a null hostent on failure.
 */
static bool resolve_ipv4(const char *host, in_addr &out) {
    addrinfo hints{};
    hints.ai_family = AF_INET;
    addrinfo *result = nullptr;
    const int rc = getaddrinfo(host, nullptr, &hints, &result);
    if (rc != 0) {
        std::cerr << "cannot resolve endpoint " << host << ": " << gai_strerror(rc) << '\n';
        return false;
    }
    if (result == nullptr || result->ai_addr == nullptr) {
        std::cerr << "cannot resolve endpoint " << host << ": no address\n";
        if (result != nullptr) freeaddrinfo(result);
        return false;
    }
    out = reinterpret_cast<const sockaddr_in *>(result->ai_addr)->sin_addr;
    freeaddrinfo(result);
    return true;
}

/*
 * Tunnel entry point, configured entirely through environment variables:
 *   local_id, remote_id  peer ids (0-255) written to / checked in Meta
 *   proto                IP protocol number (0-255) for the raw socket
 *   secret               non-empty XOR key shared with the peer
 *   dev                  TUN interface name (shorter than IFNAMSIZ)
 *   endpoint (optional)  peer host; otherwise learned from incoming packets
 *   up (optional)        shell command run after the TUN device is created
 * Opens the raw socket and TUN device, then runs inbound() and outbound() on
 * two threads. Returns -1 on any setup error.
 */
int main([[maybe_unused]] int argc, [[maybe_unused]] char *argv[]) {
    char *local_id_text = require_env("local_id");
    char *remote_id_text = require_env("remote_id");
    char *proto_text = require_env("proto");
    secret = require_env("secret");
    char *dev = require_env("dev");
    if (!local_id_text || !remote_id_text || !proto_text || !secret || !dev) {
        return -1;
    }
    const char *endpoint = getenv("endpoint");
    const char *up = getenv("up");

    unsigned char proto = 0;
    if (!parse_byte("local_id", local_id_text, local_id) ||
        !parse_byte("remote_id", remote_id_text, remote_id) ||
        !parse_byte("proto", proto_text, proto)) {
        return -1;
    }

    secret_length = strlen(secret);
    if (secret_length == 0) {
        // encrypt_package uses `i % secret_length`; an empty key is unusable.
        std::cerr << "secret must not be empty\n";
        return -1;
    }

    if (endpoint != nullptr && *endpoint != '\0' && !resolve_ipv4(endpoint, remote_addr.sin_addr)) {
        return -1;
    }

    if (strlen(dev) >= IFNAMSIZ) {
        std::cerr << "dev name too long (max " << IFNAMSIZ - 1 << " characters): " << dev << '\n';
        return -1;
    }
    ifreq ifr{};
    ifr.ifr_flags = IFF_TUN | IFF_NO_PI;
    // Length checked above and ifr is zero-initialized, so the name stays NUL-terminated.
    strncpy(ifr.ifr_name, dev, IFNAMSIZ - 1);

    const int raw = socket(AF_INET, SOCK_RAW, proto);
    if (raw < 0) {
        perror("socket init error");
        return -1;
    }
    const int tun = open("/dev/net/tun", O_RDWR);
    if (tun < 0) {
        perror("tun init error");
        close(raw);
        return -1;
    }

    if (ioctl(tun, TUNSETIFF, &ifr) < 0) {
        perror("ioctl error");
        close(tun);
        close(raw);
        return -1;
    }

    if (up != nullptr && *up != '\0') {
        const int status = system(up);
        if (status != 0) {
            std::cerr << "warning: up command returned " << status << '\n';
        }
    }

    std::thread t1(inbound, raw, tun);
    std::thread t2(outbound, raw, tun);
    t1.join();
    t2.join();

    return 0;
}

