mod router;

use crate::router::{Router, RouterReader, RouterWriter, SECRET_LENGTH};
use std::collections::HashMap;
use std::env;
use std::error::Error;
use std::intrinsics::transmute;
use std::io::{Read, Write};
use std::mem::MaybeUninit;
use std::sync::Arc;

/// Fixed 4-byte tunnel header placed between the outer IPv4 header and the
/// encrypted payload of every tunnel packet.
///
/// The layout is `#[repr(C)]` (`u8`, `u8`, `u16`, so no padding). `main` copies
/// the raw bytes of this struct onto the wire, so the `u16` field travels in the
/// sender's native byte order.
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Meta {
    /// Identifier of the sending node (the sender's `Config::local_id`).
    pub src_id: u8,
    /// Identifier of the intended receiving node. Receivers drop packets whose
    /// `dst_id` is not their own `local_id`.
    pub dst_id: u8,
    /// Must be zero; receivers drop packets where it is not.
    /// NOTE(review): the name looks like a typo for "reserved". It is kept as-is
    /// for compatibility with existing code.
    pub reversed: u16,
}

use serde::Deserialize;

/// Per-peer tunnel configuration: one element of `Config::routers`, which
/// `main` deserializes from the JSON given as the first command-line argument.
///
/// Each entry is passed to `Router::new` (defined in the `router` module, not
/// visible here). Most fields are interpreted there, so the notes below marked
/// "presumably" should be confirmed against that module.
#[derive(Deserialize)]
pub struct ConfigRouter {
    /// Identifier of the remote peer. `main` keys its router map by this value,
    /// and inbound packets are dispatched by looking up their `Meta::src_id` in
    /// that map.
    pub remote_id: u8,
    /// NOTE(review): presumably the IP protocol number used for the raw
    /// socket. It is consumed by `Router::new`; confirm there.
    pub proto: i32,
    /// NOTE(review): presumably the socket address family. It is consumed by
    /// `Router::new`; confirm there.
    pub family: u8,
    /// Mark value. On Linux, `main` passes `router.config.mark` to
    /// `Socket::set_mark` before sending. This assumes `router.config` refers to
    /// this struct.
    pub mark: u32,
    /// NOTE(review): presumably the peer's initial address. Its parsing is done
    /// by `Router::new` (not visible).
    pub endpoint: String,
    /// NOTE(review): presumably the peer's secret string. It is consumed by
    /// `Router::new`; confirm how it is turned into a key.
    pub remote_secret: String,
    /// NOTE(review): presumably the name of the TUN device for this peer.
    pub dev: String,
    /// NOTE(review): presumably a command run when the device is brought up.
    /// Confirm in `Router::new`.
    pub up: String,
}

/// Top-level configuration. `main` parses it from the JSON string given as the
/// first command-line argument.
#[derive(Deserialize)]
pub struct Config {
    /// This node's identifier. It is written as `Meta::src_id` on outgoing
    /// packets, and inbound packets are accepted only if their `Meta::dst_id`
    /// equals it.
    pub local_id: u8,
    /// Secret string. `Router::create_secret` converts it into a fixed-size key,
    /// which is used when decrypting inbound packets.
    pub local_secret: String,
    /// One entry per remote peer; each is turned into a `Router`.
    pub routers: Vec<ConfigRouter>,
}
use crossbeam_utils::thread;
use grouping_by::GroupingBy;
use pnet::packet::ipv4::Ipv4Packet;
use socket2::Socket;

// Tuning parameters, intended for high-latency networks.

/// Assumed link MTU, in bytes.
const MTU: usize = 1_500;
/// Size of the per-packet buffers in `main`: the MTU minus a 20-byte IPv4
/// header (no options).
const MAX_PACKET_SIZE: usize = MTU - 20;
/// Batch size used by the packet-forwarding loops in `main`.
const BATCH_SIZE: usize = 64;
/// Requested kernel send/receive buffer size for the raw sockets (16 MiB).
const SOCKET_BUFFER_SIZE: usize = 16 << 20;

/// How long a worker sleeps after a read that produced no packet, so that an
/// idle (or non-blocking) descriptor does not spin a core.
const IDLE_BACKOFF: std::time::Duration = std::time::Duration::from_micros(100);

/// Receive buffer size for the raw sockets.
///
/// The IPv4 total-length field is 16 bits, so no datagram can exceed this. A
/// `MAX_PACKET_SIZE` buffer would be too small: a full tunnel packet is the
/// 20-byte outer header + `Meta` + up to `MAX_PACKET_SIZE - size_of::<Meta>()`
/// payload bytes. `recv_from` silently truncates anything that does not fit.
const RECV_BUF_SIZE: usize = u16::MAX as usize;

/// Entry point.
///
/// Builds one `Router` per configured peer from the JSON config passed as the
/// first command-line argument, then forwards packets forever:
/// - one sender thread per router (TUN -> raw socket), and
/// - one receiver thread per distinct raw socket (raw socket -> the TUN of the
///   peer named in each packet's `Meta::src_id`).
///
/// # Errors
/// Returns an error if the argument is missing, the JSON is invalid, or a
/// secret or router cannot be created.
fn main() -> Result<(), Box<dyn Error>> {
    let config: Config = serde_json::from_str(env::args().nth(1).ok_or("need param")?.as_str())?;
    let local_secret: [u8; SECRET_LENGTH] = Router::create_secret(config.local_secret.as_str())?;
    let mut sockets: HashMap<u16, Arc<Socket>> = HashMap::new();

    let routers: HashMap<u8, Router> = config
        .routers
        .iter()
        .map(|c| Router::new(c, &mut sockets).map(|router| (c.remote_id, router)))
        .collect::<Result<_, _>>()?;

    for socket in sockets.values() {
        tune_socket(socket);
    }

    let (mut router_readers, router_writers): (
        HashMap<u8, RouterReader>,
        HashMap<u8, RouterWriter>,
    ) = routers
        .into_iter()
        .map(|(id, router)| {
            let (reader, writer) = router.split();
            ((id, reader), (id, writer))
        })
        .unzip();

    // Group writers by the raw socket they receive on. Each socket then gets
    // exactly one receiver thread, which dispatches to all of that socket's peers.
    let writers_by_socket: Vec<(Arc<Socket>, HashMap<u8, RouterWriter>)> = router_writers
        .into_iter()
        .grouping_by(|(_, v)| v.key())
        .into_iter()
        .map(|(k, v)| {
            let socket = sockets
                .get(&k)
                .expect("every router's socket key must be present in `sockets`");
            (Arc::clone(socket), v.into_iter().collect())
        })
        .collect();
    println!("created tuns");

    let local_id = config.local_id;
    thread::scope(|s| {
        for router in router_readers.values_mut() {
            s.spawn(move |_| run_sender(router, local_id));
        }
        for (socket, writers) in writers_by_socket {
            s.spawn(move |_| run_receiver(socket, writers, local_id, local_secret));
        }
    })
    .unwrap();
    Ok(())
}

/// Best-effort enlargement of a raw socket's kernel buffers to
/// `SOCKET_BUFFER_SIZE`.
///
/// Failures are ignored on purpose: the regular setters are capped by system
/// limits, and the `*FORCE` variants need `CAP_NET_ADMIN`.
///
/// `IP_RECVERR` is deliberately not enabled. With it, ICMP errors are queued on
/// the socket's error queue, which is charged against the receive buffer and
/// drained only by `recvmsg(.., MSG_ERRQUEUE)`. Nothing in this file reads that
/// queue, so accumulated errors would eventually starve normal receives.
fn tune_socket(socket: &Socket) {
    let _ = socket.set_send_buffer_size(SOCKET_BUFFER_SIZE);
    let _ = socket.set_recv_buffer_size(SOCKET_BUFFER_SIZE);

    #[cfg(target_os = "linux")]
    {
        use std::os::unix::io::AsRawFd;

        let force_size = SOCKET_BUFFER_SIZE as libc::c_int;
        for option in [libc::SO_RCVBUFFORCE, libc::SO_SNDBUFFORCE] {
            // SAFETY: the fd belongs to `socket`, which is alive for this call.
            // The value pointer and length describe the live `force_size` c_int.
            unsafe {
                libc::setsockopt(
                    socket.as_raw_fd(),
                    libc::SOL_SOCKET,
                    option,
                    (&force_size as *const libc::c_int).cast(),
                    std::mem::size_of::<libc::c_int>() as libc::socklen_t,
                );
            }
        }
    }
}

/// Sender loop for one peer.
///
/// Reads one IP packet at a time from the router's TUN device and prefixes it
/// with a `Meta` header. It then encrypts the payload and sends it right away
/// to the peer's last known endpoint. Packets read before any endpoint is known
/// are dropped.
///
/// Each packet is sent as soon as it is read, so it never waits for later
/// packets. This holds whether the TUN fd is blocking or non-blocking.
fn run_sender(router: &mut RouterReader, local_id: u8) {
    let meta_size = std::mem::size_of::<Meta>();
    let meta = Meta {
        src_id: local_id,
        dst_id: router.config.remote_id,
        reversed: 0,
    };
    // SAFETY: `Meta` is #[repr(C)] with fields u8, u8, u16 (4 bytes, no
    // padding), so all `meta_size` bytes of `meta` are initialized.
    let meta_bytes =
        unsafe { std::slice::from_raw_parts((&meta as *const Meta).cast::<u8>(), meta_size) };

    let mut buffer = vec![0u8; MAX_PACKET_SIZE];
    buffer[..meta_size].copy_from_slice(meta_bytes);

    let mut burst = 0usize;
    loop {
        let len = match router.tun_reader.read(&mut buffer[meta_size..]) {
            Ok(n) if n > 0 => n,
            // Nothing read (EOF, WouldBlock or error): back off briefly instead
            // of spinning.
            _ => {
                burst = 0;
                std::thread::sleep(IDLE_BACKOFF);
                continue;
            }
        };

        // Copy the endpoint out so the lock is not held across the send. The
        // receiver thread takes the write lock to update it.
        let endpoint = (*router.endpoint.read().unwrap()).clone();
        if let Some(addr) = endpoint {
            // NOTE(review): several routers can share one raw socket (writers
            // are grouped per socket in `main`). If so, another sender may
            // change the mark between `set_mark` and `send_to`. Confirm.
            #[cfg(target_os = "linux")]
            let _ = router.socket.set_mark(router.config.mark);

            router.encrypt(&mut buffer[meta_size..meta_size + len]);
            let _ = router.socket.send_to(&buffer[..meta_size + len], &addr);
        }

        // Yield occasionally under sustained load so other workers get a turn.
        burst += 1;
        if burst >= BATCH_SIZE {
            burst = 0;
            std::thread::yield_now();
        }
    }
}

/// Receiver loop for one raw socket.
///
/// For each datagram that is well formed, addressed to `local_id`, and has a
/// zero `Meta::reversed`, this function:
/// - records the sender address as that peer's endpoint,
/// - decrypts the payload, and
/// - writes it to the peer's TUN device as exactly one packet.
///
/// Malformed packets, packets for other nodes, and packets from unknown peers
/// are dropped.
fn run_receiver(
    socket: Arc<Socket>,
    mut router_writers: HashMap<u8, RouterWriter>,
    local_id: u8,
    local_secret: [u8; SECRET_LENGTH],
) {
    let mut recv_buf = vec![MaybeUninit::<u8>::uninit(); RECV_BUF_SIZE];
    let mut burst = 0usize;
    loop {
        let (len, addr) = match socket.recv_from(&mut recv_buf) {
            Ok(received) => received,
            Err(_) => {
                burst = 0;
                std::thread::sleep(IDLE_BACKOFF);
                continue;
            }
        };

        burst += 1;
        if burst >= BATCH_SIZE {
            burst = 0;
            std::thread::yield_now();
        }

        let len = len.min(recv_buf.len());
        // SAFETY: `recv_from` initialized the first `len` bytes of `recv_buf`
        // (clamped to the buffer length above). `MaybeUninit<u8>` has the same
        // layout as `u8`.
        let data = unsafe { std::slice::from_raw_parts_mut(recv_buf.as_mut_ptr().cast::<u8>(), len) };

        let Some((meta, payload)) = split_tunnel_packet(data) else {
            continue;
        };
        if meta.dst_id != local_id || meta.reversed != 0 {
            continue;
        }
        let Some(router) = router_writers.get_mut(&meta.src_id) else {
            continue;
        };

        *router.endpoint.write().unwrap() = Some(addr);
        router.decrypt(payload, &local_secret);
        // TUN devices are packet-oriented: every write must carry exactly one IP
        // packet. Never coalesce several payloads into a single write.
        let _ = router.tun_writer.write_all(payload);
    }
}

/// Splits a received raw IPv4 datagram into its tunnel header and payload.
///
/// Returns `None` if the datagram is too short for its IPv4 header (as given by
/// the IHL field) plus a `Meta` header.
fn split_tunnel_packet(data: &mut [u8]) -> Option<(Meta, &mut [u8])> {
    let header_len = usize::from(Ipv4Packet::new(data)?.get_header_length()) * 4;
    let rest = data.get_mut(header_len..)?;
    let (meta_bytes, payload) = rest.split_at_mut_checked(std::mem::size_of::<Meta>())?;
    // SAFETY: `meta_bytes` is exactly `size_of::<Meta>()` bytes long, and `Meta`
    // is #[repr(C)] of plain integers, so any bit pattern is valid. An unaligned
    // read is used because the buffer only guarantees byte alignment.
    let meta = unsafe { meta_bytes.as_ptr().cast::<Meta>().read_unaligned() };
    Some((meta, payload))
}
