mod router;

use std::collections::HashMap;
use std::env;
use std::error::Error;
use std::io::{Read, Write};
use std::sync::{Arc, Mutex};

use crossbeam_utils::thread;
use grouping_by::GroupingBy;
use pnet::packet::ipv4::Ipv4Packet;
use socket2::Socket;

use crate::router::{Router, RouterReader, RouterWriter, SECRET_LENGTH};

/// Fixed-size header that precedes every tunnelled payload on the wire.
///
/// `#[repr(C)]` pins the layout (two `u8`s followed by a `u16`: four bytes, no padding)
/// because `main` copies this struct to and from raw packet bytes. The `u16` is therefore
/// transmitted in the host's native byte order.
#[repr(C)]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Meta {
    /// Identifier of the node that sent the packet.
    pub src_id: u8,
    /// Identifier of the node the packet is addressed to.
    pub dst_id: u8,
    /// Always sent as 0 by `main`; incoming packets with a non-zero value are ignored.
    // NOTE(review): the name looks like a misspelling of "reserved"; it is kept because the
    // field is public and renaming it would break users of this struct.
    pub reversed: u16,
}

use serde::Deserialize;

/// Settings for one remote peer, deserialized from the `routers` array of the JSON
/// configuration.
///
/// `main` hands each entry to `Router::new`; fields without a grounded description below
/// are consumed there, outside this file.
#[derive(Deserialize)]
pub struct ConfigRouter {
    /// Identifier of the remote peer. `main` keys its router maps by this value, writes it
    /// into `Meta::dst_id` of outgoing packets and looks routers up by `Meta::src_id` of
    /// incoming ones.
    pub remote_id: u8,
    // NOTE(review): presumably the IP protocol number used for the socket — confirm in `router`.
    pub proto: i32,
    // NOTE(review): presumably the address family (IPv4/IPv6) of the socket — confirm in `router`.
    pub family: u8,
    /// Socket mark. On Linux, `main` passes the reader's `config.mark` to
    /// `Socket::set_mark` before each send (presumably this value — confirm in `router`).
    pub mark: u32,
    // NOTE(review): presumably the peer's initial address; `main` also replaces the stored
    // endpoint with the source address of every accepted packet — confirm format in `router`.
    pub endpoint: String,
    // NOTE(review): presumably the peer's secret used for outgoing encryption — confirm in `router`.
    pub remote_secret: String,
    // NOTE(review): presumably the name of the tun device to create — confirm in `router`.
    pub dev: String,
    // NOTE(review): presumably a command or script run when the device comes up — confirm in `router`.
    pub up: String,
}

/// Top-level configuration, parsed by `main` from the JSON text given as the first
/// command-line argument.
#[derive(Deserialize)]
pub struct Config {
    /// Identifier of this node: sent as `Meta::src_id` and required in `Meta::dst_id` of
    /// every incoming packet that is accepted.
    pub local_id: u8,
    /// Secret text that `main` converts with `Router::create_secret`; the result is passed
    /// to `decrypt` for incoming payloads.
    pub local_secret: String,
    /// One entry per remote peer; `main` builds a `Router` from each.
    pub routers: Vec<ConfigRouter>,
}

/// Entry point of the tunnel daemon.
///
/// Parses the JSON configuration given as the first command-line argument, builds one
/// [`Router`] per configured peer and then forwards traffic until the process is stopped:
///
/// * one thread per peer reads packets from that peer's tun reader, prepends a [`Meta`]
///   header, encrypts the packet and sends it to the peer's currently known endpoint
///   (packets are discarded while no endpoint is known);
/// * one thread per socket receives datagrams, checks the [`Meta`] header, records the
///   sender as the peer's endpoint, decrypts the payload and writes it to the peer's tun
///   writer.
///
/// Every reader and every socket gets its own thread because each of them blocks forever in
/// its own loop; sharing a thread between several of them would starve all but the first.
///
/// # Errors
///
/// Returns an error if the argument is missing, the JSON is invalid, `Router::create_secret`
/// or `Router::new` fails, or a router reports a socket key that is not present in the
/// socket table.
///
/// # Panics
///
/// Panics if a worker thread panics (for example on a poisoned endpoint lock).
fn main() -> Result<(), Box<dyn Error>> {
    // Size of the buffer handed to `recv_from`.
    const RECV_BUFFER_LEN: usize = 1500;
    // Size of the outgoing buffer (header + payload).
    // NOTE(review): 1500 - 20 presumably is an Ethernet MTU minus a minimal IPv4 header — confirm.
    const SEND_BUFFER_LEN: usize = RECV_BUFFER_LEN - 20;
    // Number of bytes the `Meta` header occupies on the wire.
    const META_SIZE: usize = std::mem::size_of::<Meta>();
    // Pause after a failed receive so a persistent socket error cannot spin a core.
    const RECV_ERROR_BACKOFF: std::time::Duration = std::time::Duration::from_millis(100);

    let raw_config = env::args().nth(1).ok_or("need param")?;
    let config: Config = serde_json::from_str(&raw_config)?;
    let local_secret: [u8; SECRET_LENGTH] = Router::create_secret(config.local_secret.as_str())?;
    // Copied out so worker closures capture a plain `u8` instead of borrowing `config`.
    let local_id = config.local_id;

    // `Router::new` receives the socket table mutably; it is read back below by router key.
    let mut sockets: HashMap<u16, Arc<Socket>> = HashMap::new();
    let routers: HashMap<u8, Router> = config
        .routers
        .iter()
        .map(|c| Router::new(c, &mut sockets).map(|router| (c.remote_id, router)))
        .collect::<Result<_, _>>()?;

    // Split every router into its tun->socket half and its socket->tun half. Each half is
    // later moved into exactly one thread, so no locking is needed around them.
    let (router_readers, router_writers): (Vec<(u8, RouterReader)>, Vec<(u8, RouterWriter)>) =
        routers
            .into_iter()
            .map(|(id, router)| {
                let (reader, writer) = router.split();
                ((id, reader), (id, writer))
            })
            .unzip();

    // Writers that share a socket must live in the same receiving thread, because the
    // router a datagram belongs to is only known after its header has been read.
    let mut writers_by_socket: Vec<(Arc<Socket>, HashMap<u8, RouterWriter>)> = Vec::new();
    for (key, group) in router_writers
        .into_iter()
        .grouping_by(|(_, writer)| writer.key())
    {
        let socket = sockets
            .get(&key)
            .ok_or("no socket registered for router key")?;
        writers_by_socket.push((Arc::clone(socket), group.into_iter().collect()));
    }

    println!("created tuns");

    thread::scope(|s| {
        // tun -> socket: one blocking loop per peer.
        for (remote_id, mut reader) in router_readers {
            s.spawn(move |_| {
                let mut buffer = [0u8; SEND_BUFFER_LEN];

                // The header never changes for a given peer, so it is written once up front.
                let header = Meta {
                    src_id: local_id,
                    dst_id: remote_id,
                    reversed: 0,
                };
                // SAFETY: `Meta` is `#[repr(C)]` with fields `u8, u8, u16`, i.e. `META_SIZE`
                // bytes without padding, so all bytes behind the pointer are initialised and
                // `header` outlives the borrowed slice.
                let header_bytes = unsafe {
                    std::slice::from_raw_parts((&header as *const Meta).cast::<u8>(), META_SIZE)
                };
                buffer[..META_SIZE].copy_from_slice(header_bytes);

                loop {
                    let n = match reader.tun_reader.read(&mut buffer[META_SIZE..]) {
                        Ok(n) => n,
                        Err(e) => {
                            eprintln!("Error reading from tunnel: {}", e);
                            break;
                        }
                    };
                    // Without a known endpoint there is nowhere to send the packet.
                    if let Some(ref addr) = *reader.endpoint.read().unwrap() {
                        reader.encrypt(&mut buffer[META_SIZE..META_SIZE + n]);
                        // Marking and sending are best-effort: a failure only loses this packet.
                        #[cfg(target_os = "linux")]
                        let _ = reader.socket.set_mark(reader.config.mark);
                        let _ = reader.socket.send_to(&buffer[..META_SIZE + n], addr);
                    }
                }
            });
        }

        // socket -> tun: one blocking loop per socket.
        for (socket, mut writers) in writers_by_socket {
            s.spawn(move |_| {
                let mut recv_buf = [std::mem::MaybeUninit::<u8>::uninit(); RECV_BUFFER_LEN];
                loop {
                    let (len, addr) = match socket.recv_from(&mut recv_buf) {
                        Ok(received) => received,
                        Err(e) => {
                            eprintln!("Error receiving data: {}", e);
                            std::thread::sleep(RECV_ERROR_BACKOFF);
                            continue;
                        }
                    };

                    // Bounds-checked view of the part of the buffer that was filled.
                    let filled = &mut recv_buf[..len];
                    // SAFETY: `recv_from` reported `len` received bytes, so the first `len`
                    // elements are initialised; `MaybeUninit<u8>` has the layout of `u8`.
                    let data: &mut [u8] = unsafe {
                        std::slice::from_raw_parts_mut(
                            filled.as_mut_ptr().cast::<u8>(),
                            filled.len(),
                        )
                    };

                    // Anything that does not parse is dropped silently.
                    let Some(packet) = Ipv4Packet::new(data) else {
                        continue;
                    };
                    // The IPv4 header length field counts 32-bit words.
                    let header_len = packet.get_header_length() as usize * 4;
                    let Some((_ip_header, rest)) = data.split_at_mut_checked(header_len) else {
                        continue;
                    };
                    let Some((meta_bytes, payload)) = rest.split_at_mut_checked(META_SIZE) else {
                        continue;
                    };
                    // SAFETY: `meta_bytes` is exactly `META_SIZE` initialised bytes and every
                    // bit pattern is a valid `Meta`. `read_unaligned` is required because the
                    // byte buffer gives no alignment guarantee for the `u16` field.
                    let meta =
                        unsafe { std::ptr::read_unaligned(meta_bytes.as_ptr().cast::<Meta>()) };

                    if meta.dst_id != local_id || meta.reversed != 0 {
                        continue;
                    }
                    let Some(writer) = writers.get_mut(&meta.src_id) else {
                        continue;
                    };
                    // Remember where the peer currently is so replies reach it.
                    *writer.endpoint.write().unwrap() = Some(addr);
                    writer.decrypt(payload, &local_secret);
                    // Best-effort, as for sending: a failed write only loses this packet.
                    let _ = writer.tun_writer.write_all(payload);
                }
            });
        }
    })
    .expect("worker thread panicked");

    Ok(())
}
