Commit 76fb2986 authored by nanahira's avatar nanahira

again

parent 6402cfe6
Pipeline #37435 failed with stages
in 2 minutes and 33 seconds
......@@ -8,7 +8,6 @@ use std::intrinsics::transmute;
use std::io::{Read, Write};
use std::mem::MaybeUninit;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
#[repr(C)]
pub struct Meta {
......@@ -40,12 +39,13 @@ pub struct Config {
use crossbeam_utils::thread;
use grouping_by::GroupingBy;
use pnet::packet::ipv4::Ipv4Packet;
use socket2::{Socket, SockAddr};
use socket2::Socket;
// 优化参数
const BUFFER_SIZE: usize = 65536; // 64KB 缓冲区
const BATCH_SIZE: usize = 32; // 批量处理大小
const SOCKET_BUFFER_SIZE: usize = 8 * 1024 * 1024; // 8MB socket 缓冲区
// 优化参数 - 针对高延迟网络
const MTU: usize = 1500;
const MAX_PACKET_SIZE: usize = MTU - 20; // 减去 IP 头部
const BATCH_SIZE: usize = 64; // 批量处理数量
const SOCKET_BUFFER_SIZE: usize = 16 * 1024 * 1024; // 16MB socket 缓冲区
fn main() -> Result<(), Box<dyn Error>> {
let config: Config = serde_json::from_str(env::args().nth(1).ok_or("need param")?.as_str())?;
......@@ -58,22 +58,42 @@ fn main() -> Result<(), Box<dyn Error>> {
.map(|c| Router::new(c, &mut sockets).map(|router| (c.remote_id, router)))
.collect::<Result<_, _>>()?;
// 优化 socket 缓冲区大小
// 优化 raw socket 缓冲区
for socket in sockets.values() {
let _ = socket.set_send_buffer_size(SOCKET_BUFFER_SIZE);
let _ = socket.set_recv_buffer_size(SOCKET_BUFFER_SIZE);
// Linux 特定优化
#[cfg(target_os = "linux")]
{
// 启用 GSO/GRO
use std::os::unix::io::AsRawFd;
unsafe {
// 设置 IP_RECVERR 以快速检测错误
let enable = 1i32;
libc::setsockopt(
socket.as_raw_fd(),
libc::SOL_UDP,
libc::UDP_GRO,
libc::IPPROTO_IP,
libc::IP_RECVERR,
&enable as *const _ as *const libc::c_void,
std::mem::size_of_val(&enable) as libc::socklen_t,
);
// 设置 SO_RCVBUFFORCE 和 SO_SNDBUFFORCE 绕过系统限制(需要 CAP_NET_ADMIN)
let force_size = SOCKET_BUFFER_SIZE as i32;
libc::setsockopt(
socket.as_raw_fd(),
libc::SOL_SOCKET,
libc::SO_RCVBUFFORCE,
&force_size as *const _ as *const libc::c_void,
std::mem::size_of_val(&force_size) as libc::socklen_t,
);
libc::setsockopt(
socket.as_raw_fd(),
libc::SOL_SOCKET,
libc::SO_SNDBUFFORCE,
&force_size as *const _ as *const libc::c_void,
std::mem::size_of_val(&force_size) as libc::socklen_t,
);
}
}
}
......@@ -102,97 +122,113 @@ fn main() -> Result<(), Box<dyn Error>> {
println!("created tuns");
thread::scope(|s| {
// 为每个路由创建多个发送线程
// 发送线程 - 批量处理以提高吞吐量
for router in router_readers.values_mut() {
let router_id = router.config.remote_id;
let local_id = config.local_id;
let mark = router.config.mark;
// 创建 4 个并发发送线程
for _ in 0..4 {
let socket = Arc::clone(&router.socket);
let endpoint = Arc::clone(&router.endpoint);
let tun_reader = router.tun_reader.try_clone().unwrap();
let encrypt_fn = router.encrypt.clone();
s.spawn(|_| {
// 为批量发送准备多个缓冲区
let mut buffers: Vec<Vec<u8>> = (0..BATCH_SIZE)
.map(|_| vec![0u8; MAX_PACKET_SIZE])
.collect();
let meta_size = size_of::<Meta>();
// 预初始化所有缓冲区的 Meta 头
let meta = Meta {
src_id: config.local_id,
dst_id: router.config.remote_id,
reversed: 0,
};
let meta_bytes = unsafe {
std::slice::from_raw_parts(&meta as *const Meta as *const u8, meta_size)
};
for buffer in &mut buffers {
buffer[..meta_size].copy_from_slice(meta_bytes);
}
let mut batch_count = 0;
let mut batch_data: Vec<(usize, usize)> = Vec::with_capacity(BATCH_SIZE); // (buffer_idx, data_len)
s.spawn(move |_| {
let mut buffers: Vec<Vec<u8>> = (0..BATCH_SIZE)
.map(|_| vec![0u8; BUFFER_SIZE])
.collect();
let meta_size = size_of::<Meta>();
// 预初始化 Meta 头
let meta = Meta {
src_id: local_id,
dst_id: router_id,
reversed: 0,
};
let meta_bytes = unsafe {
std::slice::from_raw_parts(&meta as *const Meta as *const u8, meta_size)
};
for buffer in &mut buffers {
buffer[..meta_size].copy_from_slice(meta_bytes);
}
let mut current_buffer = 0;
loop {
let buffer = &mut buffers[current_buffer];
match tun_reader.read(&mut buffer[meta_size..]) {
loop {
// 批量读取
batch_data.clear();
for i in 0..BATCH_SIZE {
match router.tun_reader.read(&mut buffers[i][meta_size..]) {
Ok(n) if n > 0 => {
if let Some(ref addr) = *endpoint.read().unwrap() {
encrypt_fn(&mut buffer[meta_size..meta_size + n]);
#[cfg(target_os = "linux")]
let _ = socket.set_mark(mark);
// 使用 MSG_DONTWAIT 避免阻塞
match socket.send_to(&buffer[..meta_size + n], addr) {
Ok(_) => {},
Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
// 缓冲区满,稍后重试
std::thread::yield_now();
},
Err(_) => {},
}
batch_data.push((i, n));
if batch_data.len() >= 32 { // 达到一定数量就发送
break;
}
current_buffer = (current_buffer + 1) % BATCH_SIZE;
},
_ => std::thread::yield_now(),
}
Ok(_) => break,
Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => break,
Err(_) => break,
}
}
});
}
// 批量加密和发送
if !batch_data.is_empty() {
if let Some(ref addr) = *router.endpoint.read().unwrap() {
#[cfg(target_os = "linux")]
let _ = router.socket.set_mark(router.config.mark);
// 批量处理所有包
for &(idx, len) in &batch_data {
let buffer = &mut buffers[idx];
router.encrypt(&mut buffer[meta_size..meta_size + len]);
// 快速发送,不等待
let _ = router.socket.send_to(&buffer[..meta_size + len], addr);
}
batch_count += batch_data.len();
// 定期 yield 以避免饥饿其他线程
if batch_count > 1000 {
batch_count = 0;
std::thread::yield_now();
}
}
} else {
// 没有数据时短暂休眠
std::thread::sleep(std::time::Duration::from_micros(100));
}
}
});
}
// 为每个 socket 创建多个接收线程
// 接收线程 - 批量处理和缓存写入
for (socket, mut router_writers) in router_writers3 {
// 创建 4 个并发接收线程
for _ in 0..4 {
let socket = Arc::clone(&socket);
let mut router_writers = router_writers.clone();
let local_id = config.local_id;
let local_secret = local_secret.clone();
s.spawn(move |_| {
// 多个接收缓冲区用于批量处理
let mut recv_bufs: Vec<[MaybeUninit<u8>; MAX_PACKET_SIZE]> =
(0..BATCH_SIZE).map(|_| [MaybeUninit::uninit(); MAX_PACKET_SIZE]).collect();
// 为每个 router 维护写入缓冲区
let mut write_buffers: HashMap<u8, Vec<u8>> = HashMap::new();
let mut recv_count = 0;
s.spawn(move |_| {
let mut recv_bufs: Vec<[MaybeUninit<u8>; BUFFER_SIZE]> = (0..BATCH_SIZE)
.map(|_| [MaybeUninit::uninit(); BUFFER_SIZE])
.collect();
let mut current_buffer = 0;
loop {
// 批量接收
let mut received_packets = Vec::new();
loop {
let recv_buf = &mut recv_bufs[current_buffer];
// 尝试接收多个包
for i in 0..32 {
match socket.recv_from(&mut recv_bufs[i]) {
Ok((len, addr)) => {
received_packets.push((i, len, addr));
}
Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
break;
}
Err(_) => break,
}
}
// 批量处理接收到的包
for (buf_idx, len, addr) in received_packets {
let _ = (|| {
let (len, addr) = match socket.recv_from(recv_buf) {
Ok(result) => result,
Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
std::thread::yield_now();
return Ok(());
},
Err(_) => return Ok(()),
let data: &mut [u8] = unsafe {
transmute(&mut recv_bufs[buf_idx][..len])
};
let data: &mut [u8] = unsafe { transmute(&mut recv_buf[..len]) };
let packet = Ipv4Packet::new(data).ok_or("malformed packet")?;
let header_len = packet.get_header_length() as usize * 4;
......@@ -203,38 +239,52 @@ fn main() -> Result<(), Box<dyn Error>> {
.split_at_mut_checked(size_of::<Meta>())
.ok_or("malformed packet")?;
let meta: &Meta = unsafe { transmute(meta_bytes.as_ptr()) };
if meta.dst_id == local_id && meta.reversed == 0 {
if meta.dst_id == config.local_id && meta.reversed == 0 {
let router = router_writers
.get_mut(&meta.src_id)
.ok_or("missing router")?;
*router.endpoint.write().unwrap() = Some(addr);
router.decrypt(payload, &local_secret);
// 批量写入以减少系统调用
let mut offset = 0;
while offset < payload.len() {
match router.tun_writer.write(&payload[offset..]) {
Ok(n) => offset += n,
Err(e) if e.kind() == std::io::ErrorKind::WouldBlock => {
std::thread::yield_now();
},
Err(_) => break,
}
// 缓存数据以批量写入
let write_buf = write_buffers
.entry(meta.src_id)
.or_insert_with(|| Vec::with_capacity(65536));
write_buf.extend_from_slice(payload);
// 当缓冲区达到一定大小时写入
if write_buf.len() >= 32768 {
let data = std::mem::take(write_buf);
let _ = router.tun_writer.write_all(&data);
}
}
current_buffer = (current_buffer + 1) % BATCH_SIZE;
Ok::<(), Box<dyn Error>>(())
})();
}
});
}
// 定期刷新所有缓冲区
recv_count += 1;
if recv_count > 100 {
recv_count = 0;
for (router_id, data) in write_buffers.drain() {
if !data.is_empty() {
if let Some(router) = router_writers.get_mut(&router_id) {
let _ = router.tun_writer.write_all(&data);
}
}
}
}
// 如果没有接收到数据,短暂休眠
if received_packets.is_empty() {
std::thread::sleep(std::time::Duration::from_micros(100));
}
}
});
}
})
.unwrap();
Ok(())
}
// 辅助函数:设置 socket 为非阻塞模式
#[cfg(target_os = "linux")]
use std::os::unix::io::AsRawFd;
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment