Commit 2337b2b3 authored by nanamicat's avatar nanamicat

tcp

parent 50d9f945
Pipeline #41943 passed with stages
in 3 minutes and 8 seconds
...@@ -3,10 +3,13 @@ mod router; ...@@ -3,10 +3,13 @@ mod router;
use crate::router::{Meta, Router, META_SIZE, SECRET_LENGTH}; use crate::router::{Meta, Router, META_SIZE, SECRET_LENGTH};
use crate::Schema::{TCP, UDP}; use crate::Schema::{TCP, UDP};
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use crossbeam::epoch::{pin, Owned};
use crossbeam_utils::thread; use crossbeam_utils::thread;
use itertools::Itertools; use itertools::Itertools;
use serde::{Deserialize, Deserializer}; use serde::{Deserialize, Deserializer};
use socket2::Domain; use socket2::Domain;
use std::net::Shutdown;
use std::sync::atomic::Ordering;
use std::time::Duration; use std::time::Duration;
use std::{collections::HashMap, env, mem::MaybeUninit, sync::Arc}; use std::{collections::HashMap, env, mem::MaybeUninit, sync::Arc};
...@@ -57,7 +60,7 @@ where ...@@ -57,7 +60,7 @@ where
fn main() -> Result<()> { fn main() -> Result<()> {
println!("Starting"); println!("Starting");
let config = Arc::new(serde_json::from_str::<Config>(env::args().nth(1).context("need param")?.as_str())?); let config = serde_json::from_str::<Config>(env::args().nth(1).context("need param")?.as_str())?;
let local_secret: [u8; SECRET_LENGTH] = Router::create_secret(config.local_secret.as_str())?; let local_secret: [u8; SECRET_LENGTH] = Router::create_secret(config.local_secret.as_str())?;
let routers = Arc::new( let routers = Arc::new(
...@@ -68,9 +71,9 @@ fn main() -> Result<()> { ...@@ -68,9 +71,9 @@ fn main() -> Result<()> {
.sorted_by_key(|r| r.remote_id) .sorted_by_key(|r| r.remote_id)
.map(|c| { .map(|c| {
let remote_id = c.remote_id; let remote_id = c.remote_id;
Router::new(c, config.local_id).map(|r| (remote_id, Arc::new(r))) Router::new(c, config.local_id).map(|r| (remote_id, r))
}) })
.collect::<Result<HashMap<u8, Arc<Router>>, _>>()?, .collect::<Result<HashMap<u8, Router>, _>>()?,
); );
for (_, group) in &routers for (_, group) in &routers
...@@ -120,23 +123,44 @@ fn main() -> Result<()> { ...@@ -120,23 +123,44 @@ fn main() -> Result<()> {
// accept 出错直接 panic // accept 出错直接 panic
loop { loop {
let (connection, _) = router.socket.accept().unwrap(); let (connection, _) = router.socket.accept().unwrap();
thread::scope(|s| {
s.spawn(|_| {
connection.set_tcp_nodelay(true).unwrap();
let mut meta_bytes = [MaybeUninit::uninit(); META_SIZE]; let mut meta_bytes = [MaybeUninit::uninit(); META_SIZE];
Router::recv_exact_tcp(&connection, &mut meta_bytes).unwrap(); Router::recv_exact_tcp(&connection, &mut meta_bytes).unwrap();
let meta: &Meta = unsafe { &*(meta_bytes.as_ptr() as *const Meta) }; let meta: &Meta = Meta::from_bytes(&meta_bytes);
if meta.reversed == 0 if meta.reversed == 0
&& meta.dst_id == config.local_id && meta.dst_id == config.local_id
&& let Some(router) = routers.get(&meta.src_id) && let Some(router) = routers.get(&meta.src_id)
{ {
let connection = Arc::new(connection);
// tcp listener 只许一个连接,过来新连接就把前一个关掉。
{
let guard = pin();
let new_shared = Owned::new(connection.clone()).into_shared(&guard);
let old_shared = router.tcp_listener_connection.swap(new_shared, Ordering::Release, &guard);
unsafe {
if let Some(old) = old_shared.as_ref() {
let _ = old.shutdown(Shutdown::Both);
}
guard.defer_destroy(old_shared)
}
}
let _ = thread::scope(|s| { let _ = thread::scope(|s| {
s.spawn(|_| router.handle_outbound_tcp(&connection)); s.spawn(|_| router.handle_outbound_tcp(&connection));
s.spawn(|_| router.handle_inbound_tcp(&connection, &local_secret)); s.spawn(|_| router.handle_inbound_tcp(&connection, &local_secret));
}); });
} }
});
})
.unwrap();
} }
}); });
} }
}) })
.map_err(|_| anyhow::anyhow!("Thread panicked"))?; .unwrap();
Ok(()) Ok(())
} }
...@@ -4,7 +4,6 @@ use base64::prelude::BASE64_STANDARD; ...@@ -4,7 +4,6 @@ use base64::prelude::BASE64_STANDARD;
use base64::Engine; use base64::Engine;
use socket2::{Domain, Protocol, SockAddr, SockFilter, Socket, Type}; use socket2::{Domain, Protocol, SockAddr, SockFilter, Socket, Type};
use std::net::Shutdown; use std::net::Shutdown;
use std::thread::Scope;
use std::{ use std::{
ffi::c_void, ffi::c_void,
mem::MaybeUninit, mem::MaybeUninit,
...@@ -18,7 +17,7 @@ use std::{ ...@@ -18,7 +17,7 @@ use std::{
use tun::Device; use tun::Device;
use crate::Schema::IP; use crate::Schema::IP;
use crossbeam::epoch::Atomic; use crossbeam::epoch::{pin, Atomic};
use libc::{ use libc::{
setsockopt, sock_filter, sock_fprog, socklen_t, BPF_ABS, BPF_B, BPF_IND, BPF_JEQ, BPF_JMP, BPF_K, BPF_LD, BPF_LDX, BPF_MSH, BPF_RET, BPF_W, setsockopt, sock_filter, sock_fprog, socklen_t, BPF_ABS, BPF_B, BPF_IND, BPF_JEQ, BPF_JMP, BPF_K, BPF_LD, BPF_LDX, BPF_MSH, BPF_RET, BPF_W,
MSG_FASTOPEN, SOL_SOCKET, SO_ATTACH_REUSEPORT_CBPF, MSG_FASTOPEN, SOL_SOCKET, SO_ATTACH_REUSEPORT_CBPF,
...@@ -38,7 +37,7 @@ impl Meta { ...@@ -38,7 +37,7 @@ impl Meta {
pub fn as_bytes(&self) -> &[u8; META_SIZE] { pub fn as_bytes(&self) -> &[u8; META_SIZE] {
unsafe { &*(self as *const Meta as *const [u8; META_SIZE]) } unsafe { &*(self as *const Meta as *const [u8; META_SIZE]) }
} }
pub fn from_bytes(bytes: &[u8]) -> &Meta { pub fn from_bytes(bytes: &[MaybeUninit<u8>; META_SIZE]) -> &Meta {
unsafe { &*(bytes.as_ptr() as *const Meta) } unsafe { &*(bytes.as_ptr() as *const Meta) }
} }
} }
...@@ -49,6 +48,8 @@ pub struct Router { ...@@ -49,6 +48,8 @@ pub struct Router {
pub tun: Device, pub tun: Device,
pub socket: Arc<Socket>, pub socket: Arc<Socket>,
pub endpoint: Arc<Atomic<SockAddr>>, pub endpoint: Arc<Atomic<SockAddr>>,
pub tcp_listener_connection: Arc<Atomic<Arc<Socket>>>,
} }
impl Router { impl Router {
...@@ -111,13 +112,14 @@ impl Router { ...@@ -111,13 +112,14 @@ impl Router {
// tcp client 的 socket 不要在初始化时创建,在循环里创建 // tcp client 的 socket 不要在初始化时创建,在循环里创建
// 创建 socket 和 获取 endpoint 失败会 panic,连接失败会 error // 创建 socket 和 获取 endpoint 失败会 panic,连接失败会 error
let result = Socket::new(self.config.family, Type::STREAM, Some(Protocol::TCP)).unwrap(); let result = Socket::new(self.config.family, Type::STREAM, Some(Protocol::TCP)).unwrap();
result.set_tcp_nodelay(true).unwrap();
result.set_mark(self.config.mark).unwrap(); result.set_mark(self.config.mark).unwrap();
let meta = Meta { let meta = Meta {
src_id: local_id, src_id: local_id,
dst_id: self.config.remote_id, dst_id: self.config.remote_id,
reversed: 0, reversed: 0,
}; };
let guard = crossbeam::epoch::pin(); let guard = pin();
let endpoint_ref = self.endpoint.load(Ordering::Relaxed, &guard); let endpoint_ref = self.endpoint.load(Ordering::Relaxed, &guard);
let endpoint = unsafe { endpoint_ref.as_ref() }.unwrap(); let endpoint = unsafe { endpoint_ref.as_ref() }.unwrap();
...@@ -138,6 +140,11 @@ impl Router { ...@@ -138,6 +140,11 @@ impl Router {
// 如果是纯自定义 IP 协议,这里是 0 // 如果是纯自定义 IP 协议,这里是 0
let payload_offset = 0; let payload_offset = 0;
// IP filter 工作原理:
// 每个对端起一个 raw socket
// 根据报文内容判断是给谁的。拒绝掉不是给自己的报文
// IPv4 raw socket 带 IP 头,IPv6 不带
let filters: &[SockFilter] = match socket.domain()? { let filters: &[SockFilter] = match socket.domain()? {
Domain::IPV4 => &[ Domain::IPV4 => &[
// [IPv4] 计算 IPv4 头长度: X = 4 * (IP[0] & 0xf) // [IPv4] 计算 IPv4 头长度: X = 4 * (IP[0] & 0xf)
...@@ -168,7 +175,7 @@ impl Router { ...@@ -168,7 +175,7 @@ impl Router {
Ok(()) Ok(())
} }
pub fn attach_filter_udp(group: Vec<&Arc<Router>>, local_id: u8) -> Result<()> { pub fn attach_filter_udp(group: Vec<&Router>, local_id: u8) -> Result<()> {
let values: Vec<u32> = group let values: Vec<u32> = group
.iter() .iter()
.map(|&f| { .map(|&f| {
...@@ -182,6 +189,11 @@ impl Router { ...@@ -182,6 +189,11 @@ impl Router {
.collect(); .collect();
let mut filters: Vec<SockFilter> = Vec::with_capacity(1 + values.len() * 2 + 1); let mut filters: Vec<SockFilter> = Vec::with_capacity(1 + values.len() * 2 + 1);
// udp filter 工作原理:
// 每个对端起一个 udp socket
// 根据报文内容判断是给谁的,调度给对应的端口复用组序号
// Load the first 4 bytes of the packet into the accumulator (A) // Load the first 4 bytes of the packet into the accumulator (A)
filters.push(bpf_stmt(BPF_LD | BPF_W | BPF_ABS, 0)); filters.push(bpf_stmt(BPF_LD | BPF_W | BPF_ABS, 0));
for (i, &val) in values.iter().enumerate() { for (i, &val) in values.iter().enumerate() {
...@@ -224,7 +236,7 @@ impl Router { ...@@ -224,7 +236,7 @@ impl Router {
loop { loop {
let n = self.tun.recv(&mut buffer[META_SIZE..]).unwrap(); // recv 失败直接 panic let n = self.tun.recv(&mut buffer[META_SIZE..]).unwrap(); // recv 失败直接 panic
let guard = crossbeam::epoch::pin(); let guard = pin();
let endpoint_ref = self.endpoint.load(Ordering::Relaxed, &guard); let endpoint_ref = self.endpoint.load(Ordering::Relaxed, &guard);
if let Some(endpoint) = unsafe { endpoint_ref.as_ref() } { if let Some(endpoint) = unsafe { endpoint_ref.as_ref() } {
self.encrypt(&mut buffer[META_SIZE..META_SIZE + n]); self.encrypt(&mut buffer[META_SIZE..META_SIZE + n]);
...@@ -248,7 +260,7 @@ impl Router { ...@@ -248,7 +260,7 @@ impl Router {
} + META_SIZE; } + META_SIZE;
{ {
let guard = crossbeam::epoch::pin(); let guard = pin();
let current_shared = self.endpoint.load(Ordering::Relaxed, &guard); let current_shared = self.endpoint.load(Ordering::Relaxed, &guard);
let is_same = unsafe { current_shared.as_ref() }.map(|c| *c == addr).unwrap_or(false); let is_same = unsafe { current_shared.as_ref() }.map(|c| *c == addr).unwrap_or(false);
if !is_same { if !is_same {
...@@ -349,7 +361,7 @@ impl Router { ...@@ -349,7 +361,7 @@ impl Router {
tun: Self::create_tun_device(&config)?, tun: Self::create_tun_device(&config)?,
endpoint: Self::create_endpoint(&config), endpoint: Self::create_endpoint(&config),
socket: Arc::new(Self::create_socket(&config, local_id)?), socket: Arc::new(Self::create_socket(&config, local_id)?),
tcp_listener_connection: Arc::new(Atomic::null()),
config, config,
}; };
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment