Commit ce345e4d authored by nanahira's avatar nanahira

optimized by claude

parent f5ffc9fe
Pipeline #37391 failed with stages
in 1 minute and 46 seconds
mod router; mod router;
use crate::router::{Router, RouterReader, RouterWriter, SECRET_LENGTH}; use crate::router::{Router, RouterReader, RouterWriter, SECRET_LENGTH};
use arc_swap::ArcSwap;
use bytes::{Bytes, BytesMut};
use crossbeam_channel::{bounded, Receiver, Sender, TryRecvError};
use parking_lot::RwLock;
use std::collections::HashMap; use std::collections::HashMap;
use std::env; use std::env;
use std::error::Error; use std::error::Error;
use std::intrinsics::transmute; use std::intrinsics::transmute;
use std::io::{Read, Write}; use std::io::{Read, Write};
use std::mem::MaybeUninit; use std::mem::{self, MaybeUninit};
use std::net::SocketAddr;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc; use std::sync::Arc;
use std::time::{Duration, Instant};
#[repr(C)] #[repr(C)]
#[derive(Copy, Clone)]
pub struct Meta { pub struct Meta {
pub src_id: u8, pub src_id: u8,
pub dst_id: u8, pub dst_id: u8,
...@@ -36,20 +44,132 @@ pub struct Config { ...@@ -36,20 +44,132 @@ pub struct Config {
pub local_secret: String, pub local_secret: String,
pub routers: Vec<ConfigRouter>, pub routers: Vec<ConfigRouter>,
} }
use crossbeam_utils::thread; use crossbeam_utils::thread;
use grouping_by::GroupingBy; use grouping_by::GroupingBy;
use pnet::packet::ipv4::Ipv4Packet; use pnet::packet::ipv4::Ipv4Packet;
use socket2::Socket; use socket2::{Domain, Protocol, Socket, Type};
// 性能优化配置
const BATCH_SIZE: usize = 64; // 增加批处理大小
const BATCH_TIMEOUT_MICROS: u64 = 50; // 减少批处理超时
const CHANNEL_SIZE: usize = 4096; // 增加通道容量
const SOCKET_BUFFER_SIZE: usize = 8 * 1024 * 1024; // 8MB socket 缓冲区
const PACKET_POOL_SIZE: usize = 256; // 预分配的数据包池大小
const MAX_PACKET_SIZE: usize = 1500;
// 零拷贝数据包结构
struct Packet {
data: BytesMut,
len: usize,
timestamp: Instant,
}
impl Packet {
fn new() -> Self {
Self {
data: BytesMut::with_capacity(MAX_PACKET_SIZE),
len: 0,
timestamp: Instant::now(),
}
}
fn reset(&mut self) {
self.data.clear();
self.len = 0;
self.timestamp = Instant::now();
}
}
// 数据包池 - 减少内存分配
struct PacketPool {
pool: Vec<Packet>,
used: AtomicU64,
}
impl PacketPool {
fn new(size: usize) -> Self {
let pool = (0..size).map(|_| Packet::new()).collect();
Self {
pool,
used: AtomicU64::new(0),
}
}
fn get(&self) -> Option<&mut Packet> {
let idx = self.used.fetch_add(1, Ordering::Relaxed) as usize % self.pool.len();
unsafe { Some(&mut *(self.pool.as_ptr().add(idx) as *mut Packet)) }
}
}
// 统计信息
struct Stats {
packets_sent: AtomicU64,
packets_recv: AtomicU64,
bytes_sent: AtomicU64,
bytes_recv: AtomicU64,
}
fn optimize_socket(socket: &Socket) -> Result<(), Box<dyn Error>> {
socket.set_send_buffer_size(SOCKET_BUFFER_SIZE)?;
socket.set_recv_buffer_size(SOCKET_BUFFER_SIZE)?;
socket.set_nonblocking(false)?; // 使用阻塞模式配合批处理
#[cfg(target_os = "linux")]
{
// Linux 特定优化
use libc::{c_int, c_void, setsockopt, SOL_SOCKET};
// SO_BUSY_POLL - 减少延迟
const SO_BUSY_POLL: c_int = 46;
let busy_poll: c_int = 50; // 50us
unsafe {
setsockopt(
socket.as_raw_fd(),
SOL_SOCKET,
SO_BUSY_POLL,
&busy_poll as *const _ as *const c_void,
mem::size_of::<c_int>() as u32,
);
}
// SO_INCOMING_CPU - CPU 亲和性
const SO_INCOMING_CPU: c_int = 49;
let cpu: c_int = 0;
unsafe {
setsockopt(
socket.as_raw_fd(),
SOL_SOCKET,
SO_INCOMING_CPU,
&cpu as *const _ as *const c_void,
mem::size_of::<c_int>() as u32,
);
}
}
Ok(())
}
fn main() -> Result<(), Box<dyn Error>> { fn main() -> Result<(), Box<dyn Error>> {
let config: Config = serde_json::from_str(env::args().nth(1).ok_or("need param")?.as_str())?; let config: Config = serde_json::from_str(env::args().nth(1).ok_or("need param")?.as_str())?;
let local_secret: [u8; SECRET_LENGTH] = Router::create_secret(config.local_secret.as_str())?; let local_secret: [u8; SECRET_LENGTH] = Router::create_secret(config.local_secret.as_str())?;
let mut sockets: HashMap<u16, Arc<Socket>> = HashMap::new(); let mut sockets: HashMap<u16, Arc<Socket>> = HashMap::new();
// 创建并优化 routers
let routers: HashMap<u8, Router> = config let routers: HashMap<u8, Router> = config
.routers .routers
.iter() .iter()
.map(|c| Router::new(c, &mut sockets).map(|router| (c.remote_id, router))) .map(|c| {
Router::new(c, &mut sockets).map(|router| {
// 优化每个 socket
if let Some(socket) = sockets.get(&router.key()) {
let _ = optimize_socket(socket);
}
(c.remote_id, router)
})
})
.collect::<Result<_, _>>()?; .collect::<Result<_, _>>()?;
let (mut router_readers, router_writers): ( let (mut router_readers, router_writers): (
HashMap<u8, RouterReader>, HashMap<u8, RouterReader>,
HashMap<u8, RouterWriter>, HashMap<u8, RouterWriter>,
...@@ -60,81 +180,256 @@ fn main() -> Result<(), Box<dyn Error>> { ...@@ -60,81 +180,256 @@ fn main() -> Result<(), Box<dyn Error>> {
((id, reader), (id, writer)) ((id, reader), (id, writer))
}) })
.unzip(); .unzip();
let router_writers3: Vec<(Arc<Socket>, HashMap<u8, RouterWriter>)> = router_writers
.into_iter() // 使用 ArcSwap 存储 writers 以减少锁竞争
.grouping_by(|(_, v)| v.key()) let router_writers_arc: Arc<HashMap<u8, Arc<RouterWriter>>> = Arc::new(
.into_iter() router_writers.into_iter()
.map(|(k, v)| { .map(|(k, v)| (k, Arc::new(v)))
( .collect()
Arc::clone(sockets.get_mut(&k).unwrap()), );
v.into_iter().collect(),
) let router_writers3: Vec<(Arc<Socket>, Arc<HashMap<u8, Arc<RouterWriter>>>)> =
}) router_writers_arc.iter()
.collect(); .fold(HashMap::<u16, Vec<(u8, Arc<RouterWriter>)>>::new(), |mut acc, (id, writer)| {
println!("created tuns"); acc.entry(writer.key()).or_insert_with(Vec::new).push((*id, Arc::clone(writer)));
acc
})
.into_iter()
.map(|(key, writers)| {
let socket = Arc::clone(sockets.get(&key).unwrap());
let writers_map = Arc::new(writers.into_iter().collect::<HashMap<_, _>>());
(socket, writers_map)
})
.collect();
// 全局统计
let stats = Arc::new(Stats {
packets_sent: AtomicU64::new(0),
packets_recv: AtomicU64::new(0),
bytes_sent: AtomicU64::new(0),
bytes_recv: AtomicU64::new(0),
});
println!("Created optimized TUN devices");
thread::scope(|s| { thread::scope(|s| {
for router in router_readers.values_mut() { // 为每个 router 创建高性能读写线程组
s.spawn(|_| { for (router_id, router) in router_readers.iter_mut() {
let mut buffer = [0u8; 1500 - 20]; // minus typical IP header space let router_id = *router_id;
let meta_size = size_of::<Meta>(); let local_id = config.local_id;
let stats_clone = Arc::clone(&stats);
// Pre-initialize with our Meta header (local -> remote)
// 创建多个通道用于负载均衡
let mut channels = Vec::new();
for _ in 0..2 { // 2个并行处理通道
channels.push(bounded::<Packet>(CHANNEL_SIZE));
}
// TUN 读取线程 - 零拷贝优化
let channels_clone = channels.clone();
let mut router_reader = router.clone(); // 假设实现了 Clone
s.spawn(move |_| {
let packet_pool = PacketPool::new(PACKET_POOL_SIZE);
let meta_size = mem::size_of::<Meta>();
let mut channel_idx = 0;
// 预构建 Meta
let meta = Meta { let meta = Meta {
src_id: config.local_id, src_id: local_id,
dst_id: router.config.remote_id, dst_id: router_id,
reversed: 0, reversed: 0,
}; };
// Turn the Meta struct into bytes
let meta_bytes = unsafe {
std::slice::from_raw_parts(&meta as *const Meta as *const u8, meta_size)
};
buffer[..meta_size].copy_from_slice(meta_bytes);
loop { loop {
let n = router.tun_reader.read(&mut buffer[meta_size..]).unwrap(); if let Some(packet) = packet_pool.get() {
if let Some(ref addr) = *router.endpoint.read().unwrap() { packet.reset();
router.encrypt(&mut buffer[meta_size..meta_size + n]); packet.data.resize(MAX_PACKET_SIZE, 0);
#[cfg(target_os = "linux")]
let _ = router.socket.set_mark(router.config.mark); // 写入 Meta 头
let _ = router.socket.send_to(&buffer[..meta_size + n], addr); let meta_bytes = unsafe {
std::slice::from_raw_parts(&meta as *const _ as *const u8, meta_size)
};
packet.data[..meta_size].copy_from_slice(meta_bytes);
// 读取数据
match router_reader.tun_reader.read(&mut packet.data[meta_size..]) {
Ok(n) if n > 0 => {
packet.len = meta_size + n;
packet.data.truncate(packet.len);
// 轮询发送到不同通道
let tx = &channels_clone[channel_idx].0;
let _ = tx.try_send(*packet);
channel_idx = (channel_idx + 1) % channels_clone.len();
}
_ => continue,
}
} }
} }
}); });
// 多个加密发送线程
for (tx, rx) in channels {
let router_clone = router.clone();
let stats_clone2 = Arc::clone(&stats_clone);
s.spawn(move |_| {
let mut batch = Vec::with_capacity(BATCH_SIZE);
let meta_size = mem::size_of::<Meta>();
let mut last_batch_time = Instant::now();
loop {
batch.clear();
// 智能批处理 - 根据负载动态调整
let timeout = if batch.is_empty() {
Duration::from_millis(1)
} else {
Duration::from_micros(BATCH_TIMEOUT_MICROS)
};
// 收集数据包
match rx.recv_timeout(timeout) {
Ok(packet) => batch.push(packet),
Err(_) => {
if batch.is_empty() { continue; }
}
}
// 继续收集直到批量大小或超时
while batch.len() < BATCH_SIZE {
match rx.try_recv() {
Ok(packet) => batch.push(packet),
Err(TryRecvError::Empty) => {
if last_batch_time.elapsed() > Duration::from_micros(BATCH_TIMEOUT_MICROS) {
break;
}
}
Err(TryRecvError::Disconnected) => return,
}
}
if !batch.is_empty() {
// 获取endpoint(使用缓存减少锁竞争)
if let Some(ref addr) = *router_clone.endpoint.read().unwrap() {
// 并行加密
batch.par_iter_mut().for_each(|packet| {
router_clone.encrypt(&mut packet.data[meta_size..packet.len]);
});
// 设置 mark(每批只设置一次)
#[cfg(target_os = "linux")]
let _ = router_clone.socket.set_mark(router_clone.config.mark);
// 批量发送
let mut total_bytes = 0;
for packet in &batch {
if let Ok(_) = router_clone.socket.send_to(&packet.data[..packet.len], addr) {
total_bytes += packet.len;
}
}
// 更新统计
stats_clone2.packets_sent.fetch_add(batch.len() as u64, Ordering::Relaxed);
stats_clone2.bytes_sent.fetch_add(total_bytes as u64, Ordering::Relaxed);
}
last_batch_time = Instant::now();
}
}
});
}
} }
for (socket, mut router_writers) in router_writers3 { // 优化的接收线程
for (socket, router_writers) in router_writers3 {
let stats_clone = Arc::clone(&stats);
let local_secret_clone = local_secret.clone();
let config_local_id = config.local_id;
s.spawn(move |_| { s.spawn(move |_| {
let mut recv_buf = [MaybeUninit::uninit(); 1500]; // 预分配接收缓冲区池
const RECV_POOL_SIZE: usize = 128;
let mut recv_pool: Vec<Vec<u8>> = (0..RECV_POOL_SIZE)
.map(|_| vec![0u8; MAX_PACKET_SIZE])
.collect();
let mut pool_idx = 0;
// Endpoint 缓存
let mut endpoint_cache: HashMap<u8, SocketAddr> = HashMap::new();
loop { loop {
let _ = (|| { let recv_buf = &mut recv_pool[pool_idx];
let (len, addr) = socket.recv_from(&mut recv_buf).unwrap(); pool_idx = (pool_idx + 1) % RECV_POOL_SIZE;
let data: &mut [u8] = unsafe { transmute(&mut recv_buf[..len]) };
match socket.recv_from(recv_buf) {
let packet = Ipv4Packet::new(data).ok_or("malformed packet")?; Ok((len, addr)) => {
let header_len = packet.get_header_length() as usize * 4; let data = &mut recv_buf[..len];
let (_ip_header, rest) = data
.split_at_mut_checked(header_len) // 快速路径检查
.ok_or("malformed packet")?; if len < 20 + mem::size_of::<Meta>() {
let (meta_bytes, payload) = rest continue;
.split_at_mut_checked(size_of::<Meta>()) }
.ok_or("malformed packet")?;
let meta: &Meta = unsafe { transmute(meta_bytes.as_ptr()) }; if let Some(packet) = Ipv4Packet::new(data) {
if meta.dst_id == config.local_id && meta.reversed == 0 { let header_len = packet.get_header_length() as usize * 4;
let router = router_writers if let Some((_ip_header, rest)) = data.split_at_mut_checked(header_len) {
.get_mut(&meta.src_id) if let Some((meta_bytes, payload)) = rest.split_at_mut_checked(mem::size_of::<Meta>()) {
.ok_or("missing router")?; let meta: &Meta = unsafe { &*(meta_bytes.as_ptr() as *const Meta) };
*router.endpoint.write().unwrap() = Some(addr);
router.decrypt(payload, &local_secret); if meta.dst_id == config_local_id && meta.reversed == 0 {
router.tun_writer.write_all(payload)?; if let Some(router) = router_writers.get(&meta.src_id) {
// 更新endpoint缓存
endpoint_cache.insert(meta.src_id, addr);
*router.endpoint.write().unwrap() = Some(addr);
// 解密
router.decrypt(payload, &local_secret_clone);
// 写入TUN
let _ = router.tun_writer.write_all(payload);
// 更新统计
stats_clone.packets_recv.fetch_add(1, Ordering::Relaxed);
stats_clone.bytes_recv.fetch_add(payload.len() as u64, Ordering::Relaxed);
}
}
}
}
}
} }
Err(_) => continue,
Ok::<(), Box<dyn Error>>(()) }
})();
} }
}); });
} }
// 统计线程
let stats_clone = Arc::clone(&stats);
s.spawn(move |_| {
let mut last_stats = (0u64, 0u64, 0u64, 0u64);
loop {
thread::sleep(Duration::from_secs(10));
let sent_pkts = stats_clone.packets_sent.load(Ordering::Relaxed);
let recv_pkts = stats_clone.packets_recv.load(Ordering::Relaxed);
let sent_bytes = stats_clone.bytes_sent.load(Ordering::Relaxed);
let recv_bytes = stats_clone.bytes_recv.load(Ordering::Relaxed);
let sent_pps = (sent_pkts - last_stats.0) / 10;
let recv_pps = (recv_pkts - last_stats.1) / 10;
let sent_mbps = ((sent_bytes - last_stats.2) * 8) / (10 * 1024 * 1024);
let recv_mbps = ((recv_bytes - last_stats.3) * 8) / (10 * 1024 * 1024);
println!(
"Stats - TX: {} pps, {} Mbps | RX: {} pps, {} Mbps",
sent_pps, sent_mbps, recv_pps, recv_mbps
);
last_stats = (sent_pkts, recv_pkts, sent_bytes, recv_bytes);
}
});
}) })
.unwrap(); .unwrap();
Ok(()) Ok(())
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment