Commit 2ec72cdb authored by nanamicat's avatar nanamicat

rust-next

parent f5ffc9fe
Pipeline #41895 passed with stages
in 2 minutes and 57 seconds
This diff is collapsed.
...@@ -4,12 +4,12 @@ version = "0.1.0" ...@@ -4,12 +4,12 @@ version = "0.1.0"
edition = "2021" edition = "2021"
[dependencies] [dependencies]
tun = "0.7" tun = "0.8"
socket2 = { version = "0.5.8", features = ["all"] } socket2 = { version = "0.6.1", features = ["all"] }
pnet = "0.35.0" pnet = "0.35.0"
serde = { version = "1.0.217", features = ["derive"] } serde = { version = "1.0.228", features = ["derive"] }
serde_json = "1.0" serde_json = "1.0"
base64 = "0.22.1" base64 = "0.22.1"
crossbeam = "0.8.4" crossbeam = "0.8.4"
crossbeam-utils = "0.8.20" crossbeam-utils = "0.8.21"
grouping_by = "0.2.2" libc = "0.2.178"
FROM rust:1.84-alpine3.21 as builder FROM rust:1.91.1-alpine3.22 as builder
RUN apk add --no-cache musl-dev RUN apk add --no-cache musl-dev
WORKDIR /usr/src/app WORKDIR /usr/src/app
...@@ -12,7 +12,7 @@ RUN mkdir src && \ ...@@ -12,7 +12,7 @@ RUN mkdir src && \
COPY src src COPY src src
RUN cargo build --release RUN cargo build --release
FROM alpine:3.21 FROM alpine:3.22
RUN apk --no-cache add libgcc libstdc++ bash iproute2 iptables iptables-legacy ipset netcat-openbsd jq RUN apk --no-cache add libgcc libstdc++ bash iproute2 iptables iptables-legacy ipset netcat-openbsd jq
COPY --from=builder /usr/src/app/target/release/tun1 /usr/local/bin/tun COPY --from=builder /usr/src/app/target/release/tun1 /usr/local/bin/tun
COPY ./entrypoint.sh /entrypoint.sh COPY ./entrypoint.sh /entrypoint.sh
......
mod router; mod router;
use crate::router::{Router, SECRET_LENGTH};
use crate::router::{Router, RouterReader, RouterWriter, SECRET_LENGTH};
use std::collections::HashMap;
use std::env; use std::env;
use std::error::Error; use std::error::Error;
use std::intrinsics::transmute;
use std::io::{Read, Write}; use std::io::{Read, Write};
use std::mem::MaybeUninit; use std::mem::size_of;
use std::sync::Arc; use std::mem::{transmute, MaybeUninit};
use std::sync::atomic::Ordering;
#[repr(C)] #[repr(C)]
pub struct Meta { pub struct Meta {
...@@ -18,10 +16,16 @@ pub struct Meta { ...@@ -18,10 +16,16 @@ pub struct Meta {
use serde::Deserialize; use serde::Deserialize;
#[derive(Deserialize)]
pub struct Config {
pub local_id: u8,
pub local_secret: String,
pub routers: Vec<ConfigRouter>,
}
#[derive(Deserialize)] #[derive(Deserialize)]
pub struct ConfigRouter { pub struct ConfigRouter {
pub remote_id: u8, pub remote_id: u8,
pub proto: i32, pub proto: u8,
pub family: u8, pub family: u8,
pub mark: u32, pub mark: u32,
pub endpoint: String, pub endpoint: String,
...@@ -30,85 +34,55 @@ pub struct ConfigRouter { ...@@ -30,85 +34,55 @@ pub struct ConfigRouter {
pub up: String, pub up: String,
} }
#[derive(Deserialize)]
pub struct Config {
pub local_id: u8,
pub local_secret: String,
pub routers: Vec<ConfigRouter>,
}
use crossbeam_utils::thread; use crossbeam_utils::thread;
use grouping_by::GroupingBy;
use pnet::packet::ipv4::Ipv4Packet; use pnet::packet::ipv4::Ipv4Packet;
use socket2::Socket;
fn main() -> Result<(), Box<dyn Error>> { fn main() -> Result<(), Box<dyn Error>> {
let config: Config = serde_json::from_str(env::args().nth(1).ok_or("need param")?.as_str())?; let config: Config = serde_json::from_str(env::args().nth(1).ok_or("need param")?.as_str())?;
let local_secret: [u8; SECRET_LENGTH] = Router::create_secret(config.local_secret.as_str())?; let local_secret: [u8; SECRET_LENGTH] = Router::create_secret(config.local_secret.as_str())?;
let mut sockets: HashMap<u16, Arc<Socket>> = HashMap::new(); let routers: Vec<Router> = config
let routers: HashMap<u8, Router> = config
.routers .routers
.iter()
.map(|c| Router::new(c, &mut sockets).map(|router| (c.remote_id, router)))
.collect::<Result<_, _>>()?;
let (mut router_readers, router_writers): (
HashMap<u8, RouterReader>,
HashMap<u8, RouterWriter>,
) = routers
.into_iter() .into_iter()
.map(|(id, router)| { .map(|c| Router::new(c, config.local_id))
let (reader, writer) = router.split(); .collect::<Result<Vec<_>, _>>()?;
((id, reader), (id, writer))
})
.unzip();
let router_writers3: Vec<(Arc<Socket>, HashMap<u8, RouterWriter>)> = router_writers
.into_iter()
.grouping_by(|(_, v)| v.key())
.into_iter()
.map(|(k, v)| {
(
Arc::clone(sockets.get_mut(&k).unwrap()),
v.into_iter().collect(),
)
})
.collect();
println!("created tuns"); println!("created tuns");
thread::scope(|s| { thread::scope(|s| {
for router in router_readers.values_mut() { for router in routers {
s.spawn(|_| { let (mut reader, mut writer) = router.split();
let mut buffer = [0u8; 1500 - 20]; // minus typical IP header space
let meta_size = size_of::<Meta>();
s.spawn(move |_| {
let mut buffer = [0u8; 1500 - 20]; // minus typical IP header space
const META_SIZE: usize = size_of::<Meta>();
// Pre-initialize with our Meta header (local -> remote) // Pre-initialize with our Meta header (local -> remote)
let meta = Meta { let meta = Meta {
src_id: config.local_id, src_id: config.local_id,
dst_id: router.config.remote_id, dst_id: reader.config.remote_id,
reversed: 0, reversed: 0,
}; };
// Turn the Meta struct into bytes // Turn the Meta struct into bytes
let meta_bytes = unsafe { let meta_bytes: &[u8; META_SIZE] =
std::slice::from_raw_parts(&meta as *const Meta as *const u8, meta_size) unsafe { &*(&meta as *const Meta as *const [u8; META_SIZE]) };
}; buffer[..META_SIZE].copy_from_slice(meta_bytes);
buffer[..meta_size].copy_from_slice(meta_bytes);
loop { loop {
let n = router.tun_reader.read(&mut buffer[meta_size..]).unwrap(); let n = reader.tun_reader.read(&mut buffer[META_SIZE..]).unwrap();
if let Some(ref addr) = *router.endpoint.read().unwrap() { let guard = crossbeam::epoch::pin();
router.encrypt(&mut buffer[meta_size..meta_size + n]); let shared = reader.endpoint.load(Ordering::Acquire, &guard);
#[cfg(target_os = "linux")] if let Some(addr) = unsafe { shared.as_ref() } {
let _ = router.socket.set_mark(router.config.mark); reader.encrypt(&mut buffer[META_SIZE..META_SIZE + n]);
let _ = router.socket.send_to(&buffer[..meta_size + n], addr); let _ = reader.socket.send_to(&buffer[..META_SIZE + n], addr);
} }
} }
}); });
}
for (socket, mut router_writers) in router_writers3 {
s.spawn(move |_| { s.spawn(move |_| {
let mut recv_buf = [MaybeUninit::uninit(); 1500]; let mut recv_buf = [MaybeUninit::uninit(); 1500];
loop { loop {
let _ = (|| { let _ = (|| {
let (len, addr) = socket.recv_from(&mut recv_buf).unwrap(); let (len, addr) = writer.socket.recv_from(&mut recv_buf).unwrap();
let data: &mut [u8] = unsafe { transmute(&mut recv_buf[..len]) }; let data: &mut [u8] = unsafe { transmute(&mut recv_buf[..len]) };
let packet = Ipv4Packet::new(data).ok_or("malformed packet")?; let packet = Ipv4Packet::new(data).ok_or("malformed packet")?;
...@@ -119,16 +93,24 @@ fn main() -> Result<(), Box<dyn Error>> { ...@@ -119,16 +93,24 @@ fn main() -> Result<(), Box<dyn Error>> {
let (meta_bytes, payload) = rest let (meta_bytes, payload) = rest
.split_at_mut_checked(size_of::<Meta>()) .split_at_mut_checked(size_of::<Meta>())
.ok_or("malformed packet")?; .ok_or("malformed packet")?;
let meta: &Meta = unsafe { transmute(meta_bytes.as_ptr()) };
if meta.dst_id == config.local_id && meta.reversed == 0 { let guard = crossbeam::epoch::pin();
let router = router_writers let current_shared = writer.endpoint.load(Ordering::SeqCst, &guard);
.get_mut(&meta.src_id) let is_same = unsafe { current_shared.as_ref() }
.ok_or("missing router")?; .map(|c| *c == addr)
*router.endpoint.write().unwrap() = Some(addr); .unwrap_or(false);
router.decrypt(payload, &local_secret); if !is_same {
router.tun_writer.write_all(payload)?; let new_shared = crossbeam::epoch::Owned::new(addr).into_shared(&guard);
let old_shared =
writer.endpoint.swap(new_shared, Ordering::SeqCst, &guard);
unsafe {
guard.defer_destroy(old_shared);
}
} }
writer.decrypt(payload, &local_secret);
writer.tun_writer.write_all(payload)?;
Ok::<(), Box<dyn Error>>(()) Ok::<(), Box<dyn Error>>(())
})(); })();
} }
......
use socket2::{Domain, Protocol, SockAddr, Socket, Type}; use base64::prelude::BASE64_STANDARD;
use std::collections::hash_map::Entry; use base64::Engine;
use std::collections::HashMap; use socket2::{Domain, Protocol, SockAddr, SockFilter, Socket, Type};
use std::net::ToSocketAddrs; use std::net::ToSocketAddrs;
use std::process::{Command, ExitStatus}; use std::process::{Command, ExitStatus};
use std::sync::{Arc, RwLock}; use std::sync::Arc;
use tun::{Reader, Writer}; use tun::{Reader, Writer};
pub const SECRET_LENGTH: usize = 32; pub const SECRET_LENGTH: usize = 32;
use crate::ConfigRouter; use crate::{ConfigRouter, Meta};
use base64::prelude::*; use crossbeam::epoch::Atomic;
// tun -> raw // tun -> raw
pub struct RouterReader<'a> { pub struct RouterReader {
pub config: &'a ConfigRouter, pub config: ConfigRouter,
pub secret: [u8; SECRET_LENGTH], pub secret: [u8; SECRET_LENGTH],
pub tun_reader: Reader, pub tun_reader: Reader,
pub socket: Arc<Socket>, pub socket: Arc<Socket>,
pub endpoint: Arc<RwLock<Option<SockAddr>>>, pub endpoint: Arc<Atomic<SockAddr>>,
} }
impl<'a> RouterReader<'a> { impl RouterReader {
pub(crate) fn encrypt(&self, data: &mut [u8]) { pub(crate) fn encrypt(&self, data: &mut [u8]) {
for (i, b) in data.iter_mut().enumerate() { for (i, b) in data.iter_mut().enumerate() {
*b ^= self.secret[i % SECRET_LENGTH]; *b ^= self.secret[i % SECRET_LENGTH];
...@@ -27,34 +27,30 @@ impl<'a> RouterReader<'a> { ...@@ -27,34 +27,30 @@ impl<'a> RouterReader<'a> {
} }
// raw -> tun // raw -> tun
pub struct RouterWriter<'a> { pub struct RouterWriter {
pub config: &'a ConfigRouter,
pub tun_writer: Writer, pub tun_writer: Writer,
pub endpoint: Arc<RwLock<Option<SockAddr>>>, pub socket: Arc<Socket>,
pub endpoint: Arc<Atomic<SockAddr>>,
} }
impl<'a> RouterWriter<'a> { impl RouterWriter {
pub(crate) fn decrypt(&self, data: &mut [u8], secret: &[u8; SECRET_LENGTH]) { pub(crate) fn decrypt(&self, data: &mut [u8], secret: &[u8; SECRET_LENGTH]) {
for (i, b) in data.iter_mut().enumerate() { for (i, b) in data.iter_mut().enumerate() {
*b ^= secret[i % SECRET_LENGTH]; *b ^= secret[i % SECRET_LENGTH];
} }
} }
pub(crate) fn key(&self) -> u16 {
Router::key(self.config)
}
} }
pub struct Router<'a> { pub struct Router {
pub config: &'a ConfigRouter, pub config: ConfigRouter,
pub secret: [u8; SECRET_LENGTH], pub secret: [u8; SECRET_LENGTH],
pub tun_reader: Reader, pub tun_reader: Reader,
pub tun_writer: Writer, pub tun_writer: Writer,
pub socket: Arc<Socket>, pub socket: Arc<Socket>,
pub endpoint: Arc<RwLock<Option<SockAddr>>>, pub endpoint: Arc<Atomic<SockAddr>>,
} }
impl<'a> Router<'a> { impl Router {
pub(crate) fn create_secret( pub(crate) fn create_secret(
config: &str, config: &str,
) -> Result<[u8; SECRET_LENGTH], Box<dyn std::error::Error>> { ) -> Result<[u8; SECRET_LENGTH], Box<dyn std::error::Error>> {
...@@ -65,30 +61,78 @@ impl<'a> Router<'a> { ...@@ -65,30 +61,78 @@ impl<'a> Router<'a> {
Ok(secret) Ok(secret)
} }
fn key(config: &ConfigRouter) -> u16 {
(config.family as u16) << 8 | config.proto as u16
}
fn create_raw_socket( fn create_raw_socket(
config: &ConfigRouter, config: &ConfigRouter,
sockets: &mut HashMap<u16, Arc<Socket>>, local_id: u8,
) -> Result<Arc<Socket>, Box<dyn std::error::Error>> { ) -> Result<Arc<Socket>, std::io::Error> {
let key = Router::key(config); let result = Socket::new(
let result = match sockets.entry(key) {
Entry::Occupied(entry) => entry.get().clone(),
Entry::Vacant(entry) => entry
.insert(Arc::new(Socket::new(
if config.family == 6 { if config.family == 6 {
Domain::IPV6 Domain::IPV6
} else { } else {
Domain::IPV4 Domain::IPV4
}, },
Type::RAW, Type::RAW,
Some(Protocol::from(config.proto)), Some(Protocol::from(config.proto as i32)),
)?)) )?;
.clone(), #[cfg(target_os = "linux")]
result.set_mark(config.mark)?;
Self::attach_readable_filter(config, local_id, &result)?;
Ok(Arc::new(result))
}
fn attach_readable_filter(
config: &ConfigRouter,
local_id: u8,
socket: &Socket,
) -> std::io::Result<()> {
// 由于多个对端可能会使用相同的 ipprpto 号,这里确保每个 socket 上只会收到自己对应的对端发来的消息
const META_SIZE: usize = size_of::<Meta>();
let meta = Meta {
src_id: config.remote_id,
dst_id: local_id,
reversed: 0,
}; };
Ok(result) let meta_bytes: [u8; META_SIZE] =
unsafe { *(&meta as *const Meta as *const [u8; META_SIZE]) };
let target_val = u32::from_be_bytes(meta_bytes);
// 如果你的协议是 UDP,这里必须是 8 (跳过 UDP 头: SrcPort(2)+DstPort(2)+Len(2)+Sum(2))
// 如果是纯自定义 IP 协议,这里是 0
let payload_offset = 0;
socket.attach_filter(&[
// 1. 【加载 IP 头长度到寄存器 X】
// BPF_LDX (Load X) | BPF_B (Byte) | BPF_MSH (Magic Shift)
// 这是一个特殊的 BPF 指令,专门用于计算 IPv4 头长度:X = 4 * (IP[0] & 0xf)
bpf_stmt(libc::BPF_LDX | libc::BPF_B | libc::BPF_MSH, 0),
// 2. 【读取 Payload 数据到累加器 A】
// BPF_LD (Load) | BPF_W (Word, 4 bytes) | BPF_IND (Indirect, relative to X)
// 逻辑:A = Packet[X + k]
bpf_stmt(libc::BPF_LD | libc::BPF_W | libc::BPF_IND, payload_offset),
// 3. 【比较并跳转】
// BPF_JMP (Jump) | BPF_JEQ (Jump if Equal) | BPF_K (Const)
// 逻辑:if (A == target_val) goto True(0); else goto False(1);
// jt=0: 继续执行下一条
// jf=1: 跳过下一条 (直接跳到 Reject)
bpf_jump(
libc::BPF_JMP | libc::BPF_JEQ | libc::BPF_K,
target_val,
0,
1,
),
// 4. 【接受 (True 路径)】
// BPF_RET (Return) | BPF_K (Constant)
// 返回 -1 (0xFFFFFFFF) 表示截取整个包的最大长度(即接收包)
bpf_stmt(libc::BPF_RET | libc::BPF_K, u32::MAX),
// 5. 【拒绝 (False 路径)】
// BPF_RET (Return) | BPF_K (Constant)
// 返回 0 表示截取 0 字节(即丢弃包)
bpf_stmt(libc::BPF_RET | libc::BPF_K, 0),
])?;
Ok(())
} }
fn create_tun_device( fn create_tun_device(
config: &ConfigRouter, config: &ConfigRouter,
) -> Result<(Reader, Writer), Box<dyn std::error::Error>> { ) -> Result<(Reader, Writer), Box<dyn std::error::Error>> {
...@@ -104,21 +148,18 @@ impl<'a> Router<'a> { ...@@ -104,21 +148,18 @@ impl<'a> Router<'a> {
fn create_endpoint( fn create_endpoint(
config: &ConfigRouter, config: &ConfigRouter,
) -> Result<Arc<RwLock<Option<SockAddr>>>, Box<dyn std::error::Error>> { ) -> Result<Arc<Atomic<SockAddr>>, Box<dyn std::error::Error>> {
let parsed = (config.endpoint.clone(), 0u16) let parsed = (config.endpoint.clone(), 0u16)
.to_socket_addrs()? .to_socket_addrs()?
.next() .next()
.ok_or(config.endpoint.clone())?; .ok_or(config.endpoint.clone())?;
Ok(Arc::new(RwLock::new(Some(parsed.into())))) Ok(Arc::new(Atomic::new(parsed.into())))
} }
pub fn new( pub fn new(config: ConfigRouter, local_id: u8) -> Result<Router, Box<dyn std::error::Error>> {
config: &'a ConfigRouter,
sockets: &mut HashMap<u16, Arc<Socket>>,
) -> Result<Router<'a>, Box<dyn std::error::Error>> {
let secret = Self::create_secret(config.remote_secret.as_str())?; let secret = Self::create_secret(config.remote_secret.as_str())?;
let endpoint = Self::create_endpoint(&config)?; let endpoint = Self::create_endpoint(&config)?;
let socket = Self::create_raw_socket(&config, sockets)?; let socket = Self::create_raw_socket(&config, local_id)?;
let (tun_reader, tun_writer) = Self::create_tun_device(&config)?; let (tun_reader, tun_writer) = Self::create_tun_device(&config)?;
Self::run_up_script(&config)?; Self::run_up_script(&config)?;
...@@ -134,11 +175,11 @@ impl<'a> Router<'a> { ...@@ -134,11 +175,11 @@ impl<'a> Router<'a> {
Ok(router) Ok(router)
} }
pub fn split(self) -> (RouterReader<'a>, RouterWriter<'a>) { pub fn split(self) -> (RouterReader, RouterWriter) {
let writer = RouterWriter { let writer = RouterWriter {
config: self.config, endpoint: self.endpoint.clone(),
endpoint: Arc::clone(&self.endpoint),
tun_writer: self.tun_writer, tun_writer: self.tun_writer,
socket: self.socket.clone(),
}; };
let reader = RouterReader { let reader = RouterReader {
...@@ -152,3 +193,11 @@ impl<'a> Router<'a> { ...@@ -152,3 +193,11 @@ impl<'a> Router<'a> {
(reader, writer) (reader, writer)
} }
} }
fn bpf_stmt(code: u32, k: u32) -> SockFilter {
SockFilter::new(code as u16, 0, 0, k)
}
fn bpf_jump(code: u32, k: u32, jt: u8, jf: u8) -> SockFilter {
SockFilter::new(code as u16, jt, jf, k)
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment