Commit d73826c4 authored by nanahira's avatar nanahira

fix

parent d59999fa
Pipeline #37450 failed with stages
in 1 minute and 50 seconds
......@@ -8,8 +8,6 @@ use std::intrinsics::transmute;
use std::io::{Read, Write};
use std::mem::MaybeUninit;
use std::sync::Arc;
use crossbeam::channel::{bounded, Sender, Receiver};
use num_cpus;
#[repr(C)]
pub struct Meta {
......@@ -38,43 +36,27 @@ pub struct Config {
pub local_secret: String,
pub routers: Vec<ConfigRouter>,
}
use crossbeam_utils::thread;
use grouping_by::GroupingBy;
use pnet::packet::ipv4::Ipv4Packet;
use socket2::Socket;
// 从 TUN 到 Socket 的任务
struct TunToSocketTask {
data: Vec<u8>,
len: usize,
meta_size: usize,
router: Arc<RouterReader>,
}
// 从 Socket 到 TUN 的任务
struct SocketToTunTask {
data: Vec<u8>,
len: usize,
addr: std::net::SocketAddr,
writers: Arc<HashMap<u8, RouterWriter>>,
local_id: u8,
local_secret: Arc<[u8; SECRET_LENGTH]>,
}
fn main() -> Result<(), Box<dyn Error>> {
// 获取系统 CPU 核心数
let num_threads = std::thread::available_parallelism()
.map(|n| n.get())
.unwrap_or(1);
println!("System has {} CPU cores available", num_threads);
let config: Config = serde_json::from_str(env::args().nth(1).ok_or("need param")?.as_str())?;
let local_secret: [u8; SECRET_LENGTH] = Router::create_secret(config.local_secret.as_str())?;
let local_secret = Arc::new(local_secret);
let mut sockets: HashMap<u16, Arc<Socket>> = HashMap::new();
let routers: HashMap<u8, Router> = config
.routers
.iter()
.map(|c| Router::new(c, &mut sockets).map(|router| (c.remote_id, router)))
.collect::<Result<_, _>>()?;
let (router_readers, router_writers): (
let (mut router_readers, router_writers): (
HashMap<u8, RouterReader>,
HashMap<u8, RouterWriter>,
) = routers
......@@ -84,155 +66,81 @@ fn main() -> Result<(), Box<dyn Error>> {
((id, reader), (id, writer))
})
.unzip();
// 将 router_readers 转换为 Arc 以便在多线程中共享
let router_readers: HashMap<u8, Arc<RouterReader>> = router_readers
.into_iter()
.map(|(id, reader)| (id, Arc::new(reader)))
.collect();
let router_writers3: Vec<(Arc<Socket>, Arc<HashMap<u8, RouterWriter>>)> = router_writers
let router_writers3: Vec<(Arc<Socket>, HashMap<u8, RouterWriter>)> = router_writers
.into_iter()
.grouping_by(|(_, v)| v.key())
.into_iter()
.map(|(k, v)| {
(
Arc::clone(sockets.get_mut(&k).unwrap()),
Arc::new(v.into_iter().collect()),
v.into_iter().collect(),
)
})
.collect();
println!("created tuns");
// 获取系统线程数
let num_threads = num_cpus::get();
println!("Using {} worker threads", num_threads);
// 创建处理通道,使用有界通道避免内存无限增长
let (tun_tx, tun_rx): (Sender<TunToSocketTask>, Receiver<TunToSocketTask>) = bounded(1000);
let (socket_tx, socket_rx): (Sender<SocketToTunTask>, Receiver<SocketToTunTask>) = bounded(1000);
thread::scope(|s| {
// 启动加密工作线程(处理 TUN -> Socket)
for _ in 0..num_threads/2 {
let tun_rx = tun_rx.clone();
s.spawn(move |_| {
while let Ok(task) = tun_rx.recv() {
let TunToSocketTask { mut data, len, meta_size, router } = task;
if let Some(ref addr) = *router.endpoint.read().unwrap() {
// 加密数据
router.encrypt(&mut data[meta_size..meta_size + len]);
// 设置 mark 并发送
#[cfg(target_os = "linux")]
let _ = router.socket.set_mark(router.config.mark);
let _ = router.socket.send_to(&data[..meta_size + len], addr);
}
}
});
}
// 启动解密工作线程(处理 Socket -> TUN)
for _ in 0..num_threads/2 {
let socket_rx = socket_rx.clone();
s.spawn(move |_| {
while let Ok(task) = socket_rx.recv() {
let SocketToTunTask { mut data, len, addr, writers, local_id, local_secret } = task;
if let Some(packet) = Ipv4Packet::new(&data[..len]) {
let header_len = packet.get_header_length() as usize * 4;
if let Some((_ip_header, rest)) = data.split_at_mut_checked(header_len) {
if let Some((meta_bytes, payload)) = rest.split_at_mut_checked(size_of::<Meta>()) {
let meta: &Meta = unsafe { transmute(meta_bytes.as_ptr()) };
if meta.dst_id == local_id && meta.reversed == 0 {
if let Some(router) = writers.get(&meta.src_id) {
*router.endpoint.write().unwrap() = Some(addr);
router.decrypt(payload, &local_secret);
let _ = router.tun_writer.write_all(payload);
}
}
}
}
}
}
});
}
// TUN 读取线程
for (router_id, router) in &router_readers {
let tun_tx = tun_tx.clone();
let router = Arc::clone(router);
let config_local_id = config.local_id;
s.spawn(move |_| {
for router in router_readers.values_mut() {
s.spawn(|_| {
let mut buffer = [0u8; 1500 - 20]; // minus typical IP header space
let meta_size = size_of::<Meta>();
loop {
let mut buffer = vec![0u8; 1500];
// 预先填充 Meta 头部
// Pre-initialize with our Meta header (local -> remote)
let meta = Meta {
src_id: config_local_id,
src_id: config.local_id,
dst_id: router.config.remote_id,
reversed: 0,
};
// Turn the Meta struct into bytes
let meta_bytes = unsafe {
std::slice::from_raw_parts(&meta as *const Meta as *const u8, meta_size)
};
buffer[..meta_size].copy_from_slice(meta_bytes);
// 从 TUN 读取数据
match router.tun_reader.read(&mut buffer[meta_size..]) {
Ok(n) => {
// 发送到工作线程处理
let _ = tun_tx.send(TunToSocketTask {
data: buffer,
len: n,
meta_size,
router: Arc::clone(&router),
});
}
Err(_) => break,
loop {
let n = router.tun_reader.read(&mut buffer[meta_size..]).unwrap();
if let Some(ref addr) = *router.endpoint.read().unwrap() {
router.encrypt(&mut buffer[meta_size..meta_size + n]);
#[cfg(target_os = "linux")]
let _ = router.socket.set_mark(router.config.mark);
let _ = router.socket.send_to(&buffer[..meta_size + n], addr);
}
}
});
}
// Socket 接收线程
for (socket, writers) in router_writers3 {
let socket_tx = socket_tx.clone();
let local_secret = Arc::clone(&local_secret);
let local_id = config.local_id;
for (socket, mut router_writers) in router_writers3 {
s.spawn(move |_| {
let mut recv_buf = vec![0u8; 1500];
let mut recv_buf = [MaybeUninit::uninit(); 1500];
loop {
match socket.recv_from(&mut recv_buf) {
Ok((len, addr)) => {
// 发送到工作线程处理
let _ = socket_tx.send(SocketToTunTask {
data: recv_buf.clone(),
len,
addr,
writers: Arc::clone(&writers),
local_id,
local_secret: Arc::clone(&local_secret),
});
}
Err(_) => continue,
let _ = (|| {
let (len, addr) = socket.recv_from(&mut recv_buf).unwrap();
let data: &mut [u8] = unsafe { transmute(&mut recv_buf[..len]) };
let packet = Ipv4Packet::new(data).ok_or("malformed packet")?;
let header_len = packet.get_header_length() as usize * 4;
let (_ip_header, rest) = data
.split_at_mut_checked(header_len)
.ok_or("malformed packet")?;
let (meta_bytes, payload) = rest
.split_at_mut_checked(size_of::<Meta>())
.ok_or("malformed packet")?;
let meta: &Meta = unsafe { transmute(meta_bytes.as_ptr()) };
if meta.dst_id == config.local_id && meta.reversed == 0 {
let router = router_writers
.get_mut(&meta.src_id)
.ok_or("missing router")?;
*router.endpoint.write().unwrap() = Some(addr);
router.decrypt(payload, &local_secret);
router.tun_writer.write_all(payload)?;
}
Ok::<(), Box<dyn Error>>(())
})();
}
});
}
})
.unwrap();
Ok(())
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment