mod router;

use crate::router::{Router, SECRET_LENGTH};
use crossbeam_utils::thread;
use std::collections::HashMap;
use std::env;
use std::error::Error;
use std::io::{Read, Write};
use std::mem::MaybeUninit;
use std::mem::{size_of, transmute};
use std::sync::Arc;
use std::sync::atomic::Ordering;

/// Fixed-size header prepended to every datagram sent to a peer; the receiver skips it.
///
/// `main` copies the raw in-memory bytes of this struct onto the wire. That means
/// `#[repr(C)]` and the field order below *are* the wire format: 4 bytes, no padding, with
/// `reversed` in native byte order. Do not reorder or resize fields without updating both
/// the sending and receiving paths.
#[repr(C)]
pub struct Meta {
    /// Id of the sending node (the sender's `Config::local_id`).
    pub src_id: u8,
    /// Id of the receiving node (the sender's `ConfigRouter::remote_id`).
    pub dst_id: u8,
    /// Always sent as 0 by `main`. NOTE(review): name presumably meant `reserved`.
    pub reversed: u16,
}

use serde::{Deserialize, Deserializer};
use socket2::{Domain, Socket};

/// Top-level configuration, parsed from the JSON string passed as the first command-line
/// argument.
#[derive(Deserialize)]
pub struct Config {
    /// This node's id; sent as `Meta::src_id` in every outgoing header.
    pub local_id: u8,
    /// Secret string turned into the local key by `Router::create_secret`; `main` passes
    /// that key to `decrypt` for incoming traffic.
    pub local_secret: String,
    /// One entry per peer; each becomes a `Router` with its own pair of forwarding threads.
    pub routers: Vec<ConfigRouter>,
}
/// Per-peer tunnel settings; each entry is consumed by `Router::new`.
///
/// NOTE(review): most fields are interpreted inside `Router::new`, which is not visible from
/// this file. Descriptions marked "presumably" are inferred from field names only; confirm
/// them against `Router::new`.
#[derive(Deserialize)]
pub struct ConfigRouter {
    /// The peer's node id; sent as `Meta::dst_id` in outgoing headers.
    pub remote_id: u8,
    /// Transport used to reach the peer; defaults to `Schema::IP` when omitted.
    #[serde(default)]
    pub schema: Schema,
    /// Defaults to 0 when omitted. Presumably the IP protocol number used with `Schema::IP`.
    #[serde(default)]
    pub proto: u8,
    /// Defaults to 0 when omitted. Presumably the local port for port-based schemas.
    #[serde(default)]
    pub src_port: u16,
    /// Defaults to 0 when omitted. Presumably the peer's port for port-based schemas.
    #[serde(default)]
    pub dst_port: u16,
    /// Address family, written in JSON as the number `4` (IPv4) or `6` (IPv6); any other
    /// value is rejected by `deserialize_domain`.
    #[serde(deserialize_with = "deserialize_domain")]
    pub family: Domain,
    /// Presumably a socket mark (`SO_MARK`) applied to the tunnel socket; confirm.
    pub mark: u32,
    /// Presumably the peer's initial address; confirm the accepted format.
    pub endpoint: String,
    /// Presumably the peer's secret, used to encrypt outgoing traffic; confirm.
    pub remote_secret: String,
    /// Presumably the name of the TUN device to create or open.
    pub dev: String,
    /// Presumably a command run once the device is up; confirm.
    pub up: String,
}

/// Transport used to carry tunnelled packets to a peer. `Schema::IP` is the default.
///
/// In JSON each variant is written as its name, e.g. `"UDP"`.
/// NOTE(review): the variants are interpreted by `Router`, which is not visible here; the
/// per-variant descriptions below are inferred from names.
#[derive(Deserialize, Default)]
pub enum Schema {
    /// Default. Presumably raw IP packets using `ConfigRouter::proto`.
    #[default]
    IP,
    /// Presumably UDP datagrams.
    UDP,
    /// Presumably TCP.
    TCP,
    /// Presumably TCP-shaped packets sent without a real TCP connection.
    FakeTCP,
}

fn deserialize_domain<'de, D>(d: D) -> Result<Domain, D::Error>
where
    D: Deserializer<'de>,
{
    match u8::deserialize(d)? {
        4 => Ok(Domain::IPV4),
        6 => Ok(Domain::IPV6),
        _ => Err(serde::de::Error::custom("Invalid domain")),
    }
}

/// Program entry point. Sets up one [`Router`] per configured peer and forwards traffic
/// between each router's TUN device and its socket.
///
/// The first command-line argument must be the configuration as a JSON string (see
/// [`Config`]). Two threads are spawned per router:
///
/// * **TUN -> socket**: each packet read from the TUN device is prefixed with a [`Meta`]
///   header (`local_id` -> `remote_id`). Its payload is encrypted in place, and the datagram
///   is sent to the peer's current endpoint. Packets are dropped while no endpoint is known.
/// * **socket -> TUN**: each received datagram is validated first. Raw IPv4 sockets also
///   deliver the outer IP header, which is skipped together with the [`Meta`] header. The
///   sender is then recorded as the peer's endpoint, and the payload is decrypted with the
///   local secret and written to the TUN device. Malformed datagrams and per-datagram I/O
///   errors are dropped silently so that one bad packet cannot stop the tunnel.
///
/// # Errors
///
/// Returns an error in any of these cases:
///
/// * the configuration argument is missing or cannot be parsed;
/// * the local secret or any router cannot be created;
/// * a UDP filter cannot be attached;
/// * a forwarding thread panicked.
///
/// The forwarding threads never return normally. In practice this function returns only
/// during setup, or after every forwarding thread has panicked.
fn main() -> Result<(), Box<dyn Error>> {
    println!("Starting");
    let config: Config = serde_json::from_str(env::args().nth(1).ok_or("need param")?.as_str())?;
    let local_secret: [u8; SECRET_LENGTH] = Router::create_secret(config.local_secret.as_str())?;

    // Filled in by `Router::new`. Presumably routers that share a UDP port share one socket,
    // so a single filter is attached per socket below; confirm in `Router::new`.
    let mut udp_groups: HashMap<u16, (Arc<Socket>, Vec<u8>)> = HashMap::new();

    let routers: Vec<Router> = config
        .routers
        .into_iter()
        .map(|c| Router::new(c, config.local_id, &mut udp_groups))
        .collect::<Result<Vec<_>, _>>()?;

    for (socket, group) in udp_groups.values() {
        Router::attach_filter_udp(socket, group, config.local_id)?;
    }

    println!("created tuns");
    const META_SIZE: usize = size_of::<Meta>();

    thread::scope(|s| {
        for router in routers {
            let (mut reader, mut writer) = router.split();

            // TUN -> socket.
            s.spawn(move |_| {
                let mut buffer = [0u8; 1500];

                // Every outgoing datagram starts with the same header (local -> remote).
                // Write it once; each iteration only refills the payload area after it.
                let meta = Meta {
                    src_id: config.local_id,
                    dst_id: reader.config.remote_id,
                    reversed: 0,
                };
                // SAFETY: `Meta` is `#[repr(C)]` with only integer fields (u8, u8, u16), so it
                // has no padding and all `META_SIZE` bytes are initialised. `[u8; N]` has
                // alignment 1, which any `Meta` reference satisfies.
                let meta_bytes: &[u8; META_SIZE] =
                    unsafe { &*(&meta as *const Meta as *const [u8; META_SIZE]) };
                buffer[..META_SIZE].copy_from_slice(meta_bytes);

                loop {
                    let n = match reader.tun_reader.read(&mut buffer[META_SIZE..]) {
                        Ok(n) => n,
                        // A signal interrupted the read before any data arrived; just retry.
                        Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
                        Err(e) => panic!("failed to read from TUN device: {e}"),
                    };
                    let guard = crossbeam::epoch::pin();
                    let shared = reader.endpoint.load(Ordering::Acquire, &guard);
                    // SAFETY: the pointer was loaded under `guard`. The only writer (the
                    // receive thread) retires replaced endpoints with `defer_destroy`, so the
                    // pointee stays alive while `guard` is pinned.
                    if let Some(addr) = unsafe { shared.as_ref() } {
                        reader.encrypt(&mut buffer[META_SIZE..META_SIZE + n]);
                        // Best effort: a failed send just drops this packet.
                        let _ = reader.socket.send_to(&buffer[..META_SIZE + n], addr);
                    }
                }
            });

            // Socket -> TUN.
            s.spawn(move |_| {
                let mut recv_buf = [MaybeUninit::uninit(); 1500];
                loop {
                    // Any error here affects only the current datagram. It is dropped
                    // deliberately, and the loop keeps serving the tunnel.
                    let _ = (|| -> Result<(), Box<dyn Error>> {
                        let (len, addr) = writer.socket.recv_from(&mut recv_buf)?;
                        // SAFETY: `recv_from` initialised the first `len` bytes of `recv_buf`,
                        // and `MaybeUninit<u8>` has the same size and alignment as `u8`.
                        let packet: &mut [u8] = unsafe { transmute(&mut recv_buf[..len]) };

                        // Only IPv4 raw sockets (identified by sender port 0) hand us the IP
                        // header. Skip it using its IHL field, which counts 32-bit words.
                        let ip_header_len =
                            if addr.is_ipv4() && addr.as_socket_ipv4().ok_or("?")?.port() == 0 {
                                let first = *packet.first().ok_or("empty packet")?;
                                (first & 0x0f) as usize * 4
                            } else {
                                0
                            };
                        let offset = ip_header_len + META_SIZE;
                        // The datagram comes from the network, so reject anything too short
                        // to hold the headers rather than panicking on the slice.
                        // This check runs before the endpoint update below.
                        let payload = packet.get_mut(offset..).ok_or("truncated packet")?;

                        // NOTE(review): the endpoint follows whoever sent the last well-formed
                        // datagram, before any authentication of its contents; confirm that
                        // this roaming behaviour is intended.
                        let guard = crossbeam::epoch::pin();
                        let current_shared = writer.endpoint.load(Ordering::SeqCst, &guard);
                        // SAFETY: loaded under `guard`. Replaced endpoints are only retired
                        // via `defer_destroy`, so the pointee is valid while `guard` lives.
                        let is_same = unsafe { current_shared.as_ref() }
                            .map(|c| *c == addr)
                            .unwrap_or(false);
                        if !is_same {
                            let new_shared = crossbeam::epoch::Owned::new(addr).into_shared(&guard);
                            let old_shared =
                                writer.endpoint.swap(new_shared, Ordering::SeqCst, &guard);
                            // No endpoint may have been known yet (the sending thread handles
                            // a null endpoint). Destroying a null `Shared` would turn null
                            // into an `Owned`: a panic in debug builds, UB in release.
                            if !old_shared.is_null() {
                                // SAFETY: the swap unlinked `old_shared`, so no new reader can
                                // obtain it. Threads that already loaded it are pinned, which
                                // delays the destruction until they unpin.
                                unsafe {
                                    guard.defer_destroy(old_shared);
                                }
                            }
                        }

                        writer.decrypt(payload, &local_secret);
                        writer.tun_writer.write_all(payload)?;
                        writer.last_activity = std::time::Instant::now();
                        Ok(())
                    })();
                }
            });
        }
    })
    .map_err(|_| "Thread panicked")?;
    Ok(())
}
