Commit da0071c4 authored by nanahira

rework

parent ba3c41ac
Pipeline #37404 failed with stages in 1 minute and 51 seconds
@@ -4,17 +4,15 @@ use crate::router::{Router, RouterReader, RouterWriter, SECRET_LENGTH};
use std::collections::HashMap;
use std::env;
use std::error::Error;
use std::mem::transmute;
use std::io::{Read, Write};
use std::mem::{self, MaybeUninit};
use std::net::SocketAddr;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use std::sync::mpsc::{sync_channel, Receiver, SyncSender, TryRecvError};
use std::time::{Duration, Instant};
use std::thread;
use std::mem::MaybeUninit;
use std::sync::Arc;
use std::sync::atomic::{AtomicPtr, Ordering};
use std::ptr;
use std::thread::available_parallelism;
#[repr(C)]
#[derive(Copy, Clone)]
pub struct Meta {
pub src_id: u8,
pub dst_id: u8,
@@ -42,77 +40,116 @@ pub struct Config {
pub routers: Vec<ConfigRouter>,
}
use crossbeam_utils::thread as scoped_thread;
use crossbeam_utils::thread;
use grouping_by::GroupingBy;
use pnet::packet::ipv4::Ipv4Packet;
use socket2::Socket;
use libc::{iovec, msghdr, mmsghdr, recvmmsg, sendmmsg, MSG_WAITFORONE};
use std::net::SocketAddr;
// Performance tuning configuration
// Number of packets processed per batch
const BATCH_SIZE: usize = 32;
const BATCH_TIMEOUT_MILLIS: u64 = 1;
const CHANNEL_SIZE: usize = 1024;
const SOCKET_BUFFER_SIZE: usize = 4 * 1024 * 1024; // 4MB
const MAX_PACKET_SIZE: usize = 1500;
// Packet structure
#[derive(Clone)]
struct Packet {
data: Vec<u8>,
len: usize,
// Larger buffer to support GSO
const BUFFER_SIZE: usize = 65536;
// Size of the pre-allocated buffer pool per thread (adjusted to the CPU count)
const BUFFER_POOL_SIZE_PER_THREAD: usize = 64;
// Lock-free endpoint update structure
struct AtomicEndpoint {
ptr: AtomicPtr<SocketAddr>,
}
impl Packet {
impl AtomicEndpoint {
fn new() -> Self {
Self {
data: vec![0u8; MAX_PACKET_SIZE],
len: 0,
AtomicEndpoint {
ptr: AtomicPtr::new(ptr::null_mut()),
}
}
fn load(&self) -> Option<SocketAddr> {
let ptr = self.ptr.load(Ordering::Acquire);
if ptr.is_null() {
None
} else {
Some(unsafe { *ptr })
}
}
fn store(&self, addr: SocketAddr) {
let boxed = Box::new(addr);
let old = self.ptr.swap(Box::into_raw(boxed), Ordering::Release);
if !old.is_null() {
unsafe { Box::from_raw(old); }
}
}
}
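// Hedged usage sketch (not part of the original commit): the receive path is
// expected to publish the last-seen peer address with store(), while the send
// path snapshots it with load(). Note that store() frees the previously
// published Box immediately, so this assumes no reader dereferences a stale
// raw pointer across a concurrent store.
#[cfg(test)]
mod atomic_endpoint_usage {
    use super::*;
    #[test]
    fn publish_and_read() {
        let endpoint = AtomicEndpoint::new();
        assert!(endpoint.load().is_none());
        let addr: SocketAddr = "127.0.0.1:4000".parse().unwrap();
        endpoint.store(addr);
        assert_eq!(endpoint.load(), Some(addr));
    }
}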
// Statistics
struct Stats {
packets_sent: AtomicU64,
packets_recv: AtomicU64,
bytes_sent: AtomicU64,
bytes_recv: AtomicU64,
// Buffer pool structure
struct BufferPool {
buffers: Vec<Vec<u8>>,
}
fn optimize_socket(socket: &Socket) -> Result<(), Box<dyn Error>> {
socket.set_send_buffer_size(SOCKET_BUFFER_SIZE)?;
socket.set_recv_buffer_size(SOCKET_BUFFER_SIZE)?;
socket.set_nonblocking(false)?;
impl BufferPool {
fn new(size: usize) -> Self {
let mut buffers = Vec::with_capacity(size);
for _ in 0..size {
buffers.push(vec![0u8; BUFFER_SIZE]);
}
BufferPool { buffers }
}
#[cfg(target_os = "linux")]
{
// Linux-specific optimizations
socket.set_nodelay(true)?;
fn get_buffers(&mut self, count: usize) -> Vec<Vec<u8>> {
let mut result = Vec::with_capacity(count);
for _ in 0..count.min(self.buffers.len()) {
if let Some(buf) = self.buffers.pop() {
result.push(buf);
} else {
result.push(vec![0u8; BUFFER_SIZE]);
}
}
while result.len() < count {
result.push(vec![0u8; BUFFER_SIZE]);
}
result
}
Ok(())
fn return_buffers(&mut self, buffers: Vec<Vec<u8>>) {
self.buffers.extend(buffers);
}
}
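// Hedged usage sketch (not part of the original commit): each worker thread is
// meant to own its own BufferPool, take a batch of buffers, and hand them back
// after use, so no locking is required. When the pool runs dry, get_buffers
// falls back to fresh allocations.
#[cfg(test)]
mod buffer_pool_usage {
    use super::*;
    #[test]
    fn take_and_return() {
        let mut pool = BufferPool::new(4);
        let batch = pool.get_buffers(8); // 4 pooled buffers plus 4 fresh ones
        assert_eq!(batch.len(), 8);
        pool.return_buffers(batch);
        assert_eq!(pool.buffers.len(), 8);
    }
}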
fn main() -> Result<(), Box<dyn Error>> {
// Determine the number of available CPU cores
let num_threads = available_parallelism()
.map(|n| n.get())
.unwrap_or(4); // default to 4 threads
println!("Using {} threads based on available CPU cores", num_threads);
let config: Config = serde_json::from_str(env::args().nth(1).ok_or("need param")?.as_str())?;
let local_secret: [u8; SECRET_LENGTH] = Router::create_secret(config.local_secret.as_str())?;
let mut sockets: HashMap<u16, Arc<Socket>> = HashMap::new();
// Create and optimize routers
// Give every socket a larger buffer
for socket in sockets.values_mut() {
// Size the socket buffers dynamically based on the thread count
let buffer_size = 2 * 1024 * 1024 * num_threads; // 2 MB per thread
let _ = socket.set_recv_buffer_size(buffer_size);
let _ = socket.set_send_buffer_size(buffer_size);
}
let routers: HashMap<u8, Router> = config
.routers
.iter()
.map(|c| {
Router::new(c, &mut sockets).map(|router| {
// Optimize each socket
let router_key = Router::key(c); // use the associated function
if let Some(socket) = sockets.get(&router_key) {
let _ = optimize_socket(socket);
}
(c.remote_id, router)
})
})
.map(|c| Router::new(c, &mut sockets).map(|router| (c.remote_id, router)))
.collect::<Result<_, _>>()?;
// Use atomic pointers instead of an RwLock to store endpoints
let atomic_endpoints: HashMap<u8, Arc<AtomicEndpoint>> = routers
.keys()
.map(|&id| (id, Arc::new(AtomicEndpoint::new())))
.collect();
let (mut router_readers, router_writers): (
HashMap<u8, RouterReader>,
HashMap<u8, RouterWriter>,
@@ -136,217 +173,236 @@ fn main() -> Result<(), Box<dyn Error>> {
})
.collect();
// Global statistics
let stats = Arc::new(Stats {
packets_sent: AtomicU64::new(0),
packets_recv: AtomicU64::new(0),
bytes_sent: AtomicU64::new(0),
bytes_recv: AtomicU64::new(0),
});
println!("created tuns");
println!("Created optimized TUN devices");
thread::scope(|s| {
// Start multiple sender threads per router
for (id, router) in router_readers.drain() {
let endpoint = Arc::clone(&atomic_endpoints[&id]);
scoped_thread::scope(|s| {
// Create reader/writer threads for each router
for router in router_readers.values_mut() {
let router_id = router.config.remote_id;
let local_id = config.local_id;
let stats_clone = Arc::clone(&stats);
// Create sender threads based on the CPU core count
// Cap the count so too many threads do not contend on the same TUN device
let sender_threads = num_threads.min(4);
// Create the batching channel
let (tx, rx): (SyncSender<Packet>, Receiver<Packet>) = sync_channel(CHANNEL_SIZE);
for thread_id in 0..sender_threads {
let mut router = router.clone(); // assumes RouterReader implements Clone
let endpoint = Arc::clone(&endpoint);
let local_id = config.local_id;
// TUN reader thread
s.spawn(move |_| {
let meta_size = mem::size_of::<Meta>();
let mut buffer = vec![0u8; MAX_PACKET_SIZE];
// Optionally pin the thread to a CPU
#[cfg(target_os = "linux")]
{
use libc::{cpu_set_t, CPU_SET, CPU_ZERO, sched_setaffinity};
unsafe {
let mut cpu_set: cpu_set_t = std::mem::zeroed();
CPU_ZERO(&mut cpu_set);
CPU_SET(thread_id % num_threads, &mut cpu_set);
sched_setaffinity(0, std::mem::size_of_val(&cpu_set), &cpu_set);
}
}
// Pre-build the Meta header
let mut buffer_pool = BufferPool::new(BUFFER_POOL_SIZE_PER_THREAD);
let meta_size = size_of::<Meta>();
// Pre-initialize the Meta header
let meta = Meta {
src_id: local_id,
dst_id: router_id,
dst_id: router.config.remote_id,
reversed: 0,
};
let meta_bytes = unsafe {
std::slice::from_raw_parts(&meta as *const _ as *const u8, meta_size)
};
loop {
// Read multiple packets as a batch
let mut buffers = buffer_pool.get_buffers(BATCH_SIZE);
let mut valid_count = 0;
for buffer in &mut buffers {
// Write the Meta header
let meta_bytes = unsafe {
std::slice::from_raw_parts(&meta as *const Meta as *const u8, meta_size)
};
buffer[..meta_size].copy_from_slice(meta_bytes);
// Read data
// Attempt a non-blocking read
match router.tun_reader.read(&mut buffer[meta_size..]) {
Ok(n) if n > 0 => {
let packet = Packet {
data: buffer[..meta_size + n].to_vec(),
len: meta_size + n,
};
let _ = tx.try_send(packet);
}
_ => continue,
}
}
});
// Batch sender thread
s.spawn(move |_| {
let mut batch = Vec::with_capacity(BATCH_SIZE);
let meta_size = mem::size_of::<Meta>();
let mut last_batch_time = Instant::now();
loop {
batch.clear();
// Collect packets
match rx.recv_timeout(Duration::from_millis(BATCH_TIMEOUT_MILLIS)) {
Ok(packet) => batch.push(packet),
Err(_) => {
if batch.is_empty() { continue; }
}
}
// Keep collecting until the batch is full or the timeout expires
while batch.len() < BATCH_SIZE {
match rx.try_recv() {
Ok(packet) => batch.push(packet),
Err(TryRecvError::Empty) => {
if last_batch_time.elapsed() > Duration::from_millis(BATCH_TIMEOUT_MILLIS) {
if let Some(addr) = endpoint.load() {
router.encrypt(&mut buffer[meta_size..meta_size + n]);
buffer.truncate(meta_size + n);
valid_count += 1;
} else {
break;
}
}
Err(TryRecvError::Disconnected) => return,
}
_ => break,
}
if !batch.is_empty() {
// Get the endpoint
if let Some(ref addr) = *router.endpoint.read().unwrap() {
// Encrypt all packets
for packet in &mut batch {
router.encrypt(&mut packet.data[meta_size..packet.len]);
}
// Set the mark (once per batch)
// Send the batch
if valid_count > 0 && endpoint.load().is_some() {
#[cfg(target_os = "linux")]
let _ = router.socket.set_mark(router.config.mark);
// Send the batch
let mut total_bytes = 0;
for packet in &batch {
if let Ok(_) = router.socket.send_to(&packet.data[..packet.len], addr) {
total_bytes += packet.len;
// Use sendmmsg for batched sending (where available)
let addr = endpoint.load().unwrap();
#[cfg(target_os = "linux")]
{
use std::os::unix::io::AsRawFd;
let fd = router.socket.as_raw_fd();
// Prepare the messages for batched sending
let mut messages: Vec<mmsghdr> = Vec::with_capacity(valid_count);
let mut iovecs: Vec<iovec> = Vec::with_capacity(valid_count);
let mut sockaddrs: Vec<libc::sockaddr_storage> = Vec::with_capacity(valid_count);
// TODO: implement batched sending with sendmmsg
// Fall back to plain send_to for now
for buffer in &buffers[..valid_count] {
let _ = router.socket.send_to(buffer, &addr);
}
}
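// Hedged sketch of the sendmmsg TODO above (not part of the original commit):
// a helper that packs each already-encrypted packet into one iovec/mmsghdr
// pair aimed at the same destination and submits the whole batch in a single
// syscall. It reuses the iovec/msghdr/mmsghdr/sendmmsg imports at the top of
// the file; `fd` is the UDP socket's raw fd and the return value is the number
// of messages actually sent (or -1 on error), so callers can fall back to the
// per-packet send_to loop above.
#[cfg(target_os = "linux")]
fn send_batch_mmsg(fd: libc::c_int, packets: &mut [Vec<u8>], dst: &SocketAddr) -> libc::c_int {
    let raw_dst = socket2::SockAddr::from(*dst);
    // One iovec per packet, pointing straight into the packet buffer.
    let mut iovecs: Vec<iovec> = packets
        .iter_mut()
        .map(|p| iovec {
            iov_base: p.as_mut_ptr() as *mut libc::c_void,
            iov_len: p.len(),
        })
        .collect();
    // Every message shares the same destination address.
    let mut messages: Vec<mmsghdr> = iovecs
        .iter_mut()
        .map(|iov| {
            let mut hdr: msghdr = unsafe { std::mem::zeroed() };
            hdr.msg_name = raw_dst.as_ptr() as *mut libc::c_void;
            hdr.msg_namelen = raw_dst.len();
            hdr.msg_iov = iov as *mut iovec;
            hdr.msg_iovlen = 1;
            mmsghdr { msg_hdr: hdr, msg_len: 0 }
        })
        .collect();
    unsafe { sendmmsg(fd, messages.as_mut_ptr(), messages.len() as libc::c_uint, 0) }
}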
// Update statistics
stats_clone.packets_sent.fetch_add(batch.len() as u64, Ordering::Relaxed);
stats_clone.bytes_sent.fetch_add(total_bytes as u64, Ordering::Relaxed);
#[cfg(not(target_os = "linux"))]
{
for buffer in &buffers[..valid_count] {
let _ = router.socket.send_to(buffer, &addr);
}
}
}
last_batch_time = Instant::now();
// Return the buffers to the pool
for buffer in &mut buffers {
buffer.clear();
buffer.resize(BUFFER_SIZE, 0);
}
buffer_pool.return_buffers(buffers);
// Avoid excessive CPU usage
if valid_count == 0 {
std::thread::yield_now();
}
}
});
}
}
// Optimized receiver threads
// Start multiple receiver threads per socket
for (socket, mut router_writers) in router_writers3 {
let stats_clone = Arc::clone(&stats);
let local_secret_clone = local_secret.clone();
let config_local_id = config.local_id;
s.spawn(move |_| {
// Pre-allocate several receive buffers
let mut recv_bufs: Vec<[MaybeUninit<u8>; MAX_PACKET_SIZE]> =
(0..4).map(|_| unsafe { MaybeUninit::uninit().assume_init() }).collect();
let mut buf_idx = 0;
// Convert router_writers into the atomic-pointer based version
let atomic_writers: Arc<HashMap<u8, (RouterWriter, Arc<AtomicEndpoint>)>> = Arc::new(
router_writers
.drain()
.map(|(id, writer)| {
let endpoint = Arc::clone(&atomic_endpoints[&id]);
(id, (writer, endpoint))
})
.collect()
);
loop {
let recv_buf = &mut recv_bufs[buf_idx];
buf_idx = (buf_idx + 1) % recv_bufs.len();
match socket.recv_from(recv_buf) {
Ok((len, sock_addr)) => {
// Convert to a SocketAddr
let addr = match sock_addr.as_socket() {
Some(addr) => addr,
None => continue,
};
// Create receiver threads based on the CPU core count
let receiver_threads = num_threads.min(8); // receivers can use a few more threads
// Treat the MaybeUninit buffer as initialized data
let data = unsafe {
std::slice::from_raw_parts_mut(
recv_buf.as_mut_ptr() as *mut u8,
len
)
};
for thread_id in 0..receiver_threads {
let socket = Arc::clone(&socket);
let atomic_writers = Arc::clone(&atomic_writers);
let local_id = config.local_id;
let local_secret = local_secret.clone();
// Fast-path length check
if len < 20 + mem::size_of::<Meta>() {
continue;
s.spawn(move |_| {
// Optionally pin the thread to a CPU
#[cfg(target_os = "linux")]
{
use libc::{cpu_set_t, CPU_SET, CPU_ZERO, sched_setaffinity};
unsafe {
let mut cpu_set: cpu_set_t = std::mem::zeroed();
CPU_ZERO(&mut cpu_set);
CPU_SET(thread_id % num_threads, &mut cpu_set);
sched_setaffinity(0, std::mem::size_of_val(&cpu_set), &cpu_set);
}
}
if let Some(packet) = Ipv4Packet::new(data) {
let header_len = packet.get_header_length() as usize * 4;
if header_len < data.len() {
let rest = &mut data[header_len..];
if rest.len() >= mem::size_of::<Meta>() {
let (meta_bytes, payload) = rest.split_at_mut(mem::size_of::<Meta>());
let meta: &Meta = unsafe { &*(meta_bytes.as_ptr() as *const Meta) };
if meta.dst_id == config_local_id && meta.reversed == 0 {
if let Some(router) = router_writers.get_mut(&meta.src_id) {
// Update the endpoint
*router.endpoint.write().unwrap() = Some(addr);
// Decrypt
router.decrypt(payload, &local_secret_clone);
let mut buffer_pool = BufferPool::new(BUFFER_POOL_SIZE_PER_THREAD);
// Write to the TUN device
let _ = router.tun_writer.write_all(payload);
loop {
let mut recv_buffers = buffer_pool.get_buffers(BATCH_SIZE);
let mut addrs = vec![MaybeUninit::uninit(); BATCH_SIZE];
let mut valid_count = 0;
// Update statistics
stats_clone.packets_recv.fetch_add(1, Ordering::Relaxed);
stats_clone.bytes_recv.fetch_add(payload.len() as u64, Ordering::Relaxed);
// Batched receive
#[cfg(target_os = "linux")]
{
// TODO: implement batched receiving with recvmmsg
// Fall back to plain recv_from for now
for i in 0..BATCH_SIZE {
match socket.recv_from(&mut recv_buffers[i]) {
Ok((len, addr)) => {
recv_buffers[i].truncate(len);
addrs[i].write(addr);
valid_count += 1;
}
Err(_) => break,
}
}
}
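// Hedged sketch of the recvmmsg TODO above (not part of the original commit):
// one possible way to drain up to `bufs.len()` datagrams in a single syscall.
// MSG_WAITFORONE blocks for the first datagram and then returns whatever else
// is already queued. The helper reuses the iovec/msghdr/mmsghdr/recvmmsg
// imports at the top of the file and returns (length, peer address storage)
// pairs that the caller still has to convert into SocketAddr values.
#[cfg(target_os = "linux")]
fn recv_batch_mmsg(
    fd: libc::c_int,
    bufs: &mut [Vec<u8>],
) -> Vec<(usize, libc::sockaddr_storage)> {
    let mut storages: Vec<libc::sockaddr_storage> =
        vec![unsafe { std::mem::zeroed() }; bufs.len()];
    // One iovec per receive buffer.
    let mut iovecs: Vec<iovec> = bufs
        .iter_mut()
        .map(|b| iovec {
            iov_base: b.as_mut_ptr() as *mut libc::c_void,
            iov_len: b.len(),
        })
        .collect();
    let mut messages: Vec<mmsghdr> = Vec::with_capacity(bufs.len());
    for i in 0..bufs.len() {
        let mut hdr: msghdr = unsafe { std::mem::zeroed() };
        hdr.msg_name = &mut storages[i] as *mut _ as *mut libc::c_void;
        hdr.msg_namelen = std::mem::size_of::<libc::sockaddr_storage>() as libc::socklen_t;
        hdr.msg_iov = &mut iovecs[i] as *mut iovec;
        hdr.msg_iovlen = 1;
        messages.push(mmsghdr { msg_hdr: hdr, msg_len: 0 });
    }
    let n = unsafe {
        recvmmsg(
            fd,
            messages.as_mut_ptr(),
            messages.len() as libc::c_uint,
            MSG_WAITFORONE,
            std::ptr::null_mut(),
        )
    };
    if n <= 0 {
        return Vec::new();
    }
    // msg_len carries the datagram length filled in by the kernel.
    (0..n as usize)
        .map(|i| (messages[i].msg_len as usize, storages[i]))
        .collect()
}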
#[cfg(not(target_os = "linux"))]
{
for i in 0..BATCH_SIZE {
match socket.recv_from(&mut recv_buffers[i]) {
Ok((len, addr)) => {
recv_buffers[i].truncate(len);
addrs[i].write(addr);
valid_count += 1;
}
Err(_) => break,
}
Err(_) => continue,
}
}
});
// Process the received packets as a batch
for i in 0..valid_count {
let _ = (|| {
let data = &mut recv_buffers[i];
let addr = unsafe { addrs[i].assume_init() };
let packet = Ipv4Packet::new(data).ok_or("malformed packet")?;
let header_len = packet.get_header_length() as usize * 4;
let (_ip_header, rest) = data
.split_at_mut_checked(header_len)
.ok_or("malformed packet")?;
let (meta_bytes, payload) = rest
.split_at_mut_checked(size_of::<Meta>())
.ok_or("malformed packet")?;
let meta: &Meta = unsafe { transmute(meta_bytes.as_ptr()) };
if meta.dst_id == local_id && meta.reversed == 0 {
if let Some((router, endpoint)) = atomic_writers.get(&meta.src_id) {
endpoint.store(addr);
router.decrypt(payload, &local_secret);
router.tun_writer.write_all(payload)?;
}
}
// Statistics thread
let stats_clone = Arc::clone(&stats);
s.spawn(move |_| {
let mut last_stats = (0u64, 0u64, 0u64, 0u64);
loop {
thread::sleep(Duration::from_secs(10));
let sent_pkts = stats_clone.packets_sent.load(Ordering::Relaxed);
let recv_pkts = stats_clone.packets_recv.load(Ordering::Relaxed);
let sent_bytes = stats_clone.bytes_sent.load(Ordering::Relaxed);
let recv_bytes = stats_clone.bytes_recv.load(Ordering::Relaxed);
let sent_pps = (sent_pkts - last_stats.0) / 10;
let recv_pps = (recv_pkts - last_stats.1) / 10;
let sent_mbps = ((sent_bytes - last_stats.2) * 8) / (10 * 1024 * 1024);
let recv_mbps = ((recv_bytes - last_stats.3) * 8) / (10 * 1024 * 1024);
println!(
"Stats - TX: {} pps, {} Mbps | RX: {} pps, {} Mbps",
sent_pps, sent_mbps, recv_pps, recv_mbps
);
Ok::<(), Box<dyn Error>>(())
})();
}
last_stats = (sent_pkts, recv_pkts, sent_bytes, recv_bytes);
// Return the buffers to the pool
for buffer in &mut recv_buffers {
buffer.clear();
buffer.resize(BUFFER_SIZE, 0);
}
buffer_pool.return_buffers(recv_buffers);
if valid_count == 0 {
std::thread::yield_now();
}
}
});
}
}
})
.unwrap();
Ok(())
}