Commit a3c8c95b authored by nanahira's avatar nanahira

version 6

parent 8607551f
Pipeline #37394 failed with stages
in 14 seconds
...@@ -8,8 +8,6 @@ use std::intrinsics::transmute; ...@@ -8,8 +8,6 @@ use std::intrinsics::transmute;
use std::io::{Read, Write}; use std::io::{Read, Write};
use std::mem::MaybeUninit; use std::mem::MaybeUninit;
use std::sync::Arc; use std::sync::Arc;
use std::sync::mpsc::{sync_channel, Receiver, SyncSender};
use std::time::Duration;
#[repr(C)] #[repr(C)]
pub struct Meta { pub struct Meta {
...@@ -44,27 +42,19 @@ use grouping_by::GroupingBy; ...@@ -44,27 +42,19 @@ use grouping_by::GroupingBy;
use pnet::packet::ipv4::Ipv4Packet; use pnet::packet::ipv4::Ipv4Packet;
use socket2::Socket; use socket2::Socket;
// 批处理配置 // 性能优化:增加 socket 缓冲区大小
const BATCH_SIZE: usize = 16;
const CHANNEL_SIZE: usize = 256;
const SOCKET_BUFFER_SIZE: usize = 2 * 1024 * 1024; // 2MB const SOCKET_BUFFER_SIZE: usize = 2 * 1024 * 1024; // 2MB
struct Packet {
data: Vec<u8>,
len: usize,
}
fn main() -> Result<(), Box<dyn Error>> { fn main() -> Result<(), Box<dyn Error>> {
let config: Config = serde_json::from_str(env::args().nth(1).ok_or("need param")?.as_str())?; let config: Config = serde_json::from_str(env::args().nth(1).ok_or("need param")?.as_str())?;
let local_secret: [u8; SECRET_LENGTH] = Router::create_secret(config.local_secret.as_str())?; let local_secret: [u8; SECRET_LENGTH] = Router::create_secret(config.local_secret.as_str())?;
let mut sockets: HashMap<u16, Arc<Socket>> = HashMap::new(); let mut sockets: HashMap<u16, Arc<Socket>> = HashMap::new();
let routers: HashMap<u8, Router> = config let routers: HashMap<u8, Router> = config
.routers .routers
.iter() .iter()
.map(|c| { .map(|c| {
Router::new(c, &mut sockets).map(|router| { Router::new(c, &mut sockets).map(|router| {
// 为 socket 设置更大的缓冲区 // 优化:设置更大的 socket 缓冲区
if let Some(socket) = sockets.get(&Router::key(c)) { if let Some(socket) = sockets.get(&Router::key(c)) {
let _ = socket.set_send_buffer_size(SOCKET_BUFFER_SIZE); let _ = socket.set_send_buffer_size(SOCKET_BUFFER_SIZE);
let _ = socket.set_recv_buffer_size(SOCKET_BUFFER_SIZE); let _ = socket.set_recv_buffer_size(SOCKET_BUFFER_SIZE);
...@@ -73,7 +63,6 @@ fn main() -> Result<(), Box<dyn Error>> { ...@@ -73,7 +63,6 @@ fn main() -> Result<(), Box<dyn Error>> {
}) })
}) })
.collect::<Result<_, _>>()?; .collect::<Result<_, _>>()?;
let (mut router_readers, router_writers): ( let (mut router_readers, router_writers): (
HashMap<u8, RouterReader>, HashMap<u8, RouterReader>,
HashMap<u8, RouterWriter>, HashMap<u8, RouterWriter>,
...@@ -84,7 +73,6 @@ fn main() -> Result<(), Box<dyn Error>> { ...@@ -84,7 +73,6 @@ fn main() -> Result<(), Box<dyn Error>> {
((id, reader), (id, writer)) ((id, reader), (id, writer))
}) })
.unzip(); .unzip();
let router_writers3: Vec<(Arc<Socket>, HashMap<u8, RouterWriter>)> = router_writers let router_writers3: Vec<(Arc<Socket>, HashMap<u8, RouterWriter>)> = router_writers
.into_iter() .into_iter()
.grouping_by(|(_, v)| v.key()) .grouping_by(|(_, v)| v.key())
...@@ -96,88 +84,56 @@ fn main() -> Result<(), Box<dyn Error>> { ...@@ -96,88 +84,56 @@ fn main() -> Result<(), Box<dyn Error>> {
) )
}) })
.collect(); .collect();
println!("created tuns"); println!("created tuns");
thread::scope(|s| { thread::scope(|s| {
// 发送端优化:添加批处理
for router in router_readers.values_mut() { for router in router_readers.values_mut() {
let (tx, rx): (SyncSender<Packet>, Receiver<Packet>) = sync_channel(CHANNEL_SIZE);
// 读取线程
s.spawn(|_| { s.spawn(|_| {
let mut buffer = [0u8; 1500 - 20]; // 优化:增加缓冲区大小
let mut buffer = [0u8; 1500];
let meta_size = size_of::<Meta>(); let meta_size = size_of::<Meta>();
// Pre-initialize with our Meta header (local -> remote)
let meta = Meta { let meta = Meta {
src_id: config.local_id, src_id: config.local_id,
dst_id: router.config.remote_id, dst_id: router.config.remote_id,
reversed: 0, reversed: 0,
}; };
// Turn the Meta struct into bytes
let meta_bytes = unsafe { let meta_bytes = unsafe {
std::slice::from_raw_parts(&meta as *const Meta as *const u8, meta_size) std::slice::from_raw_parts(&meta as *const Meta as *const u8, meta_size)
}; };
buffer[..meta_size].copy_from_slice(meta_bytes);
loop { loop {
let n = router.tun_reader.read(&mut buffer[meta_size..]).unwrap(); let n = router.tun_reader.read(&mut buffer[meta_size..]).unwrap();
if n > 0 {
let mut packet_data = vec![0u8; meta_size + n];
packet_data[..meta_size].copy_from_slice(meta_bytes);
packet_data[meta_size..].copy_from_slice(&buffer[meta_size..meta_size + n]);
let packet = Packet {
data: packet_data,
len: meta_size + n,
};
let _ = tx.try_send(packet);
}
}
});
// 批处理发送线程
s.spawn(|_| {
let mut batch = Vec::with_capacity(BATCH_SIZE);
let meta_size = size_of::<Meta>();
loop {
batch.clear();
// 收集批量数据,使用超时避免延迟
if let Ok(packet) = rx.recv_timeout(Duration::from_micros(100)) {
batch.push(packet);
// 尝试收集更多包
while batch.len() < BATCH_SIZE {
match rx.try_recv() {
Ok(packet) => batch.push(packet),
Err(_) => break,
}
}
}
// 批量发送
if !batch.is_empty() {
if let Some(ref addr) = *router.endpoint.read().unwrap() { if let Some(ref addr) = *router.endpoint.read().unwrap() {
// 批量加密和发送 router.encrypt(&mut buffer[meta_size..meta_size + n]);
for packet in &mut batch {
router.encrypt(&mut packet.data[meta_size..packet.len]);
#[cfg(target_os = "linux")] #[cfg(target_os = "linux")]
let _ = router.socket.set_mark(router.config.mark); let _ = router.socket.set_mark(router.config.mark);
let _ = router.socket.send_to(&packet.data[..packet.len], addr); let _ = router.socket.send_to(&buffer[..meta_size + n], addr);
}
}
} }
} }
}); });
} }
// 接收端保持原样,避免引入问题
for (socket, mut router_writers) in router_writers3 { for (socket, mut router_writers) in router_writers3 {
s.spawn(move |_| { s.spawn(move |_| {
let mut recv_buf = [MaybeUninit::uninit(); 1500]; // 优化:使用多个缓冲区轮换
let mut recv_bufs: [[MaybeUninit<u8>; 1500]; 4] = [
[MaybeUninit::uninit(); 1500],
[MaybeUninit::uninit(); 1500],
[MaybeUninit::uninit(); 1500],
[MaybeUninit::uninit(); 1500],
];
let mut buf_idx = 0;
loop { loop {
let recv_buf = &mut recv_bufs[buf_idx];
buf_idx = (buf_idx + 1) % 4;
let _ = (|| { let _ = (|| {
let (len, addr) = socket.recv_from(&mut recv_buf).unwrap(); let (len, addr) = socket.recv_from(recv_buf).unwrap();
let data: &mut [u8] = unsafe { transmute(&mut recv_buf[..len]) }; let data: &mut [u8] = unsafe { transmute(&mut recv_buf[..len]) };
let packet = Ipv4Packet::new(data).ok_or("malformed packet")?; let packet = Ipv4Packet::new(data).ok_or("malformed packet")?;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment