use pnet::packet::ipv4::Ipv4Packet;
use pnet::packet::Packet;
use socket2::{Domain, Protocol, SockAddr, Socket, Type};
use std::error::Error;
use std::io::{BufRead, Read, Write};
use std::mem::{transmute, MaybeUninit};
use std::sync::{Arc, Mutex};
use std::{env, thread};
use std::net::{SocketAddr, SocketAddrV4};

use std::net::UdpSocket;


/// Tunnel header placed in front of every payload sent over the raw socket.
///
/// `#[repr(C)]` with fields `u8, u8, u16` gives a padding-free 4-byte layout:
/// `src_id` at byte 0, `dst_id` at byte 1 and `reversed` in bytes 2..4.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(C)]
pub struct Meta {
    /// ID of the endpoint that sent the packet.
    pub src_id: u8,
    /// ID of the endpoint the packet is addressed to.
    pub dst_id: u8,
    /// Extra field; senders in this file write 0 and receivers drop packets where it is non-zero.
    // NOTE(review): the name looks like a typo for `reserved`; kept because the field is public.
    pub reversed: u16,
}

/// Repeating-key XOR transform applied to tunnel payloads.
///
/// NOTE: repeating-key XOR is obfuscation, not encryption. The key stream repeats every
/// `key.len()` bytes, so anyone who knows some plaintext can recover the key. It also
/// provides no integrity protection.
struct Secret {
    /// Raw UTF-8 bytes of the secret string.
    key: Vec<u8>,
}

impl Secret {
    /// Builds a secret from the bytes of `key`.
    ///
    /// An empty key is accepted; it turns `encrypt`/`decrypt` into no-ops.
    fn new(key: &str) -> Secret {
        Secret {
            key: key.as_bytes().to_vec(),
        }
    }

    /// XORs `data` in place with the key, repeated as often as needed.
    ///
    /// Pairing bytes with `key.iter().cycle()` means an empty key yields no pairs. That leaves
    /// `data` untouched instead of panicking on the `i % 0` the index-based form would hit.
    fn encrypt(&self, data: &mut [u8]) {
        for (byte, key_byte) in data.iter_mut().zip(self.key.iter().cycle()) {
            *byte ^= key_byte;
        }
    }

    /// Reverses [`Secret::encrypt`]; XOR with the same key stream is its own inverse.
    fn decrypt(&self, data: &mut [u8]) {
        self.encrypt(data);
    }
}

fn main() -> Result<(), Box<dyn Error>> {
    let local_id: u8 = env::var("LOCAL_ID")?.parse()?;
    let remote_id: u8 = env::var("REMOTE_ID")?.parse()?;
    let local_secret = Secret::new(&env::var("LOCAL_SECRET")?);
    let remote_secret = Secret::new(&env::var("REMOTE_SECRET")?);
    let remote_addr: Arc<Mutex<Option<SockAddr>>> = Arc::new(Mutex::new(
        match env::var("REMOTE_ADDR") {
            Ok(addr_str) => {
                let parsed = addr_str.parse::<SocketAddr>()?;
                Some(parsed.into())
            }
            Err(_) => None,
        }
    ));

    let mut config = tun::Configuration::default();
    config
        .address((10, 0, 0, local_id))
        .netmask((255, 255, 255, 0))
        .destination((10, 0, 0, remote_id))
        .up();

    #[cfg(target_os = "linux")]
    config.platform_config(|config| {
        config.ensure_root_privileges(true);
    });

    let dev = tun::create(&config)?;
    let (mut reader, mut writer) = dev.split();

    // 创建用于接收的原始套接字，协议号为 144
    let socket = Socket::new(Domain::IPV4, Type::RAW, Some(Protocol::from(144)))?;
    let socket_arc = Arc::new(Mutex::new(socket));
    let socket_for_inbound = Arc::clone(&socket_arc);

    let remote_addr_clone = Arc::clone(&remote_addr);

    let inbound = thread::spawn(move || {
        let mut recv_buf = [MaybeUninit::uninit(); 1500];
        loop {
            let sock_guard = socket_for_inbound.lock().unwrap();
            match sock_guard.recv_from(&mut recv_buf) {
                Ok((len, addr1)) => {
                    println!("recv from {:?}", recv_buf);
                    let recv_buf2: &mut [u8] = unsafe { transmute(&mut recv_buf[..len]) };
                    if let Some(packet) = Ipv4Packet::new(recv_buf2) {
                        let header_length = packet.get_header_length() as usize * 4;
                        let (header, data) = recv_buf2.split_at_mut(header_length);
                        let (meta1, payload) = data.split_at_mut(size_of::<Meta>());
                        let meta: &Meta = unsafe { transmute(&meta1) };
                        if meta.src_id == remote_id && meta.dst_id == local_id && meta.reversed == 0
                        {
                            let mut addr_lock = remote_addr_clone.lock().unwrap();
                            *addr_lock = Some(addr1);
                            local_secret.decrypt(payload);
                            writer.write(payload);
                        }
                    }
                }
                Err(e) => {
                    eprintln!("接收数据包时出错: {:?}", e);
                }
            }
        }
    });

    let outbound = thread::spawn(move || {
        let mut recv_buf = [0u8; 1500 - 20];
        let meta = Meta {
            src_id: local_id,
            dst_id: remote_id,
            reversed: 0,
        };
        let meta_size = size_of::<Meta>();

        unsafe {
            let meta_bytes =
                std::slice::from_raw_parts(&meta as *const Meta as *const u8, meta_size);
            recv_buf[..meta_size].copy_from_slice(meta_bytes);
        }
        loop {
            match reader.read(&mut recv_buf[meta_size..]) {
                Ok(len) => {
                    println!("recv from {:?}", &recv_buf[meta_size..meta_size + len]);
                    let maybe_remote = remote_addr.lock().unwrap().clone();
                    println!("1");
                    if let Some(ref remote) = maybe_remote {
                        println!("2");
                        let sock_guard = socket_arc.lock().unwrap();
                        println!("3");
                        remote_secret.encrypt(&mut recv_buf[meta_size..meta_size + len]);
                        println!("4");
                        if let Err(e) = sock_guard.send_to(&recv_buf[..meta_size + len], remote) {
                            println!("5");
                            eprintln!("Error sending packet: {:?}", e);
                        }
                    } else {
                        eprintln!("No remote address available; packet discarded.");
                    }
                }
                Err(_) => {}
            }
        }
    });

    inbound.join();
    outbound.join();

    Ok(())
}
