Commit 10b685b8 authored by nanahira's avatar nanahira

u32

parent f99f8f1a
...@@ -11,8 +11,8 @@ pub struct Config { ...@@ -11,8 +11,8 @@ pub struct Config {
} }
#[derive(Deserialize, Clone)] #[derive(Deserialize, Clone)]
pub struct ConfigRouter { pub struct ConfigRouter {
pub local_id: u8, pub local_id: u32,
pub remote_id: u8, pub remote_id: u32,
#[serde(default)] #[serde(default)]
pub schema: Schema, pub schema: Schema,
#[serde(default)] #[serde(default)]
......
...@@ -42,7 +42,7 @@ fn main() -> Result<()> { ...@@ -42,7 +42,7 @@ fn main() -> Result<()> {
.into_iter() .into_iter()
.sorted_by_key(|r| r.remote_id) .sorted_by_key(|r| r.remote_id)
.map(|c| Router::new(c).map(|router| (router.config.remote_id, router))) .map(|c| Router::new(c).map(|router| (router.config.remote_id, router)))
.collect::<Result<BTreeMap<u8, Router>, _>>()?; .collect::<Result<BTreeMap<u32, Router>, _>>()?;
for (_, group) in &routers for (_, group) in &routers
.values() .values()
...@@ -99,8 +99,7 @@ fn main() -> Result<()> { ...@@ -99,8 +99,7 @@ fn main() -> Result<()> {
let mut meta_bytes = [MaybeUninit::uninit(); META_SIZE]; let mut meta_bytes = [MaybeUninit::uninit(); META_SIZE];
Router::recv_exact_tcp(&connection, &mut meta_bytes).unwrap(); Router::recv_exact_tcp(&connection, &mut meta_bytes).unwrap();
let meta: &Meta = Meta::from_bytes(&meta_bytes); let meta: &Meta = Meta::from_bytes(&meta_bytes);
if meta.reversed == 0 if let Some(router) = routers.get(&meta.src_id)
&& let Some(router) = routers.get(&meta.src_id)
&& meta.dst_id == router.config.local_id && meta.dst_id == router.config.local_id
{ {
// let connection = Arc::new(connection); // let connection = Arc::new(connection);
......
...@@ -15,7 +15,7 @@ use tun::Device; ...@@ -15,7 +15,7 @@ use tun::Device;
use crate::config::{ConfigRouter, Schema}; use crate::config::{ConfigRouter, Schema};
use crossbeam::epoch::{pin, Atomic}; use crossbeam::epoch::{pin, Atomic};
use libc::{ use libc::{
setsockopt, sock_filter, sock_fprog, socklen_t, BPF_ABS, BPF_B, BPF_IND, BPF_JEQ, BPF_JMP, BPF_K, BPF_LD, BPF_LDX, BPF_MSH, BPF_RET, BPF_W, setsockopt, sock_filter, sock_fprog, socklen_t, BPF_ABS, BPF_B, BPF_IND, BPF_JEQ, BPF_JMP, BPF_K, BPF_LD, BPF_LDX, BPF_MEM, BPF_MSH, BPF_RET, BPF_ST, BPF_W,
MSG_FASTOPEN, SOL_SOCKET, SO_ATTACH_REUSEPORT_CBPF, MSG_FASTOPEN, SOL_SOCKET, SO_ATTACH_REUSEPORT_CBPF,
}; };
...@@ -25,9 +25,8 @@ pub const META_SIZE: usize = size_of::<Meta>(); ...@@ -25,9 +25,8 @@ pub const META_SIZE: usize = size_of::<Meta>();
#[repr(C)] #[repr(C)]
#[derive(Debug, Clone, Copy, Default)] #[derive(Debug, Clone, Copy, Default)]
pub struct Meta { pub struct Meta {
pub src_id: u8, pub src_id: u32,
pub dst_id: u8, pub dst_id: u32,
pub reversed: u16,
} }
impl Meta { impl Meta {
pub fn as_bytes(&self) -> &[u8; META_SIZE] { pub fn as_bytes(&self) -> &[u8; META_SIZE] {
...@@ -119,7 +118,6 @@ impl Router { ...@@ -119,7 +118,6 @@ impl Router {
let meta = Meta { let meta = Meta {
src_id: self.config.local_id, src_id: self.config.local_id,
dst_id: self.config.remote_id, dst_id: self.config.remote_id,
reversed: 0,
}; };
let guard = pin(); let guard = pin();
let endpoint_ref = self.endpoint.load(Ordering::Relaxed, &guard); let endpoint_ref = self.endpoint.load(Ordering::Relaxed, &guard);
...@@ -130,44 +128,54 @@ impl Router { ...@@ -130,44 +128,54 @@ impl Router {
} }
fn attach_filter_ip(config: &ConfigRouter, socket: &Socket) -> Result<()> { fn attach_filter_ip(config: &ConfigRouter, socket: &Socket) -> Result<()> {
// 由于多个对端可能会使用相同的 ipprpto 号,这里确保每个 socket 上只会收到自己对应的对端发来的消息 // 由于多个对端可能会使用相同的 ipproto 号,这里确保每个 socket 上只会收到自己对应的对端发来的消息
let meta = Meta {
src_id: config.remote_id, // 构造 Meta 来计算正确的字节序比较值
dst_id: config.local_id, let meta_bytes = [
reversed: 0, config.remote_id.to_le_bytes(),
}; config.local_id.to_le_bytes(),
let value = u32::from_be_bytes(*meta.as_bytes()); ];
// 如果你的协议是 UDP,这里必须是 8 (跳过 UDP 头: SrcPort(2)+DstPort(2)+Len(2)+Sum(2)) // BPF 按网络字节序(大端序)比较,所以需要把小端序字节当作大端序来构造比较值
// 如果是纯自定义 IP 协议,这里是 0 let expected_src_id = u32::from_be_bytes(meta_bytes[0]);
let payload_offset = 0; let expected_dst_id = u32::from_be_bytes(meta_bytes[1]);
// IP filter 工作原理: // IP filter 工作原理:
// 每个对端起一个 raw socket // 每个对端起一个 raw socket
// 根据报文内容判断是给谁的。拒绝掉不是给自己的报文 // 根据报文内容判断是给谁的。拒绝掉不是给自己的报文
// IPv4 raw socket 带 IP 头,IPv6 不带 // IPv4 raw socket 带 IP 头,IPv6 不带
// Meta 结构:src_id(u32) + dst_id(u32) = 8 字节
let filters: &[SockFilter] = match socket.domain()? { let filters: &[SockFilter] = match socket.domain()? {
Domain::IPV4 => &[ Domain::IPV4 => &[
// [IPv4] 计算 IPv4 头长度: X = 4 * (IP[0] & 0xf) // [IPv4] 计算 IPv4 头长度: X = 4 * (IP[0] & 0xf)
bpf_stmt(BPF_LDX | BPF_B | BPF_MSH, 0), bpf_stmt(BPF_LDX | BPF_B | BPF_MSH, 0),
// A = Packet[X + payload_offset] // A = Packet[X + 0:4] = src_id
bpf_stmt(BPF_LD | BPF_W | BPF_IND, payload_offset), bpf_stmt(BPF_LD | BPF_W | BPF_IND, 0),
// if (A == target_val) goto Accept; else goto Reject; // if A != expected_src_id, goto reject
bpf_jump(BPF_JMP | BPF_JEQ | BPF_K, value, 0, 1), bpf_jump(BPF_JMP | BPF_JEQ | BPF_K, expected_src_id, 0, 3),
// 【接受 (True 路径)】 // A = Packet[X + 4:8] = dst_id
bpf_stmt(BPF_LD | BPF_W | BPF_IND, 4),
// if A != expected_dst_id, goto reject
bpf_jump(BPF_JMP | BPF_JEQ | BPF_K, expected_dst_id, 0, 1),
// 【接受】
bpf_stmt(BPF_RET | BPF_K, u32::MAX), bpf_stmt(BPF_RET | BPF_K, u32::MAX),
// 【拒绝 (False 路径) // 【拒绝】
bpf_stmt(BPF_RET | BPF_K, 0), bpf_stmt(BPF_RET | BPF_K, 0),
], ],
Domain::IPV6 => &[ Domain::IPV6 => &[
// raw socket IPv6 没有 header,加载第 0 字节到累加器 A // raw socket IPv6 没有 header
// A = Packet[0:4] = src_id
bpf_stmt(BPF_LD | BPF_W | BPF_ABS, 0), bpf_stmt(BPF_LD | BPF_W | BPF_ABS, 0),
// if (A == target_val) goto Accept; else goto Reject; // if A != expected_src_id, goto reject
bpf_jump(BPF_JMP | BPF_JEQ | BPF_K, value, 0, 1), bpf_jump(BPF_JMP | BPF_JEQ | BPF_K, expected_src_id, 0, 3),
// 【接受 (True 路径)】 // A = Packet[4:8] = dst_id
bpf_stmt(BPF_LD | BPF_W | BPF_ABS, 4),
// if A != expected_dst_id, goto reject
bpf_jump(BPF_JMP | BPF_JEQ | BPF_K, expected_dst_id, 0, 1),
// 【接受】
bpf_stmt(BPF_RET | BPF_K, u32::MAX), bpf_stmt(BPF_RET | BPF_K, u32::MAX),
// 【拒绝 (False 路径) // 【拒绝】
bpf_stmt(BPF_RET | BPF_K, 0), bpf_stmt(BPF_RET | BPF_K, 0),
], ],
_ => bail!("unsupported family"), _ => bail!("unsupported family"),
...@@ -178,33 +186,44 @@ impl Router { ...@@ -178,33 +186,44 @@ impl Router {
} }
pub fn attach_filter_udp(group: Vec<&Router>) -> Result<()> { pub fn attach_filter_udp(group: Vec<&Router>) -> Result<()> {
let values: Vec<u32> = group // 预留空间:4 条前置指令 + 每个 router 5 条 + 1 条默认返回
.iter() let mut filters: Vec<SockFilter> = Vec::with_capacity(4 + group.len() * 5 + 1);
.map(|&f| {
let meta = Meta {
src_id: f.config.remote_id,
dst_id: f.config.local_id,
reversed: 0,
};
u32::from_be_bytes(*meta.as_bytes())
})
.collect();
let mut filters: Vec<SockFilter> = Vec::with_capacity(1 + values.len() * 2 + 1);
// udp filter 工作原理: // udp filter 工作原理:
// 每个对端起一个 udp socket // 每个对端起一个 udp socket
// 根据报文内容判断是给谁的,调度给对应的端口复用组序号 // 根据报文内容判断是给谁的,调度给对应的端口复用组序号
// Meta 结构:src_id(u32) + dst_id(u32) = 8 字节
// Load the first 4 bytes of the packet into the accumulator (A)
filters.push(bpf_stmt(BPF_LD | BPF_W | BPF_ABS, 0)); // 加载 src_id 并存储到 M[0]
for (i, &val) in values.iter().enumerate() { filters.push(bpf_stmt(BPF_LD | BPF_W | BPF_ABS, 0)); // A = packet[0:4] = src_id
// 如果匹配,继续下一句(返回),如果不匹配,跳过下一句。 filters.push(bpf_stmt(BPF_ST, 0)); // M[0] = A
filters.push(bpf_jump(BPF_JMP | BPF_JEQ | BPF_K, val, 0, 1));
// If match, return the index (i + 1, since 0 means drop) // 加载 dst_id 并存储到 M[1]
filters.push(bpf_stmt(BPF_LD | BPF_W | BPF_ABS, 4)); // A = packet[4:8] = dst_id
filters.push(bpf_stmt(BPF_ST, 1)); // M[1] = A
for (i, router) in group.iter().enumerate() {
// 字节序转换:将小端序ID转换为BPF期望的大端序比较值
let src_bytes = router.config.remote_id.to_le_bytes();
let dst_bytes = router.config.local_id.to_le_bytes();
let expected_src_id = u32::from_be_bytes(src_bytes);
let expected_dst_id = u32::from_be_bytes(dst_bytes);
// 每个 router 5 条指令:
// 0: LD M[0] ; A = src_id
// 1: JEQ expected_src_id, +0, +3 ; 匹配继续,不匹配跳过当前 router
// 2: LD M[1] ; A = dst_id
// 3: JEQ expected_dst_id, +0, +1 ; 匹配继续,不匹配跳过当前 router
// 4: RET i ; 返回索引
filters.push(bpf_stmt(BPF_LD | BPF_MEM, 0));
filters.push(bpf_jump(BPF_JMP | BPF_JEQ | BPF_K, expected_src_id, 0, 3));
filters.push(bpf_stmt(BPF_LD | BPF_MEM, 1));
filters.push(bpf_jump(BPF_JMP | BPF_JEQ | BPF_K, expected_dst_id, 0, 1));
filters.push(bpf_stmt(BPF_RET | BPF_K, i as u32)); filters.push(bpf_stmt(BPF_RET | BPF_K, i as u32));
} }
// If no match found after all comparisons, drop the packet
// 默认返回(不匹配任何 router)
filters.push(bpf_stmt(BPF_RET | BPF_K, u32::MAX)); filters.push(bpf_stmt(BPF_RET | BPF_K, u32::MAX));
let prog = sock_fprog { let prog = sock_fprog {
...@@ -232,7 +251,6 @@ impl Router { ...@@ -232,7 +251,6 @@ impl Router {
let meta = Meta { let meta = Meta {
src_id: self.config.local_id, src_id: self.config.local_id,
dst_id: self.config.remote_id, dst_id: self.config.remote_id,
reversed: 0,
}; };
buffer[..META_SIZE].copy_from_slice(meta.as_bytes()); buffer[..META_SIZE].copy_from_slice(meta.as_bytes());
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment