Commit 0907669a authored by nanahira's avatar nanahira

again

parent f5ffc9fe
Pipeline #37432 failed with stages
in 20 seconds
...@@ -8,6 +8,7 @@ use std::intrinsics::transmute; ...@@ -8,6 +8,7 @@ use std::intrinsics::transmute;
use std::io::{Read, Write}; use std::io::{Read, Write};
use std::mem::MaybeUninit; use std::mem::MaybeUninit;
use std::sync::Arc; use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
#[repr(C)] #[repr(C)]
pub struct Meta { pub struct Meta {
...@@ -39,17 +40,44 @@ pub struct Config { ...@@ -39,17 +40,44 @@ pub struct Config {
use crossbeam_utils::thread; use crossbeam_utils::thread;
use grouping_by::GroupingBy; use grouping_by::GroupingBy;
use pnet::packet::ipv4::Ipv4Packet; use pnet::packet::ipv4::Ipv4Packet;
use socket2::Socket; use socket2::{Socket, SockAddr};
// 优化参数
const BUFFER_SIZE: usize = 65536; // 64KB 缓冲区
const BATCH_SIZE: usize = 32; // 批量处理大小
const SOCKET_BUFFER_SIZE: usize = 8 * 1024 * 1024; // 8MB socket 缓冲区
fn main() -> Result<(), Box<dyn Error>> { fn main() -> Result<(), Box<dyn Error>> {
let config: Config = serde_json::from_str(env::args().nth(1).ok_or("need param")?.as_str())?; let config: Config = serde_json::from_str(env::args().nth(1).ok_or("need param")?.as_str())?;
let local_secret: [u8; SECRET_LENGTH] = Router::create_secret(config.local_secret.as_str())?; let local_secret: [u8; SECRET_LENGTH] = Router::create_secret(config.local_secret.as_str())?;
let mut sockets: HashMap<u16, Arc<Socket>> = HashMap::new(); let mut sockets: HashMap<u16, Arc<Socket>> = HashMap::new();
let routers: HashMap<u8, Router> = config let routers: HashMap<u8, Router> = config
.routers .routers
.iter() .iter()
.map(|c| Router::new(c, &mut sockets).map(|router| (c.remote_id, router))) .map(|c| Router::new(c, &mut sockets).map(|router| (c.remote_id, router)))
.collect::<Result<_, _>>()?; .collect::<Result<_, _>>()?;
// 优化 socket 缓冲区大小
for socket in sockets.values() {
let _ = socket.set_send_buffer_size(SOCKET_BUFFER_SIZE);
let _ = socket.set_recv_buffer_size(SOCKET_BUFFER_SIZE);
#[cfg(target_os = "linux")]
{
// 启用 GSO/GRO
unsafe {
let enable = 1i32;
libc::setsockopt(
socket.as_raw_fd(),
libc::SOL_UDP,
libc::UDP_GRO,
&enable as *const _ as *const libc::c_void,
std::mem::size_of_val(&enable) as libc::socklen_t,
);
}
}
}
let (mut router_readers, router_writers): ( let (mut router_readers, router_writers): (
HashMap<u8, RouterReader>, HashMap<u8, RouterReader>,
HashMap<u8, RouterWriter>, HashMap<u8, RouterWriter>,
...@@ -74,67 +102,139 @@ fn main() -> Result<(), Box<dyn Error>> { ...@@ -74,67 +102,139 @@ fn main() -> Result<(), Box<dyn Error>> {
println!("created tuns"); println!("created tuns");
thread::scope(|s| { thread::scope(|s| {
// 为每个路由创建多个发送线程
for router in router_readers.values_mut() { for router in router_readers.values_mut() {
s.spawn(|_| { let router_id = router.config.remote_id;
let mut buffer = [0u8; 1500 - 20]; // minus typical IP header space let local_id = config.local_id;
let meta_size = size_of::<Meta>(); let mark = router.config.mark;
// Pre-initialize with our Meta header (local -> remote) // 创建 4 个并发发送线程
let meta = Meta { for _ in 0..4 {
src_id: config.local_id, let socket = Arc::clone(&router.socket);
dst_id: router.config.remote_id, let endpoint = Arc::clone(&router.endpoint);
reversed: 0, let tun_reader = router.tun_reader.try_clone().unwrap();
}; let encrypt_fn = router.encrypt.clone();
// Turn the Meta struct into bytes
let meta_bytes = unsafe { s.spawn(move |_| {
std::slice::from_raw_parts(&meta as *const Meta as *const u8, meta_size) let mut buffers: Vec<Vec<u8>> = (0..BATCH_SIZE)
}; .map(|_| vec![0u8; BUFFER_SIZE])
buffer[..meta_size].copy_from_slice(meta_bytes); .collect();
let meta_size = size_of::<Meta>();
loop {
let n = router.tun_reader.read(&mut buffer[meta_size..]).unwrap(); // 预初始化 Meta 头
if let Some(ref addr) = *router.endpoint.read().unwrap() { let meta = Meta {
router.encrypt(&mut buffer[meta_size..meta_size + n]); src_id: local_id,
#[cfg(target_os = "linux")] dst_id: router_id,
let _ = router.socket.set_mark(router.config.mark); reversed: 0,
let _ = router.socket.send_to(&buffer[..meta_size + n], addr); };
let meta_bytes = unsafe {
std::slice::from_raw_parts(&meta as *const Meta as *const u8, meta_size)
};
for buffer in &mut buffers {
buffer[..meta_size].copy_from_slice(meta_bytes);
}
let mut current_buffer = 0;
loop {
let buffer = &mut buffers[current_buffer];
match tun_reader.read(&mut buffer[meta_size..]) {
Ok(n) if n > 0 => {
if let Some(ref addr) = *endpoint.read().unwrap() {
encrypt_fn(&mut buffer[meta_size..meta_size + n]);
#[cfg(target_os = "linux")]
let _ = socket.set_mark(mark);
// 使用 MSG_DONTWAIT 避免阻塞
match socket.send_to(&buffer[..meta_size + n], addr) {
Ok(_) => {},
Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
// 缓冲区满,稍后重试
std::thread::yield_now();
},
Err(_) => {},
}
}
current_buffer = (current_buffer + 1) % BATCH_SIZE;
},
_ => std::thread::yield_now(),
}
} }
} });
}); }
} }
// 为每个 socket 创建多个接收线程
for (socket, mut router_writers) in router_writers3 { for (socket, mut router_writers) in router_writers3 {
s.spawn(move |_| { // 创建 4 个并发接收线程
let mut recv_buf = [MaybeUninit::uninit(); 1500]; for _ in 0..4 {
loop { let socket = Arc::clone(&socket);
let _ = (|| { let mut router_writers = router_writers.clone();
let (len, addr) = socket.recv_from(&mut recv_buf).unwrap(); let local_id = config.local_id;
let data: &mut [u8] = unsafe { transmute(&mut recv_buf[..len]) }; let local_secret = local_secret.clone();
let packet = Ipv4Packet::new(data).ok_or("malformed packet")?; s.spawn(move |_| {
let header_len = packet.get_header_length() as usize * 4; let mut recv_bufs: Vec<[MaybeUninit<u8>; BUFFER_SIZE]> = (0..BATCH_SIZE)
let (_ip_header, rest) = data .map(|_| [MaybeUninit::uninit(); BUFFER_SIZE])
.split_at_mut_checked(header_len) .collect();
.ok_or("malformed packet")?; let mut current_buffer = 0;
let (meta_bytes, payload) = rest
.split_at_mut_checked(size_of::<Meta>()) loop {
.ok_or("malformed packet")?; let recv_buf = &mut recv_bufs[current_buffer];
let meta: &Meta = unsafe { transmute(meta_bytes.as_ptr()) }; let _ = (|| {
if meta.dst_id == config.local_id && meta.reversed == 0 { let (len, addr) = match socket.recv_from(recv_buf) {
let router = router_writers Ok(result) => result,
.get_mut(&meta.src_id) Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
.ok_or("missing router")?; std::thread::yield_now();
*router.endpoint.write().unwrap() = Some(addr); return Ok(());
router.decrypt(payload, &local_secret); },
router.tun_writer.write_all(payload)?; Err(_) => return Ok(()),
} };
let data: &mut [u8] = unsafe { transmute(&mut recv_buf[..len]) };
Ok::<(), Box<dyn Error>>(()) let packet = Ipv4Packet::new(data).ok_or("malformed packet")?;
})(); let header_len = packet.get_header_length() as usize * 4;
} let (_ip_header, rest) = data
}); .split_at_mut_checked(header_len)
.ok_or("malformed packet")?;
let (meta_bytes, payload) = rest
.split_at_mut_checked(size_of::<Meta>())
.ok_or("malformed packet")?;
let meta: &Meta = unsafe { transmute(meta_bytes.as_ptr()) };
if meta.dst_id == local_id && meta.reversed == 0 {
let router = router_writers
.get_mut(&meta.src_id)
.ok_or("missing router")?;
*router.endpoint.write().unwrap() = Some(addr);
router.decrypt(payload, &local_secret);
// 批量写入以减少系统调用
let mut offset = 0;
while offset < payload.len() {
match router.tun_writer.write(&payload[offset..]) {
Ok(n) => offset += n,
Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
std::thread::yield_now();
},
Err(_) => break,
}
}
}
current_buffer = (current_buffer + 1) % BATCH_SIZE;
Ok::<(), Box<dyn Error>>(())
})();
}
});
}
} }
}) })
.unwrap(); .unwrap();
Ok(()) Ok(())
} }
// 辅助函数:设置 socket 为非阻塞模式
#[cfg(target_os = "linux")]
use std::os::unix::io::AsRawFd;
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment