Commit 6866a72a authored by nanahira's avatar nanahira

change

parent ce345e4d
...@@ -13,3 +13,4 @@ base64 = "0.22.1" ...@@ -13,3 +13,4 @@ base64 = "0.22.1"
crossbeam = "0.8.4" crossbeam = "0.8.4"
crossbeam-utils = "0.8.20" crossbeam-utils = "0.8.20"
grouping_by = "0.2.2" grouping_by = "0.2.2"
libc = "0.2"
mod router; mod router;
use crate::router::{Router, RouterReader, RouterWriter, SECRET_LENGTH}; use crate::router::{Router, RouterReader, RouterWriter, SECRET_LENGTH};
use arc_swap::ArcSwap;
use bytes::{Bytes, BytesMut};
use crossbeam_channel::{bounded, Receiver, Sender, TryRecvError};
use parking_lot::RwLock;
use std::collections::HashMap; use std::collections::HashMap;
use std::env; use std::env;
use std::error::Error; use std::error::Error;
use std::intrinsics::transmute;
use std::io::{Read, Write}; use std::io::{Read, Write};
use std::mem::{self, MaybeUninit}; use std::mem::{self, MaybeUninit};
use std::net::SocketAddr; use std::net::SocketAddr;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering}; use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc; use std::sync::{Arc, Mutex};
use std::sync::mpsc::{sync_channel, Receiver, SyncSender, TryRecvError};
use std::time::{Duration, Instant}; use std::time::{Duration, Instant};
use std::thread;
#[repr(C)] #[repr(C)]
#[derive(Copy, Clone)] #[derive(Copy, Clone)]
...@@ -45,61 +42,32 @@ pub struct Config { ...@@ -45,61 +42,32 @@ pub struct Config {
pub routers: Vec<ConfigRouter>, pub routers: Vec<ConfigRouter>,
} }
use crossbeam_utils::thread; use crossbeam_utils::thread as scoped_thread;
use grouping_by::GroupingBy; use grouping_by::GroupingBy;
use pnet::packet::ipv4::Ipv4Packet; use pnet::packet::ipv4::Ipv4Packet;
use socket2::{Domain, Protocol, Socket, Type}; use socket2::Socket;
// 性能优化配置 // 性能优化配置
const BATCH_SIZE: usize = 64; // 增加批处理大小 const BATCH_SIZE: usize = 32;
const BATCH_TIMEOUT_MICROS: u64 = 50; // 减少批处理超时 const BATCH_TIMEOUT_MILLIS: u64 = 1;
const CHANNEL_SIZE: usize = 4096; // 增加通道容量 const CHANNEL_SIZE: usize = 1024;
const SOCKET_BUFFER_SIZE: usize = 8 * 1024 * 1024; // 8MB socket 缓冲区 const SOCKET_BUFFER_SIZE: usize = 4 * 1024 * 1024; // 4MB
const PACKET_POOL_SIZE: usize = 256; // 预分配的数据包池大小
const MAX_PACKET_SIZE: usize = 1500; const MAX_PACKET_SIZE: usize = 1500;
// 零拷贝数据包结构 // 数据包结构
#[derive(Clone)]
struct Packet { struct Packet {
data: BytesMut, data: Vec<u8>,
len: usize, len: usize,
timestamp: Instant,
} }
impl Packet { impl Packet {
fn new() -> Self { fn new() -> Self {
Self { Self {
data: BytesMut::with_capacity(MAX_PACKET_SIZE), data: vec![0u8; MAX_PACKET_SIZE],
len: 0, len: 0,
timestamp: Instant::now(),
} }
} }
fn reset(&mut self) {
self.data.clear();
self.len = 0;
self.timestamp = Instant::now();
}
}
// 数据包池 - 减少内存分配
struct PacketPool {
pool: Vec<Packet>,
used: AtomicU64,
}
impl PacketPool {
fn new(size: usize) -> Self {
let pool = (0..size).map(|_| Packet::new()).collect();
Self {
pool,
used: AtomicU64::new(0),
}
}
fn get(&self) -> Option<&mut Packet> {
let idx = self.used.fetch_add(1, Ordering::Relaxed) as usize % self.pool.len();
unsafe { Some(&mut *(self.pool.as_ptr().add(idx) as *mut Packet)) }
}
} }
// 统计信息 // 统计信息
...@@ -113,38 +81,12 @@ struct Stats { ...@@ -113,38 +81,12 @@ struct Stats {
fn optimize_socket(socket: &Socket) -> Result<(), Box<dyn Error>> { fn optimize_socket(socket: &Socket) -> Result<(), Box<dyn Error>> {
socket.set_send_buffer_size(SOCKET_BUFFER_SIZE)?; socket.set_send_buffer_size(SOCKET_BUFFER_SIZE)?;
socket.set_recv_buffer_size(SOCKET_BUFFER_SIZE)?; socket.set_recv_buffer_size(SOCKET_BUFFER_SIZE)?;
socket.set_nonblocking(false)?; // 使用阻塞模式配合批处理 socket.set_nonblocking(false)?;
#[cfg(target_os = "linux")] #[cfg(target_os = "linux")]
{ {
// Linux 特定优化 // Linux 特定优化
use libc::{c_int, c_void, setsockopt, SOL_SOCKET}; socket.set_nodelay(true)?;
// SO_BUSY_POLL - 减少延迟
const SO_BUSY_POLL: c_int = 46;
let busy_poll: c_int = 50; // 50us
unsafe {
setsockopt(
socket.as_raw_fd(),
SOL_SOCKET,
SO_BUSY_POLL,
&busy_poll as *const _ as *const c_void,
mem::size_of::<c_int>() as u32,
);
}
// SO_INCOMING_CPU - CPU 亲和性
const SO_INCOMING_CPU: c_int = 49;
let cpu: c_int = 0;
unsafe {
setsockopt(
socket.as_raw_fd(),
SOL_SOCKET,
SO_INCOMING_CPU,
&cpu as *const _ as *const c_void,
mem::size_of::<c_int>() as u32,
);
}
} }
Ok(()) Ok(())
...@@ -162,7 +104,8 @@ fn main() -> Result<(), Box<dyn Error>> { ...@@ -162,7 +104,8 @@ fn main() -> Result<(), Box<dyn Error>> {
.map(|c| { .map(|c| {
Router::new(c, &mut sockets).map(|router| { Router::new(c, &mut sockets).map(|router| {
// 优化每个 socket // 优化每个 socket
if let Some(socket) = sockets.get(&router.key()) { let router_key = Router::key(c); // 使用关联函数
if let Some(socket) = sockets.get(&router_key) {
let _ = optimize_socket(socket); let _ = optimize_socket(socket);
} }
(c.remote_id, router) (c.remote_id, router)
...@@ -181,26 +124,17 @@ fn main() -> Result<(), Box<dyn Error>> { ...@@ -181,26 +124,17 @@ fn main() -> Result<(), Box<dyn Error>> {
}) })
.unzip(); .unzip();
// 使用 ArcSwap 存储 writers 以减少锁竞争 let router_writers3: Vec<(Arc<Socket>, HashMap<u8, RouterWriter>)> = router_writers
let router_writers_arc: Arc<HashMap<u8, Arc<RouterWriter>>> = Arc::new( .into_iter()
router_writers.into_iter() .grouping_by(|(_, v)| v.key())
.map(|(k, v)| (k, Arc::new(v))) .into_iter()
.collect() .map(|(k, v)| {
); (
Arc::clone(sockets.get_mut(&k).unwrap()),
let router_writers3: Vec<(Arc<Socket>, Arc<HashMap<u8, Arc<RouterWriter>>>)> = v.into_iter().collect(),
router_writers_arc.iter() )
.fold(HashMap::<u16, Vec<(u8, Arc<RouterWriter>)>>::new(), |mut acc, (id, writer)| { })
acc.entry(writer.key()).or_insert_with(Vec::new).push((*id, Arc::clone(writer))); .collect();
acc
})
.into_iter()
.map(|(key, writers)| {
let socket = Arc::clone(sockets.get(&key).unwrap());
let writers_map = Arc::new(writers.into_iter().collect::<HashMap<_, _>>());
(socket, writers_map)
})
.collect();
// 全局统计 // 全局统计
let stats = Arc::new(Stats { let stats = Arc::new(Stats {
...@@ -212,26 +146,20 @@ fn main() -> Result<(), Box<dyn Error>> { ...@@ -212,26 +146,20 @@ fn main() -> Result<(), Box<dyn Error>> {
println!("Created optimized TUN devices"); println!("Created optimized TUN devices");
thread::scope(|s| { scoped_thread::scope(|s| {
// 为每个 router 创建高性能读写线程组 // 为每个 router 创建读写线程
for (router_id, router) in router_readers.iter_mut() { for router in router_readers.values_mut() {
let router_id = *router_id; let router_id = router.config.remote_id;
let local_id = config.local_id; let local_id = config.local_id;
let stats_clone = Arc::clone(&stats); let stats_clone = Arc::clone(&stats);
// 创建多个通道用于负载均衡 // 创建批处理通道
let mut channels = Vec::new(); let (tx, rx): (SyncSender<Packet>, Receiver<Packet>) = sync_channel(CHANNEL_SIZE);
for _ in 0..2 { // 2个并行处理通道
channels.push(bounded::<Packet>(CHANNEL_SIZE));
}
// TUN 读取线程 - 零拷贝优化 // TUN 读取线程
let channels_clone = channels.clone();
let mut router_reader = router.clone(); // 假设实现了 Clone
s.spawn(move |_| { s.spawn(move |_| {
let packet_pool = PacketPool::new(PACKET_POOL_SIZE);
let meta_size = mem::size_of::<Meta>(); let meta_size = mem::size_of::<Meta>();
let mut channel_idx = 0; let mut buffer = vec![0u8; MAX_PACKET_SIZE];
// 预构建 Meta // 预构建 Meta
let meta = Meta { let meta = Meta {
...@@ -239,132 +167,120 @@ fn main() -> Result<(), Box<dyn Error>> { ...@@ -239,132 +167,120 @@ fn main() -> Result<(), Box<dyn Error>> {
dst_id: router_id, dst_id: router_id,
reversed: 0, reversed: 0,
}; };
let meta_bytes = unsafe {
std::slice::from_raw_parts(&meta as *const _ as *const u8, meta_size)
};
loop { loop {
if let Some(packet) = packet_pool.get() { // 写入 Meta 头
packet.reset(); buffer[..meta_size].copy_from_slice(meta_bytes);
packet.data.resize(MAX_PACKET_SIZE, 0);
// 读取数据
// 写入 Meta 头 match router.tun_reader.read(&mut buffer[meta_size..]) {
let meta_bytes = unsafe { Ok(n) if n > 0 => {
std::slice::from_raw_parts(&meta as *const _ as *const u8, meta_size) let packet = Packet {
}; data: buffer[..meta_size + n].to_vec(),
packet.data[..meta_size].copy_from_slice(meta_bytes); len: meta_size + n,
};
// 读取数据 let _ = tx.try_send(packet);
match router_reader.tun_reader.read(&mut packet.data[meta_size..]) {
Ok(n) if n > 0 => {
packet.len = meta_size + n;
packet.data.truncate(packet.len);
// 轮询发送到不同通道
let tx = &channels_clone[channel_idx].0;
let _ = tx.try_send(*packet);
channel_idx = (channel_idx + 1) % channels_clone.len();
}
_ => continue,
} }
_ => continue,
} }
} }
}); });
// 多个加密发送线程 // 批处理发送线程
for (tx, rx) in channels { s.spawn(move |_| {
let router_clone = router.clone(); let mut batch = Vec::with_capacity(BATCH_SIZE);
let stats_clone2 = Arc::clone(&stats_clone); let meta_size = mem::size_of::<Meta>();
let mut last_batch_time = Instant::now();
s.spawn(move |_| { loop {
let mut batch = Vec::with_capacity(BATCH_SIZE); batch.clear();
let meta_size = mem::size_of::<Meta>();
let mut last_batch_time = Instant::now();
loop { // 收集数据包
batch.clear(); match rx.recv_timeout(Duration::from_millis(BATCH_TIMEOUT_MILLIS)) {
Ok(packet) => batch.push(packet),
// 智能批处理 - 根据负载动态调整 Err(_) => {
let timeout = if batch.is_empty() { if batch.is_empty() { continue; }
Duration::from_millis(1)
} else {
Duration::from_micros(BATCH_TIMEOUT_MICROS)
};
// 收集数据包
match rx.recv_timeout(timeout) {
Ok(packet) => batch.push(packet),
Err(_) => {
if batch.is_empty() { continue; }
}
} }
}
// 继续收集直到批量大小或超时
while batch.len() < BATCH_SIZE { // 继续收集直到批量大小或超时
match rx.try_recv() { while batch.len() < BATCH_SIZE {
Ok(packet) => batch.push(packet), match rx.try_recv() {
Err(TryRecvError::Empty) => { Ok(packet) => batch.push(packet),
if last_batch_time.elapsed() > Duration::from_micros(BATCH_TIMEOUT_MICROS) { Err(TryRecvError::Empty) => {
break; if last_batch_time.elapsed() > Duration::from_millis(BATCH_TIMEOUT_MILLIS) {
} break;
} }
Err(TryRecvError::Disconnected) => return,
} }
Err(TryRecvError::Disconnected) => return,
} }
}
if !batch.is_empty() {
// 获取endpoint(使用缓存减少锁竞争) if !batch.is_empty() {
if let Some(ref addr) = *router_clone.endpoint.read().unwrap() { // 获取endpoint
// 并行加密 if let Some(ref addr) = *router.endpoint.read().unwrap() {
batch.par_iter_mut().for_each(|packet| { // 加密所有包
router_clone.encrypt(&mut packet.data[meta_size..packet.len]); for packet in &mut batch {
}); router.encrypt(&mut packet.data[meta_size..packet.len]);
}
// 设置 mark(每批只设置一次)
#[cfg(target_os = "linux")] // 设置 mark(每批只设置一次)
let _ = router_clone.socket.set_mark(router_clone.config.mark); #[cfg(target_os = "linux")]
let _ = router.socket.set_mark(router.config.mark);
// 批量发送
let mut total_bytes = 0; // 批量发送
for packet in &batch { let mut total_bytes = 0;
if let Ok(_) = router_clone.socket.send_to(&packet.data[..packet.len], addr) { for packet in &batch {
total_bytes += packet.len; if let Ok(_) = router.socket.send_to(&packet.data[..packet.len], addr) {
} total_bytes += packet.len;
} }
// 更新统计
stats_clone2.packets_sent.fetch_add(batch.len() as u64, Ordering::Relaxed);
stats_clone2.bytes_sent.fetch_add(total_bytes as u64, Ordering::Relaxed);
} }
last_batch_time = Instant::now(); // 更新统计
stats_clone.packets_sent.fetch_add(batch.len() as u64, Ordering::Relaxed);
stats_clone.bytes_sent.fetch_add(total_bytes as u64, Ordering::Relaxed);
} }
last_batch_time = Instant::now();
} }
}); }
} });
} }
// 优化的接收线程 // 优化的接收线程
for (socket, router_writers) in router_writers3 { for (socket, mut router_writers) in router_writers3 {
let stats_clone = Arc::clone(&stats); let stats_clone = Arc::clone(&stats);
let local_secret_clone = local_secret.clone(); let local_secret_clone = local_secret.clone();
let config_local_id = config.local_id; let config_local_id = config.local_id;
s.spawn(move |_| { s.spawn(move |_| {
// 预分配接收缓冲区池 // 预分配多个缓冲区
const RECV_POOL_SIZE: usize = 128; let mut recv_bufs: Vec<[MaybeUninit<u8>; MAX_PACKET_SIZE]> =
let mut recv_pool: Vec<Vec<u8>> = (0..RECV_POOL_SIZE) (0..4).map(|_| unsafe { MaybeUninit::uninit().assume_init() }).collect();
.map(|_| vec![0u8; MAX_PACKET_SIZE]) let mut buf_idx = 0;
.collect();
let mut pool_idx = 0;
// Endpoint 缓存
let mut endpoint_cache: HashMap<u8, SocketAddr> = HashMap::new();
loop { loop {
let recv_buf = &mut recv_pool[pool_idx]; let recv_buf = &mut recv_bufs[buf_idx];
pool_idx = (pool_idx + 1) % RECV_POOL_SIZE; buf_idx = (buf_idx + 1) % recv_bufs.len();
match socket.recv_from(recv_buf) { match socket.recv_from(recv_buf) {
Ok((len, addr)) => { Ok((len, sock_addr)) => {
let data = &mut recv_buf[..len]; // 转换为 SocketAddr
let addr = match sock_addr.as_socket() {
Some(addr) => addr,
None => continue,
};
// 将 MaybeUninit 转换为初始化的数据
let data = unsafe {
std::slice::from_raw_parts_mut(
recv_buf.as_mut_ptr() as *mut u8,
len
)
};
// 快速路径检查 // 快速路径检查
if len < 20 + mem::size_of::<Meta>() { if len < 20 + mem::size_of::<Meta>() {
...@@ -373,14 +289,15 @@ fn main() -> Result<(), Box<dyn Error>> { ...@@ -373,14 +289,15 @@ fn main() -> Result<(), Box<dyn Error>> {
if let Some(packet) = Ipv4Packet::new(data) { if let Some(packet) = Ipv4Packet::new(data) {
let header_len = packet.get_header_length() as usize * 4; let header_len = packet.get_header_length() as usize * 4;
if let Some((_ip_header, rest)) = data.split_at_mut_checked(header_len) { if header_len < data.len() {
if let Some((meta_bytes, payload)) = rest.split_at_mut_checked(mem::size_of::<Meta>()) { let rest = &mut data[header_len..];
if rest.len() >= mem::size_of::<Meta>() {
let (meta_bytes, payload) = rest.split_at_mut(mem::size_of::<Meta>());
let meta: &Meta = unsafe { &*(meta_bytes.as_ptr() as *const Meta) }; let meta: &Meta = unsafe { &*(meta_bytes.as_ptr() as *const Meta) };
if meta.dst_id == config_local_id && meta.reversed == 0 { if meta.dst_id == config_local_id && meta.reversed == 0 {
if let Some(router) = router_writers.get(&meta.src_id) { if let Some(router) = router_writers.get_mut(&meta.src_id) {
// 更新endpoint缓存 // 更新endpoint
endpoint_cache.insert(meta.src_id, addr);
*router.endpoint.write().unwrap() = Some(addr); *router.endpoint.write().unwrap() = Some(addr);
// 解密 // 解密
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment