use base64::prelude::BASE64_STANDARD;
use base64::Engine;
use socket2::{Domain, Protocol, SockAddr, SockFilter, Socket, Type};
use std::net::ToSocketAddrs;
use std::process::{Command, ExitStatus};
use std::sync::Arc;
use tun::{Reader, Writer};
/// Length in bytes of the shared XOR key (a 256-bit key) used by `RouterReader::encrypt` and
/// `RouterWriter::decrypt`.
pub const SECRET_LENGTH: usize = 256 / 8;
use crate::{ConfigRouter, Meta};
use crossbeam::epoch::Atomic;
use libc::{BPF_ABS, BPF_B, BPF_IND, BPF_JEQ, BPF_JMP, BPF_K, BPF_LD, BPF_LDX, BPF_MSH, BPF_RET, BPF_W};

/// Half of a [`Router`] that handles the tun -> raw direction, produced by [`Router::split`].
///
/// It owns the router configuration, the shared secret used by [`RouterReader::encrypt`] and the
/// tun device's read half. The raw socket and the peer endpoint are shared (via `Arc`) with the
/// [`RouterWriter`] produced by the same split.
pub struct RouterReader {
    /// Router configuration this half was built from.
    pub config: ConfigRouter,
    /// XOR key applied by [`RouterReader::encrypt`].
    pub secret: [u8; SECRET_LENGTH],
    /// Read half of the tun device.
    pub tun_reader: Reader,
    /// Raw IP socket, shared with the matching [`RouterWriter`].
    pub socket: Arc<Socket>,
    /// Peer address, shared with the matching [`RouterWriter`].
    pub endpoint: Arc<Atomic<SockAddr>>,
}

impl RouterReader {
    pub(crate) fn encrypt(&self, data: &mut [u8]) {
        for (i, b) in data.iter_mut().enumerate() {
            *b ^= self.secret[i % SECRET_LENGTH];
        }
    }
}

/// Half of a [`Router`] that handles the raw -> tun direction, produced by [`Router::split`].
///
/// Unlike [`RouterReader`] it stores no secret; the key is passed to
/// [`RouterWriter::decrypt`] by the caller. The raw socket and the peer endpoint are shared
/// (via `Arc`) with the [`RouterReader`] produced by the same split.
pub struct RouterWriter {
    /// Write half of the tun device.
    pub tun_writer: Writer,
    /// Raw IP socket, shared with the matching [`RouterReader`].
    pub socket: Arc<Socket>,
    /// Peer address, shared with the matching [`RouterReader`].
    pub endpoint: Arc<Atomic<SockAddr>>,
}

impl RouterWriter {
    /// XORs `data` in place with `secret`, repeating the key every [`SECRET_LENGTH`] bytes and
    /// always starting from key offset 0.
    ///
    /// This is the exact inverse of `RouterReader::encrypt` with the same key, because XOR is
    /// self-inverse.
    pub(crate) fn decrypt(&self, data: &mut [u8], secret: &[u8; SECRET_LENGTH]) {
        for (byte, key) in data.iter_mut().zip(secret.iter().cycle()) {
            *byte ^= *key;
        }
    }
}

/// A tunnel to a single peer: raw IP socket, tun device and shared secret.
///
/// Built by [`Router::new`]. Call [`Router::split`] to obtain the tun -> raw
/// ([`RouterReader`]) and raw -> tun ([`RouterWriter`]) halves.
pub struct Router {
    /// Configuration the router was built from.
    pub config: ConfigRouter,
    /// Key decoded from `config.remote_secret`.
    pub secret: [u8; SECRET_LENGTH],
    /// Read half of the tun device.
    pub tun_reader: Reader,
    /// Write half of the tun device.
    pub tun_writer: Writer,
    /// Raw IP socket with the per-peer receive filter attached.
    pub socket: Arc<Socket>,
    /// Peer address resolved from `config.endpoint`.
    pub endpoint: Arc<Atomic<SockAddr>>,
}

impl Router {
    /// Decodes the base64-encoded (standard alphabet) shared secret into a fixed-size key.
    ///
    /// Decoded input longer than [`SECRET_LENGTH`] bytes is truncated; shorter input is
    /// zero-padded.
    ///
    /// NOTE(review): the key is applied with XOR, so zero padding makes the padded key bytes
    /// no-ops, and an empty secret disables obfuscation entirely. Consider rejecting secrets
    /// whose decoded length is not exactly `SECRET_LENGTH`.
    ///
    /// # Errors
    ///
    /// Returns an error if `config` is not valid base64.
    pub(crate) fn create_secret(
        config: &str,
    ) -> Result<[u8; SECRET_LENGTH], Box<dyn std::error::Error>> {
        let mut secret = [0u8; SECRET_LENGTH];
        let decoded = BASE64_STANDARD.decode(config)?;
        let len = decoded.len().min(SECRET_LENGTH);
        secret[..len].copy_from_slice(&decoded[..len]);
        Ok(secret)
    }

    /// Opens the raw IP socket used to exchange packets with the peer.
    ///
    /// The socket has these properties:
    /// - It is IPv6 when `config.family == 6` and IPv4 otherwise.
    /// - It carries IP protocol `config.proto`.
    /// - On Linux it is tagged with `config.mark` (`SO_MARK`).
    /// - It gets the per-peer receive filter from [`Router::attach_readable_filter`].
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - the socket cannot be created (raw sockets normally require `CAP_NET_RAW`);
    /// - the mark cannot be set;
    /// - the filter cannot be attached, including when `config.family` is unsupported.
    fn create_raw_socket(
        config: &ConfigRouter,
        local_id: u8,
    ) -> Result<Arc<Socket>, Box<dyn std::error::Error>> {
        let result = Socket::new(
            if config.family == 6 {
                Domain::IPV6
            } else {
                Domain::IPV4
            },
            Type::RAW,
            Some(Protocol::from(config.proto as i32)),
        )?;
        #[cfg(target_os = "linux")]
        result.set_mark(config.mark)?;
        Self::attach_readable_filter(config, local_id, &result)?;
        Ok(Arc::new(result))
    }

    /// Attaches a classic-BPF filter so that `socket` only receives packets from this peer.
    ///
    /// Several peers may use the same IP protocol number, so every socket would otherwise see
    /// every peer's traffic. The filter checks the first 4 payload bytes (at `payload_offset`).
    /// It accepts the packet only when they equal the serialized [`Meta`]
    /// `{ src_id: config.remote_id, dst_id: local_id, reversed: 0 }`, and drops it otherwise.
    ///
    /// # Errors
    ///
    /// Returns an error if `config.family` is neither 4 nor 6, or if attaching the filter fails.
    fn attach_readable_filter(
        config: &ConfigRouter,
        local_id: u8,
        socket: &Socket,
    ) -> Result<(), Box<dyn std::error::Error>> {
        const META_SIZE: usize = size_of::<Meta>();
        let meta = Meta {
            src_id: config.remote_id,
            dst_id: local_id,
            reversed: 0,
        };
        // SAFETY: reads exactly `META_SIZE` bytes from a live `Meta` value into a byte array of
        // the same size; `[u8; N]` has alignment 1, so the pointer cast is always aligned.
        // NOTE(review): this is only sound if `Meta` contains no padding bytes (e.g. it is
        // `#[repr(C)]` and its fields exactly fill its size). Confirm at its definition.
        let meta_bytes: [u8; META_SIZE] =
            unsafe { *(&meta as *const Meta as *const [u8; META_SIZE]) };
        // Classic BPF `BPF_W` loads read the packet in network (big-endian) byte order, so the
        // comparison constant is built from the bytes the same way.
        let value = u32::from_be_bytes(meta_bytes);

        // Offset of the `Meta` header within the IP payload.
        // - UDP transport: this must be 8, to skip the UDP header
        //   (SrcPort(2) + DstPort(2) + Len(2) + Sum(2)).
        // - Bare custom IP protocol: this is 0.
        let payload_offset: u32 = 0;

        let filter: &[SockFilter] = match config.family {
            4 => &[
                // IPv4 raw sockets see the IP header: X = 4 * (IP[0] & 0xf), the header length.
                bpf_stmt(BPF_LDX | BPF_B | BPF_MSH, 0),
                // A = 32-bit word at packet[X + payload_offset].
                bpf_stmt(BPF_LD | BPF_W | BPF_IND, payload_offset),
                // If A == value, fall through to accept; otherwise skip one instruction to reject.
                bpf_jump(BPF_JMP | BPF_JEQ | BPF_K, value, 0, 1),
                // Accept: keep the whole packet.
                bpf_stmt(BPF_RET | BPF_K, u32::MAX),
                // Reject: drop the packet.
                bpf_stmt(BPF_RET | BPF_K, 0),
            ],
            6 => &[
                // IPv6 raw sockets do not see the IPv6 header, so the payload starts at offset 0.
                // A = 32-bit word at packet[payload_offset].
                bpf_stmt(BPF_LD | BPF_W | BPF_ABS, payload_offset),
                // If A == value, fall through to accept; otherwise skip one instruction to reject.
                bpf_jump(BPF_JMP | BPF_JEQ | BPF_K, value, 0, 1),
                // Accept: keep the whole packet.
                bpf_stmt(BPF_RET | BPF_K, u32::MAX),
                // Reject: drop the packet.
                bpf_stmt(BPF_RET | BPF_K, 0),
            ],
            _ => Err("unsupported family")?,
        };
        socket.attach_filter(filter)?;

        Ok(())
    }

    /// Creates the tun device named `config.dev`, marks it up, and splits it into its read and
    /// write halves.
    ///
    /// # Errors
    ///
    /// Returns an error if the device cannot be created.
    fn create_tun_device(
        config: &ConfigRouter,
    ) -> Result<(Reader, Writer), Box<dyn std::error::Error>> {
        let mut tun_config = tun::Configuration::default();
        tun_config.tun_name(config.dev.as_str()).up();

        let dev = tun::create(&tun_config)?;
        Ok(dev.split())
    }

    /// Runs `config.up` with `sh -c`, waits for it to finish, and returns its exit status.
    ///
    /// The command inherits this process's stdio. `config.up` is executed verbatim by the
    /// shell, so it must come from trusted configuration.
    ///
    /// # Errors
    ///
    /// Returns an error only if the shell could not be spawned or waited on. A non-zero exit
    /// is reported through the returned [`ExitStatus`].
    fn run_up_script(config: &ConfigRouter) -> std::io::Result<ExitStatus> {
        Command::new("sh").args(["-c", config.up.as_str()]).status()
    }

    /// Resolves `config.endpoint` (a host name or IP literal) to the peer address.
    ///
    /// Only addresses of the socket's family are considered. This mirrors the family choice in
    /// [`Router::create_raw_socket`]: IPv6 when `config.family == 6`, IPv4 otherwise. As a
    /// result, a dual-stack host name never yields an address the raw socket cannot send to.
    /// The port is meaningless for raw IP sockets and is set to 0.
    ///
    /// # Errors
    ///
    /// Returns an error if resolution fails or yields no address of the required family.
    fn create_endpoint(
        config: &ConfigRouter,
    ) -> Result<Arc<Atomic<SockAddr>>, Box<dyn std::error::Error>> {
        let want_v6 = config.family == 6;
        let parsed = (config.endpoint.as_str(), 0u16)
            .to_socket_addrs()?
            .find(|addr| addr.is_ipv6() == want_v6)
            .ok_or_else(|| {
                format!(
                    "endpoint {} did not resolve to an IPv{} address",
                    config.endpoint,
                    if want_v6 { 6 } else { 4 }
                )
            })?;
        Ok(Arc::new(Atomic::new(parsed.into())))
    }

    /// Builds a router from `config`.
    ///
    /// The steps are:
    /// 1. Decode the secret.
    /// 2. Resolve the endpoint.
    /// 3. Open and filter the raw socket, using `local_id` as the expected destination id.
    /// 4. Create the tun device.
    /// 5. Run the `up` script.
    ///
    /// # Errors
    ///
    /// Returns an error if any step fails, including when the `up` script exits unsuccessfully.
    pub fn new(config: ConfigRouter, local_id: u8) -> Result<Router, Box<dyn std::error::Error>> {
        let secret = Self::create_secret(config.remote_secret.as_str())?;
        let endpoint = Self::create_endpoint(&config)?;
        let socket = Self::create_raw_socket(&config, local_id)?;
        let (tun_reader, tun_writer) = Self::create_tun_device(&config)?;
        // `?` only covers failing to spawn the shell. Without this check, a script that runs
        // but fails would silently leave the device unconfigured.
        let status = Self::run_up_script(&config)?;
        if !status.success() {
            return Err(format!("up script for {} exited with {}", config.dev, status).into());
        }

        let router = Router {
            config,
            secret,
            endpoint,
            tun_reader,
            tun_writer,
            socket,
        };

        Ok(router)
    }

    /// Splits the router into its tun -> raw ([`RouterReader`]) and raw -> tun
    /// ([`RouterWriter`]) halves.
    ///
    /// Both halves share the raw socket and the endpoint through `Arc`. The reader keeps the
    /// configuration and the secret; the writer keeps neither.
    pub fn split(self) -> (RouterReader, RouterWriter) {
        let writer = RouterWriter {
            endpoint: self.endpoint.clone(),
            tun_writer: self.tun_writer,
            socket: self.socket.clone(),
        };

        let reader = RouterReader {
            config: self.config,
            secret: self.secret,
            endpoint: self.endpoint,
            tun_reader: self.tun_reader,
            socket: self.socket,
        };

        (reader, writer)
    }
}

/// Builds a non-branching classic-BPF instruction; the equivalent of the `BPF_STMT(code, k)`
/// macro from `<linux/filter.h>`.
///
/// `code` is an opcode assembled from the `libc::BPF_*` constants. It is narrowed with `as` to
/// the 16-bit opcode field, so any bits above 16 are discarded.
fn bpf_stmt(code: u32, k: u32) -> SockFilter {
    let opcode = code as u16;
    let (no_jump_true, no_jump_false) = (0u8, 0u8);
    SockFilter::new(opcode, no_jump_true, no_jump_false, k)
}

/// Builds a conditional-jump classic-BPF instruction; the equivalent of the
/// `BPF_JUMP(code, k, jt, jf)` macro from `<linux/filter.h>`.
///
/// `jt` / `jf` are the number of following instructions to skip when the comparison against
/// `k` is true / false. `code` is narrowed with `as` to the 16-bit opcode field.
fn bpf_jump(code: u32, k: u32, jt: u8, jf: u8) -> SockFilter {
    let opcode = code as u16;
    SockFilter::new(opcode, jt, jf, k)
}
