use pnet::packet::ipv4::Ipv4Packet;
use socket2::{Domain, Protocol, Socket, Type};
use std::error::Error;
use std::io::{Read, Write};
use std::mem::{size_of, MaybeUninit};
use std::net::ToSocketAddrs;
use std::process::Command;
use std::sync::{Arc, RwLock};
use std::{env, thread};

/// Fixed 4-byte header that precedes every encrypted tunnel payload.
///
/// Layout (`repr(C)`, no padding): `src_id`, `dst_id`, then the 16-bit `reversed` field.
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Meta {
    /// Identifier of the sending peer.
    pub src_id: u8,
    /// Identifier of the intended receiving peer.
    pub dst_id: u8,
    /// Reserved field; senders set it to 0 and receivers drop packets where it is non-zero.
    /// NOTE(review): the name is presumably a typo of "reserved"; kept for compatibility.
    pub reversed: u16,
}

/// Key for the tunnel's repeating-key XOR transform.
///
/// NOTE(review): XOR with a static key only obfuscates traffic; it provides no real
/// confidentiality or integrity protection.
struct Secret {
    /// Raw key bytes, taken from the UTF-8 encoding of the configured string.
    key: Vec<u8>,
}

impl Secret {
    /// Builds a secret from the UTF-8 bytes of `key`.
    fn new(key: &str) -> Self {
        Self {
            key: key.as_bytes().to_vec(),
        }
    }

    /// XORs `data` in place with the key, restarting the key at the first byte of `data` and
    /// repeating it as often as needed to cover the whole buffer.
    ///
    /// An empty key leaves `data` unchanged rather than panicking on a modulo by zero.
    fn encrypt(&self, data: &mut [u8]) {
        for (byte, key_byte) in data.iter_mut().zip(self.key.iter().cycle()) {
            *byte ^= *key_byte;
        }
    }

    /// Inverts [`Secret::encrypt`]: XOR is its own inverse, so the same transform is applied.
    fn decrypt(&self, data: &mut [u8]) {
        self.encrypt(data);
    }
}

fn main() -> Result<(), Box<dyn Error>> {
    let local_id: u8 = env::var("LOCAL_ID")?.parse()?;
    let remote_id: u8 = env::var("REMOTE_ID")?.parse()?;
    let local_secret = Secret::new(&env::var("LOCAL_SECRET")?);
    let remote_secret = Secret::new(&env::var("REMOTE_SECRET")?);
    let proto: i32 = env::var("PROTO")?.parse()?;
    let up = env::var("UP_SCRIPT")?;
    let dev = env::var("DEV")?;
    let family = if env::var("FAMILY")?.parse::<u8>()? == 6 {
        Domain::IPV6
    } else {
        Domain::IPV4
    };

    let endpoint = Arc::new(RwLock::new(match env::var("ENDPOINT") {
        Ok(addr_str) => {
            let parsed = addr_str.to_socket_addrs()?.next().unwrap();
            Some(parsed.into())
        }
        Err(_) => None,
    }));

    let mut config = tun::Configuration::default();
    config.tun_name(dev).up();

    let dev = tun::create(&config)?;
    let (mut reader, mut writer) = dev.split();

    // Create a raw socket with protocol number 144
    let socket = Socket::new(family, Type::RAW, Some(Protocol::from(proto)))?;
    #[cfg(target_os = "linux")]
    socket.set_mark(env::var("MARK")?.parse()?);
    let socket_clone = socket.try_clone()?;

    match Command::new(up).status() {
        Ok(status) => {
            if !status.success() {
                eprintln!("Script exited with non-zero status: {}", status);
            }
        }
        Err(e) => eprintln!("Failed to run script '{}'", e),
    }

    // Thread for receiving from WAN (raw socket) and writing to TUN
    let inbound = {
        let remote_addr = Arc::clone(&endpoint);
        let local_secret = local_secret;
        thread::spawn(move || {
            let mut recv_buf = [MaybeUninit::uninit(); 1500];
            loop {
                match socket.recv_from(&mut recv_buf) {
                    Ok((len, addr)) => {
                        // Safely interpret the uninit buffer up to `len` as bytes
                        let data = unsafe {
                            std::slice::from_raw_parts_mut(recv_buf.as_mut_ptr() as *mut u8, len)
                        };

                        if let Some(packet) = Ipv4Packet::new(data) {
                            let header_len = packet.get_header_length() as usize * 4;
                            let (_ip_header, rest) = data.split_at_mut(header_len);
                            if rest.len() < size_of::<Meta>() {
                                // Malformed packet
                                continue;
                            }
                            let (meta_bytes, payload) = rest.split_at_mut(size_of::<Meta>());

                            // Extract meta
                            let meta = Meta {
                                src_id: meta_bytes[0],
                                dst_id: meta_bytes[1],
                                reversed: u16::from_le_bytes([meta_bytes[2], meta_bytes[3]]),
                            };

                            // Check if it matches our expected IDs
                            if meta.src_id == remote_id
                                && meta.dst_id == local_id
                                && meta.reversed == 0
                            {
                                // Update remote address for outbound
                                *remote_addr.write().unwrap() = Some(addr);

                                // Decrypt and push into TUN
                                local_secret.decrypt(payload);
                                let _ = writer.write_all(payload);
                            } else {
                                // eprintln!("Dropping unexpected packet");
                            }
                        }
                    }
                    Err(e) => eprintln!("Error receiving: {}", e),
                }
            }
        })
    };

    // Thread for reading from TUN, encrypting, and sending to WAN
    let outbound = {
        let remote_addr = Arc::clone(&endpoint);
        let remote_secret = remote_secret;
        thread::spawn(move || {
            let mut buffer = [0u8; 1500 - 20]; // minus typical IP header space
            let meta_size = size_of::<Meta>();

            // Pre-initialize with our Meta header (local -> remote)
            let meta = Meta {
                src_id: local_id,
                dst_id: remote_id,
                reversed: 0,
            };
            // Turn the Meta struct into bytes
            let meta_bytes =
                unsafe { std::slice::from_raw_parts(&meta as *const Meta as *const u8, meta_size) };
            buffer[..meta_size].copy_from_slice(meta_bytes);

            loop {
                match reader.read(&mut buffer[meta_size..]) {
                    Ok(n) if n > 0 => {
                        // If we have a known remote address, encrypt and send
                        if let Some(ref addr) = *remote_addr.read().unwrap() {
                            remote_secret.encrypt(&mut buffer[meta_size..meta_size + n]);
                            if let Err(e) = socket_clone.send_to(&buffer[..meta_size + n], addr) {
                                eprintln!("Error sending packet: {}", e);
                            }
                        } else {
                            eprintln!("No remote address set; packet dropped.");
                        }
                    }
                    Err(e) => eprintln!("Error reading from TUN: {}", e),
                    _ => {}
                }
            }
        })
    };

    let _ = inbound.join();
    let _ = outbound.join();
    Ok(())
}
