Commit 1f73ae2b authored by nanamicat's avatar nanamicat

BTreeMap

parent dc2030f9
Pipeline #42038 failed with stages
in 2 minutes and 13 seconds
...@@ -2,7 +2,7 @@ mod config; ...@@ -2,7 +2,7 @@ mod config;
mod router; mod router;
use crate::config::{Config, Schema}; use crate::config::{Config, Schema};
use crate::router::{Meta, Router, META_SIZE}; use crate::router::Router;
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use crossbeam::epoch::{pin, Owned}; use crossbeam::epoch::{pin, Owned};
use crossbeam_utils::thread; use crossbeam_utils::thread;
...@@ -11,7 +11,7 @@ use std::collections::BTreeMap; ...@@ -11,7 +11,7 @@ use std::collections::BTreeMap;
use std::net::Shutdown; use std::net::Shutdown;
use std::sync::atomic::Ordering; use std::sync::atomic::Ordering;
use std::time::Duration; use std::time::Duration;
use std::{env, mem::MaybeUninit}; use std::env;
fn main() -> Result<()> { fn main() -> Result<()> {
println!("Starting"); println!("Starting");
...@@ -75,16 +75,13 @@ fn main() -> Result<()> { ...@@ -75,16 +75,13 @@ fn main() -> Result<()> {
let (connection, _) = socket.accept().unwrap(); let (connection, _) = socket.accept().unwrap();
s.spawn(move |_| { s.spawn(move |_| {
connection.set_tcp_nodelay(true).unwrap(); connection.set_tcp_nodelay(true).unwrap();
let mut meta_bytes = [0u8; 2];
let mut meta_bytes = [MaybeUninit::uninit(); META_SIZE];
Router::recv_exact_tcp(&connection, &mut meta_bytes).unwrap(); Router::recv_exact_tcp(&connection, &mut meta_bytes).unwrap();
let meta: &Meta = Meta::from_bytes(&meta_bytes); let src_id = meta_bytes[0];
if meta.reversed == 0 let dst_id = meta_bytes[1];
&& let Some(router) = routers.get(&meta.src_id) if let Some(router) = routers.get(&src_id)
&& meta.dst_id == router.config.local_id && dst_id == router.config.local_id
{ {
// let connection = Arc::new(connection);
// tcp listener 只许一个连接,过来新连接就把前一个关掉。 // tcp listener 只许一个连接,过来新连接就把前一个关掉。
{ {
let guard = pin(); let guard = pin();
......
use crate::config::{ConfigRouter, Schema};
use anyhow::{bail, ensure, Result}; use anyhow::{bail, ensure, Result};
use crossbeam::epoch::{pin, Atomic};
use libc::{
setsockopt, sock_filter, sock_fprog, socklen_t, BPF_ABS, BPF_B, BPF_H, BPF_IND, BPF_JEQ, BPF_JMP, BPF_K, BPF_LD, BPF_LDX, BPF_MSH, BPF_RET,
MSG_FASTOPEN, SOL_SOCKET, SO_ATTACH_REUSEPORT_CBPF,
};
use socket2::{Domain, Protocol, SockAddr, SockFilter, Socket, Type}; use socket2::{Domain, Protocol, SockAddr, SockFilter, Socket, Type};
use std::fs::{File, OpenOptions};
use std::net::Shutdown; use std::net::Shutdown;
use std::{ use std::{
ffi::c_void, ffi::c_void,
mem::MaybeUninit, mem::MaybeUninit,
net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6, ToSocketAddrs}, net::{Ipv4Addr, Ipv6Addr, SocketAddrV4, SocketAddrV6, ToSocketAddrs},
ops::Range,
os::fd::{AsRawFd, FromRawFd}, os::fd::{AsRawFd, FromRawFd},
process::{Command, ExitStatus}, process::{Command, ExitStatus},
sync::atomic::Ordering, sync::atomic::Ordering,
}; };
use tun::Device;
use crate::config::{ConfigRouter, Schema};
use crossbeam::epoch::{pin, Atomic};
use libc::{
setsockopt, sock_filter, sock_fprog, socklen_t, BPF_ABS, BPF_B, BPF_IND, BPF_JEQ, BPF_JMP, BPF_K, BPF_LD, BPF_LDX, BPF_MSH, BPF_RET, BPF_W,
MSG_FASTOPEN, SOL_SOCKET, SO_ATTACH_REUSEPORT_CBPF,
};
pub const SECRET_LENGTH: usize = 32; pub const SECRET_LENGTH: usize = 32;
pub const META_SIZE: usize = size_of::<Meta>(); pub const META_SIZE: usize = 4;
#[repr(C)]
#[derive(Debug, Clone, Copy, Default)]
pub struct Meta {
pub src_id: u8,
pub dst_id: u8,
pub reversed: u16,
}
impl Meta {
pub fn as_bytes(&self) -> &[u8; META_SIZE] {
unsafe { &*(self as *const Meta as *const [u8; META_SIZE]) }
}
pub fn from_bytes(bytes: &[MaybeUninit<u8>; META_SIZE]) -> &Meta {
unsafe { &*(bytes.as_ptr() as *const Meta) }
}
}
pub struct Router { pub struct Router {
pub config: ConfigRouter, pub config: ConfigRouter,
pub tun: Device, pub tun: File,
pub socket: Socket, pub socket: Socket,
pub endpoint: Atomic<SockAddr>, pub endpoint: Atomic<SockAddr>,
...@@ -48,17 +30,11 @@ pub struct Router { ...@@ -48,17 +30,11 @@ pub struct Router {
} }
impl Router { impl Router {
pub(crate) fn decrypt(&self, data: &mut [u8], secret: &[u8; SECRET_LENGTH]) { pub(crate) fn decrypt(&self, data: &mut [u8]) {
for (i, b) in data.iter_mut().enumerate() { for (i, b) in data.iter_mut().enumerate() {
*b ^= secret[i % SECRET_LENGTH]; *b ^= self.config.local_secret[i % SECRET_LENGTH];
} }
} }
pub(crate) fn decrypt2(&self, data: &mut [u8], secret: &[u8; SECRET_LENGTH], range: Range<usize>) {
for (i, b) in data[range.clone()].iter_mut().enumerate() {
*b ^= secret[(range.start + i) % SECRET_LENGTH];
}
}
pub(crate) fn encrypt(&self, data: &mut [u8]) { pub(crate) fn encrypt(&self, data: &mut [u8]) {
for (i, b) in data.iter_mut().enumerate() { for (i, b) in data.iter_mut().enumerate() {
*b ^= self.config.remote_secret[i % SECRET_LENGTH]; *b ^= self.config.remote_secret[i % SECRET_LENGTH];
...@@ -116,27 +92,17 @@ impl Router { ...@@ -116,27 +92,17 @@ impl Router {
result.bind(&addr)?; result.bind(&addr)?;
} }
let meta = Meta {
src_id: self.config.local_id,
dst_id: self.config.remote_id,
reversed: 0,
};
let guard = pin(); let guard = pin();
let endpoint_ref = self.endpoint.load(Ordering::Relaxed, &guard); let endpoint_ref = self.endpoint.load(Ordering::Relaxed, &guard);
let endpoint = unsafe { endpoint_ref.as_ref() }.unwrap(); let endpoint = unsafe { endpoint_ref.as_ref() }.unwrap();
result.send_to_with_flags(meta.as_bytes(), endpoint, MSG_FASTOPEN)?; result.send_to_with_flags(&[self.config.local_id, self.config.remote_id], endpoint, MSG_FASTOPEN)?;
Ok(result) Ok(result)
} }
fn attach_filter_ip(config: &ConfigRouter, socket: &Socket) -> Result<()> { fn attach_filter_ip(config: &ConfigRouter, socket: &Socket) -> Result<()> {
// 由于多个对端可能会使用相同的 ipprpto 号,这里确保每个 socket 上只会收到自己对应的对端发来的消息 // 由于多个对端可能会使用相同的 ipprpto 号,这里确保每个 socket 上只会收到自己对应的对端发来的消息
let meta = Meta { let value = u16::from_be_bytes([config.remote_id, config.local_id]) as u32;
src_id: config.remote_id,
dst_id: config.local_id,
reversed: 0,
};
let value = u32::from_be_bytes(*meta.as_bytes());
// 如果你的协议是 UDP,这里必须是 8 (跳过 UDP 头: SrcPort(2)+DstPort(2)+Len(2)+Sum(2)) // 如果你的协议是 UDP,这里必须是 8 (跳过 UDP 头: SrcPort(2)+DstPort(2)+Len(2)+Sum(2))
// 如果是纯自定义 IP 协议,这里是 0 // 如果是纯自定义 IP 协议,这里是 0
...@@ -152,7 +118,7 @@ impl Router { ...@@ -152,7 +118,7 @@ impl Router {
// [IPv4] 计算 IPv4 头长度: X = 4 * (IP[0] & 0xf) // [IPv4] 计算 IPv4 头长度: X = 4 * (IP[0] & 0xf)
bpf_stmt(BPF_LDX | BPF_B | BPF_MSH, 0), bpf_stmt(BPF_LDX | BPF_B | BPF_MSH, 0),
// A = Packet[X + payload_offset] // A = Packet[X + payload_offset]
bpf_stmt(BPF_LD | BPF_W | BPF_IND, payload_offset), bpf_stmt(BPF_LD | BPF_H | BPF_IND, payload_offset),
// if (A == target_val) goto Accept; else goto Reject; // if (A == target_val) goto Accept; else goto Reject;
bpf_jump(BPF_JMP | BPF_JEQ | BPF_K, value, 0, 1), bpf_jump(BPF_JMP | BPF_JEQ | BPF_K, value, 0, 1),
// 【接受 (True 路径)】 // 【接受 (True 路径)】
...@@ -162,7 +128,7 @@ impl Router { ...@@ -162,7 +128,7 @@ impl Router {
], ],
Domain::IPV6 => &[ Domain::IPV6 => &[
// raw socket IPv6 没有 header,加载第 0 字节到累加器 A // raw socket IPv6 没有 header,加载第 0 字节到累加器 A
bpf_stmt(BPF_LD | BPF_W | BPF_ABS, 0), bpf_stmt(BPF_LD | BPF_H | BPF_ABS, 0),
// if (A == target_val) goto Accept; else goto Reject; // if (A == target_val) goto Accept; else goto Reject;
bpf_jump(BPF_JMP | BPF_JEQ | BPF_K, value, 0, 1), bpf_jump(BPF_JMP | BPF_JEQ | BPF_K, value, 0, 1),
// 【接受 (True 路径)】 // 【接受 (True 路径)】
...@@ -178,17 +144,7 @@ impl Router { ...@@ -178,17 +144,7 @@ impl Router {
} }
pub fn attach_filter_udp(group: Vec<&Router>) -> Result<()> { pub fn attach_filter_udp(group: Vec<&Router>) -> Result<()> {
let values: Vec<u32> = group let values: Vec<u16> = group.iter().map(|&f| u16::from_be_bytes([f.config.remote_id, f.config.local_id])).collect();
.iter()
.map(|&f| {
let meta = Meta {
src_id: f.config.remote_id,
dst_id: f.config.local_id,
reversed: 0,
};
u32::from_be_bytes(*meta.as_bytes())
})
.collect();
let mut filters: Vec<SockFilter> = Vec::with_capacity(1 + values.len() * 2 + 1); let mut filters: Vec<SockFilter> = Vec::with_capacity(1 + values.len() * 2 + 1);
...@@ -197,10 +153,10 @@ impl Router { ...@@ -197,10 +153,10 @@ impl Router {
// 根据报文内容判断是给谁的,调度给对应的端口复用组序号 // 根据报文内容判断是给谁的,调度给对应的端口复用组序号
// Load the first 4 bytes of the packet into the accumulator (A) // Load the first 4 bytes of the packet into the accumulator (A)
filters.push(bpf_stmt(BPF_LD | BPF_W | BPF_ABS, 0)); filters.push(bpf_stmt(BPF_LD | BPF_H | BPF_ABS, 0));
for (i, &val) in values.iter().enumerate() { for (i, &val) in values.iter().enumerate() {
// 如果匹配,继续下一句(返回),如果不匹配,跳过下一句。 // 如果匹配,继续下一句(返回),如果不匹配,跳过下一句。
filters.push(bpf_jump(BPF_JMP | BPF_JEQ | BPF_K, val, 0, 1)); filters.push(bpf_jump(BPF_JMP | BPF_JEQ | BPF_K, val as u32, 0, 1));
// If match, return the index (i + 1, since 0 means drop) // If match, return the index (i + 1, since 0 means drop)
filters.push(bpf_stmt(BPF_RET | BPF_K, i as u32)); filters.push(bpf_stmt(BPF_RET | BPF_K, i as u32));
} }
...@@ -226,41 +182,27 @@ impl Router { ...@@ -226,41 +182,27 @@ impl Router {
} }
pub(crate) fn handle_outbound_ip_udp(&self) { pub(crate) fn handle_outbound_ip_udp(&self) {
let mut buffer = [0u8; 1500]; let mut buffer = [0u8; 1504];
// Pre-initialize with our Meta header (local -> remote)
let meta = Meta {
src_id: self.config.local_id,
dst_id: self.config.remote_id,
reversed: 0,
};
buffer[..META_SIZE].copy_from_slice(meta.as_bytes());
loop { loop {
let n = self.tun.recv(&mut buffer[META_SIZE..]).unwrap(); // recv 失败直接 panic let n = self.recv_tun(&mut buffer[..]).unwrap();
let payload = &mut buffer[..n];
payload[0] = self.config.local_id;
payload[1] = self.config.remote_id;
let guard = pin(); let guard = pin();
let endpoint_ref = self.endpoint.load(Ordering::Relaxed, &guard); let endpoint_ref = self.endpoint.load(Ordering::Relaxed, &guard);
if let Some(endpoint) = unsafe { endpoint_ref.as_ref() } { if let Some(endpoint) = unsafe { endpoint_ref.as_ref() } {
self.encrypt(&mut buffer[META_SIZE..META_SIZE + n]); self.encrypt(&mut payload[META_SIZE..]);
let _ = self.socket.send_to(&buffer[..META_SIZE + n], endpoint); let _ = self.socket.send_to(&payload, endpoint);
} }
} }
} }
pub(crate) fn handle_inbound_ip_udp(&self) { pub(crate) fn handle_inbound_ip_udp(&self) {
let mut recv_buf = [MaybeUninit::uninit(); 1500]; let mut buffer = [MaybeUninit::uninit(); 1500];
loop { loop {
// 收到一个非法报文只丢弃一个报文 // 收到一个非法报文只丢弃一个报文
let (len, addr) = { self.socket.recv_from(&mut recv_buf).unwrap() }; // recv 出错直接 panic let (len, addr) = { self.socket.recv_from(&mut buffer).unwrap() }; // recv 出错直接 panic
let packet = unsafe { std::slice::from_raw_parts_mut(recv_buf.as_mut_ptr().cast(), len) };
// if addr.is_ipv6() { println!("{:X?}", packet) }
// 只有 ipv4 raw 会给 IP报头
let offset = if self.config.family == Domain::IPV4 && self.config.schema == Schema::IP {
(packet[0] & 0x0f) as usize * 4
} else {
0
} + META_SIZE;
{ {
let guard = pin(); let guard = pin();
let current_shared = self.endpoint.load(Ordering::Relaxed, &guard); let current_shared = self.endpoint.load(Ordering::Relaxed, &guard);
...@@ -272,9 +214,16 @@ impl Router { ...@@ -272,9 +214,16 @@ impl Router {
} }
} }
let packet: &mut [u8] = unsafe { std::mem::transmute(&mut buffer[..len]) };
// 只有 ipv4 raw 会给 IP报头
let offset = if self.config.family == Domain::IPV4 && self.config.schema == Schema::IP {
(packet[0] & 0x0f) as usize * 4
} else {
0
};
let payload = &mut packet[offset..]; let payload = &mut packet[offset..];
self.decrypt(payload, &self.config.local_secret); self.decrypt(&mut payload[META_SIZE..]);
let _ = self.tun.send(payload); let _ = self.send_tun(payload);
} }
} }
...@@ -282,36 +231,34 @@ impl Router { ...@@ -282,36 +231,34 @@ impl Router {
let _ = (|| -> Result<()> { let _ = (|| -> Result<()> {
let mut buffer = [0u8; 1500]; let mut buffer = [0u8; 1500];
loop { loop {
let n = self.tun.recv(&mut buffer)?; let n = self.recv_tun(&mut buffer)?;
self.encrypt(&mut buffer[..n]); let payload = &mut buffer[..n];
Router::send_all_tcp(&connection, &buffer[..n])?; let len_bytes = (payload.len() as u16).to_le_bytes();
payload[0] = len_bytes[0];
payload[1] = len_bytes[1];
self.encrypt(&mut payload[META_SIZE..]);
Router::send_all_tcp(&connection, &payload)?;
} }
})(); })();
let _ = connection.shutdown(Shutdown::Both); let _ = connection.shutdown(Shutdown::Both);
} }
pub(crate) fn handle_inbound_tcp(&self, connection: &Socket) { pub(crate) fn handle_inbound_tcp(&self, connection: &Socket) {
let _ = (|| -> Result<()> { let _ = (|| -> Result<()> {
let mut buf = [MaybeUninit::uninit(); 1500]; let mut buffer = [0u8; 1500];
let packet: &mut [u8] = unsafe { std::slice::from_raw_parts_mut(buf.as_mut_ptr().cast(), buf.len()) };
loop { loop {
Router::recv_exact_tcp(&connection, &mut buf[0..6])?; Router::recv_exact_tcp(&connection, &mut buffer[..2])?;
self.decrypt2(packet, &self.config.local_secret, 0..6); let len = u16::from_le_bytes([buffer[0], buffer[1]]) as usize;
let version = packet[0] >> 4; Router::recv_exact_tcp(&connection, &mut buffer[2..len])?;
let total_len = match version { let payload = &mut buffer[..len];
4 => u16::from_be_bytes([packet[2], packet[3]]) as usize, self.decrypt(&mut payload[META_SIZE..]);
6 => u16::from_be_bytes([packet[4], packet[5]]) as usize + 40, self.send_tun(payload)?;
_ => bail!("Invalid IP version"),
};
ensure!(6 < total_len && total_len <= buf.len(), "Invalid total length");
Router::recv_exact_tcp(&connection, &mut buf[6..total_len])?;
self.decrypt2(packet, &self.config.local_secret, 6..total_len);
self.tun.send(&packet[..total_len])?;
} }
})(); })();
let _ = connection.shutdown(Shutdown::Both); let _ = connection.shutdown(Shutdown::Both);
} }
pub(crate) fn recv_exact_tcp(sock: &Socket, mut buf: &mut [MaybeUninit<u8>]) -> Result<()> { pub(crate) fn recv_exact_tcp(sock: &Socket, buf: &mut [u8]) -> Result<()> {
let mut buf: &mut [MaybeUninit<u8>] = unsafe { std::mem::transmute(buf) };
while !buf.is_empty() { while !buf.is_empty() {
let n = sock.recv(buf)?; let n = sock.recv(buf)?;
ensure!(n != 0, std::io::ErrorKind::UnexpectedEof); ensure!(n != 0, std::io::ErrorKind::UnexpectedEof);
...@@ -328,14 +275,40 @@ impl Router { ...@@ -328,14 +275,40 @@ impl Router {
Ok(()) Ok(())
} }
fn create_tun_device(config: &ConfigRouter) -> Result<Device> { fn create_tun_device(config: &ConfigRouter) -> File {
println!("create_tun_device {}", config.remote_id); let file = OpenOptions::new().read(true).write(true).open("/dev/net/tun").unwrap();
let mut tun_config = tun::Configuration::default(); let fd = file.as_raw_fd();
tun_config.tun_name(config.dev.as_str()).up();
unsafe {
let mut ifr: libc::ifreq = std::mem::zeroed();
// 关键:只设 IFF_TUN,不设 IFF_NO_PI
ifr.ifr_ifru.ifru_flags = libc::IFF_TUN as i16;
let mut name_bytes = config.dev.as_bytes().to_vec();
name_bytes.push(0);
libc::strncpy(ifr.ifr_name.as_mut_ptr(), name_bytes.as_ptr() as *const i8, 16);
let dev = tun::create(&tun_config)?; let res = libc::ioctl(fd, 0x400454ca, &ifr); // TUNSETIFF 的编码
Ok(dev) println!("{}", res);
if res < 0 {
panic!("ioctl failed");
}
}
file
} }
fn recv_tun(&self, buf: &mut [u8]) -> Result<usize> {
let n = unsafe { libc::read(self.tun.as_raw_fd(), buf.as_mut_ptr() as *mut c_void, buf.len()) };
ensure!(n > 0, "Read error: {}", std::io::Error::last_os_error());
Ok(n as usize)
}
fn send_tun(&self, buf: &[u8]) -> Result<()> {
let n = unsafe { libc::write(self.tun.as_raw_fd(), buf.as_ptr() as *mut c_void, buf.len()) };
ensure!(n > 0, "Write error: {}", std::io::Error::last_os_error());
Ok(())
}
fn run_up_script(config: &ConfigRouter) -> Result<ExitStatus> { fn run_up_script(config: &ConfigRouter) -> Result<ExitStatus> {
Ok(Command::new("sh").args(["-c", config.up.as_str()]).status()?) Ok(Command::new("sh").args(["-c", config.up.as_str()]).status()?)
} }
...@@ -360,7 +333,7 @@ impl Router { ...@@ -360,7 +333,7 @@ impl Router {
pub fn new(config: ConfigRouter) -> Result<Router> { pub fn new(config: ConfigRouter) -> Result<Router> {
println!("creating {}", config.remote_id); println!("creating {}", config.remote_id);
let router = Router { let router = Router {
tun: Self::create_tun_device(&config)?, tun: Self::create_tun_device(&config),
endpoint: Self::create_endpoint(&config), endpoint: Self::create_endpoint(&config),
socket: Self::create_socket(&config)?, socket: Self::create_socket(&config)?,
tcp_listener_connection: Atomic::null(), tcp_listener_connection: Atomic::null(),
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment