Commit 8607551f authored by nanahira's avatar nanahira

version 5

parent 0a8ac54c
Pipeline #37393 failed with stages
in 1 minute and 7 seconds
...@@ -8,6 +8,8 @@ use std::intrinsics::transmute; ...@@ -8,6 +8,8 @@ use std::intrinsics::transmute;
use std::io::{Read, Write}; use std::io::{Read, Write};
use std::mem::MaybeUninit; use std::mem::MaybeUninit;
use std::sync::Arc; use std::sync::Arc;
use std::sync::mpsc::{sync_channel, Receiver, SyncSender};
use std::time::Duration;
#[repr(C)] #[repr(C)]
pub struct Meta { pub struct Meta {
...@@ -36,20 +38,42 @@ pub struct Config { ...@@ -36,20 +38,42 @@ pub struct Config {
pub local_secret: String, pub local_secret: String,
pub routers: Vec<ConfigRouter>, pub routers: Vec<ConfigRouter>,
} }
use crossbeam_utils::thread; use crossbeam_utils::thread;
use grouping_by::GroupingBy; use grouping_by::GroupingBy;
use pnet::packet::ipv4::Ipv4Packet; use pnet::packet::ipv4::Ipv4Packet;
use socket2::Socket; use socket2::Socket;
// 批处理配置
const BATCH_SIZE: usize = 16;
const CHANNEL_SIZE: usize = 256;
const SOCKET_BUFFER_SIZE: usize = 2 * 1024 * 1024; // 2MB
struct Packet {
data: Vec<u8>,
len: usize,
}
fn main() -> Result<(), Box<dyn Error>> { fn main() -> Result<(), Box<dyn Error>> {
let config: Config = serde_json::from_str(env::args().nth(1).ok_or("need param")?.as_str())?; let config: Config = serde_json::from_str(env::args().nth(1).ok_or("need param")?.as_str())?;
let local_secret: [u8; SECRET_LENGTH] = Router::create_secret(config.local_secret.as_str())?; let local_secret: [u8; SECRET_LENGTH] = Router::create_secret(config.local_secret.as_str())?;
let mut sockets: HashMap<u16, Arc<Socket>> = HashMap::new(); let mut sockets: HashMap<u16, Arc<Socket>> = HashMap::new();
let routers: HashMap<u8, Router> = config let routers: HashMap<u8, Router> = config
.routers .routers
.iter() .iter()
.map(|c| Router::new(c, &mut sockets).map(|router| (c.remote_id, router))) .map(|c| {
Router::new(c, &mut sockets).map(|router| {
// 为 socket 设置更大的缓冲区
if let Some(socket) = sockets.get(&Router::key(c)) {
let _ = socket.set_send_buffer_size(SOCKET_BUFFER_SIZE);
let _ = socket.set_recv_buffer_size(SOCKET_BUFFER_SIZE);
}
(c.remote_id, router)
})
})
.collect::<Result<_, _>>()?; .collect::<Result<_, _>>()?;
let (mut router_readers, router_writers): ( let (mut router_readers, router_writers): (
HashMap<u8, RouterReader>, HashMap<u8, RouterReader>,
HashMap<u8, RouterWriter>, HashMap<u8, RouterWriter>,
...@@ -60,6 +84,7 @@ fn main() -> Result<(), Box<dyn Error>> { ...@@ -60,6 +84,7 @@ fn main() -> Result<(), Box<dyn Error>> {
((id, reader), (id, writer)) ((id, reader), (id, writer))
}) })
.unzip(); .unzip();
let router_writers3: Vec<(Arc<Socket>, HashMap<u8, RouterWriter>)> = router_writers let router_writers3: Vec<(Arc<Socket>, HashMap<u8, RouterWriter>)> = router_writers
.into_iter() .into_iter()
.grouping_by(|(_, v)| v.key()) .grouping_by(|(_, v)| v.key())
...@@ -71,38 +96,82 @@ fn main() -> Result<(), Box<dyn Error>> { ...@@ -71,38 +96,82 @@ fn main() -> Result<(), Box<dyn Error>> {
) )
}) })
.collect(); .collect();
println!("created tuns"); println!("created tuns");
thread::scope(|s| { thread::scope(|s| {
// 发送端优化:添加批处理
for router in router_readers.values_mut() { for router in router_readers.values_mut() {
let (tx, rx): (SyncSender<Packet>, Receiver<Packet>) = sync_channel(CHANNEL_SIZE);
// 读取线程
s.spawn(|_| { s.spawn(|_| {
let mut buffer = [0u8; 1500 - 20]; // minus typical IP header space let mut buffer = [0u8; 1500 - 20];
let meta_size = std::mem::size_of::<Meta>(); let meta_size = size_of::<Meta>();
// Pre-initialize with our Meta header (local -> remote)
let meta = Meta { let meta = Meta {
src_id: config.local_id, src_id: config.local_id,
dst_id: router.config.remote_id, dst_id: router.config.remote_id,
reversed: 0, reversed: 0,
}; };
// Turn the Meta struct into bytes
let meta_bytes = unsafe { let meta_bytes = unsafe {
std::slice::from_raw_parts(&meta as *const Meta as *const u8, meta_size) std::slice::from_raw_parts(&meta as *const Meta as *const u8, meta_size)
}; };
buffer[..meta_size].copy_from_slice(meta_bytes);
loop { loop {
let n = router.tun_reader.read(&mut buffer[meta_size..]).unwrap(); let n = router.tun_reader.read(&mut buffer[meta_size..]).unwrap();
if let Some(ref addr) = *router.endpoint.read().unwrap() { if n > 0 {
router.encrypt(&mut buffer[meta_size..meta_size + n]); let mut packet_data = vec![0u8; meta_size + n];
#[cfg(target_os = "linux")] packet_data[..meta_size].copy_from_slice(meta_bytes);
let _ = router.socket.set_mark(router.config.mark); packet_data[meta_size..].copy_from_slice(&buffer[meta_size..meta_size + n]);
let _ = router.socket.send_to(&buffer[..meta_size + n], addr);
let packet = Packet {
data: packet_data,
len: meta_size + n,
};
let _ = tx.try_send(packet);
}
}
});
// 批处理发送线程
s.spawn(|_| {
let mut batch = Vec::with_capacity(BATCH_SIZE);
let meta_size = size_of::<Meta>();
loop {
batch.clear();
// 收集批量数据,使用超时避免延迟
if let Ok(packet) = rx.recv_timeout(Duration::from_micros(100)) {
batch.push(packet);
// 尝试收集更多包
while batch.len() < BATCH_SIZE {
match rx.try_recv() {
Ok(packet) => batch.push(packet),
Err(_) => break,
}
}
}
// 批量发送
if !batch.is_empty() {
if let Some(ref addr) = *router.endpoint.read().unwrap() {
// 批量加密和发送
for packet in &mut batch {
router.encrypt(&mut packet.data[meta_size..packet.len]);
#[cfg(target_os = "linux")]
let _ = router.socket.set_mark(router.config.mark);
let _ = router.socket.send_to(&packet.data[..packet.len], addr);
}
}
} }
} }
}); });
} }
// 接收端保持原样,避免引入问题
for (socket, mut router_writers) in router_writers3 { for (socket, mut router_writers) in router_writers3 {
s.spawn(move |_| { s.spawn(move |_| {
let mut recv_buf = [MaybeUninit::uninit(); 1500]; let mut recv_buf = [MaybeUninit::uninit(); 1500];
...@@ -117,7 +186,7 @@ fn main() -> Result<(), Box<dyn Error>> { ...@@ -117,7 +186,7 @@ fn main() -> Result<(), Box<dyn Error>> {
.split_at_mut_checked(header_len) .split_at_mut_checked(header_len)
.ok_or("malformed packet")?; .ok_or("malformed packet")?;
let (meta_bytes, payload) = rest let (meta_bytes, payload) = rest
.split_at_mut_checked(std::mem::size_of::<Meta>()) .split_at_mut_checked(size_of::<Meta>())
.ok_or("malformed packet")?; .ok_or("malformed packet")?;
let meta: &Meta = unsafe { transmute(meta_bytes.as_ptr()) }; let meta: &Meta = unsafe { transmute(meta_bytes.as_ptr()) };
if meta.dst_id == config.local_id && meta.reversed == 0 { if meta.dst_id == config.local_id && meta.reversed == 0 {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment