Commit 2117fd01 authored by nanamicat's avatar nanamicat

udp

parent 67667c4d
Pipeline #41923 passed with stages
in 2 minutes and 13 seconds
mod router; mod router;
use crate::router::{Router, SECRET_LENGTH}; use crate::router::{Router, SECRET_LENGTH};
use crossbeam_utils::thread; use crossbeam_utils::thread;
use std::collections::HashMap;
use std::env; use std::env;
use std::error::Error; use std::error::Error;
use std::io::{Read, Write}; use std::io::{Read, Write};
use std::mem::MaybeUninit; use std::mem::MaybeUninit;
use std::mem::{size_of, transmute}; use std::mem::{size_of, transmute};
use std::sync::Arc;
use std::sync::atomic::Ordering; use std::sync::atomic::Ordering;
#[repr(C)] #[repr(C)]
...@@ -15,7 +18,8 @@ pub struct Meta { ...@@ -15,7 +18,8 @@ pub struct Meta {
pub reversed: u16, pub reversed: u16,
} }
use serde::Deserialize; use serde::{Deserialize, Deserializer};
use socket2::{Domain, Socket};
#[derive(Deserialize)] #[derive(Deserialize)]
pub struct Config { pub struct Config {
...@@ -26,8 +30,16 @@ pub struct Config { ...@@ -26,8 +30,16 @@ pub struct Config {
#[derive(Deserialize)] #[derive(Deserialize)]
pub struct ConfigRouter { pub struct ConfigRouter {
pub remote_id: u8, pub remote_id: u8,
#[serde(default)]
pub schema: Schema,
#[serde(default)]
pub proto: u8, pub proto: u8,
pub family: u8, #[serde(default)]
pub src_port: u16,
#[serde(default)]
pub dst_port: u16,
#[serde(deserialize_with = "deserialize_domain")]
pub family: Domain,
pub mark: u32, pub mark: u32,
pub endpoint: String, pub endpoint: String,
pub remote_secret: String, pub remote_secret: String,
...@@ -35,16 +47,43 @@ pub struct ConfigRouter { ...@@ -35,16 +47,43 @@ pub struct ConfigRouter {
pub up: String, pub up: String,
} }
#[derive(Deserialize, Default)]
pub enum Schema {
#[default]
IP,
UDP,
TCP,
FakeTCP,
}
fn deserialize_domain<'de, D>(d: D) -> Result<Domain, D::Error>
where
D: Deserializer<'de>,
{
match u8::deserialize(d)? {
4 => Ok(Domain::IPV4),
6 => Ok(Domain::IPV6),
_ => Err(serde::de::Error::custom("Invalid domain")),
}
}
fn main() -> Result<(), Box<dyn Error>> { fn main() -> Result<(), Box<dyn Error>> {
println!("Starting"); println!("Starting");
let config: Config = serde_json::from_str(env::args().nth(1).ok_or("need param")?.as_str())?; let config: Config = serde_json::from_str(env::args().nth(1).ok_or("need param")?.as_str())?;
let local_secret: [u8; SECRET_LENGTH] = Router::create_secret(config.local_secret.as_str())?; let local_secret: [u8; SECRET_LENGTH] = Router::create_secret(config.local_secret.as_str())?;
let mut udp_groups: HashMap<u16, (Arc<Socket>, Vec<u8>)> = HashMap::new();
let routers: Vec<Router> = config let routers: Vec<Router> = config
.routers .routers
.into_iter() .into_iter()
.map(|c| Router::new(c, config.local_id)) .map(|c| Router::new(c, config.local_id, &mut udp_groups))
.collect::<Result<Vec<_>, _>>()?; .collect::<Result<Vec<_>, _>>()?;
for (socket, group) in udp_groups.values() {
Router::attach_filter_udp(socket, group, config.local_id)?;
}
println!("created tuns"); println!("created tuns");
const META_SIZE: usize = size_of::<Meta>(); const META_SIZE: usize = size_of::<Meta>();
...@@ -53,7 +92,7 @@ fn main() -> Result<(), Box<dyn Error>> { ...@@ -53,7 +92,7 @@ fn main() -> Result<(), Box<dyn Error>> {
let (mut reader, mut writer) = router.split(); let (mut reader, mut writer) = router.split();
s.spawn(move |_| { s.spawn(move |_| {
let mut buffer = [0u8; 1500 - 20]; // minus typical IP header space let mut buffer = [0u8; 1500];
// Pre-initialize with our Meta header (local -> remote) // Pre-initialize with our Meta header (local -> remote)
let meta = Meta { let meta = Meta {
...@@ -84,11 +123,13 @@ fn main() -> Result<(), Box<dyn Error>> { ...@@ -84,11 +123,13 @@ fn main() -> Result<(), Box<dyn Error>> {
let (len, addr) = writer.socket.recv_from(&mut recv_buf).unwrap(); let (len, addr) = writer.socket.recv_from(&mut recv_buf).unwrap();
let packet: &mut [u8] = unsafe { transmute(&mut recv_buf[..len]) }; let packet: &mut [u8] = unsafe { transmute(&mut recv_buf[..len]) };
// if addr.is_ipv6() { println!("{:X?}", packet) } // if addr.is_ipv6() { println!("{:X?}", packet) }
let offset = if addr.is_ipv4() { // 只有 ipv4 raw 会给 IP报头
(packet[0] & 0x0f) as usize * 4 let offset =
} else { if addr.is_ipv4() && addr.as_socket_ipv4().ok_or("?")?.port() == 0 {
0 (packet[0] & 0x0f) as usize * 4
} + META_SIZE; } else {
0
} + META_SIZE;
let guard = crossbeam::epoch::pin(); let guard = crossbeam::epoch::pin();
let current_shared = writer.endpoint.load(Ordering::SeqCst, &guard); let current_shared = writer.endpoint.load(Ordering::SeqCst, &guard);
...@@ -107,13 +148,13 @@ fn main() -> Result<(), Box<dyn Error>> { ...@@ -107,13 +148,13 @@ fn main() -> Result<(), Box<dyn Error>> {
let payload = &mut packet[offset..]; let payload = &mut packet[offset..];
writer.decrypt(payload, &local_secret); writer.decrypt(payload, &local_secret);
writer.tun_writer.write_all(payload)?; writer.tun_writer.write_all(payload)?;
writer.last_activity = std::time::Instant::now();
Ok::<(), Box<dyn Error>>(()) Ok::<(), Box<dyn Error>>(())
})(); })();
} }
}); });
} }
}) })
.unwrap(); .map_err(|_| "Thread panicked")?;
Ok(()) Ok(())
} }
use base64::prelude::BASE64_STANDARD; use base64::prelude::BASE64_STANDARD;
use base64::Engine; use base64::Engine;
use socket2::{Domain, Protocol, SockAddr, SockFilter, Socket, Type}; use socket2::{Domain, Protocol, SockAddr, SockFilter, Socket, Type};
use std::net::ToSocketAddrs; use std::collections::HashMap;
use std::ffi::c_void;
use std::net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6, ToSocketAddrs};
use std::os::fd::AsRawFd;
use std::process::{Command, ExitStatus}; use std::process::{Command, ExitStatus};
use std::sync::Arc; use std::sync::Arc;
use tun::{Reader, Writer}; use tun::{Reader, Writer};
pub const SECRET_LENGTH: usize = 32; pub const SECRET_LENGTH: usize = 32;
use crate::{ConfigRouter, Meta}; use crate::{ConfigRouter, Meta, Schema};
use crossbeam::epoch::Atomic; use crossbeam::epoch::Atomic;
use libc::{BPF_ABS, BPF_B, BPF_IND, BPF_JEQ, BPF_JMP, BPF_K, BPF_LD, BPF_LDX, BPF_MSH, BPF_RET, BPF_W}; use libc::{
sock_filter, sock_fprog, socklen_t, BPF_ABS, BPF_B, BPF_IND, BPF_JEQ, BPF_JMP, BPF_K, BPF_LD, BPF_LDX,
BPF_MSH, BPF_RET, BPF_W, SOL_SOCKET, SO_ATTACH_REUSEPORT_CBPF,
};
// tun -> raw // tun -> raw
pub struct RouterReader { pub struct RouterReader {
...@@ -32,6 +38,7 @@ pub struct RouterWriter { ...@@ -32,6 +38,7 @@ pub struct RouterWriter {
pub tun_writer: Writer, pub tun_writer: Writer,
pub socket: Arc<Socket>, pub socket: Arc<Socket>,
pub endpoint: Arc<Atomic<SockAddr>>, pub endpoint: Arc<Atomic<SockAddr>>,
pub last_activity: std::time::Instant,
} }
impl RouterWriter { impl RouterWriter {
...@@ -62,26 +69,69 @@ impl Router { ...@@ -62,26 +69,69 @@ impl Router {
Ok(secret) Ok(secret)
} }
fn create_raw_socket( fn create_socket(
config: &ConfigRouter, config: &ConfigRouter,
local_id: u8, local_id: u8,
groups: &mut HashMap<u16, (Arc<Socket>, Vec<u8>)>,
) -> Result<Arc<Socket>, Box<dyn std::error::Error>> { ) -> Result<Arc<Socket>, Box<dyn std::error::Error>> {
let result = Socket::new( match config.schema {
if config.family == 6 { Schema::IP => {
Domain::IPV6 let result = Socket::new(
} else { config.family,
Domain::IPV4 Type::RAW,
}, Some(Protocol::from(config.proto as i32)),
Type::RAW, )?;
Some(Protocol::from(config.proto as i32)), #[cfg(target_os = "linux")]
)?; result.set_mark(config.mark)?;
#[cfg(target_os = "linux")] Self::attach_filter_raw(config, local_id, &result)?;
result.set_mark(config.mark)?; Ok(Arc::new(result))
Self::attach_readable_filter(config, local_id, &result)?; }
Ok(Arc::new(result)) Schema::UDP => {
let result = Socket::new(config.family, Type::DGRAM, Some(Protocol::UDP))?;
if config.src_port != 0 {
result.set_reuse_port(true)?;
let addr = match config.family {
Domain::IPV4 => SockAddr::from(SocketAddrV4::new(
Ipv4Addr::UNSPECIFIED,
config.src_port,
)),
Domain::IPV6 => SockAddr::from(SocketAddrV6::new(
Ipv6Addr::UNSPECIFIED,
config.src_port,
0,
0,
)),
_ => return Err("unsupported family".into()),
};
result.bind(&addr)?;
let result1 = Arc::new(result);
match groups.get_mut(&config.src_port) {
None => {
groups
.insert(config.src_port, (result1.clone(), vec![config.remote_id]));
}
Some((_, group)) => {
group.push(config.remote_id);
}
}
Ok(result1)
} else {
Ok(Arc::new(result))
}
}
Schema::TCP => {
let result = Socket::new(config.family, Type::STREAM, Some(Protocol::TCP))?;
Ok(Arc::new(result))
}
Schema::FakeTCP => {
let result = Socket::new(config.family, Type::STREAM, Some(Protocol::TCP))?;
Ok(Arc::new(result))
}
}
} }
fn attach_readable_filter( fn attach_filter_raw(
config: &ConfigRouter, config: &ConfigRouter,
local_id: u8, local_id: u8,
socket: &Socket, socket: &Socket,
...@@ -101,8 +151,8 @@ impl Router { ...@@ -101,8 +151,8 @@ impl Router {
// 如果是纯自定义 IP 协议,这里是 0 // 如果是纯自定义 IP 协议,这里是 0
let payload_offset = 0; let payload_offset = 0;
let filter: &[SockFilter] = match config.family { let filters: &[SockFilter] = match socket.domain()? {
4 => &[ Domain::IPV4 => &[
// [IPv4] 计算 IPv4 头长度: X = 4 * (IP[0] & 0xf) // [IPv4] 计算 IPv4 头长度: X = 4 * (IP[0] & 0xf)
bpf_stmt(BPF_LDX | BPF_B | BPF_MSH, 0), bpf_stmt(BPF_LDX | BPF_B | BPF_MSH, 0),
// A = Packet[X + payload_offset] // A = Packet[X + payload_offset]
...@@ -114,7 +164,7 @@ impl Router { ...@@ -114,7 +164,7 @@ impl Router {
// 【拒绝 (False 路径)】 // 【拒绝 (False 路径)】
bpf_stmt(BPF_RET | BPF_K, 0), bpf_stmt(BPF_RET | BPF_K, 0),
], ],
6 => &[ Domain::IPV6 => &[
// raw socket IPv6 没有 header,加载第 0 字节到累加器 A // raw socket IPv6 没有 header,加载第 0 字节到累加器 A
bpf_stmt(BPF_LD | BPF_W | BPF_ABS, 0), bpf_stmt(BPF_LD | BPF_W | BPF_ABS, 0),
// if (A == target_val) goto Accept; else goto Reject; // if (A == target_val) goto Accept; else goto Reject;
...@@ -126,7 +176,69 @@ impl Router { ...@@ -126,7 +176,69 @@ impl Router {
], ],
_ => Err("unsupported family")?, _ => Err("unsupported family")?,
}; };
socket.attach_filter(filter)?; socket.attach_filter(filters)?;
Ok(())
}
pub fn attach_filter_udp(
socket: &Arc<Socket>,
group: &Vec<u8>,
local_id: u8,
) -> Result<(), Box<dyn std::error::Error>> {
let values: Vec<u32> = group
.iter()
.map(|&f| {
const META_SIZE: usize = size_of::<Meta>();
let meta = Meta {
src_id: f,
dst_id: local_id,
reversed: 0,
};
let meta_bytes: [u8; META_SIZE] =
unsafe { *(&meta as *const Meta as *const [u8; META_SIZE]) };
u32::from_be_bytes(meta_bytes)
})
.collect();
let mut filters: Vec<SockFilter> = Vec::with_capacity(1 + values.len() * 2 + 1);
// Load the first 4 bytes of the packet into the accumulator (A)
filters.push(bpf_stmt(BPF_LD | BPF_W | BPF_ABS, 0));
for (i, &val) in values.iter().enumerate() {
// 如果匹配,继续下一句(返回),如果不匹配,跳过下一句。
filters.push(bpf_jump(BPF_JMP | BPF_JEQ | BPF_K, val, 0, 1));
// If match, return the index (i + 1, since 0 means drop)
filters.push(bpf_stmt(BPF_RET | BPF_K, i as u32));
}
// If no match found after all comparisons, drop the packet
filters.push(bpf_stmt(BPF_RET | BPF_K, u32::MAX));
Self::attach_reuseport_cbpf(socket, &mut filters)?;
Ok(())
}
fn attach_reuseport_cbpf(
sock: &Arc<Socket>,
code: &mut [SockFilter],
) -> Result<(), Box<dyn std::error::Error>> {
let prog = sock_fprog {
len: code.len() as u16,
filter: code.as_mut_ptr() as *mut sock_filter,
};
let fd = sock.as_raw_fd();
let ret = unsafe {
libc::setsockopt(
fd,
SOL_SOCKET,
SO_ATTACH_REUSEPORT_CBPF,
&prog as *const _ as *const c_void,
size_of_val(&prog) as socklen_t,
)
};
if ret == -1 {
Err(std::io::Error::last_os_error())?;
}
Ok(()) Ok(())
} }
...@@ -144,20 +256,32 @@ impl Router { ...@@ -144,20 +256,32 @@ impl Router {
Command::new("sh").args(["-c", config.up.as_str()]).status() Command::new("sh").args(["-c", config.up.as_str()]).status()
} }
fn create_endpoint( fn create_endpoint(config: &ConfigRouter) -> Arc<Atomic<SockAddr>> {
config: &ConfigRouter, let addr = match (config.endpoint.clone(), config.dst_port)
) -> Result<Arc<Atomic<SockAddr>>, Box<dyn std::error::Error>> { .to_socket_addrs()
let parsed = (config.endpoint.clone(), 0u16) .unwrap_or_default()
.to_socket_addrs()? .filter(|a| match config.family {
Domain::IPV4 => a.is_ipv4(),
Domain::IPV6 => a.is_ipv6(),
_ => false,
})
.next() .next()
.ok_or(config.endpoint.clone())?; {
Ok(Arc::new(Atomic::new(parsed.into()))) None => Atomic::null(),
Some(addr) => Atomic::new(addr.into()),
};
Arc::new(addr)
} }
pub fn new(config: ConfigRouter, local_id: u8) -> Result<Router, Box<dyn std::error::Error>> { pub fn new(
config: ConfigRouter,
local_id: u8,
udp_count: &mut HashMap<u16, (Arc<Socket>, Vec<u8>)>,
) -> Result<Router, Box<dyn std::error::Error>> {
let secret = Self::create_secret(config.remote_secret.as_str())?; let secret = Self::create_secret(config.remote_secret.as_str())?;
let endpoint = Self::create_endpoint(&config)?; let endpoint = Self::create_endpoint(&config);
let socket = Self::create_raw_socket(&config, local_id)?; let socket = Self::create_socket(&config, local_id, udp_count)?;
let (tun_reader, tun_writer) = Self::create_tun_device(&config)?; let (tun_reader, tun_writer) = Self::create_tun_device(&config)?;
Self::run_up_script(&config)?; Self::run_up_script(&config)?;
...@@ -178,6 +302,7 @@ impl Router { ...@@ -178,6 +302,7 @@ impl Router {
endpoint: self.endpoint.clone(), endpoint: self.endpoint.clone(),
tun_writer: self.tun_writer, tun_writer: self.tun_writer,
socket: self.socket.clone(), socket: self.socket.clone(),
last_activity: std::time::Instant::now(),
}; };
let reader = RouterReader { let reader = RouterReader {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment