Commit 5cf0d2c5 authored by nanahira

retry

parent bbb76719
Pipeline #37401 failed with stages
in 15 seconds
......@@ -4,12 +4,17 @@ use crate::router::{Router, RouterReader, RouterWriter, SECRET_LENGTH};
use std::collections::HashMap;
use std::env;
use std::error::Error;
use std::intrinsics::transmute;
use std::io::{Read, Write};
use std::mem::MaybeUninit;
use std::sync::Arc;
use std::mem::{self, MaybeUninit};
use std::net::SocketAddr;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use std::sync::mpsc::{sync_channel, Receiver, SyncSender, TryRecvError};
use std::time::{Duration, Instant};
use std::thread;
#[repr(C)]
#[derive(Copy, Clone)]
pub struct Meta {
pub src_id: u8,
pub dst_id: u8,
......@@ -36,20 +41,78 @@ pub struct Config {
pub local_secret: String,
pub routers: Vec<ConfigRouter>,
}
use crossbeam_utils::thread;
use crossbeam_utils::thread as scoped_thread;
use grouping_by::GroupingBy;
use pnet::packet::ipv4::Ipv4Packet;
use socket2::Socket;
// Performance tuning knobs.

/// Upper bound on the number of packets gathered into one send batch.
const BATCH_SIZE: usize = 32;
/// Time, in milliseconds, the sender waits for packets before moving on.
const BATCH_TIMEOUT_MILLIS: u64 = 1;
/// Capacity of the bounded channel between a TUN reader and its sender.
const CHANNEL_SIZE: usize = 1024;
/// Send/receive buffer size requested for every socket (4 MiB).
const SOCKET_BUFFER_SIZE: usize = 4 * 1024 * 1024;
/// Size in bytes of every packet buffer.
///
/// These buffers never hold a bare payload: the TUN read buffer starts with
/// the `Meta` prefix, and the socket receive buffer starts with the IP header
/// followed by `Meta`. A value of 1500 leaves no headroom for those headers,
/// so a full 1500-byte packet would be truncated; 2048 keeps room for them.
const MAX_PACKET_SIZE: usize = 2048;
/// An owned packet buffer paired with the number of bytes in use.
#[derive(Clone)]
struct Packet {
    /// Backing storage for the packet bytes.
    data: Vec<u8>,
    /// Number of leading bytes of `data` that are meaningful.
    len: usize,
}

impl Packet {
    /// Creates an empty packet: `MAX_PACKET_SIZE` zeroed bytes of storage
    /// and a length of zero.
    fn new() -> Self {
        let data = vec![0; MAX_PACKET_SIZE];
        Packet { data, len: 0 }
    }
}
/// Traffic counters, shared through an `Arc` by the sender, receiver and
/// reporting threads. All counters only ever grow.
struct Stats {
/// Packets passed to the send loop; increased by the whole batch size after
/// each batch, whether or not every individual send succeeded.
packets_sent: AtomicU64,
/// Packets received, matched to a known router and written to its TUN writer.
packets_recv: AtomicU64,
/// Bytes of packets whose `send_to` call returned `Ok`.
bytes_sent: AtomicU64,
/// Payload bytes received, i.e. what remains after the IP header and `Meta`.
bytes_recv: AtomicU64,
}
fn optimize_socket(socket: &Socket) -> Result<(), Box<dyn Error>> {
socket.set_send_buffer_size(SOCKET_BUFFER_SIZE)?;
socket.set_recv_buffer_size(SOCKET_BUFFER_SIZE)?;
socket.set_nonblocking(false)?;
#[cfg(target_os = "linux")]
{
// Linux 特定优化
socket.set_nodelay(true)?;
}
Ok(())
}
fn main() -> Result<(), Box<dyn Error>> {
let config: Config = serde_json::from_str(env::args().nth(1).ok_or("need param")?.as_str())?;
let local_secret: [u8; SECRET_LENGTH] = Router::create_secret(config.local_secret.as_str())?;
let mut sockets: HashMap<u16, Arc<Socket>> = HashMap::new();
// 创建并优化 routers
let routers: HashMap<u8, Router> = config
.routers
.iter()
.map(|c| Router::new(c, &mut sockets).map(|router| (c.remote_id, router)))
.map(|c| {
Router::new(c, &mut sockets).map(|router| {
// 优化每个 socket
let router_key = Router::key(c); // 使用关联函数
if let Some(socket) = sockets.get(&router_key) {
let _ = optimize_socket(socket);
}
(c.remote_id, router)
})
})
.collect::<Result<_, _>>()?;
let (mut router_readers, router_writers): (
HashMap<u8, RouterReader>,
HashMap<u8, RouterWriter>,
......@@ -60,6 +123,7 @@ fn main() -> Result<(), Box<dyn Error>> {
((id, reader), (id, writer))
})
.unzip();
let router_writers3: Vec<(Arc<Socket>, HashMap<u8, RouterWriter>)> = router_writers
.into_iter()
.grouping_by(|(_, v)| v.key())
......@@ -71,77 +135,218 @@ fn main() -> Result<(), Box<dyn Error>> {
)
})
.collect();
println!("created tuns");
// 全局统计
let stats = Arc::new(Stats {
packets_sent: AtomicU64::new(0),
packets_recv: AtomicU64::new(0),
bytes_sent: AtomicU64::new(0),
bytes_recv: AtomicU64::new(0),
});
println!("Created optimized TUN devices");
thread::scope(|s| {
scoped_thread::scope(|s| {
// 为每个 router 创建读写线程
for router in router_readers.values_mut() {
s.spawn(|_| {
// 使用 2048 字节缓冲区,足够处理大多数包且避免过大的内存占用
let mut buffer = vec![0u8; 2048];
let meta_size = size_of::<Meta>();
// 预初始化 Meta 头部
let router_id = router.config.remote_id;
let local_id = config.local_id;
let stats_clone = Arc::clone(&stats);
// 创建批处理通道
let (tx, rx): (SyncSender<Packet>, Receiver<Packet>) = sync_channel(CHANNEL_SIZE);
// TUN 读取线程
s.spawn(move |_| {
let meta_size = mem::size_of::<Meta>();
let mut buffer = vec![0u8; MAX_PACKET_SIZE];
// 预构建 Meta
let meta = Meta {
src_id: config.local_id,
dst_id: router.config.remote_id,
src_id: local_id,
dst_id: router_id,
reversed: 0,
};
let meta_bytes = unsafe {
std::slice::from_raw_parts(&meta as *const Meta as *const u8, meta_size)
std::slice::from_raw_parts(&meta as *const _ as *const u8, meta_size)
};
buffer[..meta_size].copy_from_slice(meta_bytes);
loop {
// 写入 Meta 头
buffer[..meta_size].copy_from_slice(meta_bytes);
// 读取数据
match router.tun_reader.read(&mut buffer[meta_size..]) {
Ok(n) => {
if let Some(ref addr) = *router.endpoint.read().unwrap() {
router.encrypt(&mut buffer[meta_size..meta_size + n]);
#[cfg(target_os = "linux")]
let _ = router.socket.set_mark(router.config.mark);
let _ = router.socket.send_to(&buffer[..meta_size + n], addr);
Ok(n) if n > 0 => {
let packet = Packet {
data: buffer[..meta_size + n].to_vec(),
len: meta_size + n,
};
let _ = tx.try_send(packet);
}
_ => continue,
}
}
});
// 批处理发送线程
s.spawn(move |_| {
let mut batch = Vec::with_capacity(BATCH_SIZE);
let meta_size = mem::size_of::<Meta>();
let mut last_batch_time = Instant::now();
loop {
batch.clear();
// 收集数据包
match rx.recv_timeout(Duration::from_millis(BATCH_TIMEOUT_MILLIS)) {
Ok(packet) => batch.push(packet),
Err(_) => {
if batch.is_empty() { continue; }
}
}
// 继续收集直到批量大小或超时
while batch.len() < BATCH_SIZE {
match rx.try_recv() {
Ok(packet) => batch.push(packet),
Err(TryRecvError::Empty) => {
if last_batch_time.elapsed() > Duration::from_millis(BATCH_TIMEOUT_MILLIS) {
break;
}
}
Err(TryRecvError::Disconnected) => return,
}
Err(_) => continue,
}
if !batch.is_empty() {
// 获取endpoint
if let Some(ref addr) = *router.endpoint.read().unwrap() {
// 加密所有包
for packet in &mut batch {
router.encrypt(&mut packet.data[meta_size..packet.len]);
}
// 设置 mark(每批只设置一次)
#[cfg(target_os = "linux")]
let _ = router.socket.set_mark(router.config.mark);
// 批量发送
let mut total_bytes = 0;
for packet in &batch {
if let Ok(_) = router.socket.send_to(&packet.data[..packet.len], addr) {
total_bytes += packet.len;
}
}
// 更新统计
stats_clone.packets_sent.fetch_add(batch.len() as u64, Ordering::Relaxed);
stats_clone.bytes_sent.fetch_add(total_bytes as u64, Ordering::Relaxed);
}
last_batch_time = Instant::now();
}
}
});
}
// 优化的接收线程
for (socket, mut router_writers) in router_writers3 {
let stats_clone = Arc::clone(&stats);
let local_secret_clone = local_secret.clone();
let config_local_id = config.local_id;
s.spawn(move |_| {
// 使用 2048 字节接收缓冲区
let mut recv_buf = vec![MaybeUninit::uninit(); 2048];
// 预分配多个缓冲区
let mut recv_bufs: Vec<[MaybeUninit<u8>; MAX_PACKET_SIZE]> =
(0..4).map(|_| unsafe { MaybeUninit::uninit().assume_init() }).collect();
let mut buf_idx = 0;
loop {
let _ = (|| -> Result<(), Box<dyn Error>> {
let (len, addr) = socket.recv_from(&mut recv_buf)?;
let data: &mut [u8] = unsafe { transmute(&mut recv_buf[..len]) };
let packet = Ipv4Packet::new(data).ok_or("malformed packet")?;
let header_len = packet.get_header_length() as usize * 4;
let (_ip_header, rest) = data
.split_at_mut_checked(header_len)
.ok_or("malformed packet")?;
let (meta_bytes, payload) = rest
.split_at_mut_checked(size_of::<Meta>())
.ok_or("malformed packet")?;
let meta: &Meta = unsafe { transmute(meta_bytes.as_ptr()) };
if meta.dst_id == config.local_id && meta.reversed == 0 {
let router = router_writers
.get_mut(&meta.src_id)
.ok_or("missing router")?;
*router.endpoint.write().unwrap() = Some(addr);
router.decrypt(payload, &local_secret);
router.tun_writer.write_all(payload)?;
let recv_buf = &mut recv_bufs[buf_idx];
buf_idx = (buf_idx + 1) % recv_bufs.len();
match socket.recv_from(recv_buf) {
Ok((len, sock_addr)) => {
// 转换为 SocketAddr
let addr = match sock_addr.as_socket() {
Some(addr) => addr,
None => continue,
};
// 将 MaybeUninit 转换为初始化的数据
let data = unsafe {
std::slice::from_raw_parts_mut(
recv_buf.as_mut_ptr() as *mut u8,
len
)
};
// 快速路径检查
if len < 20 + mem::size_of::<Meta>() {
continue;
}
if let Some(packet) = Ipv4Packet::new(data) {
let header_len = packet.get_header_length() as usize * 4;
if header_len < data.len() {
let rest = &mut data[header_len..];
if rest.len() >= mem::size_of::<Meta>() {
let (meta_bytes, payload) = rest.split_at_mut(mem::size_of::<Meta>());
let meta: &Meta = unsafe { &*(meta_bytes.as_ptr() as *const Meta) };
if meta.dst_id == config_local_id && meta.reversed == 0 {
if let Some(router) = router_writers.get_mut(&meta.src_id) {
// 更新endpoint
*router.endpoint.write().unwrap() = Some(addr);
// 解密
router.decrypt(payload, &local_secret_clone);
// 写入TUN
let _ = router.tun_writer.write_all(payload);
// 更新统计
stats_clone.packets_recv.fetch_add(1, Ordering::Relaxed);
stats_clone.bytes_recv.fetch_add(payload.len() as u64, Ordering::Relaxed);
}
}
}
}
}
}
Ok(())
})();
Err(_) => continue,
}
}
});
}
// 统计线程
let stats_clone = Arc::clone(&stats);
s.spawn(move |_| {
let mut last_stats = (0u64, 0u64, 0u64, 0u64);
loop {
thread::sleep(Duration::from_secs(10));
let sent_pkts = stats_clone.packets_sent.load(Ordering::Relaxed);
let recv_pkts = stats_clone.packets_recv.load(Ordering::Relaxed);
let sent_bytes = stats_clone.bytes_sent.load(Ordering::Relaxed);
let recv_bytes = stats_clone.bytes_recv.load(Ordering::Relaxed);
let sent_pps = (sent_pkts - last_stats.0) / 10;
let recv_pps = (recv_pkts - last_stats.1) / 10;
let sent_mbps = ((sent_bytes - last_stats.2) * 8) / (10 * 1024 * 1024);
let recv_mbps = ((recv_bytes - last_stats.3) * 8) / (10 * 1024 * 1024);
println!(
"Stats - TX: {} pps, {} Mbps | RX: {} pps, {} Mbps",
sent_pps, sent_mbps, recv_pps, recv_mbps
);
last_stats = (sent_pkts, recv_pkts, sent_bytes, recv_bytes);
}
});
})
.unwrap();
Ok(())
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment