Commit f4b3c34c authored by nanahira's avatar nanahira

Revert "try"

This reverts commit 74855653.
parent 74855653
...@@ -8,9 +8,6 @@ use std::intrinsics::transmute; ...@@ -8,9 +8,6 @@ use std::intrinsics::transmute;
use std::io::{Read, Write}; use std::io::{Read, Write};
use std::mem::MaybeUninit; use std::mem::MaybeUninit;
use std::sync::Arc; use std::sync::Arc;
use std::sync::atomic::{AtomicBool, Ordering};
use std::net::SocketAddr;
use parking_lot::RwLock;
#[repr(C)] #[repr(C)]
pub struct Meta { pub struct Meta {
...@@ -41,16 +38,8 @@ pub struct Config { ...@@ -41,16 +38,8 @@ pub struct Config {
} }
use crossbeam_utils::thread; use crossbeam_utils::thread;
use grouping_by::GroupingBy; use grouping_by::GroupingBy;
use pnet::packet::ipv4::Ipv4Packet;
use socket2::Socket; use socket2::Socket;
// 优化的 RouterWriter 包装器,用于缓存 mark 状态
struct OptimizedRouterWriter {
writer: RouterWriter,
#[cfg(target_os = "linux")]
mark_set: AtomicBool,
}
fn main() -> Result<(), Box<dyn Error>> { fn main() -> Result<(), Box<dyn Error>> {
let config: Config = serde_json::from_str(env::args().nth(1).ok_or("need param")?.as_str())?; let config: Config = serde_json::from_str(env::args().nth(1).ok_or("need param")?.as_str())?;
let local_secret: [u8; SECRET_LENGTH] = Router::create_secret(config.local_secret.as_str())?; let local_secret: [u8; SECRET_LENGTH] = Router::create_secret(config.local_secret.as_str())?;
...@@ -70,51 +59,36 @@ fn main() -> Result<(), Box<dyn Error>> { ...@@ -70,51 +59,36 @@ fn main() -> Result<(), Box<dyn Error>> {
((id, reader), (id, writer)) ((id, reader), (id, writer))
}) })
.unzip(); .unzip();
let router_writers3: Vec<(Arc<Socket>, HashMap<u8, RouterWriter>)> = router_writers
// 使用 parking_lot 的 RwLock 替换标准库的 RwLock
let router_writers = router_writers
.into_iter()
.map(|(id, writer)| {
(id, Arc::new(RwLock::new(OptimizedRouterWriter {
writer,
#[cfg(target_os = "linux")]
mark_set: AtomicBool::new(false),
})))
})
.collect::<HashMap<_, _>>();
let router_writers3: Vec<(Arc<Socket>, HashMap<u8, Arc<RwLock<OptimizedRouterWriter>>>)> = router_writers
.iter()
.map(|(id, writer)| (*id, writer.read().writer.key(), Arc::clone(writer)))
.into_iter() .into_iter()
.grouping_by(|(_, key, _)| *key) .grouping_by(|(_, v)| v.key())
.into_iter() .into_iter()
.map(|(k, v)| { .map(|(k, v)| {
( (
Arc::clone(sockets.get_mut(&k).unwrap()), Arc::clone(sockets.get_mut(&k).unwrap()),
v.into_iter().map(|(id, _, writer)| (id, writer)).collect(), v.into_iter().collect(),
) )
}) })
.collect(); .collect();
println!("created tuns"); println!("created tuns");
thread::scope(|s| { thread::scope(|s| {
// 发送线程池
for router in router_readers.values_mut() { for router in router_readers.values_mut() {
s.spawn(|_| { #[cfg(target_os = "linux")]
let mark_set = std::sync::atomic::AtomicBool::new(false);
s.spawn(move |_| {
// 使用更大的缓冲区以支持巨帧 // 使用更大的缓冲区以支持巨帧
let mut buffer = vec![0u8; 9000]; let mut buffer = vec![0u8; 9000];
let meta_size = size_of::<Meta>(); let meta_size = size_of::<Meta>();
// 预初始化 Meta 头部 // 预初始化 Meta 头部(local -> remote)
let meta = Meta { let meta = Meta {
src_id: config.local_id, src_id: config.local_id,
dst_id: router.config.remote_id, dst_id: router.config.remote_id,
reversed: 0, reversed: 0,
}; };
// 直接写入缓冲区,避免额外的切片操作
// 直接写入缓冲区,避免额外的内存分配
unsafe { unsafe {
let meta_ptr = buffer.as_mut_ptr() as *mut Meta; let meta_ptr = buffer.as_mut_ptr() as *mut Meta;
*meta_ptr = meta; *meta_ptr = meta;
...@@ -126,15 +100,15 @@ fn main() -> Result<(), Box<dyn Error>> { ...@@ -126,15 +100,15 @@ fn main() -> Result<(), Box<dyn Error>> {
// 使用 try_read 减少锁争用 // 使用 try_read 减少锁争用
if let Ok(endpoint_guard) = router.endpoint.try_read() { if let Ok(endpoint_guard) = router.endpoint.try_read() {
if let Some(ref addr) = *endpoint_guard { if let Some(ref addr) = *endpoint_guard {
// 原地加密,避免额外的内存拷贝
router.encrypt(&mut buffer[meta_size..meta_size + n]); router.encrypt(&mut buffer[meta_size..meta_size + n]);
#[cfg(target_os = "linux")] #[cfg(target_os = "linux")]
{ {
// 只在第一次或 mark 改变时设置 // 只在第一次设置 mark
if !router.mark_set.load(Ordering::Relaxed) { use std::sync::atomic::Ordering;
if !mark_set.load(Ordering::Relaxed) {
let _ = router.socket.set_mark(router.config.mark); let _ = router.socket.set_mark(router.config.mark);
router.mark_set.store(true, Ordering::Relaxed); mark_set.store(true, Ordering::Relaxed);
} }
} }
...@@ -143,76 +117,62 @@ fn main() -> Result<(), Box<dyn Error>> { ...@@ -143,76 +117,62 @@ fn main() -> Result<(), Box<dyn Error>> {
} }
} }
Err(_) => { Err(_) => {
// 忽略读取错误,继续循环 // TUN 读取失败时短暂休眠,避免 CPU 空转
continue; std::thread::sleep(std::time::Duration::from_millis(1));
} }
} }
} }
}); });
} }
// 接收线程池 for (socket, mut router_writers) in router_writers3 {
for (socket, router_writers_map) in router_writers3 {
s.spawn(move |_| { s.spawn(move |_| {
// 使用更大的缓冲区 // 使用更大的缓冲区和重用内存
let mut recv_buf = vec![MaybeUninit::uninit(); 9000]; let mut recv_buf = vec![MaybeUninit::uninit(); 9000];
let meta_size = size_of::<Meta>(); let meta_size = size_of::<Meta>();
loop { loop {
match socket.recv_from(&mut recv_buf) { match socket.recv_from(&mut recv_buf) {
Ok((len, addr)) => { Ok((len, addr)) => {
// 快速路径:直接处理数据,不创建 Ipv4Packet // 快速边界检查
if len < 20 + meta_size { if len < 20 + meta_size {
continue; // 数据包太小,跳过 continue;
} }
let data: &mut [u8] = unsafe { let data: &mut [u8] = unsafe { transmute(&mut recv_buf[..len]) };
std::slice::from_raw_parts_mut(
recv_buf.as_mut_ptr() as *mut u8,
len
)
};
// 快速获取 IP 头部长度 // 优化:直接计算 IP 头部长度,避免创建 Ipv4Packet
let header_len = ((data[0] & 0x0f) as usize) * 4; let header_len = ((data[0] & 0x0f) as usize) * 4;
if len < header_len + meta_size { if len < header_len + meta_size {
continue; // 数据包格式错误,跳过 continue;
} }
// 直接解析 Meta 结构 // 直接从内存读取 Meta,避免额外的切片操作
let meta: &Meta = unsafe { let meta: &Meta = unsafe {
&*(data[header_len..].as_ptr() as *const Meta) &*(data.as_ptr().add(header_len) as *const Meta)
}; };
if meta.dst_id == config.local_id && meta.reversed == 0 { if meta.dst_id == config.local_id && meta.reversed == 0 {
// 使用 try_read 减少锁争用 if let Some(router) = router_writers.get_mut(&meta.src_id) {
if let Some(router_lock) = router_writers_map.get(&meta.src_id) { // 使用 try_write 减少锁争用
if let Ok(mut router) = router_lock.try_write() { if let Ok(mut endpoint) = router.endpoint.try_write() {
// 更新端点地址 *endpoint = Some(addr);
if let Ok(mut endpoint) = router.writer.endpoint.try_write() {
*endpoint = Some(addr);
}
// 原地解密
let payload_start = header_len + meta_size;
let payload_len = len - payload_start;
router.writer.decrypt(
&mut data[payload_start..len],
&local_secret
);
// 写入 TUN 设备
let _ = router.writer.tun_writer.write_all(
&data[payload_start..len]
);
} }
let payload_start = header_len + meta_size;
let payload = &mut data[payload_start..];
router.decrypt(payload, &local_secret);
// 忽略写入错误,继续处理下一个数据包
let _ = router.tun_writer.write_all(payload);
} }
} }
} }
Err(_) => { Err(_) => {
// 忽略接收错误,继续循环 // Socket 接收失败时短暂休眠
continue; std::thread::sleep(std::time::Duration::from_millis(1));
} }
} }
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment