Commit 3541fff7 authored by nanahira's avatar nanahira

fix

parent 643d5d91
Pipeline #37493 failed with stages
in 1 minute and 16 seconds
......@@ -4,9 +4,6 @@ use crate::router::{Router, RouterReader, RouterWriter, SECRET_LENGTH};
use std::collections::HashMap;
use std::env;
use std::error::Error;
use std::intrinsics::transmute;
use std::io::{Read, Write};
use std::mem::MaybeUninit;
use std::sync::{Arc, Mutex};
use crossbeam_utils::thread;
use grouping_by::GroupingBy;
......@@ -44,24 +41,34 @@ pub struct Config {
fn main() -> Result<(), Box<dyn Error>> {
let config: Config = serde_json::from_str(env::args().nth(1).ok_or("need param")?.as_str())?;
let local_secret: [u8; SECRET_LENGTH] = Router::create_secret(config.local_secret.as_str())?;
// Create shared resources (Arc<Mutex>)
let mut sockets: HashMap<u16, Arc<Socket>> = HashMap::new();
let routers: HashMap<u8, Router> = config
.routers
.iter()
.map(|c| Router::new(c, &mut sockets).map(|router| (c.remote_id, router)))
.collect::<Result<_, _>>()?;
let (mut router_readers, router_writers): (
HashMap<u8, RouterReader>,
HashMap<u8, RouterWriter>,
) = routers
.into_iter()
.map(|(id, router)| {
// Mutex to allow safe concurrent access to router readers and writers
let router_readers: Arc<Mutex<HashMap<u8, RouterReader>>> = Arc::new(Mutex::new(HashMap::new()));
let router_writers: Arc<Mutex<HashMap<u8, RouterWriter>>> = Arc::new(Mutex::new(HashMap::new()));
// Populate router_readers and router_writers
{
let mut readers = router_readers.lock().unwrap();
let mut writers = router_writers.lock().unwrap();
for (id, router) in routers {
let (reader, writer) = router.split();
((id, reader), (id, writer))
})
.unzip();
let router_writers3: Vec<(Arc<Socket>, HashMap<u8, RouterWriter>)> = router_writers
.into_iter()
readers.insert(id, reader);
writers.insert(id, writer);
}
}
let router_writers3: Vec<(Arc<Socket>, HashMap<u8, RouterWriter>)> = {
let writers = router_writers.lock().unwrap();
writers
.iter()
.grouping_by(|(_, v)| v.key())
.into_iter()
.map(|(k, v)| {
......@@ -70,74 +77,76 @@ fn main() -> Result<(), Box<dyn Error>> {
v.into_iter().collect(),
)
})
.collect();
.collect()
};
println!("created tuns");
// 获取系统的核心数量
// Get system's available cores and calculate threads per task
let num_threads = std::thread::available_parallelism()
.map_or(1, |n| n.get()); // 默认使用1个线程,如果获取不到核心数
.map_or(1, |n| n.get());
thread::scope(|s| {
// 根据核心数量调整线程数
let readers_per_thread = (router_readers.len() as f32 / num_threads as f32).ceil() as usize;
let writers_per_thread = (router_writers3.len() as f32 / num_threads as f32).ceil() as usize;
// 将 router_readers 划分到多个线程
let readers_chunks: Vec<_> = router_readers
.into_iter()
.collect::<Vec<_>>()
.chunks(readers_per_thread)
// Split tasks based on available threads
let readers_chunks: Vec<_> = {
let readers = router_readers.lock().unwrap();
readers
.iter()
.chunks((readers.len() as f32 / num_threads as f32).ceil() as usize)
.map(|chunk| chunk.to_vec())
.collect();
.collect()
};
// 将 router_writers3 划分到多个线程
let writers_chunks: Vec<_> = router_writers3
.into_iter()
.collect::<Vec<_>>()
.chunks(writers_per_thread)
.chunks((router_writers3.len() as f32 / num_threads as f32).ceil() as usize)
.map(|chunk| chunk.to_vec())
.collect();
// 启动处理 router_readers 的线程
// Spawn threads for router readers
for chunk in readers_chunks {
s.spawn(move |_| {
for router in chunk {
let mut buffer = [0u8; 1500 - 20]; // minus typical IP header space
for (id, router_reader) in chunk {
let mut buffer = [0u8; 1500 - 20];
let meta_size = std::mem::size_of::<Meta>();
// Pre-initialize with our Meta header (local -> remote)
let meta = Meta {
src_id: config.local_id,
dst_id: router.1.config.remote_id,
dst_id: id,
reversed: 0,
};
// Turn the Meta struct into bytes
let meta_bytes = unsafe {
std::slice::from_raw_parts(&meta as *const Meta as *const u8, meta_size)
};
buffer[..meta_size].copy_from_slice(meta_bytes);
loop {
let n = router.1.tun_reader.read(&mut buffer[meta_size..]).unwrap();
if let Some(ref addr) = *router.1.endpoint.read().unwrap() {
router.1.encrypt(&mut buffer[meta_size..meta_size + n]);
match router_reader.tun_reader.read(&mut buffer[meta_size..]) {
Ok(n) => {
if let Some(ref addr) = *router_reader.endpoint.read().unwrap() {
router_reader.encrypt(&mut buffer[meta_size..meta_size + n]);
#[cfg(target_os = "linux")]
let _ = router.1.socket.set_mark(router.1.config.mark);
let _ = router.1.socket.send_to(&buffer[..meta_size + n], addr);
let _ = router_reader.socket.set_mark(router_reader.config.mark);
let _ = router_reader.socket.send_to(&buffer[..meta_size + n], addr);
}
}
Err(e) => {
eprintln!("Error reading from tunnel: {}", e);
break;
}
}
}
}
});
}
// 启动处理 router_writers 的线程
// Spawn threads for router writers
for chunk in writers_chunks {
s.spawn(move |_| {
for (socket, mut router_writers) in chunk {
let mut recv_buf = [MaybeUninit::uninit(); 1500];
loop {
let _ = (|| {
let (len, addr) = socket.recv_from(&mut recv_buf).unwrap();
match socket.recv_from(&mut recv_buf) {
Ok((len, addr)) => {
let data: &mut [u8] = unsafe { transmute(&mut recv_buf[..len]) };
let packet = Ipv4Packet::new(data).ok_or("malformed packet")?;
......@@ -159,6 +168,12 @@ fn main() -> Result<(), Box<dyn Error>> {
}
Ok::<(), Box<dyn Error>>(())
}
Err(e) => {
eprintln!("Error receiving data: {}", e);
Err(e.into())
}
}
})();
}
}
......@@ -166,5 +181,6 @@ fn main() -> Result<(), Box<dyn Error>> {
}
})
.unwrap();
Ok(())
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment