mod config;
mod router;

use crate::config::{Config, Schema};
use crate::router::{Meta, Router, META_SIZE};
use anyhow::anyhow;
use anyhow::{Context, Result};
use crossbeam::epoch::{pin, Owned};
use crossbeam_utils::thread;
use itertools::Itertools;
use std::collections::BTreeMap;
use std::net::Shutdown;
use std::sync::atomic::Ordering;
use std::time::Duration;
use std::{env, mem::MaybeUninit};
use std::fs;

fn main() -> Result<()> {
    println!("Starting");
    let args: Vec<String> = env::args().collect();
    if args.len() < 2 {
        return Err(anyhow!("need JSON string or -c <config.json>"));
    }

    let config: Config;

    if args[1] == "-c" || args[1] == "--config" {
        // 从文件读
        if args.len() < 3 {
            return Err(anyhow!("missing value for -c/--config"));
        }
        let data = fs::read_to_string(&args[2])?;
        config = serde_json::from_str(&data)?;
    } else {
        // 当作 JSON 字符串解析
        config = serde_json::from_str(&args[1])?;
    }
    println!("Read config");

    let routers = &config
        .routers
        .into_iter()
        .sorted_by_key(|r| r.remote_id)
        .map(|c| Router::new(c).map(|router| (router.config.remote_id, router)))
        .collect::<Result<BTreeMap<u8, Router>, _>>()?;

    for (_, group) in &routers
        .values()
        .filter(|r| r.config.schema == Schema::UDP && r.config.src_port != 0)
        .chunk_by(|r| r.config.src_port)
    {
        Router::attach_filter_udp(group.collect())?;
    }

    println!("created tuns");
    const TCP_RECONNECT: u64 = 10;

    thread::scope(|s| {
        // IP, UDP
        for router in routers.values().filter(|&r| r.config.schema != Schema::TCP) {
            s.spawn(|_| {
                router.handle_outbound_ip_udp();
            });

            s.spawn(|_| {
                router.handle_inbound_ip_udp();
            });
        }

        for router in routers.values().filter(|&r| r.config.schema == Schema::TCP && r.config.dst_port != 0) {
            s.spawn(|_| {
                loop {
                    if let Ok(connection) = router.connect_tcp() {
                        let _ = thread::scope(|s| {
                            s.spawn(|_| router.handle_outbound_tcp(&connection));
                            s.spawn(|_| router.handle_inbound_tcp(&connection));
                        });
                    }
                    std::thread::sleep(Duration::from_secs(TCP_RECONNECT));
                }
            });
        }

        // tcp listeners
        for router in routers
            .values()
            .filter(|&r| r.config.schema == Schema::TCP && r.config.dst_port == 0)
            .unique_by(|r| r.config.src_port)
        {
            println!("listen on port {}", router.config.src_port);
            let socket = router.listen_tcp();
            s.spawn(move |s| {
                // listen 或 accept 出错直接 panic
                loop {
                    let (connection, _) = socket.accept().unwrap();
                    s.spawn(move |_| {
                        connection.set_tcp_nodelay(true).unwrap();

                        let mut meta_bytes = [MaybeUninit::uninit(); META_SIZE];
                        Router::recv_exact_tcp(&connection, &mut meta_bytes).unwrap();
                        let meta: &Meta = Meta::from_bytes(&meta_bytes);
                        if meta.reversed == 0
                            && let Some(router) = routers.get(&meta.src_id)
                            && meta.dst_id == router.config.local_id
                        {
                            // let connection = Arc::new(connection);

                            // tcp listener 只许一个连接，过来新连接就把前一个关掉。
                            {
                                let guard = pin();
                                let new_shared = Owned::new(connection).into_shared(&guard);
                                let old_shared = router.tcp_listener_connection.swap(new_shared, Ordering::AcqRel, &guard);
                                //
                                // SAFETY: this is guaranteed to still point to valid connection because
                                // the guard is swapped with AcqRel so we are for sure tracked by the pin
                                // list
                                //
                                if let Some(old) = unsafe { old_shared.as_ref() } {
                                    let _ = old.shutdown(Shutdown::Both);

                                    // SAFETY: At this point old_shared is guaranteed
                                    // to be non-null (above if let checks that)
                                    // And since it is already swapped out of the
                                    // `tcp_listener_connection` no other thread
                                    // should have access to it.
                                    unsafe {
                                        guard.defer_destroy(old_shared);
                                    }
                                }
                            }

                            let _ = thread::scope(|s| {
                                s.spawn(|_| {
                                    let guard = pin();
                                    let shared = router.tcp_listener_connection.load(Ordering::Acquire, &guard);
                                    // SAFETY: tcp_listener_connection shoud always either point to null or some valid thing
                                    if let Some(connection) = unsafe { shared.as_ref() } {
                                        router.handle_outbound_tcp(connection);
                                    }
                                });

                                s.spawn(|_| {
                                    let guard = pin();
                                    let shared = router.tcp_listener_connection.load(Ordering::Acquire, &guard);
                                    // SAFETY: tcp_listener_connection shoud always either point to null or some valid thing
                                    if let Some(connection) = unsafe { shared.as_ref() } {
                                        router.handle_inbound_tcp(&connection);
                                    }
                                });
                            });
                        }
                    });
                }
            });
        }
    })
    .unwrap();
    Ok(())
}
