Commit 2337b2b3 authored by nanamicat's avatar nanamicat

tcp

parent 50d9f945
Pipeline #41943 passed with stages
in 3 minutes and 8 seconds
......@@ -3,10 +3,13 @@ mod router;
use crate::router::{Meta, Router, META_SIZE, SECRET_LENGTH};
use crate::Schema::{TCP, UDP};
use anyhow::{Context, Result};
use crossbeam::epoch::{pin, Owned};
use crossbeam_utils::thread;
use itertools::Itertools;
use serde::{Deserialize, Deserializer};
use socket2::Domain;
use std::net::Shutdown;
use std::sync::atomic::Ordering;
use std::time::Duration;
use std::{collections::HashMap, env, mem::MaybeUninit, sync::Arc};
......@@ -57,7 +60,7 @@ where
fn main() -> Result<()> {
println!("Starting");
let config = Arc::new(serde_json::from_str::<Config>(env::args().nth(1).context("need param")?.as_str())?);
let config = serde_json::from_str::<Config>(env::args().nth(1).context("need param")?.as_str())?;
let local_secret: [u8; SECRET_LENGTH] = Router::create_secret(config.local_secret.as_str())?;
let routers = Arc::new(
......@@ -68,9 +71,9 @@ fn main() -> Result<()> {
.sorted_by_key(|r| r.remote_id)
.map(|c| {
let remote_id = c.remote_id;
Router::new(c, config.local_id).map(|r| (remote_id, Arc::new(r)))
Router::new(c, config.local_id).map(|r| (remote_id, r))
})
.collect::<Result<HashMap<u8, Arc<Router>>, _>>()?,
.collect::<Result<HashMap<u8, Router>, _>>()?,
);
for (_, group) in &routers
......@@ -120,23 +123,44 @@ fn main() -> Result<()> {
// accept 出错直接 panic
loop {
let (connection, _) = router.socket.accept().unwrap();
thread::scope(|s| {
s.spawn(|_| {
connection.set_tcp_nodelay(true).unwrap();
let mut meta_bytes = [MaybeUninit::uninit(); META_SIZE];
Router::recv_exact_tcp(&connection, &mut meta_bytes).unwrap();
let meta: &Meta = unsafe { &*(meta_bytes.as_ptr() as *const Meta) };
if meta.reversed == 0
&& meta.dst_id == config.local_id
&& let Some(router) = routers.get(&meta.src_id)
{
let _ = thread::scope(|s| {
s.spawn(|_| router.handle_outbound_tcp(&connection));
s.spawn(|_| router.handle_inbound_tcp(&connection, &local_secret));
let mut meta_bytes = [MaybeUninit::uninit(); META_SIZE];
Router::recv_exact_tcp(&connection, &mut meta_bytes).unwrap();
let meta: &Meta = Meta::from_bytes(&meta_bytes);
if meta.reversed == 0
&& meta.dst_id == config.local_id
&& let Some(router) = routers.get(&meta.src_id)
{
let connection = Arc::new(connection);
// tcp listener 只许一个连接,过来新连接就把前一个关掉。
{
let guard = pin();
let new_shared = Owned::new(connection.clone()).into_shared(&guard);
let old_shared = router.tcp_listener_connection.swap(new_shared, Ordering::Release, &guard);
unsafe {
if let Some(old) = old_shared.as_ref() {
let _ = old.shutdown(Shutdown::Both);
}
guard.defer_destroy(old_shared)
}
}
let _ = thread::scope(|s| {
s.spawn(|_| router.handle_outbound_tcp(&connection));
s.spawn(|_| router.handle_inbound_tcp(&connection, &local_secret));
});
}
});
}
})
.unwrap();
}
});
}
})
.map_err(|_| anyhow::anyhow!("Thread panicked"))?;
.unwrap();
Ok(())
}
......@@ -4,7 +4,6 @@ use base64::prelude::BASE64_STANDARD;
use base64::Engine;
use socket2::{Domain, Protocol, SockAddr, SockFilter, Socket, Type};
use std::net::Shutdown;
use std::thread::Scope;
use std::{
ffi::c_void,
mem::MaybeUninit,
......@@ -18,7 +17,7 @@ use std::{
use tun::Device;
use crate::Schema::IP;
use crossbeam::epoch::Atomic;
use crossbeam::epoch::{pin, Atomic};
use libc::{
setsockopt, sock_filter, sock_fprog, socklen_t, BPF_ABS, BPF_B, BPF_IND, BPF_JEQ, BPF_JMP, BPF_K, BPF_LD, BPF_LDX, BPF_MSH, BPF_RET, BPF_W,
MSG_FASTOPEN, SOL_SOCKET, SO_ATTACH_REUSEPORT_CBPF,
......@@ -38,7 +37,7 @@ impl Meta {
pub fn as_bytes(&self) -> &[u8; META_SIZE] {
unsafe { &*(self as *const Meta as *const [u8; META_SIZE]) }
}
pub fn from_bytes(bytes: &[u8]) -> &Meta {
pub fn from_bytes(bytes: &[MaybeUninit<u8>; META_SIZE]) -> &Meta {
unsafe { &*(bytes.as_ptr() as *const Meta) }
}
}
......@@ -49,6 +48,8 @@ pub struct Router {
pub tun: Device,
pub socket: Arc<Socket>,
pub endpoint: Arc<Atomic<SockAddr>>,
pub tcp_listener_connection: Arc<Atomic<Arc<Socket>>>,
}
impl Router {
......@@ -111,13 +112,14 @@ impl Router {
// tcp client 的 socket 不要在初始化时创建,在循环里创建
// 创建 socket 和 获取 endpoint 失败会 panic,连接失败会 error
let result = Socket::new(self.config.family, Type::STREAM, Some(Protocol::TCP)).unwrap();
result.set_tcp_nodelay(true).unwrap();
result.set_mark(self.config.mark).unwrap();
let meta = Meta {
src_id: local_id,
dst_id: self.config.remote_id,
reversed: 0,
};
let guard = crossbeam::epoch::pin();
let guard = pin();
let endpoint_ref = self.endpoint.load(Ordering::Relaxed, &guard);
let endpoint = unsafe { endpoint_ref.as_ref() }.unwrap();
......@@ -138,6 +140,11 @@ impl Router {
// 如果是纯自定义 IP 协议,这里是 0
let payload_offset = 0;
// IP filter 工作原理:
// 每个对端起一个 raw socket
// 根据报文内容判断是给谁的。拒绝掉不是给自己的报文
// IPv4 raw socket 带 IP 头,IPv6 不带
let filters: &[SockFilter] = match socket.domain()? {
Domain::IPV4 => &[
// [IPv4] 计算 IPv4 头长度: X = 4 * (IP[0] & 0xf)
......@@ -168,7 +175,7 @@ impl Router {
Ok(())
}
pub fn attach_filter_udp(group: Vec<&Arc<Router>>, local_id: u8) -> Result<()> {
pub fn attach_filter_udp(group: Vec<&Router>, local_id: u8) -> Result<()> {
let values: Vec<u32> = group
.iter()
.map(|&f| {
......@@ -182,6 +189,11 @@ impl Router {
.collect();
let mut filters: Vec<SockFilter> = Vec::with_capacity(1 + values.len() * 2 + 1);
// udp filter 工作原理:
// 每个对端起一个 udp socket
// 根据报文内容判断是给谁的,调度给对应的端口复用组序号
// Load the first 4 bytes of the packet into the accumulator (A)
filters.push(bpf_stmt(BPF_LD | BPF_W | BPF_ABS, 0));
for (i, &val) in values.iter().enumerate() {
......@@ -224,7 +236,7 @@ impl Router {
loop {
let n = self.tun.recv(&mut buffer[META_SIZE..]).unwrap(); // recv 失败直接 panic
let guard = crossbeam::epoch::pin();
let guard = pin();
let endpoint_ref = self.endpoint.load(Ordering::Relaxed, &guard);
if let Some(endpoint) = unsafe { endpoint_ref.as_ref() } {
self.encrypt(&mut buffer[META_SIZE..META_SIZE + n]);
......@@ -248,7 +260,7 @@ impl Router {
} + META_SIZE;
{
let guard = crossbeam::epoch::pin();
let guard = pin();
let current_shared = self.endpoint.load(Ordering::Relaxed, &guard);
let is_same = unsafe { current_shared.as_ref() }.map(|c| *c == addr).unwrap_or(false);
if !is_same {
......@@ -349,7 +361,7 @@ impl Router {
tun: Self::create_tun_device(&config)?,
endpoint: Self::create_endpoint(&config),
socket: Arc::new(Self::create_socket(&config, local_id)?),
tcp_listener_connection: Arc::new(Atomic::null()),
config,
};
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment