Commit 653c74fb authored by nanahira's avatar nanahira

fix crlf

parent 4a3859fc
Pipeline #42032 passed with stages
in 1 minute and 42 seconds
/.idea/
/target
/.idea/
/target
stages:
- build
- deploy
variables:
GIT_DEPTH: "1"
before_script:
- docker login -u $CI_REGISTRY_USER -p $CI_REGISTRY_PASSWORD $CI_REGISTRY
.build-image:
stage: build
script:
- docker build --pull -t $TARGET_IMAGE .
- docker push $TARGET_IMAGE
build-x86:
extends: .build-image
tags:
- docker
variables:
TARGET_IMAGE: $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_SLUG-x86
build-arm:
extends: .build-image
tags:
- docker-arm
variables:
TARGET_IMAGE: $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_SLUG-arm
.deploy:
stage: deploy
tags:
- docker
script:
- docker pull $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_SLUG-x86
- docker pull $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_SLUG-arm
- docker manifest create $TARGET_IMAGE --amend $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_SLUG-x86 --amend
$CI_REGISTRY_IMAGE:$CI_COMMIT_REF_SLUG-arm
- docker manifest push $TARGET_IMAGE
deploy_latest:
extends: .deploy
variables:
TARGET_IMAGE: $CI_REGISTRY_IMAGE:latest
only:
- master
deploy_branch:
extends: .deploy
variables:
TARGET_IMAGE: $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_SLUG
stages:
- build
- deploy
variables:
GIT_DEPTH: "1"
before_script:
- docker login -u $CI_REGISTRY_USER -p $CI_REGISTRY_PASSWORD $CI_REGISTRY
.build-image:
stage: build
script:
- docker build --pull -t $TARGET_IMAGE .
- docker push $TARGET_IMAGE
build-x86:
extends: .build-image
tags:
- docker
variables:
TARGET_IMAGE: $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_SLUG-x86
build-arm:
extends: .build-image
tags:
- docker-arm
variables:
TARGET_IMAGE: $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_SLUG-arm
.deploy:
stage: deploy
tags:
- docker
script:
- docker pull $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_SLUG-x86
- docker pull $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_SLUG-arm
- docker manifest create $TARGET_IMAGE --amend $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_SLUG-x86 --amend
$CI_REGISTRY_IMAGE:$CI_COMMIT_REF_SLUG-arm
- docker manifest push $TARGET_IMAGE
deploy_latest:
extends: .deploy
variables:
TARGET_IMAGE: $CI_REGISTRY_IMAGE:latest
only:
- master
deploy_branch:
extends: .deploy
variables:
TARGET_IMAGE: $CI_REGISTRY_IMAGE:$CI_COMMIT_REF_SLUG
FROM rust:1.91-alpine3.22 as chef
RUN apk add --no-cache musl-dev
RUN cargo install cargo-chef
WORKDIR /usr/src/app
FROM chef as planner
COPY Cargo.toml Cargo.lock ./
COPY src src
RUN cargo chef prepare --recipe-path recipe.json
FROM chef as builder
COPY --from=planner /usr/src/app/recipe.json recipe.json
RUN cargo chef cook --release --recipe-path recipe.json
COPY Cargo.toml Cargo.lock ./
COPY src src
RUN cargo build --release
FROM alpine:3.22
RUN apk --no-cache add libgcc libstdc++ bash iproute2 iptables iptables-legacy ipset netcat-openbsd jq iputils
COPY --from=builder /usr/src/app/target/release/tun1 /usr/local/bin/tun
COPY ./entrypoint.sh /entrypoint.sh
ENTRYPOINT ["/entrypoint.sh"]
CMD ["tun"]
FROM rust:1.91-alpine3.22 as chef
RUN apk add --no-cache musl-dev
RUN cargo install cargo-chef
WORKDIR /usr/src/app
FROM chef as planner
COPY Cargo.toml Cargo.lock ./
COPY src src
RUN cargo chef prepare --recipe-path recipe.json
FROM chef as builder
COPY --from=planner /usr/src/app/recipe.json recipe.json
RUN cargo chef cook --release --recipe-path recipe.json
COPY Cargo.toml Cargo.lock ./
COPY src src
RUN cargo build --release
FROM alpine:3.22
RUN apk --no-cache add libgcc libstdc++ bash iproute2 iptables iptables-legacy ipset netcat-openbsd jq iputils
COPY --from=builder /usr/src/app/target/release/tun1 /usr/local/bin/tun
COPY ./entrypoint.sh /entrypoint.sh
ENTRYPOINT ["/entrypoint.sh"]
CMD ["tun"]
mod config;
mod router;
use crate::config::{Config, Schema};
use crate::router::{Meta, Router, META_SIZE};
use anyhow::anyhow;
use anyhow::{Context, Result};
use crossbeam::epoch::{pin, Owned};
use crossbeam_utils::thread;
use itertools::Itertools;
use std::collections::BTreeMap;
use std::net::Shutdown;
use std::sync::atomic::Ordering;
use std::time::Duration;
use std::{env, mem::MaybeUninit};
use std::fs;
fn main() -> Result<()> {
println!("Starting");
let args: Vec<String> = env::args().collect();
if args.len() < 2 {
return Err(anyhow!("need JSON string or -c <config.json>"));
}
let config: Config;
if args[1] == "-c" || args[1] == "--config" {
// 从文件读
if args.len() < 3 {
return Err(anyhow!("missing value for -c/--config"));
}
let data = fs::read_to_string(&args[2])?;
config = serde_json::from_str(&data)?;
} else {
// 当作 JSON 字符串解析
config = serde_json::from_str(&args[1])?;
}
println!("Read config");
let routers = &config
.routers
.into_iter()
.sorted_by_key(|r| r.remote_id)
.map(|c| Router::new(c).map(|router| (router.config.remote_id, router)))
.collect::<Result<BTreeMap<u32, Router>, _>>()?;
for (_, group) in &routers
.values()
.filter(|r| r.config.schema == Schema::UDP && r.config.src_port != 0)
.chunk_by(|r| r.config.src_port)
{
Router::attach_filter_udp(group.collect())?;
}
println!("created tuns");
const TCP_RECONNECT: u64 = 10;
thread::scope(|s| {
// IP, UDP
for router in routers.values().filter(|&r| r.config.schema != Schema::TCP) {
s.spawn(|_| {
router.handle_outbound_ip_udp();
});
s.spawn(|_| {
router.handle_inbound_ip_udp();
});
}
for router in routers.values().filter(|&r| r.config.schema == Schema::TCP && r.config.dst_port != 0) {
s.spawn(|_| {
loop {
if let Ok(connection) = router.connect_tcp() {
let _ = thread::scope(|s| {
s.spawn(|_| router.handle_outbound_tcp(&connection));
s.spawn(|_| router.handle_inbound_tcp(&connection));
});
}
std::thread::sleep(Duration::from_secs(TCP_RECONNECT));
}
});
}
// tcp listeners
for router in routers
.values()
.filter(|&r| r.config.schema == Schema::TCP && r.config.dst_port == 0)
.unique_by(|r| r.config.src_port)
{
println!("listen on port {}", router.config.src_port);
let socket = router.listen_tcp();
s.spawn(move |s| {
// listen 或 accept 出错直接 panic
loop {
let (connection, _) = socket.accept().unwrap();
s.spawn(move |_| {
connection.set_tcp_nodelay(true).unwrap();
let mut meta_bytes = [MaybeUninit::uninit(); META_SIZE];
Router::recv_exact_tcp(&connection, &mut meta_bytes).unwrap();
let meta: &Meta = Meta::from_bytes(&meta_bytes);
if let Some(router) = routers.get(&meta.src_id)
&& meta.dst_id == router.config.local_id
{
// let connection = Arc::new(connection);
// tcp listener 只许一个连接,过来新连接就把前一个关掉。
{
let guard = pin();
let new_shared = Owned::new(connection).into_shared(&guard);
let old_shared = router.tcp_listener_connection.swap(new_shared, Ordering::AcqRel, &guard);
//
// SAFETY: this is guaranteed to still point to valid connection because
// the guard is swapped with AcqRel so we are for sure tracked by the pin
// list
//
if let Some(old) = unsafe { old_shared.as_ref() } {
let _ = old.shutdown(Shutdown::Both);
// SAFETY: At this point old_shared is guaranteed
// to be non-null (above if let checks that)
// And since it is already swapped out of the
// `tcp_listener_connection` no other thread
// should have access to it.
unsafe {
guard.defer_destroy(old_shared);
}
}
}
let _ = thread::scope(|s| {
s.spawn(|_| {
let guard = pin();
let shared = router.tcp_listener_connection.load(Ordering::Acquire, &guard);
// SAFETY: tcp_listener_connection shoud always either point to null or some valid thing
if let Some(connection) = unsafe { shared.as_ref() } {
router.handle_outbound_tcp(connection);
}
});
s.spawn(|_| {
let guard = pin();
let shared = router.tcp_listener_connection.load(Ordering::Acquire, &guard);
// SAFETY: tcp_listener_connection shoud always either point to null or some valid thing
if let Some(connection) = unsafe { shared.as_ref() } {
router.handle_inbound_tcp(&connection);
}
});
});
}
});
}
});
}
})
.unwrap();
Ok(())
}
mod config;
mod router;
use crate::config::{Config, Schema};
use crate::router::{Meta, Router, META_SIZE};
use anyhow::anyhow;
use anyhow::{Context, Result};
use crossbeam::epoch::{pin, Owned};
use crossbeam_utils::thread;
use itertools::Itertools;
use std::collections::BTreeMap;
use std::net::Shutdown;
use std::sync::atomic::Ordering;
use std::time::Duration;
use std::{env, mem::MaybeUninit};
use std::fs;
fn main() -> Result<()> {
println!("Starting");
let args: Vec<String> = env::args().collect();
if args.len() < 2 {
return Err(anyhow!("need JSON string or -c <config.json>"));
}
let config: Config;
if args[1] == "-c" || args[1] == "--config" {
// 从文件读
if args.len() < 3 {
return Err(anyhow!("missing value for -c/--config"));
}
let data = fs::read_to_string(&args[2])?;
config = serde_json::from_str(&data)?;
} else {
// 当作 JSON 字符串解析
config = serde_json::from_str(&args[1])?;
}
println!("Read config");
let routers = &config
.routers
.into_iter()
.sorted_by_key(|r| r.remote_id)
.map(|c| Router::new(c).map(|router| (router.config.remote_id, router)))
.collect::<Result<BTreeMap<u32, Router>, _>>()?;
for (_, group) in &routers
.values()
.filter(|r| r.config.schema == Schema::UDP && r.config.src_port != 0)
.chunk_by(|r| r.config.src_port)
{
Router::attach_filter_udp(group.collect())?;
}
println!("created tuns");
const TCP_RECONNECT: u64 = 10;
thread::scope(|s| {
// IP, UDP
for router in routers.values().filter(|&r| r.config.schema != Schema::TCP) {
s.spawn(|_| {
router.handle_outbound_ip_udp();
});
s.spawn(|_| {
router.handle_inbound_ip_udp();
});
}
for router in routers.values().filter(|&r| r.config.schema == Schema::TCP && r.config.dst_port != 0) {
s.spawn(|_| {
loop {
if let Ok(connection) = router.connect_tcp() {
let _ = thread::scope(|s| {
s.spawn(|_| router.handle_outbound_tcp(&connection));
s.spawn(|_| router.handle_inbound_tcp(&connection));
});
}
std::thread::sleep(Duration::from_secs(TCP_RECONNECT));
}
});
}
// tcp listeners
for router in routers
.values()
.filter(|&r| r.config.schema == Schema::TCP && r.config.dst_port == 0)
.unique_by(|r| r.config.src_port)
{
println!("listen on port {}", router.config.src_port);
let socket = router.listen_tcp();
s.spawn(move |s| {
// listen 或 accept 出错直接 panic
loop {
let (connection, _) = socket.accept().unwrap();
s.spawn(move |_| {
connection.set_tcp_nodelay(true).unwrap();
let mut meta_bytes = [MaybeUninit::uninit(); META_SIZE];
Router::recv_exact_tcp(&connection, &mut meta_bytes).unwrap();
let meta: &Meta = Meta::from_bytes(&meta_bytes);
if let Some(router) = routers.get(&meta.src_id)
&& meta.dst_id == router.config.local_id
{
// let connection = Arc::new(connection);
// tcp listener 只许一个连接,过来新连接就把前一个关掉。
{
let guard = pin();
let new_shared = Owned::new(connection).into_shared(&guard);
let old_shared = router.tcp_listener_connection.swap(new_shared, Ordering::AcqRel, &guard);
//
// SAFETY: this is guaranteed to still point to valid connection because
// the guard is swapped with AcqRel so we are for sure tracked by the pin
// list
//
if let Some(old) = unsafe { old_shared.as_ref() } {
let _ = old.shutdown(Shutdown::Both);
// SAFETY: At this point old_shared is guaranteed
// to be non-null (above if let checks that)
// And since it is already swapped out of the
// `tcp_listener_connection` no other thread
// should have access to it.
unsafe {
guard.defer_destroy(old_shared);
}
}
}
let _ = thread::scope(|s| {
s.spawn(|_| {
let guard = pin();
let shared = router.tcp_listener_connection.load(Ordering::Acquire, &guard);
// SAFETY: tcp_listener_connection shoud always either point to null or some valid thing
if let Some(connection) = unsafe { shared.as_ref() } {
router.handle_outbound_tcp(connection);
}
});
s.spawn(|_| {
let guard = pin();
let shared = router.tcp_listener_connection.load(Ordering::Acquire, &guard);
// SAFETY: tcp_listener_connection shoud always either point to null or some valid thing
if let Some(connection) = unsafe { shared.as_ref() } {
router.handle_inbound_tcp(&connection);
}
});
});
}
});
}
});
}
})
.unwrap();
Ok(())
}
use anyhow::{bail, ensure, Result};
use socket2::{Domain, Protocol, SockAddr, SockFilter, Socket, Type};
use std::net::Shutdown;
use std::{
ffi::c_void,
mem::MaybeUninit,
net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6, ToSocketAddrs},
ops::Range,
os::fd::{AsRawFd, FromRawFd},
process::{Command, ExitStatus},
sync::atomic::Ordering,
};
use tun::Device;
use crate::config::{ConfigRouter, Schema};
use crossbeam::epoch::{pin, Atomic};
use libc::{
setsockopt, sock_filter, sock_fprog, socklen_t, BPF_ABS, BPF_B, BPF_IND, BPF_JEQ, BPF_JMP, BPF_K, BPF_LD, BPF_LDX, BPF_MEM, BPF_MSH, BPF_RET, BPF_ST, BPF_W,
MSG_FASTOPEN, SOL_SOCKET, SO_ATTACH_REUSEPORT_CBPF,
};
pub const SECRET_LENGTH: usize = 32;
pub const META_SIZE: usize = size_of::<Meta>();
#[repr(C)]
#[derive(Debug, Clone, Copy, Default)]
pub struct Meta {
pub src_id: u32,
pub dst_id: u32,
}
impl Meta {
pub fn as_bytes(&self) -> &[u8; META_SIZE] {
unsafe { &*(self as *const Meta as *const [u8; META_SIZE]) }
}
pub fn from_bytes(bytes: &[MaybeUninit<u8>; META_SIZE]) -> &Meta {
unsafe { &*(bytes.as_ptr() as *const Meta) }
}
}
pub struct Router {
pub config: ConfigRouter,
pub tun: Device,
pub socket: Socket,
pub endpoint: Atomic<SockAddr>,
pub tcp_listener_connection: Atomic<Socket>,
}
#[inline]
fn xor_with_secret_offset<const L: usize>(data: &mut [u8], secret: &[u8; L], offset: usize) {
let len = data.len();
if len == 0 { return; }
let mut i = 0;
let mut key_pos = offset % L;
// 1) 先把 key_pos 补到 0(也就是把相位对齐到块边界),这样后面能走“完整块”
if key_pos != 0 {
let head = (L - key_pos).min(len);
for j in 0..head {
data[j] ^= secret[key_pos + j]; // 这里不会越界,因为 j < L - key_pos
}
i += head;
key_pos = 0;
}
// 2) 处理完整块(key_pos 已经对齐到 0)
while i + L <= len {
for j in 0..L {
data[i + j] ^= secret[j];
}
i += L;
}
// 3) 处理尾部
for j in 0..(len - i) {
data[i + j] ^= secret[j];
}
}
impl Router {
pub(crate) fn decrypt(&self, data: &mut [u8], secret: &[u8; SECRET_LENGTH]) {
xor_with_secret_offset::<SECRET_LENGTH>(data, secret, 0);
}
pub(crate) fn decrypt2(
&self,
data: &mut [u8],
secret: &[u8; SECRET_LENGTH],
range: Range<usize>,
) {
xor_with_secret_offset::<SECRET_LENGTH>(&mut data[range.clone()], secret, range.start);
}
pub(crate) fn encrypt(&self, data: &mut [u8]) {
xor_with_secret_offset::<SECRET_LENGTH>(data, &self.config.remote_secret, 0);
}
pub fn create_socket(config: &ConfigRouter) -> Result<Socket> {
println!("create_socket {}", config.remote_id);
match config.schema {
Schema::IP => {
let result = Socket::new(config.family, Type::RAW, Some(Protocol::from(config.proto as i32)))?;
if config.mark != 0 {
result.set_mark(config.mark)?;
}
Self::attach_filter_ip(config, &result)?;
Ok(result)
}
Schema::UDP => {
let result = Socket::new(config.family, Type::DGRAM, Some(Protocol::UDP))?;
if config.mark != 0 {
result.set_mark(config.mark)?;
}
if config.src_port != 0 {
result.set_reuse_port(true)?;
let addr = Self::bind_addr(config);
result.bind(&addr)?;
}
Ok(result)
}
Schema::TCP => Ok(unsafe { Socket::from_raw_fd(0) }),
}
}
pub fn listen_tcp(&self) -> Socket {
// listener
let result = Socket::new(Domain::IPV6, Type::STREAM, Some(Protocol::TCP)).unwrap();
result.set_reuse_address(true).unwrap();
let addr = SockAddr::from(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, self.config.src_port, 0, 0));
result.bind(&addr).unwrap();
result.listen(100).unwrap();
result
}
pub fn connect_tcp(&self) -> Result<Socket> {
// tcp client 的 socket 不要在初始化时创建,在循环里创建
// 创建 socket 和 获取 endpoint 失败会 panic,连接失败会 error
let result = Socket::new(self.config.family, Type::STREAM, Some(Protocol::TCP)).unwrap();
result.set_tcp_nodelay(true).unwrap();
if self.config.mark != 0 {
result.set_mark(self.config.mark).unwrap();
}
if self.config.src_port != 0 {
result.set_reuse_address(true).unwrap();
let addr = Self::bind_addr(&self.config);
result.bind(&addr)?;
}
let meta = Meta {
src_id: self.config.local_id,
dst_id: self.config.remote_id,
};
let guard = pin();
let endpoint_ref = self.endpoint.load(Ordering::Relaxed, &guard);
let endpoint = unsafe { endpoint_ref.as_ref() }.unwrap();
result.send_to_with_flags(meta.as_bytes(), endpoint, MSG_FASTOPEN)?;
Ok(result)
}
fn attach_filter_ip(config: &ConfigRouter, socket: &Socket) -> Result<()> {
// 由于多个对端可能会使用相同的 ipproto 号,这里确保每个 socket 上只会收到自己对应的对端发来的消息
// 构造 Meta 来计算正确的字节序比较值
let meta_bytes = [
config.remote_id.to_le_bytes(),
config.local_id.to_le_bytes(),
];
// BPF 按网络字节序(大端序)比较,所以需要把小端序字节当作大端序来构造比较值
let expected_src_id = u32::from_be_bytes(meta_bytes[0]);
let expected_dst_id = u32::from_be_bytes(meta_bytes[1]);
// IP filter 工作原理:
// 每个对端起一个 raw socket
// 根据报文内容判断是给谁的。拒绝掉不是给自己的报文
// IPv4 raw socket 带 IP 头,IPv6 不带
// Meta 结构:src_id(u32) + dst_id(u32) = 8 字节
let filters: &[SockFilter] = match socket.domain()? {
Domain::IPV4 => &[
// [IPv4] 计算 IPv4 头长度: X = 4 * (IP[0] & 0xf)
bpf_stmt(BPF_LDX | BPF_B | BPF_MSH, 0),
// A = Packet[X + 0:4] = src_id
bpf_stmt(BPF_LD | BPF_W | BPF_IND, 0),
// if A != expected_src_id, goto reject
bpf_jump(BPF_JMP | BPF_JEQ | BPF_K, expected_src_id, 0, 3),
// A = Packet[X + 4:8] = dst_id
bpf_stmt(BPF_LD | BPF_W | BPF_IND, 4),
// if A != expected_dst_id, goto reject
bpf_jump(BPF_JMP | BPF_JEQ | BPF_K, expected_dst_id, 0, 1),
// 【接受】
bpf_stmt(BPF_RET | BPF_K, u32::MAX),
// 【拒绝】
bpf_stmt(BPF_RET | BPF_K, 0),
],
Domain::IPV6 => &[
// raw socket IPv6 没有 header
// A = Packet[0:4] = src_id
bpf_stmt(BPF_LD | BPF_W | BPF_ABS, 0),
// if A != expected_src_id, goto reject
bpf_jump(BPF_JMP | BPF_JEQ | BPF_K, expected_src_id, 0, 3),
// A = Packet[4:8] = dst_id
bpf_stmt(BPF_LD | BPF_W | BPF_ABS, 4),
// if A != expected_dst_id, goto reject
bpf_jump(BPF_JMP | BPF_JEQ | BPF_K, expected_dst_id, 0, 1),
// 【接受】
bpf_stmt(BPF_RET | BPF_K, u32::MAX),
// 【拒绝】
bpf_stmt(BPF_RET | BPF_K, 0),
],
_ => bail!("unsupported family"),
};
socket.attach_filter(filters)?;
Ok(())
}
pub fn attach_filter_udp(group: Vec<&Router>) -> Result<()> {
// 预留空间:4 条前置指令 + 每个 router 5 条 + 1 条默认返回
let mut filters: Vec<SockFilter> = Vec::with_capacity(4 + group.len() * 5 + 1);
// udp filter 工作原理:
// 每个对端起一个 udp socket
// 根据报文内容判断是给谁的,调度给对应的端口复用组序号
// Meta 结构:src_id(u32) + dst_id(u32) = 8 字节
// 加载 src_id 并存储到 M[0]
filters.push(bpf_stmt(BPF_LD | BPF_W | BPF_ABS, 0)); // A = packet[0:4] = src_id
filters.push(bpf_stmt(BPF_ST, 0)); // M[0] = A
// 加载 dst_id 并存储到 M[1]
filters.push(bpf_stmt(BPF_LD | BPF_W | BPF_ABS, 4)); // A = packet[4:8] = dst_id
filters.push(bpf_stmt(BPF_ST, 1)); // M[1] = A
for (i, router) in group.iter().enumerate() {
// 字节序转换:将小端序ID转换为BPF期望的大端序比较值
let src_bytes = router.config.remote_id.to_le_bytes();
let dst_bytes = router.config.local_id.to_le_bytes();
let expected_src_id = u32::from_be_bytes(src_bytes);
let expected_dst_id = u32::from_be_bytes(dst_bytes);
// 每个 router 5 条指令:
// 0: LD M[0] ; A = src_id
// 1: JEQ expected_src_id, +0, +3 ; 匹配继续,不匹配跳过当前 router
// 2: LD M[1] ; A = dst_id
// 3: JEQ expected_dst_id, +0, +1 ; 匹配继续,不匹配跳过当前 router
// 4: RET i ; 返回索引
filters.push(bpf_stmt(BPF_LD | BPF_MEM, 0));
filters.push(bpf_jump(BPF_JMP | BPF_JEQ | BPF_K, expected_src_id, 0, 3));
filters.push(bpf_stmt(BPF_LD | BPF_MEM, 1));
filters.push(bpf_jump(BPF_JMP | BPF_JEQ | BPF_K, expected_dst_id, 0, 1));
filters.push(bpf_stmt(BPF_RET | BPF_K, i as u32));
}
// 默认返回(不匹配任何 router)
filters.push(bpf_stmt(BPF_RET | BPF_K, u32::MAX));
let prog = sock_fprog {
len: filters.len() as u16,
filter: filters.as_mut_ptr() as *mut sock_filter,
};
let fd = group[0].socket.as_raw_fd();
let ret = unsafe {
setsockopt(
fd,
SOL_SOCKET,
SO_ATTACH_REUSEPORT_CBPF,
&prog as *const _ as *const c_void,
size_of_val(&prog) as socklen_t,
)
};
ensure!(ret != -1, std::io::Error::last_os_error());
Ok(())
}
pub(crate) fn handle_outbound_ip_udp(&self) {
let mut buffer = [0u8; 1500];
// Pre-initialize with our Meta header (local -> remote)
let meta = Meta {
src_id: self.config.local_id,
dst_id: self.config.remote_id,
};
buffer[..META_SIZE].copy_from_slice(meta.as_bytes());
loop {
let n = self.tun.recv(&mut buffer[META_SIZE..]).unwrap(); // recv 失败直接 panic
let guard = pin();
let endpoint_ref = self.endpoint.load(Ordering::Relaxed, &guard);
if let Some(endpoint) = unsafe { endpoint_ref.as_ref() } {
self.encrypt(&mut buffer[META_SIZE..META_SIZE + n]);
let _ = self.socket.send_to(&buffer[..META_SIZE + n], endpoint);
}
}
}
pub(crate) fn handle_inbound_ip_udp(&self) {
let mut recv_buf = [MaybeUninit::uninit(); 1500];
loop {
// 收到一个非法报文只丢弃一个报文
let (len, addr) = { self.socket.recv_from(&mut recv_buf).unwrap() }; // recv 出错直接 panic
let packet = unsafe { std::slice::from_raw_parts_mut(recv_buf.as_mut_ptr().cast(), len) };
// if addr.is_ipv6() { println!("{:X?}", packet) }
// 只有 ipv4 raw 会给 IP报头
let offset = if self.config.family == Domain::IPV4 && self.config.schema == Schema::IP {
(packet[0] & 0x0f) as usize * 4
} else {
0
} + META_SIZE;
{
let guard = pin();
let current_shared = self.endpoint.load(Ordering::Relaxed, &guard);
let is_same = unsafe { current_shared.as_ref() }.map(|c| *c == addr).unwrap_or(false);
if !is_same {
let new_shared = crossbeam::epoch::Owned::new(addr).into_shared(&guard);
let old_shared = self.endpoint.swap(new_shared, Ordering::Release, &guard);
unsafe { guard.defer_destroy(old_shared) }
}
}
let payload = &mut packet[offset..];
self.decrypt(payload, &self.config.local_secret);
let _ = self.tun.send(payload);
}
}
pub(crate) fn handle_outbound_tcp(&self, connection: &Socket) {
let _ = (|| -> Result<()> {
let mut buffer = [0u8; 1500];
loop {
let n = self.tun.recv(&mut buffer)?;
self.encrypt(&mut buffer[..n]);
Router::send_all_tcp(&connection, &buffer[..n])?;
}
})();
let _ = connection.shutdown(Shutdown::Both);
}
pub(crate) fn handle_inbound_tcp(&self, connection: &Socket) {
let _ = (|| -> Result<()> {
let mut buf = [MaybeUninit::uninit(); 1500];
let packet: &mut [u8] = unsafe { std::slice::from_raw_parts_mut(buf.as_mut_ptr().cast(), buf.len()) };
loop {
Router::recv_exact_tcp(&connection, &mut buf[0..6])?;
self.decrypt2(packet, &self.config.local_secret, 0..6);
let version = packet[0] >> 4;
let total_len = match version {
4 => u16::from_be_bytes([packet[2], packet[3]]) as usize,
6 => u16::from_be_bytes([packet[4], packet[5]]) as usize + 40,
_ => bail!("Invalid IP version"),
};
ensure!(6 < total_len && total_len <= buf.len(), "Invalid total length");
Router::recv_exact_tcp(&connection, &mut buf[6..total_len])?;
self.decrypt2(packet, &self.config.local_secret, 6..total_len);
self.tun.send(&packet[..total_len])?;
}
})();
let _ = connection.shutdown(Shutdown::Both);
}
pub(crate) fn recv_exact_tcp(sock: &Socket, mut buf: &mut [MaybeUninit<u8>]) -> Result<()> {
while !buf.is_empty() {
let n = sock.recv(buf)?;
ensure!(n != 0, std::io::ErrorKind::UnexpectedEof);
buf = &mut buf[n..];
}
Ok(())
}
pub(crate) fn send_all_tcp(sock: &Socket, mut buf: &[u8]) -> Result<()> {
while !buf.is_empty() {
let n = sock.send(buf)?;
buf = &buf[n..];
}
Ok(())
}
fn create_tun_device(config: &ConfigRouter) -> Result<Device> {
println!("create_tun_device {}", config.remote_id);
let mut tun_config = tun::Configuration::default();
tun_config.tun_name(config.dev.as_str()).up();
let dev = tun::create(&tun_config)?;
Ok(dev)
}
fn run_up_script(config: &ConfigRouter) -> Result<ExitStatus> {
Ok(Command::new("sh").args(["-c", config.up.as_str()]).status()?)
}
fn create_endpoint(config: &ConfigRouter) -> Atomic<SockAddr> {
println!("create_endpoint {}", config.remote_id);
match (config.endpoint.clone(), config.dst_port)
.to_socket_addrs()
.unwrap_or_default()
.filter(|a| match config.family {
Domain::IPV4 => a.is_ipv4(),
Domain::IPV6 => a.is_ipv6(),
_ => false,
})
.next()
{
None => Atomic::null(),
Some(addr) => Atomic::new(addr.into()),
}
}
pub fn new(config: ConfigRouter) -> Result<Router> {
println!("creating {}", config.remote_id);
let router = Router {
tun: Self::create_tun_device(&config)?,
endpoint: Self::create_endpoint(&config),
socket: Self::create_socket(&config)?,
tcp_listener_connection: Atomic::null(),
config,
};
println!("run_up_script {}", &router.config.remote_id);
Self::run_up_script(&router.config)?;
Ok(router)
}
fn bind_addr(config: &ConfigRouter) -> SockAddr {
match config.family {
Domain::IPV4 => SockAddr::from(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, config.src_port)),
Domain::IPV6 => SockAddr::from(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, config.src_port, 0, 0)),
_ => panic!("unsupported family"),
}
}
}
fn bpf_stmt(code: u32, k: u32) -> SockFilter {
SockFilter::new(code as u16, 0, 0, k)
}
fn bpf_jump(code: u32, k: u32, jt: u8, jf: u8) -> SockFilter {
SockFilter::new(code as u16, jt, jf, k)
}
use anyhow::{bail, ensure, Result};
use socket2::{Domain, Protocol, SockAddr, SockFilter, Socket, Type};
use std::net::Shutdown;
use std::{
ffi::c_void,
mem::MaybeUninit,
net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6, ToSocketAddrs},
ops::Range,
os::fd::{AsRawFd, FromRawFd},
process::{Command, ExitStatus},
sync::atomic::Ordering,
};
use tun::Device;
use crate::config::{ConfigRouter, Schema};
use crossbeam::epoch::{pin, Atomic};
use libc::{
setsockopt, sock_filter, sock_fprog, socklen_t, BPF_ABS, BPF_B, BPF_IND, BPF_JEQ, BPF_JMP, BPF_K, BPF_LD, BPF_LDX, BPF_MEM, BPF_MSH, BPF_RET, BPF_ST, BPF_W,
MSG_FASTOPEN, SOL_SOCKET, SO_ATTACH_REUSEPORT_CBPF,
};
pub const SECRET_LENGTH: usize = 32;
pub const META_SIZE: usize = size_of::<Meta>();
#[repr(C)]
#[derive(Debug, Clone, Copy, Default)]
pub struct Meta {
pub src_id: u32,
pub dst_id: u32,
}
impl Meta {
pub fn as_bytes(&self) -> &[u8; META_SIZE] {
unsafe { &*(self as *const Meta as *const [u8; META_SIZE]) }
}
pub fn from_bytes(bytes: &[MaybeUninit<u8>; META_SIZE]) -> &Meta {
unsafe { &*(bytes.as_ptr() as *const Meta) }
}
}
pub struct Router {
pub config: ConfigRouter,
pub tun: Device,
pub socket: Socket,
pub endpoint: Atomic<SockAddr>,
pub tcp_listener_connection: Atomic<Socket>,
}
#[inline]
fn xor_with_secret_offset<const L: usize>(data: &mut [u8], secret: &[u8; L], offset: usize) {
let len = data.len();
if len == 0 { return; }
let mut i = 0;
let mut key_pos = offset % L;
// 1) 先把 key_pos 补到 0(也就是把相位对齐到块边界),这样后面能走“完整块”
if key_pos != 0 {
let head = (L - key_pos).min(len);
for j in 0..head {
data[j] ^= secret[key_pos + j]; // 这里不会越界,因为 j < L - key_pos
}
i += head;
key_pos = 0;
}
// 2) 处理完整块(key_pos 已经对齐到 0)
while i + L <= len {
for j in 0..L {
data[i + j] ^= secret[j];
}
i += L;
}
// 3) 处理尾部
for j in 0..(len - i) {
data[i + j] ^= secret[j];
}
}
impl Router {
pub(crate) fn decrypt(&self, data: &mut [u8], secret: &[u8; SECRET_LENGTH]) {
xor_with_secret_offset::<SECRET_LENGTH>(data, secret, 0);
}
pub(crate) fn decrypt2(
&self,
data: &mut [u8],
secret: &[u8; SECRET_LENGTH],
range: Range<usize>,
) {
xor_with_secret_offset::<SECRET_LENGTH>(&mut data[range.clone()], secret, range.start);
}
pub(crate) fn encrypt(&self, data: &mut [u8]) {
xor_with_secret_offset::<SECRET_LENGTH>(data, &self.config.remote_secret, 0);
}
pub fn create_socket(config: &ConfigRouter) -> Result<Socket> {
println!("create_socket {}", config.remote_id);
match config.schema {
Schema::IP => {
let result = Socket::new(config.family, Type::RAW, Some(Protocol::from(config.proto as i32)))?;
if config.mark != 0 {
result.set_mark(config.mark)?;
}
Self::attach_filter_ip(config, &result)?;
Ok(result)
}
Schema::UDP => {
let result = Socket::new(config.family, Type::DGRAM, Some(Protocol::UDP))?;
if config.mark != 0 {
result.set_mark(config.mark)?;
}
if config.src_port != 0 {
result.set_reuse_port(true)?;
let addr = Self::bind_addr(config);
result.bind(&addr)?;
}
Ok(result)
}
Schema::TCP => Ok(unsafe { Socket::from_raw_fd(0) }),
}
}
pub fn listen_tcp(&self) -> Socket {
// listener
let result = Socket::new(Domain::IPV6, Type::STREAM, Some(Protocol::TCP)).unwrap();
result.set_reuse_address(true).unwrap();
let addr = SockAddr::from(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, self.config.src_port, 0, 0));
result.bind(&addr).unwrap();
result.listen(100).unwrap();
result
}
pub fn connect_tcp(&self) -> Result<Socket> {
// tcp client 的 socket 不要在初始化时创建,在循环里创建
// 创建 socket 和 获取 endpoint 失败会 panic,连接失败会 error
let result = Socket::new(self.config.family, Type::STREAM, Some(Protocol::TCP)).unwrap();
result.set_tcp_nodelay(true).unwrap();
if self.config.mark != 0 {
result.set_mark(self.config.mark).unwrap();
}
if self.config.src_port != 0 {
result.set_reuse_address(true).unwrap();
let addr = Self::bind_addr(&self.config);
result.bind(&addr)?;
}
let meta = Meta {
src_id: self.config.local_id,
dst_id: self.config.remote_id,
};
let guard = pin();
let endpoint_ref = self.endpoint.load(Ordering::Relaxed, &guard);
let endpoint = unsafe { endpoint_ref.as_ref() }.unwrap();
result.send_to_with_flags(meta.as_bytes(), endpoint, MSG_FASTOPEN)?;
Ok(result)
}
fn attach_filter_ip(config: &ConfigRouter, socket: &Socket) -> Result<()> {
// 由于多个对端可能会使用相同的 ipproto 号,这里确保每个 socket 上只会收到自己对应的对端发来的消息
// 构造 Meta 来计算正确的字节序比较值
let meta_bytes = [
config.remote_id.to_le_bytes(),
config.local_id.to_le_bytes(),
];
// BPF 按网络字节序(大端序)比较,所以需要把小端序字节当作大端序来构造比较值
let expected_src_id = u32::from_be_bytes(meta_bytes[0]);
let expected_dst_id = u32::from_be_bytes(meta_bytes[1]);
// IP filter 工作原理:
// 每个对端起一个 raw socket
// 根据报文内容判断是给谁的。拒绝掉不是给自己的报文
// IPv4 raw socket 带 IP 头,IPv6 不带
// Meta 结构:src_id(u32) + dst_id(u32) = 8 字节
let filters: &[SockFilter] = match socket.domain()? {
Domain::IPV4 => &[
// [IPv4] 计算 IPv4 头长度: X = 4 * (IP[0] & 0xf)
bpf_stmt(BPF_LDX | BPF_B | BPF_MSH, 0),
// A = Packet[X + 0:4] = src_id
bpf_stmt(BPF_LD | BPF_W | BPF_IND, 0),
// if A != expected_src_id, goto reject
bpf_jump(BPF_JMP | BPF_JEQ | BPF_K, expected_src_id, 0, 3),
// A = Packet[X + 4:8] = dst_id
bpf_stmt(BPF_LD | BPF_W | BPF_IND, 4),
// if A != expected_dst_id, goto reject
bpf_jump(BPF_JMP | BPF_JEQ | BPF_K, expected_dst_id, 0, 1),
// 【接受】
bpf_stmt(BPF_RET | BPF_K, u32::MAX),
// 【拒绝】
bpf_stmt(BPF_RET | BPF_K, 0),
],
Domain::IPV6 => &[
// raw socket IPv6 没有 header
// A = Packet[0:4] = src_id
bpf_stmt(BPF_LD | BPF_W | BPF_ABS, 0),
// if A != expected_src_id, goto reject
bpf_jump(BPF_JMP | BPF_JEQ | BPF_K, expected_src_id, 0, 3),
// A = Packet[4:8] = dst_id
bpf_stmt(BPF_LD | BPF_W | BPF_ABS, 4),
// if A != expected_dst_id, goto reject
bpf_jump(BPF_JMP | BPF_JEQ | BPF_K, expected_dst_id, 0, 1),
// 【接受】
bpf_stmt(BPF_RET | BPF_K, u32::MAX),
// 【拒绝】
bpf_stmt(BPF_RET | BPF_K, 0),
],
_ => bail!("unsupported family"),
};
socket.attach_filter(filters)?;
Ok(())
}
pub fn attach_filter_udp(group: Vec<&Router>) -> Result<()> {
// 预留空间:4 条前置指令 + 每个 router 5 条 + 1 条默认返回
let mut filters: Vec<SockFilter> = Vec::with_capacity(4 + group.len() * 5 + 1);
// udp filter 工作原理:
// 每个对端起一个 udp socket
// 根据报文内容判断是给谁的,调度给对应的端口复用组序号
// Meta 结构:src_id(u32) + dst_id(u32) = 8 字节
// 加载 src_id 并存储到 M[0]
filters.push(bpf_stmt(BPF_LD | BPF_W | BPF_ABS, 0)); // A = packet[0:4] = src_id
filters.push(bpf_stmt(BPF_ST, 0)); // M[0] = A
// 加载 dst_id 并存储到 M[1]
filters.push(bpf_stmt(BPF_LD | BPF_W | BPF_ABS, 4)); // A = packet[4:8] = dst_id
filters.push(bpf_stmt(BPF_ST, 1)); // M[1] = A
for (i, router) in group.iter().enumerate() {
// 字节序转换:将小端序ID转换为BPF期望的大端序比较值
let src_bytes = router.config.remote_id.to_le_bytes();
let dst_bytes = router.config.local_id.to_le_bytes();
let expected_src_id = u32::from_be_bytes(src_bytes);
let expected_dst_id = u32::from_be_bytes(dst_bytes);
// 每个 router 5 条指令:
// 0: LD M[0] ; A = src_id
// 1: JEQ expected_src_id, +0, +3 ; 匹配继续,不匹配跳过当前 router
// 2: LD M[1] ; A = dst_id
// 3: JEQ expected_dst_id, +0, +1 ; 匹配继续,不匹配跳过当前 router
// 4: RET i ; 返回索引
filters.push(bpf_stmt(BPF_LD | BPF_MEM, 0));
filters.push(bpf_jump(BPF_JMP | BPF_JEQ | BPF_K, expected_src_id, 0, 3));
filters.push(bpf_stmt(BPF_LD | BPF_MEM, 1));
filters.push(bpf_jump(BPF_JMP | BPF_JEQ | BPF_K, expected_dst_id, 0, 1));
filters.push(bpf_stmt(BPF_RET | BPF_K, i as u32));
}
// 默认返回(不匹配任何 router)
filters.push(bpf_stmt(BPF_RET | BPF_K, u32::MAX));
let prog = sock_fprog {
len: filters.len() as u16,
filter: filters.as_mut_ptr() as *mut sock_filter,
};
let fd = group[0].socket.as_raw_fd();
let ret = unsafe {
setsockopt(
fd,
SOL_SOCKET,
SO_ATTACH_REUSEPORT_CBPF,
&prog as *const _ as *const c_void,
size_of_val(&prog) as socklen_t,
)
};
ensure!(ret != -1, std::io::Error::last_os_error());
Ok(())
}
pub(crate) fn handle_outbound_ip_udp(&self) {
let mut buffer = [0u8; 1500];
// Pre-initialize with our Meta header (local -> remote)
let meta = Meta {
src_id: self.config.local_id,
dst_id: self.config.remote_id,
};
buffer[..META_SIZE].copy_from_slice(meta.as_bytes());
loop {
let n = self.tun.recv(&mut buffer[META_SIZE..]).unwrap(); // recv 失败直接 panic
let guard = pin();
let endpoint_ref = self.endpoint.load(Ordering::Relaxed, &guard);
if let Some(endpoint) = unsafe { endpoint_ref.as_ref() } {
self.encrypt(&mut buffer[META_SIZE..META_SIZE + n]);
let _ = self.socket.send_to(&buffer[..META_SIZE + n], endpoint);
}
}
}
pub(crate) fn handle_inbound_ip_udp(&self) {
let mut recv_buf = [MaybeUninit::uninit(); 1500];
loop {
// 收到一个非法报文只丢弃一个报文
let (len, addr) = { self.socket.recv_from(&mut recv_buf).unwrap() }; // recv 出错直接 panic
let packet = unsafe { std::slice::from_raw_parts_mut(recv_buf.as_mut_ptr().cast(), len) };
// if addr.is_ipv6() { println!("{:X?}", packet) }
// 只有 ipv4 raw 会给 IP报头
let offset = if self.config.family == Domain::IPV4 && self.config.schema == Schema::IP {
(packet[0] & 0x0f) as usize * 4
} else {
0
} + META_SIZE;
{
let guard = pin();
let current_shared = self.endpoint.load(Ordering::Relaxed, &guard);
let is_same = unsafe { current_shared.as_ref() }.map(|c| *c == addr).unwrap_or(false);
if !is_same {
let new_shared = crossbeam::epoch::Owned::new(addr).into_shared(&guard);
let old_shared = self.endpoint.swap(new_shared, Ordering::Release, &guard);
unsafe { guard.defer_destroy(old_shared) }
}
}
let payload = &mut packet[offset..];
self.decrypt(payload, &self.config.local_secret);
let _ = self.tun.send(payload);
}
}
pub(crate) fn handle_outbound_tcp(&self, connection: &Socket) {
let _ = (|| -> Result<()> {
let mut buffer = [0u8; 1500];
loop {
let n = self.tun.recv(&mut buffer)?;
self.encrypt(&mut buffer[..n]);
Router::send_all_tcp(&connection, &buffer[..n])?;
}
})();
let _ = connection.shutdown(Shutdown::Both);
}
pub(crate) fn handle_inbound_tcp(&self, connection: &Socket) {
let _ = (|| -> Result<()> {
let mut buf = [MaybeUninit::uninit(); 1500];
let packet: &mut [u8] = unsafe { std::slice::from_raw_parts_mut(buf.as_mut_ptr().cast(), buf.len()) };
loop {
Router::recv_exact_tcp(&connection, &mut buf[0..6])?;
self.decrypt2(packet, &self.config.local_secret, 0..6);
let version = packet[0] >> 4;
let total_len = match version {
4 => u16::from_be_bytes([packet[2], packet[3]]) as usize,
6 => u16::from_be_bytes([packet[4], packet[5]]) as usize + 40,
_ => bail!("Invalid IP version"),
};
ensure!(6 < total_len && total_len <= buf.len(), "Invalid total length");
Router::recv_exact_tcp(&connection, &mut buf[6..total_len])?;
self.decrypt2(packet, &self.config.local_secret, 6..total_len);
self.tun.send(&packet[..total_len])?;
}
})();
let _ = connection.shutdown(Shutdown::Both);
}
pub(crate) fn recv_exact_tcp(sock: &Socket, mut buf: &mut [MaybeUninit<u8>]) -> Result<()> {
while !buf.is_empty() {
let n = sock.recv(buf)?;
ensure!(n != 0, std::io::ErrorKind::UnexpectedEof);
buf = &mut buf[n..];
}
Ok(())
}
pub(crate) fn send_all_tcp(sock: &Socket, mut buf: &[u8]) -> Result<()> {
while !buf.is_empty() {
let n = sock.send(buf)?;
buf = &buf[n..];
}
Ok(())
}
fn create_tun_device(config: &ConfigRouter) -> Result<Device> {
println!("create_tun_device {}", config.remote_id);
let mut tun_config = tun::Configuration::default();
tun_config.tun_name(config.dev.as_str()).up();
let dev = tun::create(&tun_config)?;
Ok(dev)
}
fn run_up_script(config: &ConfigRouter) -> Result<ExitStatus> {
Ok(Command::new("sh").args(["-c", config.up.as_str()]).status()?)
}
fn create_endpoint(config: &ConfigRouter) -> Atomic<SockAddr> {
println!("create_endpoint {}", config.remote_id);
match (config.endpoint.clone(), config.dst_port)
.to_socket_addrs()
.unwrap_or_default()
.filter(|a| match config.family {
Domain::IPV4 => a.is_ipv4(),
Domain::IPV6 => a.is_ipv6(),
_ => false,
})
.next()
{
None => Atomic::null(),
Some(addr) => Atomic::new(addr.into()),
}
}
pub fn new(config: ConfigRouter) -> Result<Router> {
println!("creating {}", config.remote_id);
let router = Router {
tun: Self::create_tun_device(&config)?,
endpoint: Self::create_endpoint(&config),
socket: Self::create_socket(&config)?,
tcp_listener_connection: Atomic::null(),
config,
};
println!("run_up_script {}", &router.config.remote_id);
Self::run_up_script(&router.config)?;
Ok(router)
}
fn bind_addr(config: &ConfigRouter) -> SockAddr {
match config.family {
Domain::IPV4 => SockAddr::from(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, config.src_port)),
Domain::IPV6 => SockAddr::from(SocketAddrV6::new(Ipv6Addr::UNSPECIFIED, config.src_port, 0, 0)),
_ => panic!("unsupported family"),
}
}
}
fn bpf_stmt(code: u32, k: u32) -> SockFilter {
SockFilter::new(code as u16, 0, 0, k)
}
fn bpf_jump(code: u32, k: u32, jt: u8, jf: u8) -> SockFilter {
SockFilter::new(code as u16, jt, jf, k)
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment