mod router;

use crate::router::{Router, SECRET_LENGTH};
use std::collections::HashMap;
use std::env;
use std::error::Error;
use std::intrinsics::transmute;
use std::io::{Read, Write};
use std::mem::MaybeUninit;
use std::net::ToSocketAddrs;
use std::sync::Arc;

/// Fixed-size header that prefixes every tunnelled packet, ahead of the encrypted payload.
///
/// `#[repr(C)]` keeps the fields in declaration order. With `u8, u8, u16` that is exactly
/// 4 bytes with no padding: `[src_id, dst_id, reversed (2 bytes)]`.
#[repr(C)]
pub struct Meta {
    /// Id of the sending node; outgoing packets carry the local `config.local_id`.
    pub src_id: u8,
    /// Id of the intended receiver; incoming packets whose `dst_id` is not our local id are dropped.
    pub dst_id: u8,
    /// Always sent as 0; incoming packets with a non-zero value are dropped. On receive it is
    /// decoded little-endian.
    /// NOTE(review): the name looks like a typo for `reserved` — consider renaming together
    /// with every use site.
    pub reversed: u16,
}

use serde::{Deserialize, Serialize};

/// Per-peer configuration entry, deserialized from the JSON configuration and handed to
/// `Router::new` to build one router per remote peer.
#[derive(Debug, Serialize, Deserialize)]
pub struct ConfigRouter {
    /// Id of the remote peer. It keys the router map, is written as `dst_id` in outgoing
    /// headers, and is matched against `src_id` of incoming ones.
    pub remote_id: u8,
    /// Passed to `Router::new`; presumably the IP protocol number used for the peer's
    /// socket — TODO confirm.
    pub proto: i32,
    /// Presumably the address family of `endpoint` (e.g. IPv4 vs IPv6) — TODO confirm
    /// against `Router::new`.
    pub family: u8,
    /// Socket mark applied with `Socket::set_mark` before sending (Linux only).
    pub mark: u32,
    /// Presumably the initial remote endpoint address — TODO confirm how `Router::new`
    /// parses it.
    pub endpoint: String,
    /// Presumably the peer's secret, parsed like `Config::local_secret` — TODO confirm.
    pub remote_secret: String,
    /// Presumably the name of the TUN device for this peer — TODO confirm.
    pub dev: String,
    /// Presumably a command run once the device is up — TODO confirm.
    pub up: String,
}

/// Top-level tunnel configuration, deserialized from a JSON document.
#[derive(Debug, Serialize, Deserialize)]
pub struct Config {
    /// Id of this node. It is written as `src_id` in outgoing headers; incoming packets
    /// must carry it as `dst_id`.
    pub local_id: u8,
    /// Local secret string, turned into key material by `Router::create_secret`.
    pub local_secret: String,
    /// One entry per remote peer.
    pub routers: Vec<ConfigRouter>,
}
use crossbeam_utils::thread;
use lazy_static::lazy_static;
use pnet::packet::ipv4::Ipv4Packet;
use socket2::Socket;

lazy_static! {
    /// Process configuration, parsed from the first command-line argument. That argument
    /// must be a JSON document matching [`Config`].
    ///
    /// # Panics
    /// On first access, if the argument is missing or is not valid JSON for [`Config`].
    static ref config: Config = {
        // `args().nth(0)` is the program path; the configuration is the first real argument.
        let raw = env::args()
            .nth(1)
            .expect("missing argument: expected the configuration as a JSON string");
        serde_json::from_str(raw.as_str())
            .expect("configuration argument is not a valid Config JSON document")
    };
    /// Key material derived from `config.local_secret` via `Router::create_secret`.
    ///
    /// # Panics
    /// On first access, if `Router::create_secret` rejects the configured secret.
    static ref local_secret: [u8; SECRET_LENGTH] =
        Router::create_secret(config.local_secret.as_str())
            .expect("local_secret in the configuration is not a valid secret");
}

/// Size in bytes of the [`Meta`] header on the wire (4 for the `repr(C)` `u8, u8, u16` layout).
const META_SIZE: usize = std::mem::size_of::<Meta>();

/// Smallest legal IPv4 header (IHL = 5 words).
const MIN_IPV4_HEADER_LEN: usize = 20;

/// How long each socket may block before the receive loop moves on to the next one.
/// This only applies when there is more than one socket to service.
const RECV_POLL_INTERVAL: std::time::Duration = std::time::Duration::from_millis(1);

/// Serializes `meta` as `[src_id, dst_id, reversed (little-endian)]`.
/// This mirrors the decoding done in [`split_datagram`], with no raw-pointer casts.
fn encode_meta(meta: &Meta) -> [u8; META_SIZE] {
    let [lo, hi] = meta.reversed.to_le_bytes();
    [meta.src_id, meta.dst_id, lo, hi]
}

/// Splits a received datagram that starts with an IPv4 header into its [`Meta`] header and
/// the (still encrypted) payload that follows it.
///
/// Returns `None` if the buffer is not a parsable IPv4 packet, if its header length field is
/// invalid or exceeds the buffer, or if too few bytes remain for a [`Meta`] header.
fn split_datagram(data: &mut [u8]) -> Option<(Meta, &mut [u8])> {
    // IHL counts 32-bit words; reject values that are illegal or point past the buffer,
    // which would otherwise make the split below panic.
    let header_len = usize::from(Ipv4Packet::new(data)?.get_header_length()) * 4;
    if header_len < MIN_IPV4_HEADER_LEN || header_len > data.len() {
        return None;
    }
    let rest = &mut data[header_len..];
    if rest.len() < META_SIZE {
        return None;
    }
    let (meta_bytes, payload) = rest.split_at_mut(META_SIZE);
    let meta = Meta {
        src_id: meta_bytes[0],
        dst_id: meta_bytes[1],
        reversed: u16::from_le_bytes([meta_bytes[2], meta_bytes[3]]),
    };
    Some((meta, payload))
}

/// Applies the per-router socket mark before sending (Linux only).
#[cfg(target_os = "linux")]
fn apply_socket_mark(socket: &Socket, mark: u32) -> std::io::Result<()> {
    socket.set_mark(mark)
}

/// Socket marks are only applied on Linux; elsewhere this is a no-op.
#[cfg(not(target_os = "linux"))]
fn apply_socket_mark(_socket: &Socket, _mark: u32) -> std::io::Result<()> {
    Ok(())
}

/// Sender loop for one router; it never returns.
///
/// It reads packets from the router's TUN device and prefixes each with a [`Meta`] header
/// (local id -> remote id). It then encrypts the packet body and sends the result to the
/// router's current endpoint. Packets are dropped, with a log line, while no endpoint is
/// known or when marking or sending fails.
fn forward_tun_to_socket(router: &mut Router) {
    // 1500-byte MTU minus a 20-byte outer IPv4 header (no options).
    let mut buffer = [0u8; 1500 - 20];
    let header = encode_meta(&Meta {
        src_id: config.local_id,
        dst_id: router.config.remote_id,
        reversed: 0,
    });
    buffer[..META_SIZE].copy_from_slice(&header);

    loop {
        let n = match router.tun_reader.read(&mut buffer[META_SIZE..]) {
            Ok(0) => continue,
            Ok(n) => n,
            Err(e) => {
                eprintln!("Error reading from TUN: {}", e);
                continue;
            }
        };
        if router.endpoint.is_none() {
            eprintln!("No remote address set; packet dropped.");
            continue;
        }
        let packet_len = META_SIZE + n;
        // Encrypt before borrowing `endpoint`, so this compiles whether `encrypt` takes
        // `&self` or `&mut self`.
        router.encrypt(&mut buffer[META_SIZE..packet_len]);
        // NOTE(review): routers may share one socket (Router::new receives the shared
        // `sockets` map). If they do, concurrent senders can interleave set_mark/send_to
        // and send with another router's mark.
        if let Err(e) = apply_socket_mark(&router.socket, router.config.mark) {
            eprintln!("Error setting socket mark: {}", e);
            continue;
        }
        if let Some(addr) = router.endpoint.as_ref() {
            if let Err(e) = router.socket.send_to(&buffer[..packet_len], addr) {
                eprintln!("Error sending packet: {}", e);
            }
        }
    }
}

/// Receive loop; it never returns unless `sockets` is empty.
///
/// Each datagram is expected to start with an IPv4 header, which suggests raw IP sockets
/// created by `Router::new` — TODO confirm. A datagram is accepted only if its [`Meta`]
/// header names a known router as `src_id`, our `local_id` as `dst_id`, and has
/// `reversed == 0`. For an accepted datagram, the sender address becomes that router's
/// endpoint, and the decrypted payload is written to the router's TUN device.
///
/// With several sockets, each one gets a short read timeout and they are serviced
/// round-robin. The original code looped forever on the first socket and never read the
/// others.
fn receive_from_sockets(sockets: &HashMap<u16, Arc<Socket>>, routers: &mut HashMap<u8, Router>) {
    let sockets: Vec<&Arc<Socket>> = sockets.values().collect();
    if sockets.is_empty() {
        return;
    }
    let polling = sockets.len() > 1;
    if polling {
        for socket in &sockets {
            if let Err(e) = socket.set_read_timeout(Some(RECV_POLL_INTERVAL)) {
                eprintln!("Error setting receive timeout: {}", e);
            }
        }
    }

    let mut recv_buf = [MaybeUninit::<u8>::uninit(); 1500];
    loop {
        for socket in &sockets {
            let (len, addr) = match socket.recv_from(&mut recv_buf) {
                Ok(received) => received,
                // A poll timeout just means this socket was idle; move on to the next one.
                Err(e)
                    if polling
                        && matches!(
                            e.kind(),
                            std::io::ErrorKind::WouldBlock | std::io::ErrorKind::TimedOut
                        ) =>
                {
                    continue
                }
                Err(e) => {
                    eprintln!("Error receiving: {}", e);
                    continue;
                }
            };
            // SAFETY: `recv_from` initialised the first `len` bytes of `recv_buf`, and
            // `len <= recv_buf.len()`. `MaybeUninit<u8>` has the same layout as `u8`.
            let data = unsafe {
                std::slice::from_raw_parts_mut(recv_buf.as_mut_ptr().cast::<u8>(), len)
            };
            let Some((meta, payload)) = split_datagram(data) else {
                continue;
            };
            if meta.dst_id != config.local_id || meta.reversed != 0 {
                continue;
            }
            let Some(router) = routers.get_mut(&meta.src_id) else {
                continue;
            };
            router.endpoint = Some(addr);
            router.decrypt(payload);
            if let Err(e) = router.tun_writer.write_all(payload) {
                eprintln!("Error writing to TUN: {}", e);
            }
        }
    }
}

/// Entry point. It builds one router per configured peer, spawns one sender thread per
/// router, and runs the receive loop on the current thread.
///
/// # Errors
/// Returns an error if any `Router::new` fails or if a worker thread panics.
fn main() -> Result<(), Box<dyn Error>> {
    let mut sockets: HashMap<u16, Arc<Socket>> = HashMap::new();
    let mut routers: HashMap<u8, Router> = config
        .routers
        .iter()
        .map(|c| Router::new(c, &mut sockets).map(|router| (c.remote_id, router)))
        .collect::<Result<_, _>>()?;

    thread::scope(|s| {
        // NOTE(review): each sender thread holds `&mut Router` for the whole scope, yet the
        // receive loop below also needs `&mut` access to the same routers (endpoint update,
        // decrypt, TUN write). The borrow checker rejects this aliasing, and sharing
        // `endpoint` across threads would be a data race. A fix needs Router-side changes,
        // e.g. independently owned send/receive halves with `endpoint` behind an RwLock.
        for router in routers.values_mut() {
            s.spawn(move |_| forward_tun_to_socket(router));
        }
        receive_from_sockets(&sockets, &mut routers);
    })
    .map_err(|_| "a worker thread panicked")?;

    Ok(())
}
