mod router;

use crate::router::{Meta, Router, META_SIZE, SECRET_LENGTH};
use crate::Schema::{TCP, UDP};
use anyhow::{Context, Result};
use crossbeam::epoch::{pin, Owned};
use crossbeam_utils::thread;
use itertools::Itertools;
use serde::{Deserialize, Deserializer};
use socket2::Domain;
use std::net::Shutdown;
use std::sync::atomic::Ordering;
use std::time::Duration;
use std::{collections::HashMap, env, mem::MaybeUninit, sync::Arc};

/// Top-level configuration, deserialized in `main` from the JSON passed as the first
/// command-line argument.
#[derive(Deserialize)]
pub struct Config {
    /// This node's id. Accepted TCP connections are only served when `Meta::dst_id` equals it.
    pub local_id: u8,
    /// This node's secret in textual form; converted with `Router::create_secret` at startup.
    pub local_secret: String,
    /// Tunnel definitions; each entry becomes one `Router`, keyed by its `remote_id`.
    pub routers: Vec<ConfigRouter>,
}
/// Settings for one tunnel to a single remote peer.
#[derive(Deserialize, Clone)]
pub struct ConfigRouter {
    /// Peer id: key of the router map built in `main`, and matched against `Meta::src_id`
    /// when dispatching accepted TCP connections.
    pub remote_id: u8,
    /// Transport for this tunnel; falls back to `Schema::default()` (`IP`) when omitted.
    #[serde(default)]
    pub schema: Schema,
    /// Defaults to 0 when omitted.
    /// NOTE(review): presumably an IP protocol number for `Schema::IP` — confirm in `router`.
    #[serde(default)]
    pub proto: u8,
    /// Local port; defaults to 0. UDP routers sharing the same non-zero value are passed together
    /// to `Router::attach_filter_udp`; TCP listeners are started once per distinct value.
    #[serde(default)]
    pub src_port: u16,
    /// Remote port; defaults to 0. For TCP, non-zero means this side connects out and zero
    /// means this side accepts connections.
    #[serde(default)]
    pub dst_port: u16,
    /// Address family, written as `4` or `6` in the config (see `deserialize_domain`).
    #[serde(deserialize_with = "deserialize_domain")]
    pub family: Domain,
    /// NOTE(review): presumably a socket/firewall mark (e.g. SO_MARK) — confirm in `router`.
    pub mark: u32,
    /// NOTE(review): presumably the remote peer's address — confirm in `router`.
    pub endpoint: String,
    /// NOTE(review): presumably the peer's secret in the same textual form as
    /// `Config::local_secret` — confirm in `router`.
    pub remote_secret: String,
    /// NOTE(review): presumably the name of the TUN device for this tunnel — confirm in `router`.
    pub dev: String,
    /// NOTE(review): presumably a command run once the device is up — confirm in `router`.
    pub up: String,
}

/// Transport a tunnel is carried over.
///
/// Deserialized from the variant name (`"IP"`, `"UDP"` or `"TCP"`); `IP` is the `Default`.
/// `Debug` and `Eq` are derived so the value can be logged and compared as a total equality.
#[derive(Deserialize, Default, Debug, PartialEq, Eq, Clone, Copy)]
pub enum Schema {
    /// Default transport; served by the same `*_ip_udp` handlers as `UDP`.
    #[default]
    IP,
    /// Served by the `*_ip_udp` handlers; routers sharing a source port share one filter.
    UDP,
    /// Stream transport with dedicated connect (client) and accept (listener) loops in `main`.
    TCP,
}

fn deserialize_domain<'de, D>(d: D) -> Result<Domain, D::Error>
where
    D: Deserializer<'de>,
{
    match u8::deserialize(d)? {
        4 => Ok(Domain::IPV4),
        6 => Ok(Domain::IPV6),
        _ => Err(serde::de::Error::custom("Invalid domain")),
    }
}

/// Program entry point.
///
/// Parses the JSON [`Config`] given as the first command-line argument, builds one [`Router`]
/// per configured tunnel, attaches the shared UDP filters and then runs every forwarding loop on
/// scoped threads. It only returns once all of those worker threads have finished.
///
/// # Errors
/// Fails if the argument is missing or is not valid config JSON, if the local secret cannot be
/// converted, or if creating a router or attaching a UDP filter fails.
///
/// # Panics
/// Panics if a TCP listener's `accept` fails (deliberately fatal), or if the outer thread scope
/// reports a panicked worker.
fn main() -> Result<()> {
    println!("Starting");
    let config = serde_json::from_str::<Config>(env::args().nth(1).context("need param")?.as_str())?;
    let local_secret: [u8; SECRET_LENGTH] = Router::create_secret(config.local_secret.as_str())?;

    // Routers are created in ascending `remote_id` order and stored keyed by `remote_id`, so an
    // accepted TCP connection can be dispatched by the peer's `src_id`.
    let routers = Arc::new(
        config
            .routers
            .iter()
            .cloned()
            .sorted_by_key(|r| r.remote_id)
            .map(|c| {
                let remote_id = c.remote_id;
                Router::new(c, config.local_id).map(|r| (remote_id, r))
            })
            .collect::<Result<HashMap<u8, Router>, _>>()?,
    );

    // Attach one filter per shared UDP source port. `chunk_by` only groups *adjacent* items and
    // `HashMap` iteration order is arbitrary, so sort by port first; otherwise routers sharing a
    // port could be split across several partial groups.
    for (_, group) in &routers
        .values()
        .filter(|r| r.config.schema == UDP && r.config.src_port != 0)
        .sorted_by_key(|r| r.config.src_port)
        .chunk_by(|r| r.config.src_port)
    {
        Router::attach_filter_udp(group.sorted_by_key(|r| r.config.remote_id).collect(), config.local_id)?;
    }

    println!("created tuns");
    const TCP_RECONNECT: u64 = 10;

    thread::scope(|s| {
        // IP and UDP routers: one thread per direction. TCP routers are handled below.
        for router in routers.values().filter(|&r| r.config.schema != TCP) {
            s.spawn(|_| {
                router.handle_outbound_ip_udp(config.local_id);
            });

            s.spawn(|_| {
                router.handle_inbound_ip_udp(&local_secret);
            });
        }

        // TCP clients (non-zero `dst_port`): connect, serve until the connection ends, then retry
        // after `TCP_RECONNECT` seconds.
        for router in routers.values().filter(|&r| r.config.schema == TCP && r.config.dst_port != 0) {
            s.spawn(|_| {
                loop {
                    if let Ok(connection) = router.connect_tcp(config.local_id) {
                        let _ = thread::scope(|s| {
                            s.spawn(|_| router.handle_outbound_tcp(&connection));
                            s.spawn(|_| router.handle_inbound_tcp(&connection, &local_secret));
                        });
                    }
                    std::thread::sleep(Duration::from_secs(TCP_RECONNECT));
                }
            });
        }

        // TCP listeners (`dst_port == 0`): one accept loop per distinct local port. Each accepted
        // connection is dispatched to the router named by the `src_id` in its `Meta` header.
        for router in routers
            .values()
            .filter(|&r| r.config.schema == TCP && r.config.dst_port == 0)
            .unique_by(|r| r.config.src_port)
        {
            s.spawn(|_| {
                let routers = &routers;
                let local_id = config.local_id;
                let local_secret = &local_secret;
                // Serve each connection on its own thread so the accept loop never waits for an
                // earlier connection to end (a nested scope per connection would block here and
                // leave other peers stuck in the backlog). Unlike a crossbeam scope, std's scope
                // does not retain a join handle for every thread it has spawned, so this endless
                // loop does not accumulate them.
                std::thread::scope(|conn_scope| loop {
                    // An accept error is treated as fatal for this listener.
                    let (connection, _) = router.socket.accept().unwrap();
                    conn_scope.spawn(move || {
                        // A peer that fails the handshake only loses its own connection; it must
                        // not take the whole listener down.
                        if connection.set_tcp_nodelay(true).is_err() {
                            return;
                        }

                        let mut meta_bytes = [MaybeUninit::uninit(); META_SIZE];
                        if Router::recv_exact_tcp(&connection, &mut meta_bytes).is_err() {
                            return;
                        }
                        let meta: &Meta = Meta::from_bytes(&meta_bytes);
                        if meta.reversed == 0
                            && meta.dst_id == local_id
                            && let Some(router) = routers.get(&meta.src_id)
                        {
                            let connection = Arc::new(connection);

                            // Each listening router keeps a single connection: publish the new
                            // one and shut down the connection it replaces.
                            {
                                let guard = pin();
                                let new_shared = Owned::new(connection.clone()).into_shared(&guard);
                                // AcqRel: Release publishes the new Arc; Acquire is required
                                // because the previous Arc, stored by another thread, is
                                // dereferenced below.
                                let old_shared = router.tcp_listener_connection.swap(new_shared, Ordering::AcqRel, &guard);
                                // SAFETY: `old_shared` was loaded while `guard` is pinned, so it
                                // stays valid for this block. The swap unlinked it, so it is
                                // retired exactly once here (assuming no other code retires values
                                // swapped out of this field). A null pointer (no previous
                                // connection) must never reach `defer_destroy`: its deferred
                                // `into_owned` requires a non-null pointer.
                                unsafe {
                                    if let Some(old) = old_shared.as_ref() {
                                        let _ = old.shutdown(Shutdown::Both);
                                        guard.defer_destroy(old_shared);
                                    }
                                }
                            }

                            let _ = thread::scope(|s| {
                                s.spawn(|_| router.handle_outbound_tcp(&connection));
                                s.spawn(|_| router.handle_inbound_tcp(&connection, local_secret));
                            });
                        }
                    });
                })
            });
        }
    })
    .unwrap();
    Ok(())
}
