Commit 1d2cbed5 authored by nanamicat's avatar nanamicat

clean

parent 7d4d292d
mod router; mod router;
use crate::router::{Router, META_SIZE, SECRET_LENGTH};
use crate::Schema::{IP, TCP, UDP}; use crate::Schema::{IP, TCP, UDP};
use anyhow::{bail, ensure, Context, Result}; use crate::router::{META_SIZE, Router, SECRET_LENGTH, Meta};
use anyhow::{Context, Result, bail, ensure};
use crossbeam_utils::thread; use crossbeam_utils::thread;
use itertools::Itertools; use itertools::Itertools;
use serde::{Deserialize, Deserializer}; use serde::{Deserialize, Deserializer};
...@@ -13,16 +13,9 @@ use std::{ ...@@ -13,16 +13,9 @@ use std::{
collections::HashMap, collections::HashMap,
env, env,
mem::MaybeUninit, mem::MaybeUninit,
sync::{atomic::Ordering, Arc}, sync::{Arc, atomic::Ordering},
}; };
#[repr(C)]
pub struct Meta {
pub src_id: u8,
pub dst_id: u8,
pub reversed: u16,
}
#[derive(Deserialize)] #[derive(Deserialize)]
pub struct Config { pub struct Config {
pub local_id: u8, pub local_id: u8,
...@@ -109,7 +102,7 @@ fn main() -> Result<()> { ...@@ -109,7 +102,7 @@ fn main() -> Result<()> {
dst_id: router.config.remote_id, dst_id: router.config.remote_id,
reversed: 0, reversed: 0,
}; };
buffer[..META_SIZE].copy_from_slice(unsafe { &*(&meta as *const Meta as *const [u8; META_SIZE]) }); buffer[..META_SIZE].copy_from_slice(meta.as_bytes());
loop { loop {
let n = router.tun.recv(&mut buffer[META_SIZE..]).unwrap(); // recv 失败直接 panic let n = router.tun.recv(&mut buffer[META_SIZE..]).unwrap(); // recv 失败直接 panic
...@@ -128,8 +121,7 @@ fn main() -> Result<()> { ...@@ -128,8 +121,7 @@ fn main() -> Result<()> {
let _ = (|| -> Result<()> { let _ = (|| -> Result<()> {
// 收到一个非法报文只丢弃一个报文 // 收到一个非法报文只丢弃一个报文
let (len, addr) = { router.socket.recv_from(&mut recv_buf).unwrap() }; // recv 出错直接 panic let (len, addr) = { router.socket.recv_from(&mut recv_buf).unwrap() }; // recv 出错直接 panic
let packet = unsafe { std::slice::from_raw_parts_mut(recv_buf.as_mut_ptr().cast(), len) };
let packet: &mut [u8] = unsafe { std::slice::from_raw_parts_mut(recv_buf.as_mut_ptr() as *mut u8, len) };
// if addr.is_ipv6() { println!("{:X?}", packet) } // if addr.is_ipv6() { println!("{:X?}", packet) }
// 只有 ipv4 raw 会给 IP报头 // 只有 ipv4 raw 会给 IP报头
let offset = if router.config.family == Domain::IPV4 && router.config.schema == IP { let offset = if router.config.family == Domain::IPV4 && router.config.schema == IP {
...@@ -220,7 +212,7 @@ fn handle_tcp(router: &Arc<Router>, connection: socket2::Socket, local_secret: & ...@@ -220,7 +212,7 @@ fn handle_tcp(router: &Arc<Router>, connection: socket2::Socket, local_secret: &
s.spawn(|_| { s.spawn(|_| {
let _ = (|| -> Result<()> { let _ = (|| -> Result<()> {
let mut buf = [MaybeUninit::uninit(); 1500]; let mut buf = [MaybeUninit::uninit(); 1500];
let packet: &mut [u8] = unsafe { std::slice::from_raw_parts_mut(buf.as_mut_ptr() as *mut u8, buf.len()) }; let packet: &mut [u8] = unsafe { std::slice::from_raw_parts_mut(buf.as_mut_ptr().cast(), buf.len()) };
loop { loop {
Router::recv_exact(&connection, &mut buf[0..6])?; Router::recv_exact(&connection, &mut buf[0..6])?;
router.decrypt2(packet, &local_secret, 0..6); router.decrypt2(packet, &local_secret, 0..6);
......
...@@ -12,16 +12,32 @@ use std::{ ...@@ -12,16 +12,32 @@ use std::{
sync::Arc, sync::Arc,
}; };
use crate::{ConfigRouter, Meta, Schema}; use crate::{ConfigRouter, Schema};
use anyhow::{bail, ensure, Result}; use anyhow::{bail, ensure, Error, Result};
use tun::Device; use tun::Device;
use crossbeam::epoch::Atomic; use crossbeam::epoch::Atomic;
use libc::{sock_filter, sock_fprog, socklen_t, BPF_ABS, BPF_B, BPF_IND, BPF_JEQ, BPF_JMP, BPF_K, BPF_LD, BPF_LDX, BPF_MSH, BPF_RET, BPF_W, MSG_FASTOPEN, SOL_SOCKET, SO_ATTACH_REUSEPORT_CBPF}; use libc::{setsockopt, sock_filter, sock_fprog, socklen_t, BPF_ABS, BPF_B, BPF_IND, BPF_JEQ, BPF_JMP, BPF_K, BPF_LD, BPF_LDX, BPF_MSH, BPF_RET, BPF_W, MSG_FASTOPEN, SOL_SOCKET, SO_ATTACH_REUSEPORT_CBPF};
pub const SECRET_LENGTH: usize = 32; pub const SECRET_LENGTH: usize = 32;
pub const META_SIZE: usize = size_of::<Meta>(); pub const META_SIZE: usize = size_of::<Meta>();
#[repr(C)]
#[derive(Debug, Clone, Copy, Default)]
pub struct Meta {
pub src_id: u8,
pub dst_id: u8,
pub reversed: u16,
}
impl Meta {
pub fn as_bytes(&self) -> &[u8; META_SIZE] {
unsafe { &*(self as *const Meta as *const [u8; META_SIZE]) }
}
pub fn from_bytes(bytes: &[u8]) -> &Meta {
unsafe { &*(bytes.as_ptr() as *const Meta) }
}
}
pub struct Router { pub struct Router {
pub config: ConfigRouter, pub config: ConfigRouter,
pub secret: [u8; SECRET_LENGTH], pub secret: [u8; SECRET_LENGTH],
...@@ -32,11 +48,7 @@ pub struct Router { ...@@ -32,11 +48,7 @@ pub struct Router {
impl Router { impl Router {
pub(crate) fn create_secret(config: &str) -> Result<[u8; SECRET_LENGTH]> { pub(crate) fn create_secret(config: &str) -> Result<[u8; SECRET_LENGTH]> {
let mut secret = [0u8; SECRET_LENGTH]; BASE64_STANDARD.decode(config)?.as_slice().try_into().map_err(Error::from)
let decoded = BASE64_STANDARD.decode(config)?;
let len = decoded.len().min(SECRET_LENGTH);
secret[..len].copy_from_slice(&decoded[..len]);
Ok(secret)
} }
pub(crate) fn decrypt(&self, data: &mut [u8], secret: &[u8; SECRET_LENGTH]) { pub(crate) fn decrypt(&self, data: &mut [u8], secret: &[u8; SECRET_LENGTH]) {
...@@ -81,14 +93,12 @@ impl Router { ...@@ -81,14 +93,12 @@ impl Router {
fn attach_filter_raw(config: &ConfigRouter, local_id: u8, socket: &Socket) -> Result<()> { fn attach_filter_raw(config: &ConfigRouter, local_id: u8, socket: &Socket) -> Result<()> {
// 由于多个对端可能会使用相同的 ipprpto 号,这里确保每个 socket 上只会收到自己对应的对端发来的消息 // 由于多个对端可能会使用相同的 ipprpto 号,这里确保每个 socket 上只会收到自己对应的对端发来的消息
const META_SIZE: usize = size_of::<Meta>();
let meta = Meta { let meta = Meta {
src_id: config.remote_id, src_id: config.remote_id,
dst_id: local_id, dst_id: local_id,
reversed: 0, reversed: 0,
}; };
let meta_bytes: [u8; META_SIZE] = unsafe { *(&meta as *const Meta as *const [u8; META_SIZE]) }; let value = u32::from_be_bytes(*meta.as_bytes());
let value = u32::from_be_bytes(meta_bytes);
// 如果你的协议是 UDP,这里必须是 8 (跳过 UDP 头: SrcPort(2)+DstPort(2)+Len(2)+Sum(2)) // 如果你的协议是 UDP,这里必须是 8 (跳过 UDP 头: SrcPort(2)+DstPort(2)+Len(2)+Sum(2))
// 如果是纯自定义 IP 协议,这里是 0 // 如果是纯自定义 IP 协议,这里是 0
...@@ -150,8 +160,7 @@ impl Router { ...@@ -150,8 +160,7 @@ impl Router {
dst_id: local_id, dst_id: local_id,
reversed: 0, reversed: 0,
}; };
let meta_bytes: [u8; META_SIZE] = unsafe { *(&meta as *const Meta as *const [u8; META_SIZE]) }; u32::from_be_bytes(*meta.as_bytes())
u32::from_be_bytes(meta_bytes)
}) })
.collect(); .collect();
...@@ -178,7 +187,7 @@ impl Router { ...@@ -178,7 +187,7 @@ impl Router {
let fd = sock.as_raw_fd(); let fd = sock.as_raw_fd();
let ret = unsafe { let ret = unsafe {
libc::setsockopt( setsockopt(
fd, fd,
SOL_SOCKET, SOL_SOCKET,
SO_ATTACH_REUSEPORT_CBPF, SO_ATTACH_REUSEPORT_CBPF,
...@@ -256,8 +265,7 @@ impl Router { ...@@ -256,8 +265,7 @@ impl Router {
dst_id: self.config.remote_id, dst_id: self.config.remote_id,
reversed: 0, reversed: 0,
}; };
let buf = unsafe { std::slice::from_raw_parts(&meta as *const Meta as *const u8, META_SIZE) }; result.send_to_with_flags(meta.as_bytes(), endpoint, MSG_FASTOPEN)?;
result.send_to_with_flags(buf, endpoint, MSG_FASTOPEN)?;
Ok(result) Ok(result)
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment