Commit 0907669a authored by nanahira's avatar nanahira

again

parent f5ffc9fe
Pipeline #37432 failed with stages
in 20 seconds
...@@ -8,6 +8,7 @@ use std::intrinsics::transmute; ...@@ -8,6 +8,7 @@ use std::intrinsics::transmute;
use std::io::{Read, Write}; use std::io::{Read, Write};
use std::mem::MaybeUninit; use std::mem::MaybeUninit;
use std::sync::Arc; use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
#[repr(C)] #[repr(C)]
pub struct Meta { pub struct Meta {
...@@ -39,17 +40,44 @@ pub struct Config { ...@@ -39,17 +40,44 @@ pub struct Config {
use crossbeam_utils::thread; use crossbeam_utils::thread;
use grouping_by::GroupingBy; use grouping_by::GroupingBy;
use pnet::packet::ipv4::Ipv4Packet; use pnet::packet::ipv4::Ipv4Packet;
use socket2::Socket; use socket2::{Socket, SockAddr};
// 优化参数
const BUFFER_SIZE: usize = 65536; // 64KB 缓冲区
const BATCH_SIZE: usize = 32; // 批量处理大小
const SOCKET_BUFFER_SIZE: usize = 8 * 1024 * 1024; // 8MB socket 缓冲区
fn main() -> Result<(), Box<dyn Error>> { fn main() -> Result<(), Box<dyn Error>> {
let config: Config = serde_json::from_str(env::args().nth(1).ok_or("need param")?.as_str())?; let config: Config = serde_json::from_str(env::args().nth(1).ok_or("need param")?.as_str())?;
let local_secret: [u8; SECRET_LENGTH] = Router::create_secret(config.local_secret.as_str())?; let local_secret: [u8; SECRET_LENGTH] = Router::create_secret(config.local_secret.as_str())?;
let mut sockets: HashMap<u16, Arc<Socket>> = HashMap::new(); let mut sockets: HashMap<u16, Arc<Socket>> = HashMap::new();
let routers: HashMap<u8, Router> = config let routers: HashMap<u8, Router> = config
.routers .routers
.iter() .iter()
.map(|c| Router::new(c, &mut sockets).map(|router| (c.remote_id, router))) .map(|c| Router::new(c, &mut sockets).map(|router| (c.remote_id, router)))
.collect::<Result<_, _>>()?; .collect::<Result<_, _>>()?;
// 优化 socket 缓冲区大小
for socket in sockets.values() {
let _ = socket.set_send_buffer_size(SOCKET_BUFFER_SIZE);
let _ = socket.set_recv_buffer_size(SOCKET_BUFFER_SIZE);
#[cfg(target_os = "linux")]
{
// 启用 GSO/GRO
unsafe {
let enable = 1i32;
libc::setsockopt(
socket.as_raw_fd(),
libc::SOL_UDP,
libc::UDP_GRO,
&enable as *const _ as *const libc::c_void,
std::mem::size_of_val(&enable) as libc::socklen_t,
);
}
}
}
let (mut router_readers, router_writers): ( let (mut router_readers, router_writers): (
HashMap<u8, RouterReader>, HashMap<u8, RouterReader>,
HashMap<u8, RouterWriter>, HashMap<u8, RouterWriter>,
...@@ -74,41 +102,96 @@ fn main() -> Result<(), Box<dyn Error>> { ...@@ -74,41 +102,96 @@ fn main() -> Result<(), Box<dyn Error>> {
println!("created tuns"); println!("created tuns");
thread::scope(|s| { thread::scope(|s| {
// 为每个路由创建多个发送线程
for router in router_readers.values_mut() { for router in router_readers.values_mut() {
s.spawn(|_| { let router_id = router.config.remote_id;
let mut buffer = [0u8; 1500 - 20]; // minus typical IP header space let local_id = config.local_id;
let mark = router.config.mark;
// 创建 4 个并发发送线程
for _ in 0..4 {
let socket = Arc::clone(&router.socket);
let endpoint = Arc::clone(&router.endpoint);
let tun_reader = router.tun_reader.try_clone().unwrap();
let encrypt_fn = router.encrypt.clone();
s.spawn(move |_| {
let mut buffers: Vec<Vec<u8>> = (0..BATCH_SIZE)
.map(|_| vec![0u8; BUFFER_SIZE])
.collect();
let meta_size = size_of::<Meta>(); let meta_size = size_of::<Meta>();
// Pre-initialize with our Meta header (local -> remote) // 预初始化 Meta 头
let meta = Meta { let meta = Meta {
src_id: config.local_id, src_id: local_id,
dst_id: router.config.remote_id, dst_id: router_id,
reversed: 0, reversed: 0,
}; };
// Turn the Meta struct into bytes
let meta_bytes = unsafe { let meta_bytes = unsafe {
std::slice::from_raw_parts(&meta as *const Meta as *const u8, meta_size) std::slice::from_raw_parts(&meta as *const Meta as *const u8, meta_size)
}; };
for buffer in &mut buffers {
buffer[..meta_size].copy_from_slice(meta_bytes); buffer[..meta_size].copy_from_slice(meta_bytes);
}
let mut current_buffer = 0;
loop { loop {
let n = router.tun_reader.read(&mut buffer[meta_size..]).unwrap(); let buffer = &mut buffers[current_buffer];
if let Some(ref addr) = *router.endpoint.read().unwrap() { match tun_reader.read(&mut buffer[meta_size..]) {
router.encrypt(&mut buffer[meta_size..meta_size + n]); Ok(n) if n > 0 => {
if let Some(ref addr) = *endpoint.read().unwrap() {
encrypt_fn(&mut buffer[meta_size..meta_size + n]);
#[cfg(target_os = "linux")] #[cfg(target_os = "linux")]
let _ = router.socket.set_mark(router.config.mark); let _ = socket.set_mark(mark);
let _ = router.socket.send_to(&buffer[..meta_size + n], addr);
// 使用 MSG_DONTWAIT 避免阻塞
match socket.send_to(&buffer[..meta_size + n], addr) {
Ok(_) => {},
Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
// 缓冲区满,稍后重试
std::thread::yield_now();
},
Err(_) => {},
}
}
current_buffer = (current_buffer + 1) % BATCH_SIZE;
},
_ => std::thread::yield_now(),
} }
} }
}); });
} }
}
// 为每个 socket 创建多个接收线程
for (socket, mut router_writers) in router_writers3 { for (socket, mut router_writers) in router_writers3 {
// 创建 4 个并发接收线程
for _ in 0..4 {
let socket = Arc::clone(&socket);
let mut router_writers = router_writers.clone();
let local_id = config.local_id;
let local_secret = local_secret.clone();
s.spawn(move |_| { s.spawn(move |_| {
let mut recv_buf = [MaybeUninit::uninit(); 1500]; let mut recv_bufs: Vec<[MaybeUninit<u8>; BUFFER_SIZE]> = (0..BATCH_SIZE)
.map(|_| [MaybeUninit::uninit(); BUFFER_SIZE])
.collect();
let mut current_buffer = 0;
loop { loop {
let recv_buf = &mut recv_bufs[current_buffer];
let _ = (|| { let _ = (|| {
let (len, addr) = socket.recv_from(&mut recv_buf).unwrap(); let (len, addr) = match socket.recv_from(recv_buf) {
Ok(result) => result,
Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
std::thread::yield_now();
return Ok(());
},
Err(_) => return Ok(()),
};
let data: &mut [u8] = unsafe { transmute(&mut recv_buf[..len]) }; let data: &mut [u8] = unsafe { transmute(&mut recv_buf[..len]) };
let packet = Ipv4Packet::new(data).ok_or("malformed packet")?; let packet = Ipv4Packet::new(data).ok_or("malformed packet")?;
...@@ -120,21 +203,38 @@ fn main() -> Result<(), Box<dyn Error>> { ...@@ -120,21 +203,38 @@ fn main() -> Result<(), Box<dyn Error>> {
.split_at_mut_checked(size_of::<Meta>()) .split_at_mut_checked(size_of::<Meta>())
.ok_or("malformed packet")?; .ok_or("malformed packet")?;
let meta: &Meta = unsafe { transmute(meta_bytes.as_ptr()) }; let meta: &Meta = unsafe { transmute(meta_bytes.as_ptr()) };
if meta.dst_id == config.local_id && meta.reversed == 0 { if meta.dst_id == local_id && meta.reversed == 0 {
let router = router_writers let router = router_writers
.get_mut(&meta.src_id) .get_mut(&meta.src_id)
.ok_or("missing router")?; .ok_or("missing router")?;
*router.endpoint.write().unwrap() = Some(addr); *router.endpoint.write().unwrap() = Some(addr);
router.decrypt(payload, &local_secret); router.decrypt(payload, &local_secret);
router.tun_writer.write_all(payload)?;
// 批量写入以减少系统调用
let mut offset = 0;
while offset < payload.len() {
match router.tun_writer.write(&payload[offset..]) {
Ok(n) => offset += n,
Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
std::thread::yield_now();
},
Err(_) => break,
}
}
} }
current_buffer = (current_buffer + 1) % BATCH_SIZE;
Ok::<(), Box<dyn Error>>(()) Ok::<(), Box<dyn Error>>(())
})(); })();
} }
}); });
} }
}
}) })
.unwrap(); .unwrap();
Ok(()) Ok(())
} }
// 辅助函数:设置 socket 为非阻塞模式
#[cfg(target_os = "linux")]
use std::os::unix::io::AsRawFd;
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment