mod router;

use crate::router::{Router, RouterReader, RouterWriter, SECRET_LENGTH};
use std::collections::HashMap;
use std::env;
use std::error::Error;
use std::intrinsics::transmute;
use std::io::{Read, Write};
use std::mem::MaybeUninit;
use std::sync::Arc;

/// Fixed 4-byte header prepended to every tunnelled payload.
///
/// `#[repr(C)]` pins the field order and layout (`u8` at offset 0, `u8` at
/// offset 1, `u16` at offset 2, no padding) so the struct can be copied to and
/// from packet buffers as raw bytes. Being plain integers, every bit pattern is
/// a valid `Meta`.
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Meta {
    /// Id of the sending node.
    pub src_id: u8,
    /// Id of the destination node.
    pub dst_id: u8,
    /// Reserved field (the name is presumably a typo for "reserved"); senders
    /// write 0 and receivers discard packets where it is non-zero.
    pub reversed: u16,
}

use serde::Deserialize;

/// Configuration for one remote peer; one element of [`Config::routers`].
///
/// Each entry is turned into a `Router` by `Router::new` in `main`. Only
/// `remote_id` and `mark` are read directly in this file; the remaining fields
/// are consumed inside the `router` module (not visible here), so their
/// descriptions below are assumptions to confirm there.
///
/// NOTE: serde's derived `Deserialize` also accepts a struct encoded as a JSON
/// array, in field declaration order, so do not reorder these fields.
#[derive(Deserialize)]
pub struct ConfigRouter {
    /// Peer id: the key of this router in `main`'s router table and the
    /// `dst_id` written into the `Meta` header of packets sent to this peer.
    pub remote_id: u8,
    /// Presumably the IP protocol number of the tunnel socket — TODO confirm in `router`.
    pub proto: i32,
    /// Presumably the address family of the tunnel socket — TODO confirm in `router`.
    pub family: u8,
    /// Value applied with `set_mark` (Linux only) to the socket used for this peer.
    pub mark: u32,
    /// Presumably the peer's initial address — TODO confirm in `router`.
    pub endpoint: String,
    /// Presumably the peer's secret in the textual form accepted by
    /// `Router::create_secret` — TODO confirm in `router`.
    pub remote_secret: String,
    /// Presumably the name of the TUN device for this peer — TODO confirm in `router`.
    pub dev: String,
    /// Presumably a command run once the device is up — TODO confirm in `router`.
    pub up: String,
}

/// Top-level configuration, parsed by `main` from the JSON string passed as
/// the first command-line argument.
///
/// NOTE: serde's derived `Deserialize` also accepts a struct encoded as a JSON
/// array, in field declaration order, so do not reorder these fields.
#[derive(Deserialize)]
pub struct Config {
    /// Id of this node: the `src_id` of outgoing `Meta` headers and the
    /// `dst_id` an incoming packet must carry to be accepted.
    pub local_id: u8,
    /// This node's secret; converted with `Router::create_secret` and passed
    /// to `decrypt` for every received packet.
    pub local_secret: String,
    /// One entry per remote peer; each becomes a `Router`.
    pub routers: Vec<ConfigRouter>,
}
use crossbeam_utils::thread;
use grouping_by::GroupingBy;
use socket2::Socket;

/// Entry point: parses the JSON config from the first CLI argument, builds one
/// `Router` per configured peer, then runs forwarding threads forever:
///
/// * one TUN → socket thread per router (prefix `Meta`, encrypt, send), and
/// * one socket → TUN thread per shared socket (demultiplex by `Meta`,
///   decrypt, write to the matching router's TUN device).
///
/// # Errors
/// Returns an error if the argument is missing, the config is not valid JSON,
/// the secret or a router cannot be created, or a forwarding thread panics.
fn main() -> Result<(), Box<dyn Error>> {
    let config: Config = serde_json::from_str(env::args().nth(1).ok_or("need param")?.as_str())?;
    let local_id = config.local_id;
    let local_secret: [u8; SECRET_LENGTH] = Router::create_secret(config.local_secret.as_str())?;
    let mut sockets: HashMap<u16, Arc<Socket>> = HashMap::new();
    let routers: HashMap<u8, Router> = config
        .routers
        .iter()
        .map(|c| Router::new(c, &mut sockets).map(|router| (c.remote_id, router)))
        .collect::<Result<_, _>>()?;
    let (mut router_readers, router_writers): (
        HashMap<u8, RouterReader>,
        HashMap<u8, RouterWriter>,
    ) = routers
        .into_iter()
        .map(|(id, router)| {
            let (reader, writer) = router.split();
            ((id, reader), (id, writer))
        })
        .unzip();
    // Writers that share a socket are served by a single receive thread, which
    // demultiplexes incoming packets by the sender id in their `Meta` header.
    let writers_by_socket: Vec<(Arc<Socket>, HashMap<u8, RouterWriter>)> = router_writers
        .into_iter()
        .grouping_by(|(_, v)| v.key())
        .into_iter()
        .map(|(key, writers)| {
            let socket = sockets
                .get(&key)
                .expect("router key has no matching socket");
            (Arc::clone(socket), writers.into_iter().collect())
        })
        .collect();
    println!("created tuns");

    thread::scope(|s| {
        for router in router_readers.values_mut() {
            s.spawn(move |_| forward_tun_to_socket(router, local_id));
        }
        for (socket, mut writers) in writers_by_socket {
            s.spawn(move |_| {
                forward_socket_to_tun(&socket, &mut writers, local_id, &local_secret)
            });
        }
    })
    .map_err(|_| "a forwarding thread panicked")?;
    Ok(())
}

/// Size of the TUN read buffer (header + payload), sized for jumbo frames.
const BUFFER_SIZE: usize = 9000;
/// Size of the `Meta` header that prefixes every tunnelled payload.
const META_SIZE: usize = std::mem::size_of::<Meta>();
/// Smallest legal IPv4 header (IHL = 5).
const MIN_IPV4_HEADER_LEN: usize = 20;
/// Largest legal IPv4 header (IHL = 15).
const MAX_IPV4_HEADER_LEN: usize = 60;
/// Back-off after an I/O error so a failing device/socket does not spin the CPU.
const IO_ERROR_BACKOFF: std::time::Duration = std::time::Duration::from_millis(1);

/// Reads packets from `router`'s TUN device forever, prefixes each with a
/// `Meta` header (`local_id` → the router's `remote_id`), encrypts the payload
/// and sends it to the router's current endpoint.
///
/// Packets read while no endpoint is known are dropped. Read and send errors
/// are ignored (reads back off for [`IO_ERROR_BACKOFF`]).
fn forward_tun_to_socket(router: &mut RouterReader, local_id: u8) {
    let mut buffer = vec![0u8; BUFFER_SIZE];
    // The header is the same for every packet, so write it once; each read
    // only overwrites the payload region behind it.
    write_meta(
        &mut buffer,
        Meta {
            src_id: local_id,
            dst_id: router.config.remote_id,
            reversed: 0,
        },
    );

    // Whether this thread has successfully applied its mark to the socket.
    #[cfg(target_os = "linux")]
    let mut mark_applied = false;

    loop {
        let n = match router.tun_reader.read(&mut buffer[META_SIZE..]) {
            Ok(n) => n,
            Err(_) => {
                std::thread::sleep(IO_ERROR_BACKOFF);
                continue;
            }
        };

        // Copy the endpoint out so the lock is held only for the copy, not
        // across encryption and the send syscall. A blocking `read` is used
        // (not `try_read`) so a concurrent endpoint update by the receive
        // thread delays this packet briefly instead of silently dropping it.
        let endpoint = {
            let guard = router
                .endpoint
                .read()
                .unwrap_or_else(std::sync::PoisonError::into_inner);
            (*guard).clone()
        };
        let Some(addr) = endpoint else { continue };

        let packet_len = META_SIZE + n;
        router.encrypt(&mut buffer[META_SIZE..packet_len]);

        #[cfg(target_os = "linux")]
        {
            // Only record success, so a failed `set_mark` is retried on the
            // next packet rather than abandoned.
            // NOTE(review): `router.socket` may be shared with other routers
            // (see the grouping in `main`); if their marks differ, another
            // thread's `set_mark` also changes the mark on this router's
            // packets. Confirm whether shared sockets can have differing marks.
            if !mark_applied {
                mark_applied = router.socket.set_mark(router.config.mark).is_ok();
            }
        }

        // Best-effort datagram tunnel: send failures are ignored.
        let _ = router.socket.send_to(&buffer[..packet_len], &addr);
    }
}

/// Receives datagrams from `socket` forever and delivers each valid one to the
/// TUN device of the router it came from.
///
/// A datagram is accepted when it holds an IPv4 header followed by a `Meta`
/// header addressed to `local_id` with `reversed == 0`, from a sender id that
/// has an entry in `routers`. For accepted packets the sender's endpoint is
/// updated to the datagram's source address, the payload is decrypted with
/// `local_secret` and written to that router's TUN device. Anything else is
/// dropped silently; I/O errors are ignored.
fn forward_socket_to_tun(
    socket: &Socket,
    routers: &mut HashMap<u8, RouterWriter>,
    local_id: u8,
    local_secret: &[u8; SECRET_LENGTH],
) {
    // The received datagram includes the IP header, so leave room for the
    // largest one on top of the largest `Meta` + payload a sender can produce.
    let mut recv_buf = vec![MaybeUninit::<u8>::uninit(); BUFFER_SIZE + MAX_IPV4_HEADER_LEN];

    loop {
        let (len, addr) = match socket.recv_from(&mut recv_buf) {
            Ok(received) => received,
            Err(_) => {
                std::thread::sleep(IO_ERROR_BACKOFF);
                continue;
            }
        };

        // SAFETY: `recv_from` initialised the first `len` bytes of `recv_buf`,
        // and `MaybeUninit<u8>` has the same size and layout as `u8`.
        let data: &mut [u8] =
            unsafe { transmute::<&mut [MaybeUninit<u8>], &mut [u8]>(&mut recv_buf[..len]) };

        let Some((header_len, meta)) = parse_tunnel_header(data) else {
            continue;
        };
        if meta.dst_id != local_id || meta.reversed != 0 {
            continue;
        }
        let Some(router) = routers.get_mut(&meta.src_id) else {
            continue;
        };

        // Best-effort endpoint roaming: if a sender thread holds the lock right
        // now, skip the update; the next packet from this peer retries it.
        if let Ok(mut endpoint) = router.endpoint.try_write() {
            *endpoint = Some(addr);
        }

        let payload = &mut data[header_len + META_SIZE..];
        router.decrypt(payload, local_secret);
        // Best-effort: a failed TUN write drops this packet only.
        let _ = router.tun_writer.write_all(payload);
    }
}

/// Splits a received datagram into its IPv4 header length and `Meta` header.
///
/// Returns `None` when the IHL field declares fewer than 20 bytes, or the
/// datagram is too short to hold the declared header followed by a `Meta`.
/// NOTE(review): assumes the socket delivers the IPv4 header in front of the
/// payload (raw IPv4 socket behaviour) — confirm for other `family` values.
fn parse_tunnel_header(data: &[u8]) -> Option<(usize, Meta)> {
    let first = *data.first()?;
    let header_len = usize::from(first & 0x0f) * 4;
    if header_len < MIN_IPV4_HEADER_LEN {
        return None;
    }
    let meta = read_meta(data.get(header_len..)?)?;
    Some((header_len, meta))
}

/// Copies `meta` into the first [`META_SIZE`] bytes of `buf`.
///
/// # Panics
/// Panics if `buf` is shorter than [`META_SIZE`].
fn write_meta(buf: &mut [u8], meta: Meta) {
    assert!(buf.len() >= META_SIZE, "buffer too small for Meta header");
    // SAFETY: the assert guarantees META_SIZE writable bytes, and
    // `write_unaligned` has no alignment requirement (byte buffers only
    // guarantee 1-byte alignment, while `Meta` needs 2).
    unsafe { std::ptr::write_unaligned(buf.as_mut_ptr().cast::<Meta>(), meta) }
}

/// Reads a `Meta` from the first [`META_SIZE`] bytes of `buf`, or `None` if
/// `buf` is too short.
fn read_meta(buf: &[u8]) -> Option<Meta> {
    if buf.len() < META_SIZE {
        return None;
    }
    // SAFETY: the length check guarantees META_SIZE readable bytes,
    // `read_unaligned` has no alignment requirement, and `Meta` consists only
    // of integer fields, so any byte pattern is a valid value.
    Some(unsafe { std::ptr::read_unaligned(buf.as_ptr().cast::<Meta>()) })
}
