Commit 5cf0d2c5 authored by nanahira's avatar nanahira

retry

parent bbb76719
Pipeline #37401 failed with stages
in 15 seconds
...@@ -4,12 +4,17 @@ use crate::router::{Router, RouterReader, RouterWriter, SECRET_LENGTH}; ...@@ -4,12 +4,17 @@ use crate::router::{Router, RouterReader, RouterWriter, SECRET_LENGTH};
use std::collections::HashMap; use std::collections::HashMap;
use std::env; use std::env;
use std::error::Error; use std::error::Error;
use std::intrinsics::transmute;
use std::io::{Read, Write}; use std::io::{Read, Write};
use std::mem::MaybeUninit; use std::mem::{self, MaybeUninit};
use std::sync::Arc; use std::net::SocketAddr;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use std::sync::mpsc::{sync_channel, Receiver, SyncSender, TryRecvError};
use std::time::{Duration, Instant};
use std::thread;
#[repr(C)] #[repr(C)]
#[derive(Copy, Clone)]
pub struct Meta { pub struct Meta {
pub src_id: u8, pub src_id: u8,
pub dst_id: u8, pub dst_id: u8,
...@@ -36,20 +41,78 @@ pub struct Config { ...@@ -36,20 +41,78 @@ pub struct Config {
pub local_secret: String, pub local_secret: String,
pub routers: Vec<ConfigRouter>, pub routers: Vec<ConfigRouter>,
} }
use crossbeam_utils::thread;
use crossbeam_utils::thread as scoped_thread;
use grouping_by::GroupingBy; use grouping_by::GroupingBy;
use pnet::packet::ipv4::Ipv4Packet; use pnet::packet::ipv4::Ipv4Packet;
use socket2::Socket; use socket2::Socket;
// 性能优化配置
const BATCH_SIZE: usize = 32;
const BATCH_TIMEOUT_MILLIS: u64 = 1;
const CHANNEL_SIZE: usize = 1024;
const SOCKET_BUFFER_SIZE: usize = 4 * 1024 * 1024; // 4MB
const MAX_PACKET_SIZE: usize = 1500;
// 数据包结构
#[derive(Clone)]
struct Packet {
data: Vec<u8>,
len: usize,
}
impl Packet {
fn new() -> Self {
Self {
data: vec![0u8; MAX_PACKET_SIZE],
len: 0,
}
}
}
// 统计信息
struct Stats {
packets_sent: AtomicU64,
packets_recv: AtomicU64,
bytes_sent: AtomicU64,
bytes_recv: AtomicU64,
}
fn optimize_socket(socket: &Socket) -> Result<(), Box<dyn Error>> {
socket.set_send_buffer_size(SOCKET_BUFFER_SIZE)?;
socket.set_recv_buffer_size(SOCKET_BUFFER_SIZE)?;
socket.set_nonblocking(false)?;
#[cfg(target_os = "linux")]
{
// Linux 特定优化
socket.set_nodelay(true)?;
}
Ok(())
}
fn main() -> Result<(), Box<dyn Error>> { fn main() -> Result<(), Box<dyn Error>> {
let config: Config = serde_json::from_str(env::args().nth(1).ok_or("need param")?.as_str())?; let config: Config = serde_json::from_str(env::args().nth(1).ok_or("need param")?.as_str())?;
let local_secret: [u8; SECRET_LENGTH] = Router::create_secret(config.local_secret.as_str())?; let local_secret: [u8; SECRET_LENGTH] = Router::create_secret(config.local_secret.as_str())?;
let mut sockets: HashMap<u16, Arc<Socket>> = HashMap::new(); let mut sockets: HashMap<u16, Arc<Socket>> = HashMap::new();
// 创建并优化 routers
let routers: HashMap<u8, Router> = config let routers: HashMap<u8, Router> = config
.routers .routers
.iter() .iter()
.map(|c| Router::new(c, &mut sockets).map(|router| (c.remote_id, router))) .map(|c| {
Router::new(c, &mut sockets).map(|router| {
// 优化每个 socket
let router_key = Router::key(c); // 使用关联函数
if let Some(socket) = sockets.get(&router_key) {
let _ = optimize_socket(socket);
}
(c.remote_id, router)
})
})
.collect::<Result<_, _>>()?; .collect::<Result<_, _>>()?;
let (mut router_readers, router_writers): ( let (mut router_readers, router_writers): (
HashMap<u8, RouterReader>, HashMap<u8, RouterReader>,
HashMap<u8, RouterWriter>, HashMap<u8, RouterWriter>,
...@@ -60,6 +123,7 @@ fn main() -> Result<(), Box<dyn Error>> { ...@@ -60,6 +123,7 @@ fn main() -> Result<(), Box<dyn Error>> {
((id, reader), (id, writer)) ((id, reader), (id, writer))
}) })
.unzip(); .unzip();
let router_writers3: Vec<(Arc<Socket>, HashMap<u8, RouterWriter>)> = router_writers let router_writers3: Vec<(Arc<Socket>, HashMap<u8, RouterWriter>)> = router_writers
.into_iter() .into_iter()
.grouping_by(|(_, v)| v.key()) .grouping_by(|(_, v)| v.key())
...@@ -71,77 +135,218 @@ fn main() -> Result<(), Box<dyn Error>> { ...@@ -71,77 +135,218 @@ fn main() -> Result<(), Box<dyn Error>> {
) )
}) })
.collect(); .collect();
println!("created tuns");
// 全局统计
let stats = Arc::new(Stats {
packets_sent: AtomicU64::new(0),
packets_recv: AtomicU64::new(0),
bytes_sent: AtomicU64::new(0),
bytes_recv: AtomicU64::new(0),
});
println!("Created optimized TUN devices");
thread::scope(|s| { scoped_thread::scope(|s| {
// 为每个 router 创建读写线程
for router in router_readers.values_mut() { for router in router_readers.values_mut() {
s.spawn(|_| { let router_id = router.config.remote_id;
// 使用 2048 字节缓冲区,足够处理大多数包且避免过大的内存占用 let local_id = config.local_id;
let mut buffer = vec![0u8; 2048]; let stats_clone = Arc::clone(&stats);
let meta_size = size_of::<Meta>();
// 创建批处理通道
// 预初始化 Meta 头部 let (tx, rx): (SyncSender<Packet>, Receiver<Packet>) = sync_channel(CHANNEL_SIZE);
// TUN 读取线程
s.spawn(move |_| {
let meta_size = mem::size_of::<Meta>();
let mut buffer = vec![0u8; MAX_PACKET_SIZE];
// 预构建 Meta
let meta = Meta { let meta = Meta {
src_id: config.local_id, src_id: local_id,
dst_id: router.config.remote_id, dst_id: router_id,
reversed: 0, reversed: 0,
}; };
let meta_bytes = unsafe { let meta_bytes = unsafe {
std::slice::from_raw_parts(&meta as *const Meta as *const u8, meta_size) std::slice::from_raw_parts(&meta as *const _ as *const u8, meta_size)
}; };
buffer[..meta_size].copy_from_slice(meta_bytes);
loop { loop {
// 写入 Meta 头
buffer[..meta_size].copy_from_slice(meta_bytes);
// 读取数据
match router.tun_reader.read(&mut buffer[meta_size..]) { match router.tun_reader.read(&mut buffer[meta_size..]) {
Ok(n) => { Ok(n) if n > 0 => {
if let Some(ref addr) = *router.endpoint.read().unwrap() { let packet = Packet {
router.encrypt(&mut buffer[meta_size..meta_size + n]); data: buffer[..meta_size + n].to_vec(),
#[cfg(target_os = "linux")] len: meta_size + n,
let _ = router.socket.set_mark(router.config.mark); };
let _ = router.socket.send_to(&buffer[..meta_size + n], addr); let _ = tx.try_send(packet);
}
_ => continue,
}
}
});
// 批处理发送线程
s.spawn(move |_| {
let mut batch = Vec::with_capacity(BATCH_SIZE);
let meta_size = mem::size_of::<Meta>();
let mut last_batch_time = Instant::now();
loop {
batch.clear();
// 收集数据包
match rx.recv_timeout(Duration::from_millis(BATCH_TIMEOUT_MILLIS)) {
Ok(packet) => batch.push(packet),
Err(_) => {
if batch.is_empty() { continue; }
}
}
// 继续收集直到批量大小或超时
while batch.len() < BATCH_SIZE {
match rx.try_recv() {
Ok(packet) => batch.push(packet),
Err(TryRecvError::Empty) => {
if last_batch_time.elapsed() > Duration::from_millis(BATCH_TIMEOUT_MILLIS) {
break;
}
} }
Err(TryRecvError::Disconnected) => return,
} }
Err(_) => continue, }
if !batch.is_empty() {
// 获取endpoint
if let Some(ref addr) = *router.endpoint.read().unwrap() {
// 加密所有包
for packet in &mut batch {
router.encrypt(&mut packet.data[meta_size..packet.len]);
}
// 设置 mark(每批只设置一次)
#[cfg(target_os = "linux")]
let _ = router.socket.set_mark(router.config.mark);
// 批量发送
let mut total_bytes = 0;
for packet in &batch {
if let Ok(_) = router.socket.send_to(&packet.data[..packet.len], addr) {
total_bytes += packet.len;
}
}
// 更新统计
stats_clone.packets_sent.fetch_add(batch.len() as u64, Ordering::Relaxed);
stats_clone.bytes_sent.fetch_add(total_bytes as u64, Ordering::Relaxed);
}
last_batch_time = Instant::now();
} }
} }
}); });
} }
// 优化的接收线程
for (socket, mut router_writers) in router_writers3 { for (socket, mut router_writers) in router_writers3 {
let stats_clone = Arc::clone(&stats);
let local_secret_clone = local_secret.clone();
let config_local_id = config.local_id;
s.spawn(move |_| { s.spawn(move |_| {
// 使用 2048 字节接收缓冲区 // 预分配多个缓冲区
let mut recv_buf = vec![MaybeUninit::uninit(); 2048]; let mut recv_bufs: Vec<[MaybeUninit<u8>; MAX_PACKET_SIZE]> =
(0..4).map(|_| unsafe { MaybeUninit::uninit().assume_init() }).collect();
let mut buf_idx = 0;
loop { loop {
let _ = (|| -> Result<(), Box<dyn Error>> { let recv_buf = &mut recv_bufs[buf_idx];
let (len, addr) = socket.recv_from(&mut recv_buf)?; buf_idx = (buf_idx + 1) % recv_bufs.len();
let data: &mut [u8] = unsafe { transmute(&mut recv_buf[..len]) };
match socket.recv_from(recv_buf) {
let packet = Ipv4Packet::new(data).ok_or("malformed packet")?; Ok((len, sock_addr)) => {
let header_len = packet.get_header_length() as usize * 4; // 转换为 SocketAddr
let (_ip_header, rest) = data let addr = match sock_addr.as_socket() {
.split_at_mut_checked(header_len) Some(addr) => addr,
.ok_or("malformed packet")?; None => continue,
let (meta_bytes, payload) = rest };
.split_at_mut_checked(size_of::<Meta>())
.ok_or("malformed packet")?; // 将 MaybeUninit 转换为初始化的数据
let meta: &Meta = unsafe { transmute(meta_bytes.as_ptr()) }; let data = unsafe {
std::slice::from_raw_parts_mut(
if meta.dst_id == config.local_id && meta.reversed == 0 { recv_buf.as_mut_ptr() as *mut u8,
let router = router_writers len
.get_mut(&meta.src_id) )
.ok_or("missing router")?; };
*router.endpoint.write().unwrap() = Some(addr);
router.decrypt(payload, &local_secret); // 快速路径检查
router.tun_writer.write_all(payload)?; if len < 20 + mem::size_of::<Meta>() {
continue;
}
if let Some(packet) = Ipv4Packet::new(data) {
let header_len = packet.get_header_length() as usize * 4;
if header_len < data.len() {
let rest = &mut data[header_len..];
if rest.len() >= mem::size_of::<Meta>() {
let (meta_bytes, payload) = rest.split_at_mut(mem::size_of::<Meta>());
let meta: &Meta = unsafe { &*(meta_bytes.as_ptr() as *const Meta) };
if meta.dst_id == config_local_id && meta.reversed == 0 {
if let Some(router) = router_writers.get_mut(&meta.src_id) {
// 更新endpoint
*router.endpoint.write().unwrap() = Some(addr);
// 解密
router.decrypt(payload, &local_secret_clone);
// 写入TUN
let _ = router.tun_writer.write_all(payload);
// 更新统计
stats_clone.packets_recv.fetch_add(1, Ordering::Relaxed);
stats_clone.bytes_recv.fetch_add(payload.len() as u64, Ordering::Relaxed);
}
}
}
}
}
} }
Err(_) => continue,
Ok(()) }
})();
} }
}); });
} }
// 统计线程
let stats_clone = Arc::clone(&stats);
s.spawn(move |_| {
let mut last_stats = (0u64, 0u64, 0u64, 0u64);
loop {
thread::sleep(Duration::from_secs(10));
let sent_pkts = stats_clone.packets_sent.load(Ordering::Relaxed);
let recv_pkts = stats_clone.packets_recv.load(Ordering::Relaxed);
let sent_bytes = stats_clone.bytes_sent.load(Ordering::Relaxed);
let recv_bytes = stats_clone.bytes_recv.load(Ordering::Relaxed);
let sent_pps = (sent_pkts - last_stats.0) / 10;
let recv_pps = (recv_pkts - last_stats.1) / 10;
let sent_mbps = ((sent_bytes - last_stats.2) * 8) / (10 * 1024 * 1024);
let recv_mbps = ((recv_bytes - last_stats.3) * 8) / (10 * 1024 * 1024);
println!(
"Stats - TX: {} pps, {} Mbps | RX: {} pps, {} Mbps",
sent_pps, sent_mbps, recv_pps, recv_mbps
);
last_stats = (sent_pkts, recv_pkts, sent_bytes, recv_bytes);
}
});
}) })
.unwrap(); .unwrap();
Ok(()) Ok(())
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment