mod router;

use crate::router::{Router, RouterReader, RouterWriter, SECRET_LENGTH};
use std::collections::HashMap;
use std::env;
use std::error::Error;
use std::io::{Read, Write};
use std::mem::{self, MaybeUninit};
use std::net::SocketAddr;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use std::sync::mpsc::{sync_channel, Receiver, SyncSender, TryRecvError};
use std::time::{Duration, Instant};
use std::thread;

/// Fixed-layout header that precedes the payload of every tunnelled packet.
///
/// `main` copies the raw bytes of this struct to the front of each outgoing
/// buffer and reinterprets the bytes following the IPv4 header of each received
/// datagram as a `Meta`, so the `#[repr(C)]` field order is the wire format and
/// must not be changed.
#[repr(C)]
#[derive(Copy, Clone)]
pub struct Meta {
    /// Id of the sending node; the receive path uses it to select the router.
    pub src_id: u8,
    /// Id of the intended recipient; received packets whose `dst_id` differs
    /// from the local id are dropped.
    pub dst_id: u8,
    /// Always sent as 0; received packets with a non-zero value are dropped.
    /// NOTE(review): presumably a misspelling of "reserved" — the name is part
    /// of the public interface, so confirm before renaming.
    pub reversed: u16,
}

use serde::Deserialize;

/// Configuration of one remote peer, deserialized as an element of
/// [`Config::routers`].
///
/// Only `remote_id` and `mark` are read directly in this file; the remaining
/// fields are consumed by `Router` in the `router` module (not shown here), so
/// their descriptions below are assumptions to be confirmed against that code.
#[derive(Deserialize)]
pub struct ConfigRouter {
    /// Id of the remote peer; used as the map key for its router and as the
    /// `dst_id` of packets sent to it.
    pub remote_id: u8,
    /// Presumably the IP protocol number of the tunnel socket — TODO confirm.
    pub proto: i32,
    /// Presumably the address family of the tunnel socket — TODO confirm.
    pub family: u8,
    /// Firewall mark applied to the socket before sending (Linux only).
    pub mark: u32,
    /// Presumably the initial remote address of the peer — TODO confirm.
    pub endpoint: String,
    /// Presumably the shared secret for traffic to this peer — TODO confirm.
    pub remote_secret: String,
    /// Presumably the name of the TUN device for this peer — TODO confirm.
    pub dev: String,
    /// Presumably a command run when the device comes up — TODO confirm.
    pub up: String,
}

/// Top-level configuration, parsed by `main` from the JSON document given as
/// the first command-line argument.
#[derive(Deserialize)]
pub struct Config {
    /// Id of this node: written as `src_id` on outgoing packets and compared
    /// against `dst_id` on incoming ones.
    pub local_id: u8,
    /// Textual secret that `main` turns into key bytes via `Router::create_secret`.
    pub local_secret: String,
    /// One entry per remote peer.
    pub routers: Vec<ConfigRouter>,
}

use crossbeam_utils::thread as scoped_thread;
use grouping_by::GroupingBy;
use pnet::packet::ipv4::Ipv4Packet;
use socket2::Socket;

// Performance tuning parameters.
// Maximum number of packets the sender thread collects before flushing a batch.
const BATCH_SIZE: usize = 32;
// Time budget, in milliseconds, used by the sender thread both when waiting for
// the first packet of a batch and when deciding to stop collecting more.
const BATCH_TIMEOUT_MILLIS: u64 = 1;
// Capacity of the bounded channel between each TUN reader and its sender thread.
const CHANNEL_SIZE: usize = 1024;
// Kernel send/receive buffer size requested for every tunnel socket.
const SOCKET_BUFFER_SIZE: usize = 4 * 1024 * 1024; // 4 MiB
// Size in bytes of the buffers used for TUN reads and socket receives.
const MAX_PACKET_SIZE: usize = 1500;

// A packet queued between a TUN reader thread and its sender thread.
#[derive(Clone)]
struct Packet {
    // Packet bytes; as built in `main`, the `Meta` header followed by the
    // payload read from the TUN device.
    data: Vec<u8>,
    // Number of valid bytes at the start of `data`.
    len: usize,
}

impl Packet {
    fn new() -> Self {
        Self {
            data: vec![0u8; MAX_PACKET_SIZE],
            len: 0,
        }
    }
}

// Traffic counters shared (through an `Arc`) by all worker threads and
// reported periodically by the statistics thread in `main`.
struct Stats {
    // Packets handed to the socket by the sender threads.
    packets_sent: AtomicU64,
    // Packets accepted by the receive threads and written to a TUN device.
    packets_recv: AtomicU64,
    // Bytes corresponding to `packets_sent`.
    bytes_sent: AtomicU64,
    // Payload bytes corresponding to `packets_recv`.
    bytes_recv: AtomicU64,
}

/// Applies throughput-oriented settings to a tunnel socket.
///
/// Requests `SOCKET_BUFFER_SIZE`-byte kernel send and receive buffers and puts
/// the socket into blocking mode. On Linux it additionally tries to enable
/// `TCP_NODELAY`, but only as a best-effort step.
///
/// # Errors
///
/// Returns the underlying I/O error if resizing either buffer or switching to
/// blocking mode fails.
fn optimize_socket(socket: &Socket) -> Result<(), Box<dyn Error>> {
    socket.set_send_buffer_size(SOCKET_BUFFER_SIZE)?;
    socket.set_recv_buffer_size(SOCKET_BUFFER_SIZE)?;
    // The receive loops in `main` retry immediately on error, so a non-blocking
    // socket would make them spin; force blocking mode.
    socket.set_nonblocking(false)?;

    #[cfg(target_os = "linux")]
    {
        // TCP_NODELAY exists only on TCP sockets. The sockets in this file are
        // driven with `send_to`/`recv_from` and their datagrams are parsed
        // starting at an IPv4 header, so they are not TCP streams and the kernel
        // rejects this option. Propagating that error would report the whole
        // tuning step as failed even though everything above succeeded, so the
        // result is deliberately ignored.
        let _ = socket.set_nodelay(true);
    }

    Ok(())
}

fn main() -> Result<(), Box<dyn Error>> {
    let config: Config = serde_json::from_str(env::args().nth(1).ok_or("need param")?.as_str())?;
    let local_secret: [u8; SECRET_LENGTH] = Router::create_secret(config.local_secret.as_str())?;
    let mut sockets: HashMap<u16, Arc<Socket>> = HashMap::new();
    
    // 创建并优化 routers
    let routers: HashMap<u8, Router> = config
        .routers
        .iter()
        .map(|c| {
            Router::new(c, &mut sockets).map(|router| {
                // 优化每个 socket
                let router_key = Router::key(c); // 使用关联函数
                if let Some(socket) = sockets.get(&router_key) {
                    let _ = optimize_socket(socket);
                }
                (c.remote_id, router)
            })
        })
        .collect::<Result<_, _>>()?;
    
    let (mut router_readers, router_writers): (
        HashMap<u8, RouterReader>,
        HashMap<u8, RouterWriter>,
    ) = routers
        .into_iter()
        .map(|(id, router)| {
            let (reader, writer) = router.split();
            ((id, reader), (id, writer))
        })
        .unzip();
    
    let router_writers3: Vec<(Arc<Socket>, HashMap<u8, RouterWriter>)> = router_writers
        .into_iter()
        .grouping_by(|(_, v)| v.key())
        .into_iter()
        .map(|(k, v)| {
            (
                Arc::clone(sockets.get_mut(&k).unwrap()),
                v.into_iter().collect(),
            )
        })
        .collect();
    
    // 全局统计
    let stats = Arc::new(Stats {
        packets_sent: AtomicU64::new(0),
        packets_recv: AtomicU64::new(0),
        bytes_sent: AtomicU64::new(0),
        bytes_recv: AtomicU64::new(0),
    });
    
    println!("Created optimized TUN devices");

    scoped_thread::scope(|s| {
        // 为每个 router 创建读写线程
        for router in router_readers.values_mut() {
            let router_id = router.config.remote_id;
            let local_id = config.local_id;
            let stats_clone = Arc::clone(&stats);
            
            // 创建批处理通道
            let (tx, rx): (SyncSender<Packet>, Receiver<Packet>) = sync_channel(CHANNEL_SIZE);
            
            // TUN 读取线程
            s.spawn(move |_| {
                let meta_size = mem::size_of::<Meta>();
                let mut buffer = vec![0u8; MAX_PACKET_SIZE];
                
                // 预构建 Meta
                let meta = Meta {
                    src_id: local_id,
                    dst_id: router_id,
                    reversed: 0,
                };
                let meta_bytes = unsafe {
                    std::slice::from_raw_parts(&meta as *const _ as *const u8, meta_size)
                };
                
                loop {
                    // 写入 Meta 头
                    buffer[..meta_size].copy_from_slice(meta_bytes);
                    
                    // 读取数据
                    match router.tun_reader.read(&mut buffer[meta_size..]) {
                        Ok(n) if n > 0 => {
                            let packet = Packet {
                                data: buffer[..meta_size + n].to_vec(),
                                len: meta_size + n,
                            };
                            let _ = tx.try_send(packet);
                        }
                        _ => continue,
                    }
                }
            });
            
            // 批处理发送线程
            s.spawn(move |_| {
                let mut batch = Vec::with_capacity(BATCH_SIZE);
                let meta_size = mem::size_of::<Meta>();
                let mut last_batch_time = Instant::now();
                
                loop {
                    batch.clear();
                    
                    // 收集数据包
                    match rx.recv_timeout(Duration::from_millis(BATCH_TIMEOUT_MILLIS)) {
                        Ok(packet) => batch.push(packet),
                        Err(_) => {
                            if batch.is_empty() { continue; }
                        }
                    }
                    
                    // 继续收集直到批量大小或超时
                    while batch.len() < BATCH_SIZE {
                        match rx.try_recv() {
                            Ok(packet) => batch.push(packet),
                            Err(TryRecvError::Empty) => {
                                if last_batch_time.elapsed() > Duration::from_millis(BATCH_TIMEOUT_MILLIS) {
                                    break;
                                }
                            }
                            Err(TryRecvError::Disconnected) => return,
                        }
                    }
                    
                    if !batch.is_empty() {
                        // 获取endpoint
                        if let Some(ref addr) = *router.endpoint.read().unwrap() {
                            // 加密所有包
                            for packet in &mut batch {
                                router.encrypt(&mut packet.data[meta_size..packet.len]);
                            }
                            
                            // 设置 mark（每批只设置一次）
                            #[cfg(target_os = "linux")]
                            let _ = router.socket.set_mark(router.config.mark);
                            
                            // 批量发送
                            let mut total_bytes = 0;
                            for packet in &batch {
                                if let Ok(_) = router.socket.send_to(&packet.data[..packet.len], addr) {
                                    total_bytes += packet.len;
                                }
                            }
                            
                            // 更新统计
                            stats_clone.packets_sent.fetch_add(batch.len() as u64, Ordering::Relaxed);
                            stats_clone.bytes_sent.fetch_add(total_bytes as u64, Ordering::Relaxed);
                        }
                        
                        last_batch_time = Instant::now();
                    }
                }
            });
        }

        // 优化的接收线程
        for (socket, mut router_writers) in router_writers3 {
            let stats_clone = Arc::clone(&stats);
            let local_secret_clone = local_secret.clone();
            let config_local_id = config.local_id;
            
            s.spawn(move |_| {
                // 预分配多个缓冲区
                let mut recv_bufs: Vec<[MaybeUninit<u8>; MAX_PACKET_SIZE]> = 
                    (0..4).map(|_| unsafe { MaybeUninit::uninit().assume_init() }).collect();
                let mut buf_idx = 0;
                
                loop {
                    let recv_buf = &mut recv_bufs[buf_idx];
                    buf_idx = (buf_idx + 1) % recv_bufs.len();
                    
                    match socket.recv_from(recv_buf) {
                        Ok((len, sock_addr)) => {
                            // 转换为 SocketAddr
                            let addr = match sock_addr.as_socket() {
                                Some(addr) => addr,
                                None => continue,
                            };
                            
                            // 将 MaybeUninit 转换为初始化的数据
                            let data = unsafe {
                                std::slice::from_raw_parts_mut(
                                    recv_buf.as_mut_ptr() as *mut u8,
                                    len
                                )
                            };
                            
                            // 快速路径检查
                            if len < 20 + mem::size_of::<Meta>() {
                                continue;
                            }
                            
                            if let Some(packet) = Ipv4Packet::new(data) {
                                let header_len = packet.get_header_length() as usize * 4;
                                if header_len < data.len() {
                                    let rest = &mut data[header_len..];
                                    if rest.len() >= mem::size_of::<Meta>() {
                                        let (meta_bytes, payload) = rest.split_at_mut(mem::size_of::<Meta>());
                                        let meta: &Meta = unsafe { &*(meta_bytes.as_ptr() as *const Meta) };
                                        
                                        if meta.dst_id == config_local_id && meta.reversed == 0 {
                                            if let Some(router) = router_writers.get_mut(&meta.src_id) {
                                                // 更新endpoint
                                                *router.endpoint.write().unwrap() = Some(addr);
                                                
                                                // 解密
                                                router.decrypt(payload, &local_secret_clone);
                                                
                                                // 写入TUN
                                                let _ = router.tun_writer.write_all(payload);
                                                
                                                // 更新统计
                                                stats_clone.packets_recv.fetch_add(1, Ordering::Relaxed);
                                                stats_clone.bytes_recv.fetch_add(payload.len() as u64, Ordering::Relaxed);
                                            }
                                        }
                                    }
                                }
                            }
                        }
                        Err(_) => continue,
                    }
                }
            });
        }
        
        // 统计线程
        let stats_clone = Arc::clone(&stats);
        s.spawn(move |_| {
            let mut last_stats = (0u64, 0u64, 0u64, 0u64);
            loop {
                thread::sleep(Duration::from_secs(10));
                let sent_pkts = stats_clone.packets_sent.load(Ordering::Relaxed);
                let recv_pkts = stats_clone.packets_recv.load(Ordering::Relaxed);
                let sent_bytes = stats_clone.bytes_sent.load(Ordering::Relaxed);
                let recv_bytes = stats_clone.bytes_recv.load(Ordering::Relaxed);
                
                let sent_pps = (sent_pkts - last_stats.0) / 10;
                let recv_pps = (recv_pkts - last_stats.1) / 10;
                let sent_mbps = ((sent_bytes - last_stats.2) * 8) / (10 * 1024 * 1024);
                let recv_mbps = ((recv_bytes - last_stats.3) * 8) / (10 * 1024 * 1024);
                
                println!(
                    "Stats - TX: {} pps, {} Mbps | RX: {} pps, {} Mbps",
                    sent_pps, sent_mbps, recv_pps, recv_mbps
                );
                
                last_stats = (sent_pkts, recv_pkts, sent_bytes, recv_bytes);
            }
        });
    })
    .unwrap();
    
    Ok(())
}
