Commit 6866a72a authored by nanahira's avatar nanahira

change

parent ce345e4d
......@@ -13,3 +13,4 @@ base64 = "0.22.1"
crossbeam = "0.8.4"
crossbeam-utils = "0.8.20"
grouping_by = "0.2.2"
libc = "0.2"
mod router;
use crate::router::{Router, RouterReader, RouterWriter, SECRET_LENGTH};
use arc_swap::ArcSwap;
use bytes::{Bytes, BytesMut};
use crossbeam_channel::{bounded, Receiver, Sender, TryRecvError};
use parking_lot::RwLock;
use std::collections::HashMap;
use std::env;
use std::error::Error;
use std::intrinsics::transmute;
use std::io::{Read, Write};
use std::mem::{self, MaybeUninit};
use std::net::SocketAddr;
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use std::sync::mpsc::{sync_channel, Receiver, SyncSender, TryRecvError};
use std::time::{Duration, Instant};
use std::thread;
#[repr(C)]
#[derive(Copy, Clone)]
......@@ -45,61 +42,32 @@ pub struct Config {
pub routers: Vec<ConfigRouter>,
}
use crossbeam_utils::thread;
use crossbeam_utils::thread as scoped_thread;
use grouping_by::GroupingBy;
use pnet::packet::ipv4::Ipv4Packet;
use socket2::{Domain, Protocol, Socket, Type};
use socket2::Socket;
// 性能优化配置
const BATCH_SIZE: usize = 64; // 增加批处理大小
const BATCH_TIMEOUT_MICROS: u64 = 50; // 减少批处理超时
const CHANNEL_SIZE: usize = 4096; // 增加通道容量
const SOCKET_BUFFER_SIZE: usize = 8 * 1024 * 1024; // 8MB socket 缓冲区
const PACKET_POOL_SIZE: usize = 256; // 预分配的数据包池大小
const BATCH_SIZE: usize = 32;
const BATCH_TIMEOUT_MILLIS: u64 = 1;
const CHANNEL_SIZE: usize = 1024;
const SOCKET_BUFFER_SIZE: usize = 4 * 1024 * 1024; // 4MB
const MAX_PACKET_SIZE: usize = 1500;
// 零拷贝数据包结构
// 数据包结构
#[derive(Clone)]
struct Packet {
data: BytesMut,
data: Vec<u8>,
len: usize,
timestamp: Instant,
}
impl Packet {
fn new() -> Self {
Self {
data: BytesMut::with_capacity(MAX_PACKET_SIZE),
data: vec![0u8; MAX_PACKET_SIZE],
len: 0,
timestamp: Instant::now(),
}
}
fn reset(&mut self) {
self.data.clear();
self.len = 0;
self.timestamp = Instant::now();
}
}
// 数据包池 - 减少内存分配
struct PacketPool {
pool: Vec<Packet>,
used: AtomicU64,
}
impl PacketPool {
fn new(size: usize) -> Self {
let pool = (0..size).map(|_| Packet::new()).collect();
Self {
pool,
used: AtomicU64::new(0),
}
}
fn get(&self) -> Option<&mut Packet> {
let idx = self.used.fetch_add(1, Ordering::Relaxed) as usize % self.pool.len();
unsafe { Some(&mut *(self.pool.as_ptr().add(idx) as *mut Packet)) }
}
}
// 统计信息
......@@ -113,38 +81,12 @@ struct Stats {
fn optimize_socket(socket: &Socket) -> Result<(), Box<dyn Error>> {
socket.set_send_buffer_size(SOCKET_BUFFER_SIZE)?;
socket.set_recv_buffer_size(SOCKET_BUFFER_SIZE)?;
socket.set_nonblocking(false)?; // 使用阻塞模式配合批处理
socket.set_nonblocking(false)?;
#[cfg(target_os = "linux")]
{
// Linux 特定优化
use libc::{c_int, c_void, setsockopt, SOL_SOCKET};
// SO_BUSY_POLL - 减少延迟
const SO_BUSY_POLL: c_int = 46;
let busy_poll: c_int = 50; // 50us
unsafe {
setsockopt(
socket.as_raw_fd(),
SOL_SOCKET,
SO_BUSY_POLL,
&busy_poll as *const _ as *const c_void,
mem::size_of::<c_int>() as u32,
);
}
// SO_INCOMING_CPU - CPU 亲和性
const SO_INCOMING_CPU: c_int = 49;
let cpu: c_int = 0;
unsafe {
setsockopt(
socket.as_raw_fd(),
SOL_SOCKET,
SO_INCOMING_CPU,
&cpu as *const _ as *const c_void,
mem::size_of::<c_int>() as u32,
);
}
socket.set_nodelay(true)?;
}
Ok(())
......@@ -162,7 +104,8 @@ fn main() -> Result<(), Box<dyn Error>> {
.map(|c| {
Router::new(c, &mut sockets).map(|router| {
// 优化每个 socket
if let Some(socket) = sockets.get(&router.key()) {
let router_key = Router::key(c); // 使用关联函数
if let Some(socket) = sockets.get(&router_key) {
let _ = optimize_socket(socket);
}
(c.remote_id, router)
......@@ -181,26 +124,17 @@ fn main() -> Result<(), Box<dyn Error>> {
})
.unzip();
// 使用 ArcSwap 存储 writers 以减少锁竞争
let router_writers_arc: Arc<HashMap<u8, Arc<RouterWriter>>> = Arc::new(
router_writers.into_iter()
.map(|(k, v)| (k, Arc::new(v)))
.collect()
);
let router_writers3: Vec<(Arc<Socket>, Arc<HashMap<u8, Arc<RouterWriter>>>)> =
router_writers_arc.iter()
.fold(HashMap::<u16, Vec<(u8, Arc<RouterWriter>)>>::new(), |mut acc, (id, writer)| {
acc.entry(writer.key()).or_insert_with(Vec::new).push((*id, Arc::clone(writer)));
acc
})
.into_iter()
.map(|(key, writers)| {
let socket = Arc::clone(sockets.get(&key).unwrap());
let writers_map = Arc::new(writers.into_iter().collect::<HashMap<_, _>>());
(socket, writers_map)
})
.collect();
let router_writers3: Vec<(Arc<Socket>, HashMap<u8, RouterWriter>)> = router_writers
.into_iter()
.grouping_by(|(_, v)| v.key())
.into_iter()
.map(|(k, v)| {
(
Arc::clone(sockets.get_mut(&k).unwrap()),
v.into_iter().collect(),
)
})
.collect();
// 全局统计
let stats = Arc::new(Stats {
......@@ -212,26 +146,20 @@ fn main() -> Result<(), Box<dyn Error>> {
println!("Created optimized TUN devices");
thread::scope(|s| {
// 为每个 router 创建高性能读写线程组
for (router_id, router) in router_readers.iter_mut() {
let router_id = *router_id;
scoped_thread::scope(|s| {
// 为每个 router 创建读写线程
for router in router_readers.values_mut() {
let router_id = router.config.remote_id;
let local_id = config.local_id;
let stats_clone = Arc::clone(&stats);
// 创建多个通道用于负载均衡
let mut channels = Vec::new();
for _ in 0..2 { // 2个并行处理通道
channels.push(bounded::<Packet>(CHANNEL_SIZE));
}
// 创建批处理通道
let (tx, rx): (SyncSender<Packet>, Receiver<Packet>) = sync_channel(CHANNEL_SIZE);
// TUN 读取线程 - 零拷贝优化
let channels_clone = channels.clone();
let mut router_reader = router.clone(); // 假设实现了 Clone
// TUN 读取线程
s.spawn(move |_| {
let packet_pool = PacketPool::new(PACKET_POOL_SIZE);
let meta_size = mem::size_of::<Meta>();
let mut channel_idx = 0;
let mut buffer = vec![0u8; MAX_PACKET_SIZE];
// 预构建 Meta
let meta = Meta {
......@@ -239,132 +167,120 @@ fn main() -> Result<(), Box<dyn Error>> {
dst_id: router_id,
reversed: 0,
};
let meta_bytes = unsafe {
std::slice::from_raw_parts(&meta as *const _ as *const u8, meta_size)
};
loop {
if let Some(packet) = packet_pool.get() {
packet.reset();
packet.data.resize(MAX_PACKET_SIZE, 0);
// 写入 Meta 头
let meta_bytes = unsafe {
std::slice::from_raw_parts(&meta as *const _ as *const u8, meta_size)
};
packet.data[..meta_size].copy_from_slice(meta_bytes);
// 读取数据
match router_reader.tun_reader.read(&mut packet.data[meta_size..]) {
Ok(n) if n > 0 => {
packet.len = meta_size + n;
packet.data.truncate(packet.len);
// 轮询发送到不同通道
let tx = &channels_clone[channel_idx].0;
let _ = tx.try_send(*packet);
channel_idx = (channel_idx + 1) % channels_clone.len();
}
_ => continue,
// 写入 Meta 头
buffer[..meta_size].copy_from_slice(meta_bytes);
// 读取数据
match router.tun_reader.read(&mut buffer[meta_size..]) {
Ok(n) if n > 0 => {
let packet = Packet {
data: buffer[..meta_size + n].to_vec(),
len: meta_size + n,
};
let _ = tx.try_send(packet);
}
_ => continue,
}
}
});
// 多个加密发送线程
for (tx, rx) in channels {
let router_clone = router.clone();
let stats_clone2 = Arc::clone(&stats_clone);
// 批处理发送线程
s.spawn(move |_| {
let mut batch = Vec::with_capacity(BATCH_SIZE);
let meta_size = mem::size_of::<Meta>();
let mut last_batch_time = Instant::now();
s.spawn(move |_| {
let mut batch = Vec::with_capacity(BATCH_SIZE);
let meta_size = mem::size_of::<Meta>();
let mut last_batch_time = Instant::now();
loop {
batch.clear();
loop {
batch.clear();
// 智能批处理 - 根据负载动态调整
let timeout = if batch.is_empty() {
Duration::from_millis(1)
} else {
Duration::from_micros(BATCH_TIMEOUT_MICROS)
};
// 收集数据包
match rx.recv_timeout(timeout) {
Ok(packet) => batch.push(packet),
Err(_) => {
if batch.is_empty() { continue; }
}
// 收集数据包
match rx.recv_timeout(Duration::from_millis(BATCH_TIMEOUT_MILLIS)) {
Ok(packet) => batch.push(packet),
Err(_) => {
if batch.is_empty() { continue; }
}
// 继续收集直到批量大小或超时
while batch.len() < BATCH_SIZE {
match rx.try_recv() {
Ok(packet) => batch.push(packet),
Err(TryRecvError::Empty) => {
if last_batch_time.elapsed() > Duration::from_micros(BATCH_TIMEOUT_MICROS) {
break;
}
}
// 继续收集直到批量大小或超时
while batch.len() < BATCH_SIZE {
match rx.try_recv() {
Ok(packet) => batch.push(packet),
Err(TryRecvError::Empty) => {
if last_batch_time.elapsed() > Duration::from_millis(BATCH_TIMEOUT_MILLIS) {
break;
}
Err(TryRecvError::Disconnected) => return,
}
Err(TryRecvError::Disconnected) => return,
}
if !batch.is_empty() {
// 获取endpoint(使用缓存减少锁竞争)
if let Some(ref addr) = *router_clone.endpoint.read().unwrap() {
// 并行加密
batch.par_iter_mut().for_each(|packet| {
router_clone.encrypt(&mut packet.data[meta_size..packet.len]);
});
// 设置 mark(每批只设置一次)
#[cfg(target_os = "linux")]
let _ = router_clone.socket.set_mark(router_clone.config.mark);
// 批量发送
let mut total_bytes = 0;
for packet in &batch {
if let Ok(_) = router_clone.socket.send_to(&packet.data[..packet.len], addr) {
total_bytes += packet.len;
}
}
if !batch.is_empty() {
// 获取endpoint
if let Some(ref addr) = *router.endpoint.read().unwrap() {
// 加密所有包
for packet in &mut batch {
router.encrypt(&mut packet.data[meta_size..packet.len]);
}
// 设置 mark(每批只设置一次)
#[cfg(target_os = "linux")]
let _ = router.socket.set_mark(router.config.mark);
// 批量发送
let mut total_bytes = 0;
for packet in &batch {
if let Ok(_) = router.socket.send_to(&packet.data[..packet.len], addr) {
total_bytes += packet.len;
}
// 更新统计
stats_clone2.packets_sent.fetch_add(batch.len() as u64, Ordering::Relaxed);
stats_clone2.bytes_sent.fetch_add(total_bytes as u64, Ordering::Relaxed);
}
last_batch_time = Instant::now();
// 更新统计
stats_clone.packets_sent.fetch_add(batch.len() as u64, Ordering::Relaxed);
stats_clone.bytes_sent.fetch_add(total_bytes as u64, Ordering::Relaxed);
}
last_batch_time = Instant::now();
}
});
}
}
});
}
// 优化的接收线程
for (socket, router_writers) in router_writers3 {
for (socket, mut router_writers) in router_writers3 {
let stats_clone = Arc::clone(&stats);
let local_secret_clone = local_secret.clone();
let config_local_id = config.local_id;
s.spawn(move |_| {
// 预分配接收缓冲区池
const RECV_POOL_SIZE: usize = 128;
let mut recv_pool: Vec<Vec<u8>> = (0..RECV_POOL_SIZE)
.map(|_| vec![0u8; MAX_PACKET_SIZE])
.collect();
let mut pool_idx = 0;
// Endpoint 缓存
let mut endpoint_cache: HashMap<u8, SocketAddr> = HashMap::new();
// 预分配多个缓冲区
let mut recv_bufs: Vec<[MaybeUninit<u8>; MAX_PACKET_SIZE]> =
(0..4).map(|_| unsafe { MaybeUninit::uninit().assume_init() }).collect();
let mut buf_idx = 0;
loop {
let recv_buf = &mut recv_pool[pool_idx];
pool_idx = (pool_idx + 1) % RECV_POOL_SIZE;
let recv_buf = &mut recv_bufs[buf_idx];
buf_idx = (buf_idx + 1) % recv_bufs.len();
match socket.recv_from(recv_buf) {
Ok((len, addr)) => {
let data = &mut recv_buf[..len];
Ok((len, sock_addr)) => {
// 转换为 SocketAddr
let addr = match sock_addr.as_socket() {
Some(addr) => addr,
None => continue,
};
// 将 MaybeUninit 转换为初始化的数据
let data = unsafe {
std::slice::from_raw_parts_mut(
recv_buf.as_mut_ptr() as *mut u8,
len
)
};
// 快速路径检查
if len < 20 + mem::size_of::<Meta>() {
......@@ -373,14 +289,15 @@ fn main() -> Result<(), Box<dyn Error>> {
if let Some(packet) = Ipv4Packet::new(data) {
let header_len = packet.get_header_length() as usize * 4;
if let Some((_ip_header, rest)) = data.split_at_mut_checked(header_len) {
if let Some((meta_bytes, payload)) = rest.split_at_mut_checked(mem::size_of::<Meta>()) {
if header_len < data.len() {
let rest = &mut data[header_len..];
if rest.len() >= mem::size_of::<Meta>() {
let (meta_bytes, payload) = rest.split_at_mut(mem::size_of::<Meta>());
let meta: &Meta = unsafe { &*(meta_bytes.as_ptr() as *const Meta) };
if meta.dst_id == config_local_id && meta.reversed == 0 {
if let Some(router) = router_writers.get(&meta.src_id) {
// 更新endpoint缓存
endpoint_cache.insert(meta.src_id, addr);
if let Some(router) = router_writers.get_mut(&meta.src_id) {
// 更新endpoint
*router.endpoint.write().unwrap() = Some(addr);
// 解密
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment