mod router;

use crate::router::{Router, RouterReader, RouterWriter, SECRET_LENGTH};
use std::collections::HashMap;
use std::env;
use std::error::Error;
use std::intrinsics::transmute;
use std::io::{Read, Write};
use std::mem::MaybeUninit;
use std::sync::Arc;
use std::sync::mpsc::{sync_channel, Receiver, SyncSender};
use std::time::Duration;

/// Fixed header placed in front of every tunneled payload.
///
/// `#[repr(C)]` pins the layout to 4 bytes (alignment 2): `src_id` at offset 0, `dst_id` at
/// offset 1 and `reversed` at offsets 2..4 in native byte order. These raw bytes are what is
/// sent ahead of the encrypted payload.
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Meta {
    /// Id of the sending node (the sender's `Config::local_id`).
    pub src_id: u8,
    /// Id of the intended receiving node (the `remote_id` of the sender's router).
    pub dst_id: u8,
    /// Must be zero; receivers drop packets where it is not.
    /// NOTE(review): presumably a misspelling of "reserved"; kept for compatibility.
    pub reversed: u16,
}

use serde::Deserialize;

/// Per-peer settings; every entry of `Config::routers` is handed to `Router::new`.
///
/// NOTE(review): most fields are consumed inside the `router` module, which is not visible
/// here; meanings marked "presumably" are unconfirmed.
#[derive(Deserialize)]
pub struct ConfigRouter {
    /// Id of the remote peer: the key under which this peer's router is stored, and the
    /// `dst_id` written into outgoing `Meta` headers.
    pub remote_id: u8,
    /// Presumably the IP protocol number for the raw socket — TODO confirm in `router`.
    pub proto: i32,
    /// Presumably the address family of the socket/endpoint — TODO confirm in `router`.
    pub family: u8,
    /// Value passed to `set_mark` (Linux `SO_MARK`) before each outgoing send.
    pub mark: u32,
    /// Presumably the peer's initial address — TODO confirm in `router`.
    pub endpoint: String,
    /// Presumably the secret associated with this peer — TODO confirm in `router`.
    pub remote_secret: String,
    /// Presumably the tun device name — TODO confirm in `router`.
    pub dev: String,
    /// Presumably a command run once the device is created — TODO confirm in `router`.
    pub up: String,
}

/// Top-level configuration, parsed from the JSON passed as the first command-line argument.
#[derive(Deserialize)]
pub struct Config {
    /// This node's id: written as `src_id` in outgoing `Meta` headers and required as
    /// `dst_id` on incoming ones.
    pub local_id: u8,
    /// Input to `Router::create_secret`; the derived secret is used to decrypt incoming
    /// payloads.
    pub local_secret: String,
    /// One entry per router to create.
    pub routers: Vec<ConfigRouter>,
}

use crossbeam_utils::thread;
use grouping_by::GroupingBy;
use pnet::packet::ipv4::Ipv4Packet;
use socket2::Socket;

// Batching and socket tuning parameters.

/// Upper bound on packets the sender drains from its channel in one batch.
const BATCH_SIZE: usize = 1 << 4;
/// Capacity of the bounded channel between a tun reader and its sender.
const CHANNEL_SIZE: usize = 1 << 8;
/// Requested kernel send/receive buffer size per socket, in bytes (2 MiB).
const SOCKET_BUFFER_SIZE: usize = 2 << 20;

/// A tun packet already prefixed with its `Meta` header, passed from a tun-reader thread to
/// the sender thread over a bounded channel.
struct Packet {
    /// Header bytes followed by the packet bytes read from the tun device.
    data: Vec<u8>,
    /// Number of valid bytes in `data`.
    /// NOTE(review): wherever packets are built this equals `data.len()`, so the field looks
    /// redundant.
    len: usize,
}

fn main() -> Result<(), Box<dyn Error>> {
    let config: Config = serde_json::from_str(env::args().nth(1).ok_or("need param")?.as_str())?;
    let local_secret: [u8; SECRET_LENGTH] = Router::create_secret(config.local_secret.as_str())?;
    let mut sockets: HashMap<u16, Arc<Socket>> = HashMap::new();
    
    let routers: HashMap<u8, Router> = config
        .routers
        .iter()
        .map(|c| {
            Router::new(c, &mut sockets).map(|router| {
                // 为 socket 设置更大的缓冲区
                if let Some(socket) = sockets.get(&Router::key(c)) {
                    let _ = socket.set_send_buffer_size(SOCKET_BUFFER_SIZE);
                    let _ = socket.set_recv_buffer_size(SOCKET_BUFFER_SIZE);
                }
                (c.remote_id, router)
            })
        })
        .collect::<Result<_, _>>()?;
    
    let (mut router_readers, router_writers): (
        HashMap<u8, RouterReader>,
        HashMap<u8, RouterWriter>,
    ) = routers
        .into_iter()
        .map(|(id, router)| {
            let (reader, writer) = router.split();
            ((id, reader), (id, writer))
        })
        .unzip();
    
    let router_writers3: Vec<(Arc<Socket>, HashMap<u8, RouterWriter>)> = router_writers
        .into_iter()
        .grouping_by(|(_, v)| v.key())
        .into_iter()
        .map(|(k, v)| {
            (
                Arc::clone(sockets.get_mut(&k).unwrap()),
                v.into_iter().collect(),
            )
        })
        .collect();
    
    println!("created tuns");

    thread::scope(|s| {
        // 发送端优化：添加批处理
        for router in router_readers.values_mut() {
            let (tx, rx): (SyncSender<Packet>, Receiver<Packet>) = sync_channel(CHANNEL_SIZE);
            
            // 读取线程
            s.spawn(|_| {
                let mut buffer = [0u8; 1500 - 20];
                let meta_size = size_of::<Meta>();

                let meta = Meta {
                    src_id: config.local_id,
                    dst_id: router.config.remote_id,
                    reversed: 0,
                };
                let meta_bytes = unsafe {
                    std::slice::from_raw_parts(&meta as *const Meta as *const u8, meta_size)
                };

                loop {
                    let n = router.tun_reader.read(&mut buffer[meta_size..]).unwrap();
                    if n > 0 {
                        let mut packet_data = vec![0u8; meta_size + n];
                        packet_data[..meta_size].copy_from_slice(meta_bytes);
                        packet_data[meta_size..].copy_from_slice(&buffer[meta_size..meta_size + n]);
                        
                        let packet = Packet {
                            data: packet_data,
                            len: meta_size + n,
                        };
                        let _ = tx.try_send(packet);
                    }
                }
            });
            
            // 批处理发送线程
            s.spawn(|_| {
                let mut batch = Vec::with_capacity(BATCH_SIZE);
                let meta_size = size_of::<Meta>();
                
                loop {
                    batch.clear();
                    
                    // 收集批量数据，使用超时避免延迟
                    if let Ok(packet) = rx.recv_timeout(Duration::from_micros(100)) {
                        batch.push(packet);
                        
                        // 尝试收集更多包
                        while batch.len() < BATCH_SIZE {
                            match rx.try_recv() {
                                Ok(packet) => batch.push(packet),
                                Err(_) => break,
                            }
                        }
                    }
                    
                    // 批量发送
                    if !batch.is_empty() {
                        if let Some(ref addr) = *router.endpoint.read().unwrap() {
                            // 批量加密和发送
                            for packet in &mut batch {
                                router.encrypt(&mut packet.data[meta_size..packet.len]);
                                #[cfg(target_os = "linux")]
                                let _ = router.socket.set_mark(router.config.mark);
                                let _ = router.socket.send_to(&packet.data[..packet.len], addr);
                            }
                        }
                    }
                }
            });
        }

        // 接收端保持原样，避免引入问题
        for (socket, mut router_writers) in router_writers3 {
            s.spawn(move |_| {
                let mut recv_buf = [MaybeUninit::uninit(); 1500];
                loop {
                    let _ = (|| {
                        let (len, addr) = socket.recv_from(&mut recv_buf).unwrap();
                        let data: &mut [u8] = unsafe { transmute(&mut recv_buf[..len]) };

                        let packet = Ipv4Packet::new(data).ok_or("malformed packet")?;
                        let header_len = packet.get_header_length() as usize * 4;
                        let (_ip_header, rest) = data
                            .split_at_mut_checked(header_len)
                            .ok_or("malformed packet")?;
                        let (meta_bytes, payload) = rest
                            .split_at_mut_checked(size_of::<Meta>())
                            .ok_or("malformed packet")?;
                        let meta: &Meta = unsafe { transmute(meta_bytes.as_ptr()) };
                        if meta.dst_id == config.local_id && meta.reversed == 0 {
                            let router = router_writers
                                .get_mut(&meta.src_id)
                                .ok_or("missing router")?;
                            *router.endpoint.write().unwrap() = Some(addr);
                            router.decrypt(payload, &local_secret);
                            router.tun_writer.write_all(payload)?;
                        }

                        Ok::<(), Box<dyn Error>>(())
                    })();
                }
            });
        }
    })
    .unwrap();
    Ok(())
}
