Commit da0071c4 authored by nanahira's avatar nanahira

rework

parent ba3c41ac
Pipeline #37404 failed with stages
in 1 minute and 51 seconds
...@@ -4,17 +4,15 @@ use crate::router::{Router, RouterReader, RouterWriter, SECRET_LENGTH}; ...@@ -4,17 +4,15 @@ use crate::router::{Router, RouterReader, RouterWriter, SECRET_LENGTH};
use std::collections::HashMap; use std::collections::HashMap;
use std::env; use std::env;
use std::error::Error; use std::error::Error;
use std::intrinsics::transmute;
use std::io::{Read, Write}; use std::io::{Read, Write};
use std::mem::{self, MaybeUninit}; use std::mem::MaybeUninit;
use std::net::SocketAddr; use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering}; use std::sync::atomic::{AtomicPtr, Ordering};
use std::sync::{Arc, Mutex}; use std::ptr;
use std::sync::mpsc::{sync_channel, Receiver, SyncSender, TryRecvError}; use std::thread::available_parallelism;
use std::time::{Duration, Instant};
use std::thread;
#[repr(C)] #[repr(C)]
#[derive(Copy, Clone)]
pub struct Meta { pub struct Meta {
pub src_id: u8, pub src_id: u8,
pub dst_id: u8, pub dst_id: u8,
...@@ -42,77 +40,116 @@ pub struct Config { ...@@ -42,77 +40,116 @@ pub struct Config {
pub routers: Vec<ConfigRouter>, pub routers: Vec<ConfigRouter>,
} }
use crossbeam_utils::thread as scoped_thread; use crossbeam_utils::thread;
use grouping_by::GroupingBy; use grouping_by::GroupingBy;
use pnet::packet::ipv4::Ipv4Packet; use pnet::packet::ipv4::Ipv4Packet;
use socket2::Socket; use socket2::Socket;
use libc::{iovec, msghdr, mmsghdr, recvmmsg, sendmmsg, MSG_WAITFORONE};
use std::net::SocketAddr;
// 性能优化配置 // 批量处理的包数量
const BATCH_SIZE: usize = 32; const BATCH_SIZE: usize = 32;
const BATCH_TIMEOUT_MILLIS: u64 = 1; // 更大的缓冲区以支持 GSO
const CHANNEL_SIZE: usize = 1024; const BUFFER_SIZE: usize = 65536;
const SOCKET_BUFFER_SIZE: usize = 4 * 1024 * 1024; // 4MB // 预分配的缓冲池大小(会根据 CPU 数量调整)
const MAX_PACKET_SIZE: usize = 1500; const BUFFER_POOL_SIZE_PER_THREAD: usize = 64;
// 数据包结构 // 无锁的地址更新结构
#[derive(Clone)] struct AtomicEndpoint {
struct Packet { ptr: AtomicPtr<SocketAddr>,
data: Vec<u8>,
len: usize,
} }
impl Packet { impl AtomicEndpoint {
fn new() -> Self { fn new() -> Self {
Self { AtomicEndpoint {
data: vec![0u8; MAX_PACKET_SIZE], ptr: AtomicPtr::new(ptr::null_mut()),
len: 0, }
}
fn load(&self) -> Option<SocketAddr> {
let ptr = self.ptr.load(Ordering::Acquire);
if ptr.is_null() {
None
} else {
Some(unsafe { *ptr })
}
}
fn store(&self, addr: SocketAddr) {
let boxed = Box::new(addr);
let old = self.ptr.swap(Box::into_raw(boxed), Ordering::Release);
if !old.is_null() {
unsafe { Box::from_raw(old); }
} }
} }
} }
// 统计信息 // 缓冲池结构
struct Stats { struct BufferPool {
packets_sent: AtomicU64, buffers: Vec<Vec<u8>>,
packets_recv: AtomicU64,
bytes_sent: AtomicU64,
bytes_recv: AtomicU64,
} }
fn optimize_socket(socket: &Socket) -> Result<(), Box<dyn Error>> { impl BufferPool {
socket.set_send_buffer_size(SOCKET_BUFFER_SIZE)?; fn new(size: usize) -> Self {
socket.set_recv_buffer_size(SOCKET_BUFFER_SIZE)?; let mut buffers = Vec::with_capacity(size);
socket.set_nonblocking(false)?; for _ in 0..size {
buffers.push(vec![0u8; BUFFER_SIZE]);
}
BufferPool { buffers }
}
#[cfg(target_os = "linux")] fn get_buffers(&mut self, count: usize) -> Vec<Vec<u8>> {
{ let mut result = Vec::with_capacity(count);
// Linux 特定优化 for _ in 0..count.min(self.buffers.len()) {
socket.set_nodelay(true)?; if let Some(buf) = self.buffers.pop() {
result.push(buf);
} else {
result.push(vec![0u8; BUFFER_SIZE]);
}
}
while result.len() < count {
result.push(vec![0u8; BUFFER_SIZE]);
}
result
} }
Ok(()) fn return_buffers(&mut self, buffers: Vec<Vec<u8>>) {
self.buffers.extend(buffers);
}
} }
fn main() -> Result<(), Box<dyn Error>> { fn main() -> Result<(), Box<dyn Error>> {
// 获取系统的 CPU 核心数
let num_threads = available_parallelism()
.map(|n| n.get())
.unwrap_or(4); // 默认使用 4 个线程
println!("Using {} threads based on available CPU cores", num_threads);
let config: Config = serde_json::from_str(env::args().nth(1).ok_or("need param")?.as_str())?; let config: Config = serde_json::from_str(env::args().nth(1).ok_or("need param")?.as_str())?;
let local_secret: [u8; SECRET_LENGTH] = Router::create_secret(config.local_secret.as_str())?; let local_secret: [u8; SECRET_LENGTH] = Router::create_secret(config.local_secret.as_str())?;
let mut sockets: HashMap<u16, Arc<Socket>> = HashMap::new(); let mut sockets: HashMap<u16, Arc<Socket>> = HashMap::new();
// 创建并优化 routers // 为所有 socket 设置更大的缓冲区
for socket in sockets.values_mut() {
// 设置 socket 缓冲区,根据线程数动态调整
let buffer_size = 2 * 1024 * 1024 * num_threads; // 每个线程 2MB
let _ = socket.set_recv_buffer_size(buffer_size);
let _ = socket.set_send_buffer_size(buffer_size);
}
let routers: HashMap<u8, Router> = config let routers: HashMap<u8, Router> = config
.routers .routers
.iter() .iter()
.map(|c| { .map(|c| Router::new(c, &mut sockets).map(|router| (c.remote_id, router)))
Router::new(c, &mut sockets).map(|router| {
// 优化每个 socket
let router_key = Router::key(c); // 使用关联函数
if let Some(socket) = sockets.get(&router_key) {
let _ = optimize_socket(socket);
}
(c.remote_id, router)
})
})
.collect::<Result<_, _>>()?; .collect::<Result<_, _>>()?;
// 使用原子指针替代 RwLock 来存储 endpoint
let atomic_endpoints: HashMap<u8, Arc<AtomicEndpoint>> = routers
.keys()
.map(|&id| (id, Arc::new(AtomicEndpoint::new())))
.collect();
let (mut router_readers, router_writers): ( let (mut router_readers, router_writers): (
HashMap<u8, RouterReader>, HashMap<u8, RouterReader>,
HashMap<u8, RouterWriter>, HashMap<u8, RouterWriter>,
...@@ -136,217 +173,236 @@ fn main() -> Result<(), Box<dyn Error>> { ...@@ -136,217 +173,236 @@ fn main() -> Result<(), Box<dyn Error>> {
}) })
.collect(); .collect();
// 全局统计 println!("created tuns");
let stats = Arc::new(Stats {
packets_sent: AtomicU64::new(0),
packets_recv: AtomicU64::new(0),
bytes_sent: AtomicU64::new(0),
bytes_recv: AtomicU64::new(0),
});
println!("Created optimized TUN devices"); thread::scope(|s| {
// 为每个路由器启动多个发送线程
scoped_thread::scope(|s| { for (id, router) in router_readers.drain() {
// 为每个 router 创建读写线程 let endpoint = Arc::clone(&atomic_endpoints[&id]);
for router in router_readers.values_mut() {
let router_id = router.config.remote_id;
let local_id = config.local_id;
let stats_clone = Arc::clone(&stats);
// 创建批处理通道 // 根据 CPU 核心数创建发送线程
let (tx, rx): (SyncSender<Packet>, Receiver<Packet>) = sync_channel(CHANNEL_SIZE); // 为避免过多线程竞争同一个 TUN 设备,限制最大线程数
let sender_threads = num_threads.min(4);
// TUN 读取线程 for thread_id in 0..sender_threads {
s.spawn(move |_| { let mut router = router.clone(); // 假设 RouterReader 实现了 Clone
let meta_size = mem::size_of::<Meta>(); let endpoint = Arc::clone(&endpoint);
let mut buffer = vec![0u8; MAX_PACKET_SIZE]; let local_id = config.local_id;
// 预构建 Meta s.spawn(move |_| {
let meta = Meta { // 设置线程 CPU 亲和性(可选)
src_id: local_id, #[cfg(target_os = "linux")]
dst_id: router_id, {
reversed: 0, use libc::{cpu_set_t, CPU_SET, CPU_ZERO, sched_setaffinity};
}; unsafe {
let meta_bytes = unsafe { let mut cpu_set: cpu_set_t = std::mem::zeroed();
std::slice::from_raw_parts(&meta as *const _ as *const u8, meta_size) CPU_ZERO(&mut cpu_set);
}; CPU_SET(thread_id % num_threads, &mut cpu_set);
sched_setaffinity(0, std::mem::size_of_val(&cpu_set), &cpu_set);
loop {
// 写入 Meta 头
buffer[..meta_size].copy_from_slice(meta_bytes);
// 读取数据
match router.tun_reader.read(&mut buffer[meta_size..]) {
Ok(n) if n > 0 => {
let packet = Packet {
data: buffer[..meta_size + n].to_vec(),
len: meta_size + n,
};
let _ = tx.try_send(packet);
} }
_ => continue,
} }
}
});
// 批处理发送线程
s.spawn(move |_| {
let mut batch = Vec::with_capacity(BATCH_SIZE);
let meta_size = mem::size_of::<Meta>();
let mut last_batch_time = Instant::now();
loop {
batch.clear();
// 收集数据包 let mut buffer_pool = BufferPool::new(BUFFER_POOL_SIZE_PER_THREAD);
match rx.recv_timeout(Duration::from_millis(BATCH_TIMEOUT_MILLIS)) { let meta_size = size_of::<Meta>();
Ok(packet) => batch.push(packet),
Err(_) => {
if batch.is_empty() { continue; }
}
}
// 继续收集直到批量大小或超时 // 预初始化 Meta 头
while batch.len() < BATCH_SIZE { let meta = Meta {
match rx.try_recv() { src_id: local_id,
Ok(packet) => batch.push(packet), dst_id: router.config.remote_id,
Err(TryRecvError::Empty) => { reversed: 0,
if last_batch_time.elapsed() > Duration::from_millis(BATCH_TIMEOUT_MILLIS) { };
break;
loop {
// 批量读取多个包
let mut buffers = buffer_pool.get_buffers(BATCH_SIZE);
let mut valid_count = 0;
for buffer in &mut buffers {
// 写入 Meta 头
let meta_bytes = unsafe {
std::slice::from_raw_parts(&meta as *const Meta as *const u8, meta_size)
};
buffer[..meta_size].copy_from_slice(meta_bytes);
// 尝试非阻塞读取
match router.tun_reader.read(&mut buffer[meta_size..]) {
Ok(n) if n > 0 => {
if let Some(addr) = endpoint.load() {
router.encrypt(&mut buffer[meta_size..meta_size + n]);
buffer.truncate(meta_size + n);
valid_count += 1;
} else {
break;
}
} }
_ => break,
} }
Err(TryRecvError::Disconnected) => return,
} }
}
// 批量发送
if !batch.is_empty() { if valid_count > 0 && endpoint.load().is_some() {
// 获取endpoint
if let Some(ref addr) = *router.endpoint.read().unwrap() {
// 加密所有包
for packet in &mut batch {
router.encrypt(&mut packet.data[meta_size..packet.len]);
}
// 设置 mark(每批只设置一次)
#[cfg(target_os = "linux")] #[cfg(target_os = "linux")]
let _ = router.socket.set_mark(router.config.mark); let _ = router.socket.set_mark(router.config.mark);
// 批量发送 // 使用 sendmmsg 批量发送(如果可用)
let mut total_bytes = 0; let addr = endpoint.load().unwrap();
for packet in &batch { #[cfg(target_os = "linux")]
if let Ok(_) = router.socket.send_to(&packet.data[..packet.len], addr) { {
total_bytes += packet.len; use std::os::unix::io::AsRawFd;
let fd = router.socket.as_raw_fd();
// 准备批量发送的消息
let mut messages: Vec<mmsghdr> = Vec::with_capacity(valid_count);
let mut iovecs: Vec<iovec> = Vec::with_capacity(valid_count);
let mut sockaddrs: Vec<libc::sockaddr_storage> = Vec::with_capacity(valid_count);
// TODO: 实现 sendmmsg 批量发送
// 暂时使用普通发送
for buffer in &buffers[..valid_count] {
let _ = router.socket.send_to(buffer, &addr);
} }
} }
// 更新统计 #[cfg(not(target_os = "linux"))]
stats_clone.packets_sent.fetch_add(batch.len() as u64, Ordering::Relaxed); {
stats_clone.bytes_sent.fetch_add(total_bytes as u64, Ordering::Relaxed); for buffer in &buffers[..valid_count] {
let _ = router.socket.send_to(buffer, &addr);
}
}
}
// 归还缓冲区
for buffer in &mut buffers {
buffer.clear();
buffer.resize(BUFFER_SIZE, 0);
} }
buffer_pool.return_buffers(buffers);
last_batch_time = Instant::now(); // 避免 CPU 占用过高
if valid_count == 0 {
std::thread::yield_now();
}
} }
} });
}); }
} }
// 优化的接收线程 // 为每个 socket 启动多个接收线程
for (socket, mut router_writers) in router_writers3 { for (socket, mut router_writers) in router_writers3 {
let stats_clone = Arc::clone(&stats); // 将 router_writers 转换为使用原子指针的版本
let local_secret_clone = local_secret.clone(); let atomic_writers: Arc<HashMap<u8, (RouterWriter, Arc<AtomicEndpoint>)>> = Arc::new(
let config_local_id = config.local_id; router_writers
.drain()
.map(|(id, writer)| {
let endpoint = Arc::clone(&atomic_endpoints[&id]);
(id, (writer, endpoint))
})
.collect()
);
// 根据 CPU 核心数创建接收线程
let receiver_threads = num_threads.min(8); // 接收线程可以多一些
s.spawn(move |_| { for thread_id in 0..receiver_threads {
// 预分配多个缓冲区 let socket = Arc::clone(&socket);
let mut recv_bufs: Vec<[MaybeUninit<u8>; MAX_PACKET_SIZE]> = let atomic_writers = Arc::clone(&atomic_writers);
(0..4).map(|_| unsafe { MaybeUninit::uninit().assume_init() }).collect(); let local_id = config.local_id;
let mut buf_idx = 0; let local_secret = local_secret.clone();
loop { s.spawn(move |_| {
let recv_buf = &mut recv_bufs[buf_idx]; // 设置线程 CPU 亲和性(可选)
buf_idx = (buf_idx + 1) % recv_bufs.len(); #[cfg(target_os = "linux")]
{
use libc::{cpu_set_t, CPU_SET, CPU_ZERO, sched_setaffinity};
unsafe {
let mut cpu_set: cpu_set_t = std::mem::zeroed();
CPU_ZERO(&mut cpu_set);
CPU_SET(thread_id % num_threads, &mut cpu_set);
sched_setaffinity(0, std::mem::size_of_val(&cpu_set), &cpu_set);
}
}
match socket.recv_from(recv_buf) { let mut buffer_pool = BufferPool::new(BUFFER_POOL_SIZE_PER_THREAD);
Ok((len, sock_addr)) => {
// 转换为 SocketAddr loop {
let addr = match sock_addr.as_socket() { let mut recv_buffers = buffer_pool.get_buffers(BATCH_SIZE);
Some(addr) => addr, let mut addrs = vec![MaybeUninit::uninit(); BATCH_SIZE];
None => continue, let mut valid_count = 0;
};
// 批量接收
// 将 MaybeUninit 转换为初始化的数据 #[cfg(target_os = "linux")]
let data = unsafe { {
std::slice::from_raw_parts_mut( // TODO: 实现 recvmmsg 批量接收
recv_buf.as_mut_ptr() as *mut u8, // 暂时使用普通接收
len for i in 0..BATCH_SIZE {
) match socket.recv_from(&mut recv_buffers[i]) {
}; Ok((len, addr)) => {
recv_buffers[i].truncate(len);
// 快速路径检查 addrs[i].write(addr);
if len < 20 + mem::size_of::<Meta>() { valid_count += 1;
continue; }
Err(_) => break,
}
} }
}
if let Some(packet) = Ipv4Packet::new(data) {
let header_len = packet.get_header_length() as usize * 4; #[cfg(not(target_os = "linux"))]
if header_len < data.len() { {
let rest = &mut data[header_len..]; for i in 0..BATCH_SIZE {
if rest.len() >= mem::size_of::<Meta>() { match socket.recv_from(&mut recv_buffers[i]) {
let (meta_bytes, payload) = rest.split_at_mut(mem::size_of::<Meta>()); Ok((len, addr)) => {
let meta: &Meta = unsafe { &*(meta_bytes.as_ptr() as *const Meta) }; recv_buffers[i].truncate(len);
addrs[i].write(addr);
if meta.dst_id == config_local_id && meta.reversed == 0 { valid_count += 1;
if let Some(router) = router_writers.get_mut(&meta.src_id) {
// 更新endpoint
*router.endpoint.write().unwrap() = Some(addr);
// 解密
router.decrypt(payload, &local_secret_clone);
// 写入TUN
let _ = router.tun_writer.write_all(payload);
// 更新统计
stats_clone.packets_recv.fetch_add(1, Ordering::Relaxed);
stats_clone.bytes_recv.fetch_add(payload.len() as u64, Ordering::Relaxed);
}
}
} }
Err(_) => break,
} }
} }
} }
Err(_) => continue,
// 批量处理接收到的包
for i in 0..valid_count {
let _ = (|| {
let data = &mut recv_buffers[i];
let addr = unsafe { addrs[i].assume_init() };
let packet = Ipv4Packet::new(data).ok_or("malformed packet")?;
let header_len = packet.get_header_length() as usize * 4;
let (_ip_header, rest) = data
.split_at_mut_checked(header_len)
.ok_or("malformed packet")?;
let (meta_bytes, payload) = rest
.split_at_mut_checked(size_of::<Meta>())
.ok_or("malformed packet")?;
let meta: &Meta = unsafe { transmute(meta_bytes.as_ptr()) };
if meta.dst_id == local_id && meta.reversed == 0 {
if let Some((router, endpoint)) = atomic_writers.get(&meta.src_id) {
endpoint.store(addr);
router.decrypt(payload, &local_secret);
router.tun_writer.write_all(payload)?;
}
}
Ok::<(), Box<dyn Error>>(())
})();
}
// 归还缓冲区
for buffer in &mut recv_buffers {
buffer.clear();
buffer.resize(BUFFER_SIZE, 0);
}
buffer_pool.return_buffers(recv_buffers);
if valid_count == 0 {
std::thread::yield_now();
}
} }
} });
});
}
// 统计线程
let stats_clone = Arc::clone(&stats);
s.spawn(move |_| {
let mut last_stats = (0u64, 0u64, 0u64, 0u64);
loop {
thread::sleep(Duration::from_secs(10));
let sent_pkts = stats_clone.packets_sent.load(Ordering::Relaxed);
let recv_pkts = stats_clone.packets_recv.load(Ordering::Relaxed);
let sent_bytes = stats_clone.bytes_sent.load(Ordering::Relaxed);
let recv_bytes = stats_clone.bytes_recv.load(Ordering::Relaxed);
let sent_pps = (sent_pkts - last_stats.0) / 10;
let recv_pps = (recv_pkts - last_stats.1) / 10;
let sent_mbps = ((sent_bytes - last_stats.2) * 8) / (10 * 1024 * 1024);
let recv_mbps = ((recv_bytes - last_stats.3) * 8) / (10 * 1024 * 1024);
println!(
"Stats - TX: {} pps, {} Mbps | RX: {} pps, {} Mbps",
sent_pps, sent_mbps, recv_pps, recv_mbps
);
last_stats = (sent_pkts, recv_pkts, sent_bytes, recv_bytes);
} }
}); }
}) })
.unwrap(); .unwrap();
Ok(()) Ok(())
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment