use anyhow::{bail, ensure, Result};
use socket2::{Domain, Protocol, SockAddr, SockFilter, Socket, Type};
use std::net::Shutdown;
use std::{
    ffi::c_void,
    mem::MaybeUninit,
    net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6, ToSocketAddrs},
    ops::Range,
    os::fd::{AsRawFd, FromRawFd},
    process::{Command, ExitStatus},
    sync::atomic::Ordering,
};
use tun::Device;

use crate::config::{ConfigRouter, Schema};
use crossbeam::epoch::{pin, Atomic};
use libc::{
    setsockopt, sock_filter, sock_fprog, socklen_t, BPF_ABS, BPF_B, BPF_IND, BPF_JEQ, BPF_JMP, BPF_K, BPF_LD, BPF_LDX, BPF_MEM, BPF_MSH, BPF_RET, BPF_ST, BPF_W,
    MSG_FASTOPEN, SOL_SOCKET, SO_ATTACH_REUSEPORT_CBPF,
};

/// Length in bytes of the secrets used as XOR keys by the router.
pub const SECRET_LENGTH: usize = 32;
/// Size in bytes of the [`Meta`] header.
pub const META_SIZE: usize = size_of::<Meta>();

/// Fixed-layout header made of two `u32` ids: the sender's id followed by the
/// receiver's id.
///
/// The struct is `repr(C)`, so its in-memory byte image is `src_id` followed by
/// `dst_id`, each in the host's native byte order.
#[repr(C)]
#[derive(Debug, Clone, Copy, Default)]
pub struct Meta {
    /// Id of the sending side.
    pub src_id: u32,
    /// Id of the receiving side.
    pub dst_id: u32,
}

impl Meta {
    /// Borrows this header as its raw in-memory bytes (native byte order).
    pub fn as_bytes(&self) -> &[u8; META_SIZE] {
        let raw = (self as *const Meta).cast::<[u8; META_SIZE]>();
        // SAFETY: `Meta` is `repr(C)` with two `u32` fields and no padding, so it
        // is exactly `META_SIZE` initialized bytes; `[u8; N]` has alignment 1 and
        // the returned borrow shares the lifetime of `self`.
        unsafe { &*raw }
    }

    /// Reinterprets a byte buffer as a `Meta` without copying.
    ///
    /// NOTE(review): the cast is only sound when `bytes` is fully initialized and
    /// its address is aligned for `Meta` (4 bytes); neither is checked here, so
    /// callers must guarantee both.
    pub fn from_bytes(bytes: &[MaybeUninit<u8>; META_SIZE]) -> &Meta {
        let raw = bytes.as_ptr().cast::<Meta>();
        // SAFETY: relies on the caller-side guarantees described above.
        unsafe { &*raw }
    }
}

/// State for one tunnel peer: its configuration, the TUN device carrying
/// plaintext packets, and the socket/endpoint used to reach the remote side.
pub struct Router {
    /// Peer configuration (ids, secrets, ports, address family, schema, ...).
    pub config: ConfigRouter,
    /// TUN device; outbound handlers read packets from it and inbound handlers
    /// write decrypted packets to it.
    pub tun: Device,
    /// Socket used by the IP/UDP handlers for `send_to` / `recv_from`.
    pub socket: Socket,
    /// Current remote address. May be null; the IP/UDP inbound handler replaces
    /// it whenever a packet arrives from a different address.
    pub endpoint: Atomic<SockAddr>,

    /// Initialized to null by `Router::new`.
    /// NOTE(review): presumably holds the accepted TCP connection; its users
    /// are outside the visible code — confirm.
    pub tcp_listener_connection: Atomic<Socket>,
}

/// XORs `data` in place with the repeating key `secret`, starting at key
/// position `offset % L`.
///
/// Byte `i` of `data` is combined with `secret[(offset + i) % L]`. Applying the
/// function twice with the same arguments restores the original bytes.
#[inline]
fn xor_with_secret_offset<const L: usize>(data: &mut [u8], secret: &[u8; L], offset: usize) {
    if data.is_empty() {
        return;
    }

    // 1) Head: consume key bytes from the current phase up to the end of the
    //    key, so everything after it starts at key position 0.
    let phase = offset % L;
    let head_len = (L - phase).min(data.len());
    let (head, aligned) = data.split_at_mut(head_len);
    for (byte, key) in head.iter_mut().zip(&secret[phase..]) {
        *byte ^= *key;
    }

    // 2) Whole key-sized blocks.
    let mut blocks = aligned.chunks_exact_mut(L);
    for block in &mut blocks {
        for (byte, key) in block.iter_mut().zip(secret) {
            *byte ^= *key;
        }
    }

    // 3) Tail shorter than one key length.
    for (byte, key) in blocks.into_remainder().iter_mut().zip(secret) {
        *byte ^= *key;
    }
}

impl Router {
    pub(crate) fn decrypt(&self, data: &mut [u8], secret: &[u8; SECRET_LENGTH]) {
        xor_with_secret_offset::<SECRET_LENGTH>(data, secret, 0);
    }

    pub(crate) fn decrypt2(
        &self,
        data: &mut [u8],
        secret: &[u8; SECRET_LENGTH],
        range: Range<usize>,
    ) {
        xor_with_secret_offset::<SECRET_LENGTH>(&mut data[range.clone()], secret, range.start);
    }

    /// XORs the whole of `data` in place with the configured remote secret,
    /// starting at key position 0.
    pub(crate) fn encrypt(&self, data: &mut [u8]) {
        let key = &self.config.remote_secret;
        xor_with_secret_offset::<SECRET_LENGTH>(data, key, 0);
    }

    /// Creates the socket stored in [`Router::socket`] for the configured schema.
    ///
    /// * `Schema::IP`  – raw socket for `config.proto`, with the per-peer BPF
    ///   filter attached.
    /// * `Schema::UDP` – datagram socket; when `config.src_port` is non-zero it
    ///   is bound to that port with `SO_REUSEPORT`.
    /// * `Schema::TCP` – an unconnected placeholder stream socket; the real
    ///   connections come from [`Router::listen_tcp`] / [`Router::connect_tcp`].
    ///
    /// A non-zero `config.mark` is applied to the IP and UDP sockets.
    ///
    /// # Errors
    /// Returns any error from socket creation, option setting, binding, or
    /// filter attachment.
    pub fn create_socket(config: &ConfigRouter) -> Result<Socket> {
        println!("create_socket {}", config.remote_id);
        match config.schema {
            Schema::IP => {
                let result = Socket::new(config.family, Type::RAW, Some(Protocol::from(config.proto as i32)))?;
                if config.mark != 0 {
                    result.set_mark(config.mark)?;
                }
                Self::attach_filter_ip(config, &result)?;
                Ok(result)
            }
            Schema::UDP => {
                let result = Socket::new(config.family, Type::DGRAM, Some(Protocol::UDP))?;
                if config.mark != 0 {
                    result.set_mark(config.mark)?;
                }
                if config.src_port != 0 {
                    result.set_reuse_port(true)?;
                    let addr = Self::bind_addr(config);
                    result.bind(&addr)?;
                }
                Ok(result)
            }
            // The placeholder must own a descriptor of its own. Wrapping fd 0
            // would make every TCP router claim ownership of stdin and close it
            // (possibly more than once) when the routers are dropped.
            Schema::TCP => Ok(Socket::new(config.family, Type::STREAM, Some(Protocol::TCP))?),
        }
    }

    /// Creates a TCP listening socket bound to the IPv6 wildcard address on
    /// `config.src_port`, with `SO_REUSEADDR` set and a backlog of 100.
    ///
    /// NOTE(review): the listener is always IPv6 regardless of `config.family`;
    /// whether IPv4 peers are accepted depends on the platform's dual-stack
    /// default, which is not set here.
    ///
    /// # Panics
    /// Panics if any socket operation fails.
    pub fn listen_tcp(&self) -> Socket {
        let wildcard = SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, self.config.src_port, 0, 0);

        let listener = Socket::new(Domain::IPV6, Type::STREAM, Some(Protocol::TCP)).unwrap();
        listener.set_reuse_address(true).unwrap();
        listener.bind(&SockAddr::from(wildcard)).unwrap();
        listener.listen(100).unwrap();
        listener
    }

    /// Opens a TCP connection to the current endpoint, sending the [`Meta`]
    /// header as TCP Fast Open data (`MSG_FASTOPEN`).
    ///
    /// The client socket is deliberately created here, on every attempt, rather
    /// than once at initialization.
    ///
    /// # Panics
    /// Panics if the socket cannot be created or configured, or if no endpoint
    /// is currently set.
    ///
    /// # Errors
    /// Returns an error if binding to `config.src_port` or connecting/sending
    /// fails.
    pub fn connect_tcp(&self) -> Result<Socket> {
        let result = Socket::new(self.config.family, Type::STREAM, Some(Protocol::TCP)).unwrap();
        result.set_tcp_nodelay(true).unwrap();
        if self.config.mark != 0 {
            result.set_mark(self.config.mark).unwrap();
        }
        if self.config.src_port != 0 {
            result.set_reuse_address(true).unwrap();
            let addr = Self::bind_addr(&self.config);
            result.bind(&addr)?;
        }

        let meta = Meta {
            src_id: self.config.local_id,
            dst_id: self.config.remote_id,
        };

        // Copy the address out so the epoch guard is not held across the
        // (potentially blocking) connect/send below.
        let endpoint = {
            let guard = pin();
            // Acquire pairs with the publishing swap in the inbound handler, so
            // the pointed-to address is fully visible before it is read.
            let endpoint_ref = self.endpoint.load(Ordering::Acquire, &guard);
            // SAFETY: the pointer is either null or refers to an address that
            // stays alive while `guard` is pinned; replaced values are only
            // retired through the epoch collector.
            unsafe { endpoint_ref.as_ref() }.unwrap().clone()
        };

        result.send_to_with_flags(meta.as_bytes(), &endpoint, MSG_FASTOPEN)?;
        Ok(result)
    }

    fn attach_filter_ip(config: &ConfigRouter, socket: &Socket) -> Result<()> {
        // 由于多个对端可能会使用相同的 ipproto 号，这里确保每个 socket 上只会收到自己对应的对端发来的消息
        
        // 构造 Meta 来计算正确的字节序比较值
        let meta_bytes = [
            config.remote_id.to_le_bytes(),
            config.local_id.to_le_bytes(),
        ];
        
        // BPF 按网络字节序（大端序）比较，所以需要把小端序字节当作大端序来构造比较值
        let expected_src_id = u32::from_be_bytes(meta_bytes[0]);
        let expected_dst_id = u32::from_be_bytes(meta_bytes[1]);

        // IP filter 工作原理：
        // 每个对端起一个 raw socket
        // 根据报文内容判断是给谁的。拒绝掉不是给自己的报文
        // IPv4 raw socket 带 IP 头，IPv6 不带
        // Meta 结构：src_id(u32) + dst_id(u32) = 8 字节

        let filters: &[SockFilter] = match socket.domain()? {
            Domain::IPV4 => &[
                // [IPv4] 计算 IPv4 头长度: X = 4 * (IP[0] & 0xf)
                bpf_stmt(BPF_LDX | BPF_B | BPF_MSH, 0),
                // A = Packet[X + 0:4] = src_id
                bpf_stmt(BPF_LD | BPF_W | BPF_IND, 0),
                // if A != expected_src_id, goto reject
                bpf_jump(BPF_JMP | BPF_JEQ | BPF_K, expected_src_id, 0, 3),
                // A = Packet[X + 4:8] = dst_id
                bpf_stmt(BPF_LD | BPF_W | BPF_IND, 4),
                // if A != expected_dst_id, goto reject
                bpf_jump(BPF_JMP | BPF_JEQ | BPF_K, expected_dst_id, 0, 1),
                // 【接受】
                bpf_stmt(BPF_RET | BPF_K, u32::MAX),
                // 【拒绝】
                bpf_stmt(BPF_RET | BPF_K, 0),
            ],
            Domain::IPV6 => &[
                // raw socket IPv6 没有 header
                // A = Packet[0:4] = src_id
                bpf_stmt(BPF_LD | BPF_W | BPF_ABS, 0),
                // if A != expected_src_id, goto reject
                bpf_jump(BPF_JMP | BPF_JEQ | BPF_K, expected_src_id, 0, 3),
                // A = Packet[4:8] = dst_id
                bpf_stmt(BPF_LD | BPF_W | BPF_ABS, 4),
                // if A != expected_dst_id, goto reject
                bpf_jump(BPF_JMP | BPF_JEQ | BPF_K, expected_dst_id, 0, 1),
                // 【接受】
                bpf_stmt(BPF_RET | BPF_K, u32::MAX),
                // 【拒绝】
                bpf_stmt(BPF_RET | BPF_K, 0),
            ],
            _ => bail!("unsupported family"),
        };
        socket.attach_filter(filters)?;

        Ok(())
    }

    pub fn attach_filter_udp(group: Vec<&Router>) -> Result<()> {
        // 预留空间：4 条前置指令 + 每个 router 5 条 + 1 条默认返回
        let mut filters: Vec<SockFilter> = Vec::with_capacity(4 + group.len() * 5 + 1);

        // udp filter 工作原理：
        // 每个对端起一个 udp socket
        // 根据报文内容判断是给谁的，调度给对应的端口复用组序号
        // Meta 结构：src_id(u32) + dst_id(u32) = 8 字节

        // 加载 src_id 并存储到 M[0]
        filters.push(bpf_stmt(BPF_LD | BPF_W | BPF_ABS, 0));  // A = packet[0:4] = src_id
        filters.push(bpf_stmt(BPF_ST, 0));                    // M[0] = A

        // 加载 dst_id 并存储到 M[1]
        filters.push(bpf_stmt(BPF_LD | BPF_W | BPF_ABS, 4));  // A = packet[4:8] = dst_id
        filters.push(bpf_stmt(BPF_ST, 1));                    // M[1] = A

        for (i, router) in group.iter().enumerate() {
            // 字节序转换：将小端序ID转换为BPF期望的大端序比较值
            let src_bytes = router.config.remote_id.to_le_bytes();
            let dst_bytes = router.config.local_id.to_le_bytes();
            let expected_src_id = u32::from_be_bytes(src_bytes);
            let expected_dst_id = u32::from_be_bytes(dst_bytes);

            // 每个 router 5 条指令：
            // 0: LD M[0]                      ; A = src_id
            // 1: JEQ expected_src_id, +0, +3  ; 匹配继续，不匹配跳过当前 router
            // 2: LD M[1]                      ; A = dst_id
            // 3: JEQ expected_dst_id, +0, +1  ; 匹配继续，不匹配跳过当前 router
            // 4: RET i                        ; 返回索引

            filters.push(bpf_stmt(BPF_LD | BPF_MEM, 0));
            filters.push(bpf_jump(BPF_JMP | BPF_JEQ | BPF_K, expected_src_id, 0, 3));
            filters.push(bpf_stmt(BPF_LD | BPF_MEM, 1));
            filters.push(bpf_jump(BPF_JMP | BPF_JEQ | BPF_K, expected_dst_id, 0, 1));
            filters.push(bpf_stmt(BPF_RET | BPF_K, i as u32));
        }

        // 默认返回（不匹配任何 router）
        filters.push(bpf_stmt(BPF_RET | BPF_K, u32::MAX));

        let prog = sock_fprog {
            len: filters.len() as u16,
            filter: filters.as_mut_ptr() as *mut sock_filter,
        };
        let fd = group[0].socket.as_raw_fd();
        let ret = unsafe {
            setsockopt(
                fd,
                SOL_SOCKET,
                SO_ATTACH_REUSEPORT_CBPF,
                &prog as *const _ as *const c_void,
                size_of_val(&prog) as socklen_t,
            )
        };
        ensure!(ret != -1, std::io::Error::last_os_error());
        Ok(())
    }

    /// Outbound loop for the IP and UDP schemas: reads packets from the TUN
    /// device, encrypts them, prefixes the [`Meta`] header (local -> remote) and
    /// sends them to the current endpoint.
    ///
    /// Packets read while no endpoint is known are dropped, and send errors are
    /// ignored (best effort). Never returns.
    ///
    /// # Panics
    /// Panics if reading from the TUN device fails.
    pub(crate) fn handle_outbound_ip_udp(&self) {
        let mut buffer = [0u8; 1500];

        // The header never changes, so write it once in front of the payload area.
        let meta = Meta {
            src_id: self.config.local_id,
            dst_id: self.config.remote_id,
        };
        buffer[..META_SIZE].copy_from_slice(meta.as_bytes());

        loop {
            let n = self.tun.recv(&mut buffer[META_SIZE..]).unwrap(); // a failed recv is fatal
            let frame = &mut buffer[..META_SIZE + n];

            let guard = pin();
            // Acquire pairs with the publishing swap in the inbound handler, so
            // the address behind the pointer is fully visible before it is used.
            let endpoint_ref = self.endpoint.load(Ordering::Acquire, &guard);
            // SAFETY: the pointer is either null or refers to an address kept
            // alive while `guard` is pinned.
            if let Some(endpoint) = unsafe { endpoint_ref.as_ref() } {
                self.encrypt(&mut frame[META_SIZE..]);
                let _ = self.socket.send_to(frame, endpoint);
            }
        }
    }

    /// Inbound loop for the IP and UDP schemas: receives packets from the
    /// socket, records the sender as the current endpoint, strips the headers,
    /// decrypts the payload and writes it to the TUN device.
    ///
    /// A malformed (too short) packet is dropped on its own without affecting
    /// later packets or the recorded endpoint. Errors writing to the TUN device
    /// are ignored. Never returns.
    ///
    /// # Panics
    /// Panics if receiving from the socket fails.
    pub(crate) fn handle_inbound_ip_udp(&self) {
        let mut recv_buf = [MaybeUninit::<u8>::uninit(); 1500];
        // Only IPv4 raw sockets hand us the IP header in front of the payload.
        let has_ip_header = self.config.family == Domain::IPV4 && self.config.schema == Schema::IP;

        loop {
            let (len, addr) = self.socket.recv_from(&mut recv_buf).unwrap(); // a failed recv is fatal
            // SAFETY: `recv_from` initialized the first `len` bytes of
            // `recv_buf`, and no other reference into the buffer is live.
            let packet: &mut [u8] =
                unsafe { std::slice::from_raw_parts_mut(recv_buf.as_mut_ptr().cast::<u8>(), len) };

            let ip_header_len = if has_ip_header {
                match packet.first() {
                    Some(first) => (first & 0x0f) as usize * 4,
                    None => continue, // empty packet: nothing to parse
                }
            } else {
                0
            };
            let offset = ip_header_len + META_SIZE;
            if packet.len() < offset {
                // Too short to contain the headers: drop just this packet
                // instead of panicking on the slice below.
                continue;
            }

            {
                let guard = pin();
                let current_shared = self.endpoint.load(Ordering::Acquire, &guard);
                // SAFETY: the pointer is either null or refers to an address
                // kept alive while `guard` is pinned.
                let is_same = unsafe { current_shared.as_ref() }.map(|c| *c == addr).unwrap_or(false);
                if !is_same {
                    let new_shared = crossbeam::epoch::Owned::new(addr).into_shared(&guard);
                    let old_shared = self.endpoint.swap(new_shared, Ordering::AcqRel, &guard);
                    // The endpoint starts out null when none was configured or
                    // resolved; a null pointer must never be handed to
                    // `defer_destroy`.
                    if !old_shared.is_null() {
                        // SAFETY: `old_shared` was just unlinked from
                        // `self.endpoint`, so no new reader can obtain it.
                        unsafe { guard.defer_destroy(old_shared) }
                    }
                }
            }

            let payload = &mut packet[offset..];
            self.decrypt(payload, &self.config.local_secret);
            let _ = self.tun.send(payload);
        }
    }

    pub(crate) fn handle_outbound_tcp(&self, connection: &Socket) {
        let _ = (|| -> Result<()> {
            let mut buffer = [0u8; 1500];
            loop {
                let n = self.tun.recv(&mut buffer)?;
                self.encrypt(&mut buffer[..n]);
                Router::send_all_tcp(&connection, &buffer[..n])?;
            }
        })();
        let _ = connection.shutdown(Shutdown::Both);
    }
    /// Inbound loop for the TCP schema: reassembles encrypted IP packets from
    /// the byte stream, decrypts them and writes them to the TUN device.
    ///
    /// Framing relies on the IP header itself: the first 6 bytes are read and
    /// decrypted to learn the IP version and the packet length (IPv4 total
    /// length, or IPv6 payload length + 40), then the rest of the packet is
    /// read and decrypted with the matching key offset.
    ///
    /// The loop ends on the first receive error, invalid header, or TUN write
    /// error; the error is discarded and the connection is shut down in both
    /// directions.
    pub(crate) fn handle_inbound_tcp(&self, connection: &Socket) {
        // Bytes needed to see the version nibble and either length field.
        const PROBE_LEN: usize = 6;

        let _ = (|| -> Result<()> {
            let mut buf = [MaybeUninit::<u8>::uninit(); 1500];
            loop {
                Router::recv_exact_tcp(connection, &mut buf[..PROBE_LEN])?;
                let total_len = {
                    // SAFETY: `recv_exact_tcp` succeeded, so the first
                    // `PROBE_LEN` bytes are initialized; this view is created
                    // after the receive and dropped before `buf` is borrowed again.
                    let header: &mut [u8] =
                        unsafe { std::slice::from_raw_parts_mut(buf.as_mut_ptr().cast::<u8>(), PROBE_LEN) };
                    self.decrypt2(header, &self.config.local_secret, 0..PROBE_LEN);
                    match header[0] >> 4 {
                        4 => u16::from_be_bytes([header[2], header[3]]) as usize,
                        6 => u16::from_be_bytes([header[4], header[5]]) as usize + 40,
                        _ => bail!("Invalid IP version"),
                    }
                };
                ensure!(PROBE_LEN < total_len && total_len <= buf.len(), "Invalid total length");

                Router::recv_exact_tcp(connection, &mut buf[PROBE_LEN..total_len])?;
                // SAFETY: both receives succeeded, so `buf[..total_len]` is
                // fully initialized, and no other reference into `buf` is live.
                let packet: &mut [u8] =
                    unsafe { std::slice::from_raw_parts_mut(buf.as_mut_ptr().cast::<u8>(), total_len) };
                self.decrypt2(packet, &self.config.local_secret, PROBE_LEN..total_len);
                self.tun.send(packet)?;
            }
        })();
        let _ = connection.shutdown(Shutdown::Both);
    }

    /// Receives from `sock` until `buf` is completely filled.
    ///
    /// Receives interrupted by a signal are retried.
    ///
    /// # Errors
    /// Returns `UnexpectedEof` if the peer closes the connection before `buf`
    /// is full, or the underlying I/O error from `recv`.
    pub(crate) fn recv_exact_tcp(sock: &Socket, mut buf: &mut [MaybeUninit<u8>]) -> Result<()> {
        while !buf.is_empty() {
            let n = match sock.recv(buf) {
                Ok(n) => n,
                // A signal arriving mid-call is not a connection failure.
                Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
                Err(e) => return Err(e.into()),
            };
            ensure!(n != 0, std::io::ErrorKind::UnexpectedEof);
            buf = &mut buf[n..];
        }
        Ok(())
    }

    pub(crate) fn send_all_tcp(sock: &Socket, mut buf: &[u8]) -> Result<()> {
        while !buf.is_empty() {
            let n = sock.send(buf)?;
            buf = &buf[n..];
        }
        Ok(())
    }

    /// Creates the TUN device named `config.dev` and brings it up.
    ///
    /// # Errors
    /// Returns the error from `tun::create` if the device cannot be created.
    fn create_tun_device(config: &ConfigRouter) -> Result<Device> {
        println!("create_tun_device {}", config.remote_id);

        let mut settings = tun::Configuration::default();
        settings.tun_name(config.dev.as_str());
        settings.up();

        Ok(tun::create(&settings)?)
    }
    /// Runs `config.up` through `sh -c` and waits for it to finish.
    ///
    /// Returns the script's exit status; a non-zero status is not treated as an
    /// error here.
    ///
    /// # Errors
    /// Returns an error only if the shell could not be spawned or waited on.
    fn run_up_script(config: &ConfigRouter) -> Result<ExitStatus> {
        let mut shell = Command::new("sh");
        shell.arg("-c").arg(config.up.as_str());
        let status = shell.status()?;
        Ok(status)
    }

    fn create_endpoint(config: &ConfigRouter) -> Atomic<SockAddr> {
        println!("create_endpoint {}", config.remote_id);
        match (config.endpoint.clone(), config.dst_port)
            .to_socket_addrs()
            .unwrap_or_default()
            .filter(|a| match config.family {
                Domain::IPV4 => a.is_ipv4(),
                Domain::IPV6 => a.is_ipv6(),
                _ => false,
            })
            .next()
        {
            None => Atomic::null(),
            Some(addr) => Atomic::new(addr.into()),
        }
    }

    /// Builds a router for one peer: creates the TUN device, resolves the
    /// initial endpoint, creates the socket, and finally runs the `up` script.
    ///
    /// NOTE(review): the script's exit status is discarded; only a failure to
    /// launch it is reported.
    ///
    /// # Errors
    /// Returns an error if the TUN device or socket cannot be created, or if
    /// the `up` script cannot be launched.
    pub fn new(config: ConfigRouter) -> Result<Router> {
        println!("creating {}", config.remote_id);

        let tun = Self::create_tun_device(&config)?;
        let endpoint = Self::create_endpoint(&config);
        let socket = Self::create_socket(&config)?;
        let router = Router {
            config,
            tun,
            socket,
            endpoint,
            tcp_listener_connection: Atomic::null(),
        };

        println!("run_up_script {}", &router.config.remote_id);
        Self::run_up_script(&router.config)?;
        Ok(router)
    }

    /// Returns the wildcard address of `config.family` on `config.src_port`.
    ///
    /// # Panics
    /// Panics if the family is neither IPv4 nor IPv6.
    fn bind_addr(config: &ConfigRouter) -> SockAddr {
        let port = config.src_port;
        if config.family == Domain::IPV4 {
            SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, port).into()
        } else if config.family == Domain::IPV6 {
            SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, port, 0, 0).into()
        } else {
            panic!("unsupported family")
        }
    }
}

/// Builds a non-branching classic-BPF instruction: opcode `code`, operand `k`,
/// and both jump offsets zero.
fn bpf_stmt(code: u32, k: u32) -> SockFilter {
    let opcode = code as u16;
    let (jump_true, jump_false) = (0u8, 0u8);
    SockFilter::new(opcode, jump_true, jump_false, k)
}

/// Builds a conditional classic-BPF instruction: opcode `code`, comparison
/// operand `k`, and the relative jump offsets taken when the test is true
/// (`jt`) or false (`jf`).
fn bpf_jump(code: u32, k: u32, jt: u8, jf: u8) -> SockFilter {
    let opcode = code as u16;
    SockFilter::new(opcode, jt, jf, k)
}
