mod router;

use crate::router::{Router, RouterReader, RouterWriter, SECRET_LENGTH};
use std::collections::HashMap;
use std::env;
use std::error::Error;
use std::intrinsics::transmute;
use std::io::{Read, Write};
use std::mem::MaybeUninit;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};

/// Fixed 4-byte header that precedes every tunnelled payload.
///
/// The struct is copied to and from the wire byte-for-byte, hence `#[repr(C)]`: the layout is
/// `src_id` (1 byte), `dst_id` (1 byte), `reversed` (2 bytes, native byte order), with no
/// padding. On receive it is read directly after the IPv4 header.
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Meta {
    /// Id of the node that sent the packet.
    pub src_id: u8,
    /// Id of the node the packet is addressed to.
    pub dst_id: u8,
    /// Must be zero; packets where it is not are ignored by the receiver. (The name presumably
    /// means "reserved"; it is kept as-is for compatibility.)
    pub reversed: u16,
}

use serde::Deserialize;

/// Settings for a single peer router: one entry of `Config::routers`.
#[derive(Deserialize)]
pub struct ConfigRouter {
    /// Id of the peer node. Used as the key of the router table and as the `dst_id` of
    /// outgoing `Meta` headers.
    pub remote_id: u8,
    /// Consumed by `Router::new`. NOTE(review): presumably the IP protocol number of the
    /// tunnel socket — confirm in the `router` module.
    pub proto: i32,
    /// Consumed by `Router::new`. NOTE(review): presumably selects the address family of the
    /// socket — confirm.
    pub family: u8,
    /// Socket mark applied (Linux only) before each packet sent to this peer.
    pub mark: u32,
    /// Consumed by `Router::new`. NOTE(review): presumably the peer's initial address — confirm.
    pub endpoint: String,
    /// Consumed by `Router::new`. NOTE(review): presumably key material for this peer — confirm.
    pub remote_secret: String,
    /// Consumed by `Router::new`. NOTE(review): presumably the TUN device name — confirm.
    pub dev: String,
    /// Consumed by `Router::new`. NOTE(review): meaning not visible in this file — confirm.
    pub up: String,
}

/// Top-level configuration, parsed from the JSON document passed as the first
/// command-line argument.
#[derive(Deserialize)]
pub struct Config {
    /// Id of this node: written as `src_id` in outgoing headers, and incoming packets are only
    /// accepted when their `dst_id` equals it.
    pub local_id: u8,
    /// Converted by `Router::create_secret` into the key used to decrypt received payloads.
    pub local_secret: String,
    /// One entry per peer; each becomes a `Router`.
    pub routers: Vec<ConfigRouter>,
}
use crossbeam_utils::thread;
use grouping_by::GroupingBy;
use pnet::packet::ipv4::Ipv4Packet;
use socket2::{Socket, SockAddr};

// Tuning parameters.
/// Size in bytes of each packet buffer (64 KiB).
const BUFFER_SIZE: usize = 64 * 1024;
/// Batch size, in packets.
const BATCH_SIZE: usize = 32;
/// Kernel send/receive buffer size requested for each socket (8 MiB).
const SOCKET_BUFFER_SIZE: usize = 8 << 20;

fn main() -> Result<(), Box<dyn Error>> {
    let config: Config = serde_json::from_str(env::args().nth(1).ok_or("need param")?.as_str())?;
    let local_secret: [u8; SECRET_LENGTH] = Router::create_secret(config.local_secret.as_str())?;
    let mut sockets: HashMap<u16, Arc<Socket>> = HashMap::new();
    
    let routers: HashMap<u8, Router> = config
        .routers
        .iter()
        .map(|c| Router::new(c, &mut sockets).map(|router| (c.remote_id, router)))
        .collect::<Result<_, _>>()?;
    
    // 优化 socket 缓冲区大小
    for socket in sockets.values() {
        let _ = socket.set_send_buffer_size(SOCKET_BUFFER_SIZE);
        let _ = socket.set_recv_buffer_size(SOCKET_BUFFER_SIZE);
        #[cfg(target_os = "linux")]
        {
            // 启用 GSO/GRO
            unsafe {
                let enable = 1i32;
                libc::setsockopt(
                    socket.as_raw_fd(),
                    libc::SOL_UDP,
                    libc::UDP_GRO,
                    &enable as *const _ as *const libc::c_void,
                    std::mem::size_of_val(&enable) as libc::socklen_t,
                );
            }
        }
    }
    
    let (mut router_readers, router_writers): (
        HashMap<u8, RouterReader>,
        HashMap<u8, RouterWriter>,
    ) = routers
        .into_iter()
        .map(|(id, router)| {
            let (reader, writer) = router.split();
            ((id, reader), (id, writer))
        })
        .unzip();
    let router_writers3: Vec<(Arc<Socket>, HashMap<u8, RouterWriter>)> = router_writers
        .into_iter()
        .grouping_by(|(_, v)| v.key())
        .into_iter()
        .map(|(k, v)| {
            (
                Arc::clone(sockets.get_mut(&k).unwrap()),
                v.into_iter().collect(),
            )
        })
        .collect();
    println!("created tuns");

    thread::scope(|s| {
        // 为每个路由创建多个发送线程
        for router in router_readers.values_mut() {
            let router_id = router.config.remote_id;
            let local_id = config.local_id;
            let mark = router.config.mark;
            
            // 创建 4 个并发发送线程
            for _ in 0..4 {
                let socket = Arc::clone(&router.socket);
                let endpoint = Arc::clone(&router.endpoint);
                let tun_reader = router.tun_reader.try_clone().unwrap();
                let encrypt_fn = router.encrypt.clone();
                
                s.spawn(move |_| {
                    let mut buffers: Vec<Vec<u8>> = (0..BATCH_SIZE)
                        .map(|_| vec![0u8; BUFFER_SIZE])
                        .collect();
                    let meta_size = size_of::<Meta>();
                    
                    // 预初始化 Meta 头
                    let meta = Meta {
                        src_id: local_id,
                        dst_id: router_id,
                        reversed: 0,
                    };
                    let meta_bytes = unsafe {
                        std::slice::from_raw_parts(&meta as *const Meta as *const u8, meta_size)
                    };
                    
                    for buffer in &mut buffers {
                        buffer[..meta_size].copy_from_slice(meta_bytes);
                    }
                    
                    let mut current_buffer = 0;
                    
                    loop {
                        let buffer = &mut buffers[current_buffer];
                        match tun_reader.read(&mut buffer[meta_size..]) {
                            Ok(n) if n > 0 => {
                                if let Some(ref addr) = *endpoint.read().unwrap() {
                                    encrypt_fn(&mut buffer[meta_size..meta_size + n]);
                                    #[cfg(target_os = "linux")]
                                    let _ = socket.set_mark(mark);
                                    
                                    // 使用 MSG_DONTWAIT 避免阻塞
                                    match socket.send_to(&buffer[..meta_size + n], addr) {
                                        Ok(_) => {},
                                        Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                                            // 缓冲区满，稍后重试
                                            std::thread::yield_now();
                                        },
                                        Err(_) => {},
                                    }
                                }
                                current_buffer = (current_buffer + 1) % BATCH_SIZE;
                            },
                            _ => std::thread::yield_now(),
                        }
                    }
                });
            }
        }

        // 为每个 socket 创建多个接收线程
        for (socket, mut router_writers) in router_writers3 {
            // 创建 4 个并发接收线程
            for _ in 0..4 {
                let socket = Arc::clone(&socket);
                let mut router_writers = router_writers.clone();
                let local_id = config.local_id;
                let local_secret = local_secret.clone();
                
                s.spawn(move |_| {
                    let mut recv_bufs: Vec<[MaybeUninit<u8>; BUFFER_SIZE]> = (0..BATCH_SIZE)
                        .map(|_| [MaybeUninit::uninit(); BUFFER_SIZE])
                        .collect();
                    let mut current_buffer = 0;
                    
                    loop {
                        let recv_buf = &mut recv_bufs[current_buffer];
                        let _ = (|| {
                            let (len, addr) = match socket.recv_from(recv_buf) {
                                Ok(result) => result,
                                Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                                    std::thread::yield_now();
                                    return Ok(());
                                },
                                Err(_) => return Ok(()),
                            };
                            
                            let data: &mut [u8] = unsafe { transmute(&mut recv_buf[..len]) };

                            let packet = Ipv4Packet::new(data).ok_or("malformed packet")?;
                            let header_len = packet.get_header_length() as usize * 4;
                            let (_ip_header, rest) = data
                                .split_at_mut_checked(header_len)
                                .ok_or("malformed packet")?;
                            let (meta_bytes, payload) = rest
                                .split_at_mut_checked(size_of::<Meta>())
                                .ok_or("malformed packet")?;
                            let meta: &Meta = unsafe { transmute(meta_bytes.as_ptr()) };
                            if meta.dst_id == local_id && meta.reversed == 0 {
                                let router = router_writers
                                    .get_mut(&meta.src_id)
                                    .ok_or("missing router")?;
                                *router.endpoint.write().unwrap() = Some(addr);
                                router.decrypt(payload, &local_secret);
                                
                                // 批量写入以减少系统调用
                                let mut offset = 0;
                                while offset < payload.len() {
                                    match router.tun_writer.write(&payload[offset..]) {
                                        Ok(n) => offset += n,
                                        Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                                            std::thread::yield_now();
                                        },
                                        Err(_) => break,
                                    }
                                }
                            }
                            
                            current_buffer = (current_buffer + 1) % BATCH_SIZE;
                            Ok::<(), Box<dyn Error>>(())
                        })();
                    }
                });
            }
        }
    })
    .unwrap();
    Ok(())
}

// Linux-only import of `AsRawFd`, which provides `as_raw_fd()` to obtain a socket's raw file descriptor.
#[cfg(target_os = "linux")]
use std::os::unix::io::AsRawFd;
