Commit 23b7ad07 authored by nanahira's avatar nanahira

test

parent 3541fff7
Pipeline #37494 failed with stages
in 15 seconds
...@@ -4,6 +4,9 @@ use crate::router::{Router, RouterReader, RouterWriter, SECRET_LENGTH}; ...@@ -4,6 +4,9 @@ use crate::router::{Router, RouterReader, RouterWriter, SECRET_LENGTH};
use std::collections::HashMap; use std::collections::HashMap;
use std::env; use std::env;
use std::error::Error; use std::error::Error;
use std::intrinsics::transmute;
use std::io::{Read, Write};
use std::mem::MaybeUninit;
use std::sync::{Arc, Mutex}; use std::sync::{Arc, Mutex};
use crossbeam_utils::thread; use crossbeam_utils::thread;
use grouping_by::GroupingBy; use grouping_by::GroupingBy;
...@@ -39,143 +42,149 @@ pub struct Config { ...@@ -39,143 +42,149 @@ pub struct Config {
} }
fn main() -> Result<(), Box<dyn Error>> { fn main() -> Result<(), Box<dyn Error>> {
// Log: Start the program
println!("Starting program...");
let config: Config = serde_json::from_str(env::args().nth(1).ok_or("need param")?.as_str())?; let config: Config = serde_json::from_str(env::args().nth(1).ok_or("need param")?.as_str())?;
let local_secret: [u8; SECRET_LENGTH] = Router::create_secret(config.local_secret.as_str())?; let local_secret: [u8; SECRET_LENGTH] = Router::create_secret(config.local_secret.as_str())?;
// Create shared resources (Arc<Mutex>) // Log: Print config data
println!("Loaded config: {:?}", config);
let mut sockets: HashMap<u16, Arc<Socket>> = HashMap::new(); let mut sockets: HashMap<u16, Arc<Socket>> = HashMap::new();
let routers: HashMap<u8, Router> = config let routers: HashMap<u8, Router> = config
.routers .routers
.iter() .iter()
.map(|c| Router::new(c, &mut sockets).map(|router| (c.remote_id, router))) .map(|c| {
// Log: Creating router
println!("Creating router for remote_id: {}", c.remote_id);
Router::new(c, &mut sockets).map(|router| (c.remote_id, router))
})
.collect::<Result<_, _>>()?; .collect::<Result<_, _>>()?;
// Mutex to allow safe concurrent access to router readers and writers // Log: Routers created successfully
let router_readers: Arc<Mutex<HashMap<u8, RouterReader>>> = Arc::new(Mutex::new(HashMap::new())); println!("Routers created successfully");
let router_writers: Arc<Mutex<HashMap<u8, RouterWriter>>> = Arc::new(Mutex::new(HashMap::new()));
let (mut router_readers, router_writers): (
// Populate router_readers and router_writers HashMap<u8, RouterReader>,
{ HashMap<u8, RouterWriter>,
let mut readers = router_readers.lock().unwrap(); ) = routers
let mut writers = router_writers.lock().unwrap(); .into_iter()
for (id, router) in routers { .map(|(id, router)| {
// Log: Splitting router into reader and writer
println!("Splitting router with id: {}", id);
let (reader, writer) = router.split(); let (reader, writer) = router.split();
readers.insert(id, reader); ((id, reader), (id, writer))
writers.insert(id, writer); })
} .unzip();
}
// Log: Writers grouped
let router_writers3: Vec<(Arc<Socket>, HashMap<u8, RouterWriter>)> = { println!("Grouping router writers");
let writers = router_writers.lock().unwrap(); let router_writers3: Vec<(Arc<Socket>, HashMap<u8, RouterWriter>)> = router_writers
writers .into_iter()
.iter() .grouping_by(|(_, v)| v.key())
.grouping_by(|(_, v)| v.key()) .into_iter()
.into_iter() .map(|(k, v)| {
.map(|(k, v)| { // Log: Grouping writers for key: {}
( println!("Grouping writers for key: {}", k);
Arc::clone(sockets.get_mut(&k).unwrap()), (
v.into_iter().collect(), Arc::clone(sockets.get_mut(&k).unwrap()),
) v.into_iter().collect(),
}) )
.collect() })
}; .collect();
println!("created tuns"); // Log: Tun interfaces created
println!("Created tuns");
// Get system's available cores and calculate threads per task
// Get system's available cores
let num_threads = std::thread::available_parallelism() let num_threads = std::thread::available_parallelism()
.map_or(1, |n| n.get()); .map_or(1, |n| n.get()); // Default to 1 thread if unavailable
thread::scope(|s| { // Log: Number of threads
// Split tasks based on available threads println!("System has {} available threads", num_threads);
let readers_chunks: Vec<_> = {
let readers = router_readers.lock().unwrap();
readers
.iter()
.chunks((readers.len() as f32 / num_threads as f32).ceil() as usize)
.map(|chunk| chunk.to_vec())
.collect()
};
let writers_chunks: Vec<_> = router_writers3
.chunks((router_writers3.len() as f32 / num_threads as f32).ceil() as usize)
.map(|chunk| chunk.to_vec())
.collect();
thread::scope(|s| {
// Spawn threads for router readers // Spawn threads for router readers
for chunk in readers_chunks { for router in router_readers.values_mut() {
s.spawn(move |_| { s.spawn(|_| {
for (id, router_reader) in chunk { let mut buffer = [0u8; 1500 - 20]; // minus typical IP header space
let mut buffer = [0u8; 1500 - 20]; let meta_size = std::mem::size_of::<Meta>();
let meta_size = std::mem::size_of::<Meta>();
let meta = Meta { // Pre-initialize with our Meta header (local -> remote)
src_id: config.local_id, let meta = Meta {
dst_id: id, src_id: config.local_id,
reversed: 0, dst_id: router.config.remote_id,
}; reversed: 0,
let meta_bytes = unsafe { };
std::slice::from_raw_parts(&meta as *const Meta as *const u8, meta_size)
}; // Log: Preparing meta bytes
buffer[..meta_size].copy_from_slice(meta_bytes); let meta_bytes = unsafe {
std::slice::from_raw_parts(&meta as *const Meta as *const u8, meta_size)
loop { };
match router_reader.tun_reader.read(&mut buffer[meta_size..]) { buffer[..meta_size].copy_from_slice(meta_bytes);
Ok(n) => {
if let Some(ref addr) = *router_reader.endpoint.read().unwrap() { // Log: Starting reader loop for router
router_reader.encrypt(&mut buffer[meta_size..meta_size + n]); println!("Starting reader loop for router with id: {}", router.config.remote_id);
#[cfg(target_os = "linux")]
let _ = router_reader.socket.set_mark(router_reader.config.mark); loop {
let _ = router_reader.socket.send_to(&buffer[..meta_size + n], addr); match router.tun_reader.read(&mut buffer[meta_size..]) {
} Ok(n) => {
} if let Some(ref addr) = *router.endpoint.read().unwrap() {
Err(e) => { router.encrypt(&mut buffer[meta_size..meta_size + n]);
eprintln!("Error reading from tunnel: {}", e); #[cfg(target_os = "linux")]
break; let _ = router.socket.set_mark(router.config.mark);
let _ = router.socket.send_to(&buffer[..meta_size + n], addr);
} }
} }
Err(e) => {
eprintln!("Error reading from tunnel: {}", e);
break;
}
} }
} }
}); });
} }
// Spawn threads for router writers // Spawn threads for router writers
for chunk in writers_chunks { for (socket, mut router_writers) in router_writers3 {
s.spawn(move |_| { s.spawn(move |_| {
for (socket, mut router_writers) in chunk { let mut recv_buf = [MaybeUninit::uninit(); 1500];
let mut recv_buf = [MaybeUninit::uninit(); 1500]; loop {
loop { let _ = (|| {
let _ = (|| { match socket.recv_from(&mut recv_buf) {
match socket.recv_from(&mut recv_buf) { Ok((len, addr)) => {
Ok((len, addr)) => { let data: &mut [u8] = unsafe { transmute(&mut recv_buf[..len]) };
let data: &mut [u8] = unsafe { transmute(&mut recv_buf[..len]) };
// Log: Received data
let packet = Ipv4Packet::new(data).ok_or("malformed packet")?; println!("Received data from: {}", addr);
let header_len = packet.get_header_length() as usize * 4;
let (_ip_header, rest) = data let packet = Ipv4Packet::new(data).ok_or("malformed packet")?;
.split_at_mut_checked(header_len) let header_len = packet.get_header_length() as usize * 4;
.ok_or("malformed packet")?; let (_ip_header, rest) = data
let (meta_bytes, payload) = rest .split_at_mut_checked(header_len)
.split_at_mut_checked(std::mem::size_of::<Meta>()) .ok_or("malformed packet")?;
.ok_or("malformed packet")?; let (meta_bytes, payload) = rest
let meta: &Meta = unsafe { transmute(meta_bytes.as_ptr()) }; .split_at_mut_checked(std::mem::size_of::<Meta>())
if meta.dst_id == config.local_id && meta.reversed == 0 { .ok_or("malformed packet")?;
let router = router_writers let meta: &Meta = unsafe { transmute(meta_bytes.as_ptr()) };
.get_mut(&meta.src_id) if meta.dst_id == config.local_id && meta.reversed == 0 {
.ok_or("missing router")?; let router = router_writers
*router.endpoint.write().unwrap() = Some(addr); .get_mut(&meta.src_id)
router.decrypt(payload, &local_secret); .ok_or("missing router")?;
router.tun_writer.write_all(payload)?; *router.endpoint.write().unwrap() = Some(addr);
} router.decrypt(payload, &local_secret);
router.tun_writer.write_all(payload)?;
Ok::<(), Box<dyn Error>>(())
}
Err(e) => {
eprintln!("Error receiving data: {}", e);
Err(e.into())
} }
Ok::<(), Box<dyn Error>>(())
} }
})(); Err(e) => {
} eprintln!("Error receiving data: {}", e);
Err(e.into())
}
}
})();
} }
}); });
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment