use base64::prelude::BASE64_STANDARD;
use base64::Engine;
use socket2::{Domain, Protocol, SockAddr, SockFilter, Socket, Type};
use std::collections::HashMap;
use std::ffi::c_void;
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6, ToSocketAddrs};
use std::os::fd::AsRawFd;
use std::process::{Command, ExitStatus};
use std::sync::Arc;
use tun::{Reader, Writer};
/// Length, in bytes, of the repeating XOR key shared with a peer (a 256-bit key).
pub const SECRET_LENGTH: usize = 256 / 8;
use crate::{ConfigRouter, Meta, Schema};
use crossbeam::epoch::Atomic;
use libc::{
    sock_filter, sock_fprog, socklen_t, BPF_ABS, BPF_B, BPF_IND, BPF_JEQ, BPF_JMP, BPF_K, BPF_LD, BPF_LDX,
    BPF_MSH, BPF_RET, BPF_W, SOL_SOCKET, SO_ATTACH_REUSEPORT_CBPF,
};

/// Half of a split [`Router`] that carries traffic in the tun -> raw direction.
pub struct RouterReader {
    /// Per-peer settings this half was built from.
    pub config: ConfigRouter,
    /// Repeating XOR key applied by [`RouterReader::encrypt`].
    pub secret: [u8; SECRET_LENGTH],
    /// Read side of the TUN device.
    pub tun_reader: Reader,
    /// Socket shared with the matching [`RouterWriter`].
    pub socket: Arc<Socket>,
    /// Peer address shared with the matching [`RouterWriter`]; null when the
    /// configured endpoint did not resolve at construction time.
    pub endpoint: Arc<Atomic<SockAddr>>,
}

impl RouterReader {
    /// XORs `data` in place with `self.secret`.
    ///
    /// The key repeats every `SECRET_LENGTH` bytes and restarts at offset 0 on
    /// every call. Repeating-key XOR only obfuscates the payload. It provides
    /// neither real confidentiality nor integrity protection.
    pub(crate) fn encrypt(&self, data: &mut [u8]) {
        let keystream = self.secret.iter().cycle();
        for (byte, key) in data.iter_mut().zip(keystream) {
            *byte ^= *key;
        }
    }
}

/// Half of a split [`Router`] that carries traffic in the raw -> tun direction.
pub struct RouterWriter {
    /// Write side of the TUN device.
    pub tun_writer: Writer,
    /// Socket shared with the matching [`RouterReader`].
    pub socket: Arc<Socket>,
    /// Peer address shared with the matching [`RouterReader`]; null when the
    /// configured endpoint did not resolve at construction time.
    pub endpoint: Arc<Atomic<SockAddr>>,
    /// Set to the time of [`Router::split`]; presumably refreshed whenever the
    /// peer is heard from (TODO confirm against the receive loop).
    pub last_activity: std::time::Instant,
}

impl RouterWriter {
    /// XORs `data` in place with `secret`, repeating the key every
    /// `SECRET_LENGTH` bytes from offset 0. This is the exact inverse of
    /// [`RouterReader::encrypt`] with the same key.
    ///
    /// The key is passed explicitly; no field of `self` is read.
    pub(crate) fn decrypt(&self, data: &mut [u8], secret: &[u8; SECRET_LENGTH]) {
        data.iter_mut()
            .zip(secret.iter().cycle())
            .for_each(|(byte, key)| *byte ^= key);
    }
}

/// A configured tunnel to one peer, before it is separated into its two
/// directions by [`Router::split`]. Built by [`Router::new`].
pub struct Router {
    /// Per-peer settings.
    pub config: ConfigRouter,
    /// XOR key decoded from `config.remote_secret`.
    pub secret: [u8; SECRET_LENGTH],
    /// Read side of the TUN device named `config.dev`.
    pub tun_reader: Reader,
    /// Write side of the same TUN device.
    pub tun_writer: Writer,
    /// Socket used to exchange packets with the peer. It is raw, UDP or TCP,
    /// depending on `config.schema`.
    pub socket: Arc<Socket>,
    /// Peer address resolved from `config.endpoint`; null if it did not resolve.
    pub endpoint: Arc<Atomic<SockAddr>>,
}

impl Router {
    pub(crate) fn create_secret(
        config: &str,
    ) -> Result<[u8; SECRET_LENGTH], Box<dyn std::error::Error>> {
        let mut secret = [0u8; SECRET_LENGTH];
        let decoded = BASE64_STANDARD.decode(config)?;
        let len = decoded.len().min(SECRET_LENGTH);
        secret[..len].copy_from_slice(&decoded[..len]);
        Ok(secret)
    }

    fn create_socket(
        config: &ConfigRouter,
        local_id: u8,
        groups: &mut HashMap<u16, (Arc<Socket>, Vec<u8>)>,
    ) -> Result<Arc<Socket>, Box<dyn std::error::Error>> {
        match config.schema {
            Schema::IP => {
                let result = Socket::new(
                    config.family,
                    Type::RAW,
                    Some(Protocol::from(config.proto as i32)),
                )?;
                #[cfg(target_os = "linux")]
                result.set_mark(config.mark)?;
                Self::attach_filter_raw(config, local_id, &result)?;
                Ok(Arc::new(result))
            }
            Schema::UDP => {
                let result = Socket::new(config.family, Type::DGRAM, Some(Protocol::UDP))?;
                if config.src_port != 0 {
                    result.set_reuse_port(true)?;
                    let addr = match config.family {
                        Domain::IPV4 => SockAddr::from(SocketAddrV4::new(
                            Ipv4Addr::UNSPECIFIED,
                            config.src_port,
                        )),
                        Domain::IPV6 => SockAddr::from(SocketAddrV6::new(
                            Ipv6Addr::UNSPECIFIED,
                            config.src_port,
                            0,
                            0,
                        )),
                        _ => return Err("unsupported family".into()),
                    };
                    result.bind(&addr)?;
                    let result1 = Arc::new(result);
                    match groups.get_mut(&config.src_port) {
                        None => {
                            groups
                                .insert(config.src_port, (result1.clone(), vec![config.remote_id]));
                        }
                        Some((_, group)) => {
                            group.push(config.remote_id);
                        }
                    }

                    Ok(result1)
                } else {
                    Ok(Arc::new(result))
                }
            }
            Schema::TCP => {
                let result = Socket::new(config.family, Type::STREAM, Some(Protocol::TCP))?;
                Ok(Arc::new(result))
            }
            Schema::FakeTCP => {
                let result = Socket::new(config.family, Type::STREAM, Some(Protocol::TCP))?;
                Ok(Arc::new(result))
            }
        }
    }

    fn attach_filter_raw(
        config: &ConfigRouter,
        local_id: u8,
        socket: &Socket,
    ) -> Result<(), Box<dyn std::error::Error>> {
        // 由于多个对端可能会使用相同的 ipprpto 号，这里确保每个 socket 上只会收到自己对应的对端发来的消息
        const META_SIZE: usize = size_of::<Meta>();
        let meta = Meta {
            src_id: config.remote_id,
            dst_id: local_id,
            reversed: 0,
        };
        let meta_bytes: [u8; META_SIZE] =
            unsafe { *(&meta as *const Meta as *const [u8; META_SIZE]) };
        let value = u32::from_be_bytes(meta_bytes);

        // 如果你的协议是 UDP，这里必须是 8 (跳过 UDP 头: SrcPort(2)+DstPort(2)+Len(2)+Sum(2))
        // 如果是纯自定义 IP 协议，这里是 0
        let payload_offset = 0;

        let filters: &[SockFilter] = match socket.domain()? {
            Domain::IPV4 => &[
                // [IPv4] 计算 IPv4 头长度: X = 4 * (IP[0] & 0xf)
                bpf_stmt(BPF_LDX | BPF_B | BPF_MSH, 0),
                // A = Packet[X + payload_offset]
                bpf_stmt(BPF_LD | BPF_W | BPF_IND, payload_offset),
                // if (A == target_val) goto Accept; else goto Reject;
                bpf_jump(BPF_JMP | BPF_JEQ | BPF_K, value, 0, 1),
                // 【接受 (True 路径)】
                bpf_stmt(BPF_RET | BPF_K, u32::MAX),
                // 【拒绝 (False 路径)】
                bpf_stmt(BPF_RET | BPF_K, 0),
            ],
            Domain::IPV6 => &[
                // raw socket IPv6 没有 header，加载第 0 字节到累加器 A
                bpf_stmt(BPF_LD | BPF_W | BPF_ABS, 0),
                // if (A == target_val) goto Accept; else goto Reject;
                bpf_jump(BPF_JMP | BPF_JEQ | BPF_K, value, 0, 1),
                // 【接受 (True 路径)】
                bpf_stmt(BPF_RET | BPF_K, u32::MAX),
                // 【拒绝 (False 路径)】
                bpf_stmt(BPF_RET | BPF_K, 0),
            ],
            _ => Err("unsupported family")?,
        };
        socket.attach_filter(filters)?;

        Ok(())
    }

    pub fn attach_filter_udp(
        socket: &Arc<Socket>,
        group: &Vec<u8>,
        local_id: u8,
    ) -> Result<(), Box<dyn std::error::Error>> {
        let values: Vec<u32> = group
            .iter()
            .map(|&f| {
                const META_SIZE: usize = size_of::<Meta>();
                let meta = Meta {
                    src_id: f,
                    dst_id: local_id,
                    reversed: 0,
                };
                let meta_bytes: [u8; META_SIZE] =
                    unsafe { *(&meta as *const Meta as *const [u8; META_SIZE]) };
                u32::from_be_bytes(meta_bytes)
            })
            .collect();

        let mut filters: Vec<SockFilter> = Vec::with_capacity(1 + values.len() * 2 + 1);
        // Load the first 4 bytes of the packet into the accumulator (A)
        filters.push(bpf_stmt(BPF_LD | BPF_W | BPF_ABS, 0));
        for (i, &val) in values.iter().enumerate() {
            // 如果匹配，继续下一句(返回)，如果不匹配，跳过下一句。
            filters.push(bpf_jump(BPF_JMP | BPF_JEQ | BPF_K, val, 0, 1));
            // If match, return the index (i + 1, since 0 means drop)
            filters.push(bpf_stmt(BPF_RET | BPF_K, i as u32));
        }
        // If no match found after all comparisons, drop the packet
        filters.push(bpf_stmt(BPF_RET | BPF_K, u32::MAX));
        Self::attach_reuseport_cbpf(socket, &mut filters)?;
        Ok(())
    }

    fn attach_reuseport_cbpf(
        sock: &Arc<Socket>,
        code: &mut [SockFilter],
    ) -> Result<(), Box<dyn std::error::Error>> {
        let prog = sock_fprog {
            len: code.len() as u16,
            filter: code.as_mut_ptr() as *mut sock_filter,
        };
        let fd = sock.as_raw_fd();

        let ret = unsafe {
            libc::setsockopt(
                fd,
                SOL_SOCKET,
                SO_ATTACH_REUSEPORT_CBPF,
                &prog as *const _ as *const c_void,
                size_of_val(&prog) as socklen_t,
            )
        };

        if ret == -1 {
            Err(std::io::Error::last_os_error())?;
        }

        Ok(())
    }

    fn create_tun_device(
        config: &ConfigRouter,
    ) -> Result<(Reader, Writer), Box<dyn std::error::Error>> {
        let mut tun_config = tun::Configuration::default();
        tun_config.tun_name(config.dev.as_str()).up();

        let dev = tun::create(&tun_config)?;
        Ok(dev.split())
    }
    fn run_up_script(config: &ConfigRouter) -> std::io::Result<ExitStatus> {
        Command::new("sh").args(["-c", config.up.as_str()]).status()
    }

    fn create_endpoint(config: &ConfigRouter) -> Arc<Atomic<SockAddr>> {
        let addr = match (config.endpoint.clone(), config.dst_port)
            .to_socket_addrs()
            .unwrap_or_default()
            .filter(|a| match config.family {
                Domain::IPV4 => a.is_ipv4(),
                Domain::IPV6 => a.is_ipv6(),
                _ => false,
            })
            .next()
        {
            None => Atomic::null(),
            Some(addr) => Atomic::new(addr.into()),
        };

        Arc::new(addr)
    }

    pub fn new(
        config: ConfigRouter,
        local_id: u8,
        udp_count: &mut HashMap<u16, (Arc<Socket>, Vec<u8>)>,
    ) -> Result<Router, Box<dyn std::error::Error>> {
        let secret = Self::create_secret(config.remote_secret.as_str())?;
        let endpoint = Self::create_endpoint(&config);
        let socket = Self::create_socket(&config, local_id, udp_count)?;
        let (tun_reader, tun_writer) = Self::create_tun_device(&config)?;
        Self::run_up_script(&config)?;

        let router = Router {
            config,
            secret,
            endpoint,
            tun_reader,
            tun_writer,
            socket,
        };

        Ok(router)
    }

    pub fn split(self) -> (RouterReader, RouterWriter) {
        let writer = RouterWriter {
            endpoint: self.endpoint.clone(),
            tun_writer: self.tun_writer,
            socket: self.socket.clone(),
            last_activity: std::time::Instant::now(),
        };

        let reader = RouterReader {
            config: self.config,
            secret: self.secret,
            endpoint: self.endpoint,
            tun_reader: self.tun_reader,
            socket: self.socket,
        };

        (reader, writer)
    }
}

/// Builds a non-branching classic-BPF instruction, the counterpart of the C
/// `BPF_STMT(code, k)` macro.
///
/// `code` is an OR of `BPF_*` opcode constants. It is narrowed to the 16-bit
/// opcode field with `as`, which silently drops any higher bits.
fn bpf_stmt(code: u32, k: u32) -> SockFilter {
    let opcode = code as u16;
    let (jump_if_true, jump_if_false) = (0, 0);
    SockFilter::new(opcode, jump_if_true, jump_if_false, k)
}

/// Builds a conditional classic-BPF jump, the counterpart of the C
/// `BPF_JUMP(code, k, jt, jf)` macro.
///
/// `jt` and `jf` are the numbers of instructions to skip when the comparison
/// against `k` is true or false. `code` is narrowed to 16 bits with `as`.
fn bpf_jump(code: u32, k: u32, jt: u8, jf: u8) -> SockFilter {
    let opcode = code as u16;
    SockFilter::new(opcode, jt, jf, k)
}
