mod router;

use crate::router::{Router, RouterReader, RouterWriter, SECRET_LENGTH};
use std::collections::HashMap;
use std::env;
use std::error::Error;
use std::intrinsics::transmute;
use std::io::{Read, Write};
use std::mem::MaybeUninit;
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::net::SocketAddr;
use parking_lot::RwLock;

/// Fixed 4-byte header prepended to every tunnelled packet.
///
/// On receive it is read directly after the outer IP header. `#[repr(C)]`
/// pins the field order and layout because the struct is copied
/// byte-for-byte (in native byte order) to and from packet buffers.
#[repr(C)]
pub struct Meta {
    /// Id of the sending router (the sender's `Config::local_id`).
    pub src_id: u8,
    /// Id of the intended receiver; receivers drop packets whose `dst_id`
    /// is not their own `local_id`.
    pub dst_id: u8,
    /// Always written as 0; receivers drop packets where it is non-zero.
    /// NOTE(review): the name presumably means "reserved".
    pub reversed: u16,
}

use serde::Deserialize;

/// Per-peer settings: one entry of the `routers` array in the JSON config
/// passed on the command line. Each entry is turned into a `Router` by
/// `Router::new`.
#[derive(Deserialize)]
pub struct ConfigRouter {
    /// Id of the remote router; key of this router in `main`'s maps and the
    /// `Meta::dst_id` written on packets sent to it.
    pub remote_id: u8,
    /// Presumably the IP protocol number of the raw socket — consumed by
    /// `Router::new`; TODO confirm in the `router` module.
    pub proto: i32,
    /// Presumably the socket address family — consumed by `Router::new`;
    /// TODO confirm.
    pub family: u8,
    /// Socket mark applied with `set_mark` before sending (Linux only).
    pub mark: u32,
    /// Remote endpoint address. NOTE(review): the accepted format is defined
    /// by `Router::new`.
    pub endpoint: String,
    /// Secret associated with this peer. NOTE(review): how it is used is
    /// defined in the `router` module.
    pub remote_secret: String,
    /// Presumably the TUN device name — consumed by `Router::new`; TODO confirm.
    pub dev: String,
    /// NOTE(review): semantics defined in the `router` module (possibly a
    /// command run when the device comes up) — confirm.
    pub up: String,
}

/// Top-level configuration, parsed with `serde_json` from the first
/// command-line argument.
#[derive(Deserialize)]
pub struct Config {
    /// Id of this router: written as `Meta::src_id` on outgoing packets and
    /// compared with `Meta::dst_id` on incoming ones.
    pub local_id: u8,
    /// Secret string turned into the local key by `Router::create_secret`;
    /// that key is passed to `decrypt` for incoming payloads.
    pub local_secret: String,
    /// One entry per remote peer.
    pub routers: Vec<ConfigRouter>,
}
use crossbeam_utils::thread;
use grouping_by::GroupingBy;
use pnet::packet::ipv4::Ipv4Packet;
use socket2::Socket;

/// Wrapper around `RouterWriter` intended to cache the socket-mark state.
///
/// Instances are shared as `Arc<parking_lot::RwLock<_>>` between the lookup
/// map in `main` and a receiver thread.
/// NOTE(review): the `mark_set` field is initialised in `main` but never read
/// in this file; mark handling happens on the sending (`RouterReader`) side.
struct OptimizedRouterWriter {
    writer: RouterWriter,
    #[cfg(target_os = "linux")]
    mark_set: AtomicBool,
}

/// Entry point: builds the routers described by the JSON config given as the
/// first command-line argument and then forwards traffic in both directions.
///
/// * One sender thread per router: TUN -> prepend [`Meta`] -> encrypt ->
///   `send_to` the router's learned endpoint.
/// * One receiver thread per distinct socket key: `recv_from` -> validate
///   headers -> decrypt -> write to the matching router's TUN, learning the
///   peer's endpoint from the source address.
///
/// The worker threads loop forever, so in practice this only returns on a
/// setup error.
///
/// # Errors
/// Returns an error if the config argument is missing or malformed, if the
/// local secret cannot be derived, or if any router fails to initialise.
fn main() -> Result<(), Box<dyn Error>> {
    let config: Config = serde_json::from_str(env::args().nth(1).ok_or("need param")?.as_str())?;
    let local_secret: [u8; SECRET_LENGTH] = Router::create_secret(config.local_secret.as_str())?;
    let mut sockets: HashMap<u16, Arc<Socket>> = HashMap::new();
    let routers: HashMap<u8, Router> = config
        .routers
        .iter()
        .map(|c| Router::new(c, &mut sockets).map(|router| (c.remote_id, router)))
        .collect::<Result<_, _>>()?;
    let (mut router_readers, router_writers): (
        HashMap<u8, RouterReader>,
        HashMap<u8, RouterWriter>,
    ) = routers
        .into_iter()
        .map(|(id, router)| {
            let (reader, writer) = router.split();
            ((id, reader), (id, writer))
        })
        .unzip();

    // Share each writer (behind a parking_lot RwLock) between this map and
    // the receiver thread that will own it.
    let router_writers = router_writers
        .into_iter()
        .map(|(id, writer)| {
            let wrapped = OptimizedRouterWriter {
                writer,
                #[cfg(target_os = "linux")]
                mark_set: AtomicBool::new(false),
            };
            (id, Arc::new(RwLock::new(wrapped)))
        })
        .collect::<HashMap<_, _>>();

    // Partition the writers by socket key: every socket gets one receiver
    // thread plus a map from remote router id to the writer for that peer.
    // Because this is a partition, each writer is used by exactly one thread.
    let router_writers3: Vec<(Arc<Socket>, HashMap<u8, Arc<RwLock<OptimizedRouterWriter>>>)> =
        router_writers
            .iter()
            .map(|(id, writer)| (*id, writer.read().writer.key(), Arc::clone(writer)))
            .grouping_by(|(_, key, _)| *key)
            .into_iter()
            .map(|(k, v)| {
                let socket = sockets
                    .get(&k)
                    .expect("Router::new registers a socket for every router key");
                (
                    Arc::clone(socket),
                    v.into_iter().map(|(id, _, writer)| (id, writer)).collect(),
                )
            })
            .collect();

    println!("created tuns");

    thread::scope(|s| {
        // Sender threads: one per router.
        for router in router_readers.values_mut() {
            s.spawn(|_| {
                // Large enough for jumbo frames.
                let mut buffer = vec![0u8; 9000];
                let meta_size = size_of::<Meta>();

                // The header is constant for this router, so write it once;
                // TUN reads land right after it.
                let meta = Meta {
                    src_id: config.local_id,
                    dst_id: router.config.remote_id,
                    reversed: 0,
                };
                // SAFETY: `buffer` is 9000 bytes long, far more than
                // `size_of::<Meta>()`. A `Vec<u8>` is only guaranteed 1-byte
                // alignment while `Meta` needs 2, so an unaligned write is
                // required (a plain `*ptr = meta` would be UB).
                unsafe {
                    std::ptr::write_unaligned(buffer.as_mut_ptr() as *mut Meta, meta);
                }

                loop {
                    let n = match router.tun_reader.read(&mut buffer[meta_size..]) {
                        Ok(n) => n,
                        // Best effort: ignore read errors and keep forwarding.
                        Err(_) => continue,
                    };

                    // Blocking `read` rather than `try_read`: the receiver only
                    // holds the write lock for a single assignment, and a failed
                    // `try_read` silently dropped this packet.
                    let Ok(endpoint_guard) = router.endpoint.read() else {
                        continue;
                    };
                    let Some(ref addr) = *endpoint_guard else {
                        // Peer endpoint not known yet; drop the packet.
                        continue;
                    };

                    // Encrypt in place to avoid an extra copy.
                    router.encrypt(&mut buffer[meta_size..meta_size + n]);

                    // Sockets are shared between routers with the same key, and
                    // those routers may use different marks, so the mark has to be
                    // (re)applied before every send rather than cached once.
                    #[cfg(target_os = "linux")]
                    {
                        let _ = router.socket.set_mark(router.config.mark);
                    }

                    let _ = router.socket.send_to(&buffer[..meta_size + n], addr);
                }
            });
        }

        // Receiver threads: one per socket.
        for (socket, router_writers_map) in router_writers3 {
            s.spawn(move |_| {
                let mut recv_buf = vec![MaybeUninit::<u8>::uninit(); 9000];

                loop {
                    let (len, addr) = match socket.recv_from(&mut recv_buf) {
                        Ok(received) => received,
                        // Best effort: ignore receive errors and keep listening.
                        Err(_) => continue,
                    };

                    // SAFETY: `recv_from` initialised the first `len` bytes of
                    // `recv_buf` (`len <= recv_buf.len()`), and `MaybeUninit<u8>`
                    // has the same layout as `u8`.
                    let data: &mut [u8] = unsafe {
                        std::slice::from_raw_parts_mut(recv_buf.as_mut_ptr() as *mut u8, len)
                    };

                    let Some((meta, payload_start)) = parse_meta(data) else {
                        continue; // Truncated or malformed packet.
                    };
                    if meta.dst_id != config.local_id || meta.reversed != 0 {
                        continue; // Not addressed to us.
                    }
                    let Some(router_lock) = router_writers_map.get(&meta.src_id) else {
                        continue; // Unknown sender.
                    };

                    // `parking_lot`'s `try_write` returns an `Option`, not a
                    // `Result`. This thread is the only user of the writer, so a
                    // blocking `write` never actually waits.
                    let mut router = router_lock.write();

                    // Learn/refresh the peer address. Skipping one update under
                    // contention is harmless: every later packet retries.
                    if let Ok(mut endpoint) = router.writer.endpoint.try_write() {
                        *endpoint = Some(addr);
                    }

                    // Decrypt in place, then hand the plaintext to the TUN device
                    // (write errors are ignored, best effort).
                    router.writer.decrypt(&mut data[payload_start..], &local_secret);
                    let _ = router.writer.tun_writer.write_all(&data[payload_start..]);
                }
            });
        }
    })
    .unwrap();
    Ok(())
}

/// Locates and decodes the [`Meta`] header that follows the IPv4 header of a
/// datagram received on a raw socket.
///
/// Returns the header and the offset at which the encrypted payload starts,
/// or `None` if `data` is empty, its IHL field describes a header shorter
/// than the 20-byte IPv4 minimum, or it is too short to hold the IP header
/// plus a `Meta`.
fn parse_meta(data: &[u8]) -> Option<(Meta, usize)> {
    const MIN_IPV4_HEADER_LEN: usize = 20;
    let meta_size = size_of::<Meta>();

    // The low nibble of the first byte is the IHL, in 32-bit words.
    let header_len = usize::from(*data.first()? & 0x0f) * 4;
    if header_len < MIN_IPV4_HEADER_LEN || data.len() < header_len + meta_size {
        return None;
    }

    // SAFETY: the length check guarantees `meta_size` readable bytes starting
    // at `header_len`. `read_unaligned` has no alignment requirement, and every
    // bit pattern is a valid `Meta` (two `u8`s and a `u16`).
    let meta = unsafe { std::ptr::read_unaligned(data[header_len..].as_ptr() as *const Meta) };
    Some((meta, header_len + meta_size))
}
