Commit 50d9f945 authored by nanamicat's avatar nanamicat

tcp

parent 1d2cbed5
mod router; mod router;
use crate::Schema::{IP, TCP, UDP}; use crate::router::{Meta, Router, META_SIZE, SECRET_LENGTH};
use crate::router::{META_SIZE, Router, SECRET_LENGTH, Meta}; use crate::Schema::{TCP, UDP};
use anyhow::{Context, Result, bail, ensure}; use anyhow::{Context, Result};
use crossbeam_utils::thread; use crossbeam_utils::thread;
use itertools::Itertools; use itertools::Itertools;
use serde::{Deserialize, Deserializer}; use serde::{Deserialize, Deserializer};
use socket2::Domain; use socket2::Domain;
use std::net::Shutdown;
use std::time::Duration; use std::time::Duration;
use std::{ use std::{collections::HashMap, env, mem::MaybeUninit, sync::Arc};
collections::HashMap,
env,
mem::MaybeUninit,
sync::{Arc, atomic::Ordering},
};
#[derive(Deserialize)] #[derive(Deserialize)]
pub struct Config { pub struct Config {
...@@ -94,67 +88,22 @@ fn main() -> Result<()> { ...@@ -94,67 +88,22 @@ fn main() -> Result<()> {
// IP, UDP // IP, UDP
for router in routers.values().filter(|&r| !(r.config.schema != TCP)) { for router in routers.values().filter(|&r| !(r.config.schema != TCP)) {
s.spawn(|_| { s.spawn(|_| {
let mut buffer = [0u8; 1500]; router.handle_outbound_ip_udp(config.local_id);
// Pre-initialize with our Meta header (local -> remote)
let meta = Meta {
src_id: config.local_id,
dst_id: router.config.remote_id,
reversed: 0,
};
buffer[..META_SIZE].copy_from_slice(meta.as_bytes());
loop {
let n = router.tun.recv(&mut buffer[META_SIZE..]).unwrap(); // recv 失败直接 panic
let guard = crossbeam::epoch::pin();
let endpoint_ref = router.endpoint.load(Ordering::Relaxed, &guard);
if let Some(endpoint) = unsafe { endpoint_ref.as_ref() } {
router.encrypt(&mut buffer[META_SIZE..META_SIZE + n]);
let _ = router.socket.send_to(&buffer[..META_SIZE + n], endpoint);
}
}
}); });
s.spawn(|_| { s.spawn(|_| {
let mut recv_buf = [MaybeUninit::uninit(); 1500]; router.handle_inbound_ip_udp(&local_secret);
loop {
let _ = (|| -> Result<()> {
// 收到一个非法报文只丢弃一个报文
let (len, addr) = { router.socket.recv_from(&mut recv_buf).unwrap() }; // recv 出错直接 panic
let packet = unsafe { std::slice::from_raw_parts_mut(recv_buf.as_mut_ptr().cast(), len) };
// if addr.is_ipv6() { println!("{:X?}", packet) }
// 只有 ipv4 raw 会给 IP报头
let offset = if router.config.family == Domain::IPV4 && router.config.schema == IP {
(packet[0] & 0x0f) as usize * 4
} else {
0
} + META_SIZE;
{
let guard = crossbeam::epoch::pin();
let current_shared = router.endpoint.load(Ordering::Relaxed, &guard);
let is_same = unsafe { current_shared.as_ref() }.map(|c| *c == addr).unwrap_or(false);
if !is_same {
let new_shared = crossbeam::epoch::Owned::new(addr).into_shared(&guard);
let old_shared = router.endpoint.swap(new_shared, Ordering::Release, &guard);
unsafe { guard.defer_destroy(old_shared) }
}
}
let payload = &mut packet[offset..];
router.decrypt(payload, &local_secret);
router.tun.send(payload)?;
Ok(())
})();
}
}); });
} }
for router in routers.values().filter(|&r| r.config.schema == TCP && r.config.dst_port != 0) { for router in routers.values().filter(|&r| r.config.schema == TCP && r.config.dst_port != 0) {
s.spawn(|_| { s.spawn(|_| {
loop { loop {
if let Ok(connection) = router.tcp_connect(config.local_id) { if let Ok(connection) = router.connect_tcp(config.local_id) {
let _ = handle_tcp(router, connection, &local_secret); let _ = thread::scope(|s| {
s.spawn(|_| router.handle_outbound_tcp(&connection));
s.spawn(|_| router.handle_inbound_tcp(&connection, &local_secret));
});
} }
std::thread::sleep(Duration::from_secs(TCP_RECONNECT)); std::thread::sleep(Duration::from_secs(TCP_RECONNECT));
} }
...@@ -168,68 +117,25 @@ fn main() -> Result<()> { ...@@ -168,68 +117,25 @@ fn main() -> Result<()> {
.unique_by(|r| r.config.src_port) .unique_by(|r| r.config.src_port)
{ {
s.spawn(|_| { s.spawn(|_| {
// listen 或 accept 出错直接 panic // accept 出错直接 panic
let listener = router.tcp_listen().unwrap();
loop { loop {
let (connection, _) = listener.accept().unwrap(); let (connection, _) = router.socket.accept().unwrap();
let _ = (|| -> Result<()> {
// 为了写起来方便,每个 tcp connection 有两秒钟时间发送握手报文,如果没收到就关闭连接再响应下一个
// 正常的连接是 fast open 的,期望会瞬间连好
connection.set_read_timeout(Option::from(Duration::from_secs(2)))?;
let mut meta_bytes = [MaybeUninit::uninit(); META_SIZE]; let mut meta_bytes = [MaybeUninit::uninit(); META_SIZE];
Router::recv_exact(&connection, &mut meta_bytes)?; Router::recv_exact_tcp(&connection, &mut meta_bytes).unwrap();
let meta: &Meta = unsafe { &*(meta_bytes.as_ptr() as *const Meta) }; let meta: &Meta = unsafe { &*(meta_bytes.as_ptr() as *const Meta) };
if meta.reversed == 0 if meta.reversed == 0
&& meta.dst_id == config.local_id && meta.dst_id == config.local_id
&& let Some(router) = routers.get(&meta.src_id) && let Some(router) = routers.get(&meta.src_id)
{ {
connection.set_read_timeout(None)?; let _ = thread::scope(|s| {
handle_tcp(router, connection, &local_secret)?; s.spawn(|_| router.handle_outbound_tcp(&connection));
} s.spawn(|_| router.handle_inbound_tcp(&connection, &local_secret));
Ok(())
})();
}
}); });
} }
})
.map_err(|_| anyhow::anyhow!("Thread panicked"))?;
Ok(())
}
fn handle_tcp(router: &Arc<Router>, connection: socket2::Socket, local_secret: &[u8; 32]) -> Result<()> {
thread::scope(|s| {
s.spawn(|_| {
let _ = (|| -> Result<()> {
let mut buffer = [0u8; 1500];
loop {
let n = router.tun.recv(&mut buffer)?;
router.encrypt(&mut buffer[..n]);
Router::send_all(&connection, &buffer[..n])?;
} }
})();
let _ = connection.shutdown(Shutdown::Both);
}); });
s.spawn(|_| {
let _ = (|| -> Result<()> {
let mut buf = [MaybeUninit::uninit(); 1500];
let packet: &mut [u8] = unsafe { std::slice::from_raw_parts_mut(buf.as_mut_ptr().cast(), buf.len()) };
loop {
Router::recv_exact(&connection, &mut buf[0..6])?;
router.decrypt2(packet, &local_secret, 0..6);
let version = packet[0] >> 4;
let total_len = match version {
4 => u16::from_be_bytes([packet[2], packet[3]]) as usize,
6 => u16::from_be_bytes([packet[4], packet[5]]) as usize + 40,
_ => bail!("Invalid IP version"),
};
ensure!(6 < total_len && total_len <= buf.len(), "Invalid total length");
Router::recv_exact(&connection, &mut buf[6..total_len])?;
router.decrypt2(packet, &local_secret, 6..total_len);
router.tun.send(&packet[..total_len])?;
} }
})();
let _ = connection.shutdown(Shutdown::Both);
});
}) })
.map_err(|_| anyhow::anyhow!("Thread panicked"))?; .map_err(|_| anyhow::anyhow!("Thread panicked"))?;
Ok(()) Ok(())
......
use crate::{ConfigRouter, Schema};
use anyhow::{bail, ensure, Error, Result};
use base64::prelude::BASE64_STANDARD; use base64::prelude::BASE64_STANDARD;
use base64::Engine; use base64::Engine;
use socket2::{Domain, Protocol, SockAddr, SockFilter, Socket, Type}; use socket2::{Domain, Protocol, SockAddr, SockFilter, Socket, Type};
use std::net::Shutdown;
use std::thread::Scope;
use std::{ use std::{
ffi::c_void, ffi::c_void,
mem::MaybeUninit, mem::MaybeUninit,
...@@ -11,13 +15,14 @@ use std::{ ...@@ -11,13 +15,14 @@ use std::{
sync::atomic::Ordering, sync::atomic::Ordering,
sync::Arc, sync::Arc,
}; };
use crate::{ConfigRouter, Schema};
use anyhow::{bail, ensure, Error, Result};
use tun::Device; use tun::Device;
use crate::Schema::IP;
use crossbeam::epoch::Atomic; use crossbeam::epoch::Atomic;
use libc::{setsockopt, sock_filter, sock_fprog, socklen_t, BPF_ABS, BPF_B, BPF_IND, BPF_JEQ, BPF_JMP, BPF_K, BPF_LD, BPF_LDX, BPF_MSH, BPF_RET, BPF_W, MSG_FASTOPEN, SOL_SOCKET, SO_ATTACH_REUSEPORT_CBPF}; use libc::{
setsockopt, sock_filter, sock_fprog, socklen_t, BPF_ABS, BPF_B, BPF_IND, BPF_JEQ, BPF_JMP, BPF_K, BPF_LD, BPF_LDX, BPF_MSH, BPF_RET, BPF_W,
MSG_FASTOPEN, SOL_SOCKET, SO_ATTACH_REUSEPORT_CBPF,
};
pub const SECRET_LENGTH: usize = 32; pub const SECRET_LENGTH: usize = 32;
pub const META_SIZE: usize = size_of::<Meta>(); pub const META_SIZE: usize = size_of::<Meta>();
...@@ -68,13 +73,13 @@ impl Router { ...@@ -68,13 +73,13 @@ impl Router {
} }
} }
pub fn create_socket(config: &ConfigRouter, local_id: u8) -> Result<Arc<Socket>> { pub fn create_socket(config: &ConfigRouter, local_id: u8) -> Result<Socket> {
match config.schema { match config.schema {
Schema::IP => { Schema::IP => {
let result = Socket::new(config.family, Type::RAW, Some(Protocol::from(config.proto as i32)))?; let result = Socket::new(config.family, Type::RAW, Some(Protocol::from(config.proto as i32)))?;
result.set_mark(config.mark)?; result.set_mark(config.mark)?;
Self::attach_filter_raw(config, local_id, &result)?; Self::attach_filter_ip(config, local_id, &result)?;
Ok(Arc::new(result)) Ok(result)
} }
Schema::UDP => { Schema::UDP => {
let result = Socket::new(config.family, Type::DGRAM, Some(Protocol::UDP))?; let result = Socket::new(config.family, Type::DGRAM, Some(Protocol::UDP))?;
...@@ -84,14 +89,43 @@ impl Router { ...@@ -84,14 +89,43 @@ impl Router {
let addr = Self::bind_addr(config); let addr = Self::bind_addr(config);
result.bind(&addr)?; result.bind(&addr)?;
} }
Ok(Arc::new(result)) Ok(result)
}
Schema::TCP => {
if config.dst_port == 0 {
// listener
let result = Socket::new(config.family, Type::STREAM, Some(Protocol::TCP))?;
let addr = Router::bind_addr(config);
result.bind(&addr)?;
result.listen(100)?;
Ok(result)
} else {
// tcp client 初始化时不创建 socket,在循环里使用 connect_tcp 来创建
Ok(unsafe { Socket::from_raw_fd(0) })
}
} }
// TCP 不需要一开始创建 socket,在运行时管理
Schema::TCP => Ok(Arc::new(unsafe { Socket::from_raw_fd(0) })),
} }
} }
fn attach_filter_raw(config: &ConfigRouter, local_id: u8, socket: &Socket) -> Result<()> { pub fn connect_tcp(&self, local_id: u8) -> Result<Socket> {
// tcp client 的 socket 不要在初始化时创建,在循环里创建
// 创建 socket 和 获取 endpoint 失败会 panic,连接失败会 error
let result = Socket::new(self.config.family, Type::STREAM, Some(Protocol::TCP)).unwrap();
result.set_mark(self.config.mark).unwrap();
let meta = Meta {
src_id: local_id,
dst_id: self.config.remote_id,
reversed: 0,
};
let guard = crossbeam::epoch::pin();
let endpoint_ref = self.endpoint.load(Ordering::Relaxed, &guard);
let endpoint = unsafe { endpoint_ref.as_ref() }.unwrap();
result.send_to_with_flags(meta.as_bytes(), endpoint, MSG_FASTOPEN)?;
Ok(result)
}
fn attach_filter_ip(config: &ConfigRouter, local_id: u8, socket: &Socket) -> Result<()> {
// 由于多个对端可能会使用相同的 ipprpto 号,这里确保每个 socket 上只会收到自己对应的对端发来的消息 // 由于多个对端可能会使用相同的 ipprpto 号,这里确保每个 socket 上只会收到自己对应的对端发来的消息
let meta = Meta { let meta = Meta {
src_id: config.remote_id, src_id: config.remote_id,
...@@ -134,23 +168,6 @@ impl Router { ...@@ -134,23 +168,6 @@ impl Router {
Ok(()) Ok(())
} }
pub(crate) fn recv_exact(sock: &Socket, mut buf: &mut [MaybeUninit<u8>]) -> Result<()> {
while !buf.is_empty() {
let n = sock.recv(buf)?;
ensure!(n != 0, std::io::ErrorKind::UnexpectedEof);
buf = &mut buf[n..];
}
Ok(())
}
pub(crate) fn send_all(sock: &Socket, mut buf: &[u8]) -> Result<()> {
while !buf.is_empty() {
let n = sock.send(buf)?;
buf = &buf[n..];
}
Ok(())
}
pub fn attach_filter_udp(group: Vec<&Arc<Router>>, local_id: u8) -> Result<()> { pub fn attach_filter_udp(group: Vec<&Arc<Router>>, local_id: u8) -> Result<()> {
let values: Vec<u32> = group let values: Vec<u32> = group
.iter() .iter()
...@@ -175,17 +192,12 @@ impl Router { ...@@ -175,17 +192,12 @@ impl Router {
} }
// If no match found after all comparisons, drop the packet // If no match found after all comparisons, drop the packet
filters.push(bpf_stmt(BPF_RET | BPF_K, u32::MAX)); filters.push(bpf_stmt(BPF_RET | BPF_K, u32::MAX));
Self::attach_reuseport_cbpf(&group[0].socket, &mut filters)?;
Ok(())
}
fn attach_reuseport_cbpf(sock: &Socket, code: &mut [SockFilter]) -> Result<()> {
let prog = sock_fprog { let prog = sock_fprog {
len: code.len() as u16, len: filters.len() as u16,
filter: code.as_mut_ptr() as *mut sock_filter, filter: filters.as_mut_ptr() as *mut sock_filter,
}; };
let fd = sock.as_raw_fd(); let fd = group[0].socket.as_raw_fd();
let ret = unsafe { let ret = unsafe {
setsockopt( setsockopt(
fd, fd,
...@@ -195,11 +207,110 @@ impl Router { ...@@ -195,11 +207,110 @@ impl Router {
size_of_val(&prog) as socklen_t, size_of_val(&prog) as socklen_t,
) )
}; };
ensure!(ret != -1, std::io::Error::last_os_error());
Ok(())
}
pub(crate) fn handle_outbound_ip_udp(&self, local_id: u8) {
let mut buffer = [0u8; 1500];
if ret == -1 { // Pre-initialize with our Meta header (local -> remote)
Err(std::io::Error::last_os_error())?; let meta = Meta {
src_id: local_id,
dst_id: self.config.remote_id,
reversed: 0,
};
buffer[..META_SIZE].copy_from_slice(meta.as_bytes());
loop {
let n = self.tun.recv(&mut buffer[META_SIZE..]).unwrap(); // recv 失败直接 panic
let guard = crossbeam::epoch::pin();
let endpoint_ref = self.endpoint.load(Ordering::Relaxed, &guard);
if let Some(endpoint) = unsafe { endpoint_ref.as_ref() } {
self.encrypt(&mut buffer[META_SIZE..META_SIZE + n]);
let _ = self.socket.send_to(&buffer[..META_SIZE + n], endpoint);
} }
}
}
pub(crate) fn handle_inbound_ip_udp(&self, local_secret: &[u8; 32]) {
let mut recv_buf = [MaybeUninit::uninit(); 1500];
loop {
// 收到一个非法报文只丢弃一个报文
let (len, addr) = { self.socket.recv_from(&mut recv_buf).unwrap() }; // recv 出错直接 panic
let packet = unsafe { std::slice::from_raw_parts_mut(recv_buf.as_mut_ptr().cast(), len) };
// if addr.is_ipv6() { println!("{:X?}", packet) }
// 只有 ipv4 raw 会给 IP报头
let offset = if self.config.family == Domain::IPV4 && self.config.schema == IP {
(packet[0] & 0x0f) as usize * 4
} else {
0
} + META_SIZE;
{
let guard = crossbeam::epoch::pin();
let current_shared = self.endpoint.load(Ordering::Relaxed, &guard);
let is_same = unsafe { current_shared.as_ref() }.map(|c| *c == addr).unwrap_or(false);
if !is_same {
let new_shared = crossbeam::epoch::Owned::new(addr).into_shared(&guard);
let old_shared = self.endpoint.swap(new_shared, Ordering::Release, &guard);
unsafe { guard.defer_destroy(old_shared) }
}
}
let payload = &mut packet[offset..];
self.decrypt(payload, &local_secret);
let _ = self.tun.send(payload);
}
}
pub(crate) fn handle_outbound_tcp(&self, connection: &Socket) {
let _ = (|| -> Result<()> {
let mut buffer = [0u8; 1500];
loop {
let n = self.tun.recv(&mut buffer)?;
self.encrypt(&mut buffer[..n]);
Router::send_all_tcp(&connection, &buffer[..n])?;
}
})();
let _ = connection.shutdown(Shutdown::Both);
}
pub(crate) fn handle_inbound_tcp(&self, connection: &Socket, local_secret: &[u8; 32]) {
let _ = (|| -> Result<()> {
let mut buf = [MaybeUninit::uninit(); 1500];
let packet: &mut [u8] = unsafe { std::slice::from_raw_parts_mut(buf.as_mut_ptr().cast(), buf.len()) };
loop {
Router::recv_exact_tcp(&connection, &mut buf[0..6])?;
self.decrypt2(packet, &local_secret, 0..6);
let version = packet[0] >> 4;
let total_len = match version {
4 => u16::from_be_bytes([packet[2], packet[3]]) as usize,
6 => u16::from_be_bytes([packet[4], packet[5]]) as usize + 40,
_ => bail!("Invalid IP version"),
};
ensure!(6 < total_len && total_len <= buf.len(), "Invalid total length");
Router::recv_exact_tcp(&connection, &mut buf[6..total_len])?;
self.decrypt2(packet, &local_secret, 6..total_len);
self.tun.send(&packet[..total_len])?;
}
})();
let _ = connection.shutdown(Shutdown::Both);
}
pub(crate) fn recv_exact_tcp(sock: &Socket, mut buf: &mut [MaybeUninit<u8>]) -> Result<()> {
while !buf.is_empty() {
let n = sock.recv(buf)?;
ensure!(n != 0, std::io::ErrorKind::UnexpectedEof);
buf = &mut buf[n..];
}
Ok(())
}
pub(crate) fn send_all_tcp(sock: &Socket, mut buf: &[u8]) -> Result<()> {
while !buf.is_empty() {
let n = sock.send(buf)?;
buf = &buf[n..];
}
Ok(()) Ok(())
} }
...@@ -237,7 +348,7 @@ impl Router { ...@@ -237,7 +348,7 @@ impl Router {
secret: Self::create_secret(config.remote_secret.as_str())?, secret: Self::create_secret(config.remote_secret.as_str())?,
tun: Self::create_tun_device(&config)?, tun: Self::create_tun_device(&config)?,
endpoint: Self::create_endpoint(&config), endpoint: Self::create_endpoint(&config),
socket: Self::create_socket(&config, local_id)?, socket: Arc::new(Self::create_socket(&config, local_id)?),
config, config,
}; };
...@@ -246,29 +357,6 @@ impl Router { ...@@ -246,29 +357,6 @@ impl Router {
Ok(router) Ok(router)
} }
pub fn tcp_listen(&self) -> Result<Socket> {
let result = Socket::new(self.config.family, Type::STREAM, Some(Protocol::TCP))?;
let addr = Router::bind_addr(&self.config);
result.bind(&addr)?;
result.listen(100)?;
Ok(result)
}
pub fn tcp_connect(&self, local_id: u8) -> Result<Socket> {
let result = Socket::new(self.config.family, Type::STREAM, Some(Protocol::TCP))?;
result.set_mark(self.config.mark)?;
let guard = crossbeam::epoch::pin();
let endpoint_ref = self.endpoint.load(Ordering::Relaxed, &guard);
let endpoint = unsafe { endpoint_ref.as_ref() }.ok_or_else(|| anyhow::anyhow!("endpoint info load failed"))?;
let meta = Meta {
src_id: local_id,
dst_id: self.config.remote_id,
reversed: 0,
};
result.send_to_with_flags(meta.as_bytes(), endpoint, MSG_FASTOPEN)?;
Ok(result)
}
fn bind_addr(config: &ConfigRouter) -> SockAddr { fn bind_addr(config: &ConfigRouter) -> SockAddr {
match config.family { match config.family {
Domain::IPV4 => SockAddr::from(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, config.src_port)), Domain::IPV4 => SockAddr::from(SocketAddrV4::new(Ipv4Addr::UNSPECIFIED, config.src_port)),
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment