Commit 643d5d91 authored by nanahira's avatar nanahira

again

parent 9f7e5412
Pipeline #37488 failed with stages
in 50 seconds
...@@ -7,7 +7,11 @@ use std::error::Error; ...@@ -7,7 +7,11 @@ use std::error::Error;
use std::intrinsics::transmute; use std::intrinsics::transmute;
use std::io::{Read, Write}; use std::io::{Read, Write};
use std::mem::MaybeUninit; use std::mem::MaybeUninit;
use std::sync::Arc; use std::sync::{Arc, Mutex};
use crossbeam_utils::thread;
use grouping_by::GroupingBy;
use pnet::packet::ipv4::Ipv4Packet;
use socket2::Socket;
#[repr(C)] #[repr(C)]
pub struct Meta { pub struct Meta {
...@@ -36,10 +40,6 @@ pub struct Config { ...@@ -36,10 +40,6 @@ pub struct Config {
pub local_secret: String, pub local_secret: String,
pub routers: Vec<ConfigRouter>, pub routers: Vec<ConfigRouter>,
} }
use crossbeam_utils::thread;
use grouping_by::GroupingBy;
use pnet::packet::ipv4::Ipv4Packet;
use socket2::Socket;
fn main() -> Result<(), Box<dyn Error>> { fn main() -> Result<(), Box<dyn Error>> {
let config: Config = serde_json::from_str(env::args().nth(1).ok_or("need param")?.as_str())?; let config: Config = serde_json::from_str(env::args().nth(1).ok_or("need param")?.as_str())?;
...@@ -73,64 +73,94 @@ fn main() -> Result<(), Box<dyn Error>> { ...@@ -73,64 +73,94 @@ fn main() -> Result<(), Box<dyn Error>> {
.collect(); .collect();
println!("created tuns"); println!("created tuns");
// 获取系统的核心数量
let num_threads = std::thread::available_parallelism()
.map_or(1, |n| n.get()); // 默认使用1个线程,如果获取不到核心数
thread::scope(|s| { thread::scope(|s| {
for router in router_readers.values_mut() { // 根据核心数量调整线程数
s.spawn(|_| { let readers_per_thread = (router_readers.len() as f32 / num_threads as f32).ceil() as usize;
let mut buffer = [0u8; 1500 - 20]; // minus typical IP header space let writers_per_thread = (router_writers3.len() as f32 / num_threads as f32).ceil() as usize;
let meta_size = size_of::<Meta>();
// 将 router_readers 划分到多个线程
// Pre-initialize with our Meta header (local -> remote) let readers_chunks: Vec<_> = router_readers
let meta = Meta { .into_iter()
src_id: config.local_id, .collect::<Vec<_>>()
dst_id: router.config.remote_id, .chunks(readers_per_thread)
reversed: 0, .map(|chunk| chunk.to_vec())
}; .collect();
// Turn the Meta struct into bytes
let meta_bytes = unsafe { // 将 router_writers3 划分到多个线程
std::slice::from_raw_parts(&meta as *const Meta as *const u8, meta_size) let writers_chunks: Vec<_> = router_writers3
}; .into_iter()
buffer[..meta_size].copy_from_slice(meta_bytes); .collect::<Vec<_>>()
.chunks(writers_per_thread)
loop { .map(|chunk| chunk.to_vec())
let n = router.tun_reader.read(&mut buffer[meta_size..]).unwrap(); .collect();
if let Some(ref addr) = *router.endpoint.read().unwrap() {
router.encrypt(&mut buffer[meta_size..meta_size + n]); // 启动处理 router_readers 的线程
#[cfg(target_os = "linux")] for chunk in readers_chunks {
let _ = router.socket.set_mark(router.config.mark); s.spawn(move |_| {
let _ = router.socket.send_to(&buffer[..meta_size + n], addr); for router in chunk {
let mut buffer = [0u8; 1500 - 20]; // minus typical IP header space
let meta_size = std::mem::size_of::<Meta>();
// Pre-initialize with our Meta header (local -> remote)
let meta = Meta {
src_id: config.local_id,
dst_id: router.1.config.remote_id,
reversed: 0,
};
// Turn the Meta struct into bytes
let meta_bytes = unsafe {
std::slice::from_raw_parts(&meta as *const Meta as *const u8, meta_size)
};
buffer[..meta_size].copy_from_slice(meta_bytes);
loop {
let n = router.1.tun_reader.read(&mut buffer[meta_size..]).unwrap();
if let Some(ref addr) = *router.1.endpoint.read().unwrap() {
router.1.encrypt(&mut buffer[meta_size..meta_size + n]);
#[cfg(target_os = "linux")]
let _ = router.1.socket.set_mark(router.1.config.mark);
let _ = router.1.socket.send_to(&buffer[..meta_size + n], addr);
}
} }
} }
}); });
} }
for (socket, mut router_writers) in router_writers3 { // 启动处理 router_writers 的线程
for chunk in writers_chunks {
s.spawn(move |_| { s.spawn(move |_| {
let mut recv_buf = [MaybeUninit::uninit(); 1500]; for (socket, mut router_writers) in chunk {
loop { let mut recv_buf = [MaybeUninit::uninit(); 1500];
let _ = (|| { loop {
let (len, addr) = socket.recv_from(&mut recv_buf).unwrap(); let _ = (|| {
let data: &mut [u8] = unsafe { transmute(&mut recv_buf[..len]) }; let (len, addr) = socket.recv_from(&mut recv_buf).unwrap();
let data: &mut [u8] = unsafe { transmute(&mut recv_buf[..len]) };
let packet = Ipv4Packet::new(data).ok_or("malformed packet")?;
let header_len = packet.get_header_length() as usize * 4;
let (_ip_header, rest) = data
.split_at_mut_checked(header_len)
.ok_or("malformed packet")?;
let (meta_bytes, payload) = rest
.split_at_mut_checked(size_of::<Meta>())
.ok_or("malformed packet")?;
let meta: &Meta = unsafe { transmute(meta_bytes.as_ptr()) };
if meta.dst_id == config.local_id && meta.reversed == 0 {
let router = router_writers
.get_mut(&meta.src_id)
.ok_or("missing router")?;
*router.endpoint.write().unwrap() = Some(addr);
router.decrypt(payload, &local_secret);
router.tun_writer.write_all(payload)?;
}
Ok::<(), Box<dyn Error>>(()) let packet = Ipv4Packet::new(data).ok_or("malformed packet")?;
})(); let header_len = packet.get_header_length() as usize * 4;
let (_ip_header, rest) = data
.split_at_mut_checked(header_len)
.ok_or("malformed packet")?;
let (meta_bytes, payload) = rest
.split_at_mut_checked(std::mem::size_of::<Meta>())
.ok_or("malformed packet")?;
let meta: &Meta = unsafe { transmute(meta_bytes.as_ptr()) };
if meta.dst_id == config.local_id && meta.reversed == 0 {
let router = router_writers
.get_mut(&meta.src_id)
.ok_or("missing router")?;
*router.endpoint.write().unwrap() = Some(addr);
router.decrypt(payload, &local_secret);
router.tun_writer.write_all(payload)?;
}
Ok::<(), Box<dyn Error>>(())
})();
}
} }
}); });
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment