mod router;

use crate::router::{Router, RouterReader, RouterWriter, SECRET_LENGTH};
use std::collections::HashMap;
use std::env;
use std::error::Error;
use std::intrinsics::transmute;
use std::io::{Read, Write};
use std::mem::MaybeUninit;
use std::sync::Arc;
use std::sync::atomic::{AtomicPtr, Ordering};
use std::ptr;
use std::thread::available_parallelism;

/// Fixed-size tunnel header carried in front of every encrypted payload.
///
/// The sender copies the raw bytes of this struct to the start of each
/// outgoing buffer; the receiver reads it back from the bytes that follow the
/// IPv4 header. `#[repr(C)]` keeps the field order and size stable so both
/// sides agree on the layout.
#[repr(C)]
pub struct Meta {
    /// Id of the node that sent the packet; the receiver uses it to look up
    /// the matching router.
    pub src_id: u8,
    /// Id of the node the packet is addressed to; packets whose `dst_id` is
    /// not the local id are ignored.
    pub dst_id: u8,
    /// Always written as 0 and required to be 0 on receipt.
    /// NOTE(review): presumably a misspelling of "reserved" — the name is kept
    /// because it is part of the public interface.
    pub reversed: u16,
}

use serde::Deserialize;

/// Configuration of one remote peer, deserialized from the JSON passed on the
/// command line. Field names are the JSON keys.
///
/// Most fields are consumed by the `router` module, which is not visible from
/// this file; only the uses that can be seen here are described.
#[derive(Deserialize)]
pub struct ConfigRouter {
    /// Id of the remote node; used as the router map key and written into
    /// `Meta::dst_id` of outgoing packets.
    pub remote_id: u8,
    /// Consumed by the `router` module — presumably the IP protocol number of
    /// the carrier socket; verify there.
    pub proto: i32,
    /// Consumed by the `router` module — presumably the address family
    /// selector; verify there.
    pub family: u8,
    /// Value passed to `set_mark` on the socket before sending (Linux only).
    pub mark: u32,
    /// Consumed by the `router` module — presumably the initial remote
    /// address; verify there.
    pub endpoint: String,
    /// Consumed by the `router` module — presumably the peer's shared secret;
    /// verify there.
    pub remote_secret: String,
    /// Consumed by the `router` module — presumably the TUN device name;
    /// verify there.
    pub dev: String,
    /// Consumed by the `router` module — presumably a command run when the
    /// device comes up; verify there.
    pub up: String,
}

/// Top-level configuration, parsed from the JSON string given as the first
/// command-line argument.
#[derive(Deserialize)]
pub struct Config {
    /// Id of this node; written into `Meta::src_id` of outgoing packets and
    /// compared against `Meta::dst_id` of incoming ones.
    pub local_id: u8,
    /// Secret string handed to `Router::create_secret` to derive the local key.
    pub local_secret: String,
    /// One entry per remote peer.
    pub routers: Vec<ConfigRouter>,
}

use crossbeam_utils::thread;
use grouping_by::GroupingBy;
use pnet::packet::ipv4::Ipv4Packet;
use socket2::Socket;
use libc::{iovec, msghdr, mmsghdr, recvmmsg, sendmmsg, MSG_WAITFORONE};
use std::net::SocketAddr;

// Maximum number of packets gathered per batch by each worker loop.
const BATCH_SIZE: usize = 32;
// Size in bytes of every packet buffer (original note: larger buffers, intended to support GSO).
const BUFFER_SIZE: usize = 65536;
// Number of buffers preallocated in each worker thread's buffer pool.
const BUFFER_POOL_SIZE_PER_THREAD: usize = 64;

/// Thread-safe slot holding the most recently learned remote address.
///
/// Several threads may call [`AtomicEndpoint::store`] while others call
/// [`AtomicEndpoint::load`]. An earlier raw-pointer implementation let a
/// reader dereference a heap allocation that a concurrent `store` had already
/// freed (use-after-free), leaked the final allocation, and allocated on every
/// store. `SocketAddr` is a small `Copy` value, so a read/write lock around an
/// `Option<SocketAddr>` gives the same interface without any `unsafe`.
struct AtomicEndpoint {
    addr: std::sync::RwLock<Option<SocketAddr>>,
}

impl AtomicEndpoint {
    /// Creates an empty slot; `load` returns `None` until the first `store`.
    fn new() -> Self {
        AtomicEndpoint {
            addr: std::sync::RwLock::new(None),
        }
    }

    /// Returns a copy of the current address, or `None` if none was stored yet.
    fn load(&self) -> Option<SocketAddr> {
        // The guarded value is a plain `Copy` datum that can never be observed
        // half-written, so a poisoned lock is still safe to read.
        let guard = self
            .addr
            .read()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        *guard
    }

    /// Replaces the stored address with `addr`.
    fn store(&self, addr: SocketAddr) {
        // Callers typically store the same address over and over; skip the
        // exclusive lock when nothing would change.
        if self.load() == Some(addr) {
            return;
        }
        let mut guard = self
            .addr
            .write()
            .unwrap_or_else(|poisoned| poisoned.into_inner());
        *guard = Some(addr);
    }
}

/// Pool of reusable packet buffers, each `BUFFER_SIZE` bytes long.
///
/// Invariant: every buffer held by the pool has length `BUFFER_SIZE`, so
/// callers can slice into a buffer they obtained without checking its length.
struct BufferPool {
    buffers: Vec<Vec<u8>>,
}

impl BufferPool {
    /// Creates a pool pre-filled with `size` zeroed buffers.
    fn new(size: usize) -> Self {
        BufferPool {
            buffers: (0..size).map(|_| vec![0u8; BUFFER_SIZE]).collect(),
        }
    }

    /// Hands out exactly `count` buffers.
    ///
    /// Buffers are taken from the pool (most recently returned first) while
    /// any remain; the shortfall is covered by freshly allocated zeroed
    /// buffers.
    fn get_buffers(&mut self, count: usize) -> Vec<Vec<u8>> {
        let reused = count.min(self.buffers.len());
        let keep = self.buffers.len() - reused;

        let mut result = Vec::with_capacity(count);
        // `.rev()` preserves the pop-from-the-back order.
        result.extend(self.buffers.drain(keep..).rev());
        result.resize_with(count, || vec![0u8; BUFFER_SIZE]);
        result
    }

    /// Puts `buffers` back into the pool for later reuse.
    ///
    /// Each buffer's length is restored to `BUFFER_SIZE` so that a buffer the
    /// caller truncated is never handed out short. Bytes the caller left in
    /// place are not cleared; only the re-grown tail is zero-filled.
    fn return_buffers(&mut self, buffers: Vec<Vec<u8>>) {
        self.buffers.extend(buffers.into_iter().map(|mut buf| {
            buf.resize(BUFFER_SIZE, 0);
            buf
        }));
    }
}

fn main() -> Result<(), Box<dyn Error>> {
    // 获取系统的 CPU 核心数
    let num_threads = available_parallelism()
        .map(|n| n.get())
        .unwrap_or(4); // 默认使用 4 个线程
    
    println!("Using {} threads based on available CPU cores", num_threads);
    
    let config: Config = serde_json::from_str(env::args().nth(1).ok_or("need param")?.as_str())?;
    let local_secret: [u8; SECRET_LENGTH] = Router::create_secret(config.local_secret.as_str())?;
    let mut sockets: HashMap<u16, Arc<Socket>> = HashMap::new();
    
    // 为所有 socket 设置更大的缓冲区
    for socket in sockets.values_mut() {
        // 设置 socket 缓冲区，根据线程数动态调整
        let buffer_size = 2 * 1024 * 1024 * num_threads; // 每个线程 2MB
        let _ = socket.set_recv_buffer_size(buffer_size);
        let _ = socket.set_send_buffer_size(buffer_size);
    }
    
    let routers: HashMap<u8, Router> = config
        .routers
        .iter()
        .map(|c| Router::new(c, &mut sockets).map(|router| (c.remote_id, router)))
        .collect::<Result<_, _>>()?;
    
    // 使用原子指针替代 RwLock 来存储 endpoint
    let atomic_endpoints: HashMap<u8, Arc<AtomicEndpoint>> = routers
        .keys()
        .map(|&id| (id, Arc::new(AtomicEndpoint::new())))
        .collect();
    
    let (mut router_readers, router_writers): (
        HashMap<u8, RouterReader>,
        HashMap<u8, RouterWriter>,
    ) = routers
        .into_iter()
        .map(|(id, router)| {
            let (reader, writer) = router.split();
            ((id, reader), (id, writer))
        })
        .unzip();
    
    let router_writers3: Vec<(Arc<Socket>, HashMap<u8, RouterWriter>)> = router_writers
        .into_iter()
        .grouping_by(|(_, v)| v.key())
        .into_iter()
        .map(|(k, v)| {
            (
                Arc::clone(sockets.get_mut(&k).unwrap()),
                v.into_iter().collect(),
            )
        })
        .collect();
    
    println!("created tuns");
    
    thread::scope(|s| {
        // 为每个路由器启动多个发送线程
        for (id, router) in router_readers.drain() {
            let endpoint = Arc::clone(&atomic_endpoints[&id]);
            
            // 根据 CPU 核心数创建发送线程
            // 为避免过多线程竞争同一个 TUN 设备，限制最大线程数
            let sender_threads = num_threads.min(4); 
            
            for thread_id in 0..sender_threads {
                let mut router = router.clone(); // 假设 RouterReader 实现了 Clone
                let endpoint = Arc::clone(&endpoint);
                let local_id = config.local_id;
                
                s.spawn(move |_| {
                    // 设置线程 CPU 亲和性（可选）
                    #[cfg(target_os = "linux")]
                    {
                        use libc::{cpu_set_t, CPU_SET, CPU_ZERO, sched_setaffinity};
                        unsafe {
                            let mut cpu_set: cpu_set_t = std::mem::zeroed();
                            CPU_ZERO(&mut cpu_set);
                            CPU_SET(thread_id % num_threads, &mut cpu_set);
                            sched_setaffinity(0, std::mem::size_of_val(&cpu_set), &cpu_set);
                        }
                    }
                    
                    let mut buffer_pool = BufferPool::new(BUFFER_POOL_SIZE_PER_THREAD);
                    let meta_size = size_of::<Meta>();
                    
                    // 预初始化 Meta 头
                    let meta = Meta {
                        src_id: local_id,
                        dst_id: router.config.remote_id,
                        reversed: 0,
                    };
                    
                    loop {
                        // 批量读取多个包
                        let mut buffers = buffer_pool.get_buffers(BATCH_SIZE);
                        let mut valid_count = 0;
                        
                        for buffer in &mut buffers {
                            // 写入 Meta 头
                            let meta_bytes = unsafe {
                                std::slice::from_raw_parts(&meta as *const Meta as *const u8, meta_size)
                            };
                            buffer[..meta_size].copy_from_slice(meta_bytes);
                            
                            // 尝试非阻塞读取
                            match router.tun_reader.read(&mut buffer[meta_size..]) {
                                Ok(n) if n > 0 => {
                                    if let Some(addr) = endpoint.load() {
                                        router.encrypt(&mut buffer[meta_size..meta_size + n]);
                                        buffer.truncate(meta_size + n);
                                        valid_count += 1;
                                    } else {
                                        break;
                                    }
                                }
                                _ => break,
                            }
                        }
                        
                        // 批量发送
                        if valid_count > 0 && endpoint.load().is_some() {
                            #[cfg(target_os = "linux")]
                            let _ = router.socket.set_mark(router.config.mark);
                            
                            // 使用 sendmmsg 批量发送（如果可用）
                            let addr = endpoint.load().unwrap();
                            #[cfg(target_os = "linux")]
                            {
                                use std::os::unix::io::AsRawFd;
                                let fd = router.socket.as_raw_fd();
                                
                                // 准备批量发送的消息
                                let mut messages: Vec<mmsghdr> = Vec::with_capacity(valid_count);
                                let mut iovecs: Vec<iovec> = Vec::with_capacity(valid_count);
                                let mut sockaddrs: Vec<libc::sockaddr_storage> = Vec::with_capacity(valid_count);
                                
                                // TODO: 实现 sendmmsg 批量发送
                                // 暂时使用普通发送
                                for buffer in &buffers[..valid_count] {
                                    let _ = router.socket.send_to(buffer, &addr);
                                }
                            }
                            
                            #[cfg(not(target_os = "linux"))]
                            {
                                for buffer in &buffers[..valid_count] {
                                    let _ = router.socket.send_to(buffer, &addr);
                                }
                            }
                        }
                        
                        // 归还缓冲区
                        for buffer in &mut buffers {
                            buffer.clear();
                            buffer.resize(BUFFER_SIZE, 0);
                        }
                        buffer_pool.return_buffers(buffers);
                        
                        // 避免 CPU 占用过高
                        if valid_count == 0 {
                            std::thread::yield_now();
                        }
                    }
                });
            }
        }
        
        // 为每个 socket 启动多个接收线程
        for (socket, mut router_writers) in router_writers3 {
            // 将 router_writers 转换为使用原子指针的版本
            let atomic_writers: Arc<HashMap<u8, (RouterWriter, Arc<AtomicEndpoint>)>> = Arc::new(
                router_writers
                    .drain()
                    .map(|(id, writer)| {
                        let endpoint = Arc::clone(&atomic_endpoints[&id]);
                        (id, (writer, endpoint))
                    })
                    .collect()
            );
            
            // 根据 CPU 核心数创建接收线程
            let receiver_threads = num_threads.min(8); // 接收线程可以多一些
            
            for thread_id in 0..receiver_threads {
                let socket = Arc::clone(&socket);
                let atomic_writers = Arc::clone(&atomic_writers);
                let local_id = config.local_id;
                let local_secret = local_secret.clone();
                
                s.spawn(move |_| {
                    // 设置线程 CPU 亲和性（可选）
                    #[cfg(target_os = "linux")]
                    {
                        use libc::{cpu_set_t, CPU_SET, CPU_ZERO, sched_setaffinity};
                        unsafe {
                            let mut cpu_set: cpu_set_t = std::mem::zeroed();
                            CPU_ZERO(&mut cpu_set);
                            CPU_SET(thread_id % num_threads, &mut cpu_set);
                            sched_setaffinity(0, std::mem::size_of_val(&cpu_set), &cpu_set);
                        }
                    }
                    
                    let mut buffer_pool = BufferPool::new(BUFFER_POOL_SIZE_PER_THREAD);
                    
                    loop {
                        let mut recv_buffers = buffer_pool.get_buffers(BATCH_SIZE);
                        let mut addrs = vec![MaybeUninit::uninit(); BATCH_SIZE];
                        let mut valid_count = 0;
                        
                        // 批量接收
                        #[cfg(target_os = "linux")]
                        {
                            // TODO: 实现 recvmmsg 批量接收
                            // 暂时使用普通接收
                            for i in 0..BATCH_SIZE {
                                match socket.recv_from(&mut recv_buffers[i]) {
                                    Ok((len, addr)) => {
                                        recv_buffers[i].truncate(len);
                                        addrs[i].write(addr);
                                        valid_count += 1;
                                    }
                                    Err(_) => break,
                                }
                            }
                        }
                        
                        #[cfg(not(target_os = "linux"))]
                        {
                            for i in 0..BATCH_SIZE {
                                match socket.recv_from(&mut recv_buffers[i]) {
                                    Ok((len, addr)) => {
                                        recv_buffers[i].truncate(len);
                                        addrs[i].write(addr);
                                        valid_count += 1;
                                    }
                                    Err(_) => break,
                                }
                            }
                        }
                        
                        // 批量处理接收到的包
                        for i in 0..valid_count {
                            let _ = (|| {
                                let data = &mut recv_buffers[i];
                                let addr = unsafe { addrs[i].assume_init() };
                                
                                let packet = Ipv4Packet::new(data).ok_or("malformed packet")?;
                                let header_len = packet.get_header_length() as usize * 4;
                                let (_ip_header, rest) = data
                                    .split_at_mut_checked(header_len)
                                    .ok_or("malformed packet")?;
                                let (meta_bytes, payload) = rest
                                    .split_at_mut_checked(size_of::<Meta>())
                                    .ok_or("malformed packet")?;
                                let meta: &Meta = unsafe { transmute(meta_bytes.as_ptr()) };
                                
                                if meta.dst_id == local_id && meta.reversed == 0 {
                                    if let Some((router, endpoint)) = atomic_writers.get(&meta.src_id) {
                                        endpoint.store(addr);
                                        router.decrypt(payload, &local_secret);
                                        router.tun_writer.write_all(payload)?;
                                    }
                                }
                                
                                Ok::<(), Box<dyn Error>>(())
                            })();
                        }
                        
                        // 归还缓冲区
                        for buffer in &mut recv_buffers {
                            buffer.clear();
                            buffer.resize(BUFFER_SIZE, 0);
                        }
                        buffer_pool.return_buffers(recv_buffers);
                        
                        if valid_count == 0 {
                            std::thread::yield_now();
                        }
                    }
                });
            }
        }
    })
    .unwrap();
    Ok(())
}
