Commit 50d9f945 authored by nanamicat's avatar nanamicat

tcp

parent 1d2cbed5
mod router; mod router;
use crate::Schema::{IP, TCP, UDP}; use crate::router::{Meta, Router, META_SIZE, SECRET_LENGTH};
use crate::router::{META_SIZE, Router, SECRET_LENGTH, Meta}; use crate::Schema::{TCP, UDP};
use anyhow::{Context, Result, bail, ensure}; use anyhow::{Context, Result};
use crossbeam_utils::thread; use crossbeam_utils::thread;
use itertools::Itertools; use itertools::Itertools;
use serde::{Deserialize, Deserializer}; use serde::{Deserialize, Deserializer};
use socket2::Domain; use socket2::Domain;
use std::net::Shutdown;
use std::time::Duration; use std::time::Duration;
use std::{ use std::{collections::HashMap, env, mem::MaybeUninit, sync::Arc};
collections::HashMap,
env,
mem::MaybeUninit,
sync::{Arc, atomic::Ordering},
};
#[derive(Deserialize)] #[derive(Deserialize)]
pub struct Config { pub struct Config {
...@@ -94,67 +88,22 @@ fn main() -> Result<()> { ...@@ -94,67 +88,22 @@ fn main() -> Result<()> {
// IP, UDP // IP, UDP
for router in routers.values().filter(|&r| !(r.config.schema != TCP)) { for router in routers.values().filter(|&r| !(r.config.schema != TCP)) {
s.spawn(|_| { s.spawn(|_| {
let mut buffer = [0u8; 1500]; router.handle_outbound_ip_udp(config.local_id);
// Pre-initialize with our Meta header (local -> remote)
let meta = Meta {
src_id: config.local_id,
dst_id: router.config.remote_id,
reversed: 0,
};
buffer[..META_SIZE].copy_from_slice(meta.as_bytes());
loop {
let n = router.tun.recv(&mut buffer[META_SIZE..]).unwrap(); // recv 失败直接 panic
let guard = crossbeam::epoch::pin();
let endpoint_ref = router.endpoint.load(Ordering::Relaxed, &guard);
if let Some(endpoint) = unsafe { endpoint_ref.as_ref() } {
router.encrypt(&mut buffer[META_SIZE..META_SIZE + n]);
let _ = router.socket.send_to(&buffer[..META_SIZE + n], endpoint);
}
}
}); });
s.spawn(|_| { s.spawn(|_| {
let mut recv_buf = [MaybeUninit::uninit(); 1500]; router.handle_inbound_ip_udp(&local_secret);
loop {
let _ = (|| -> Result<()> {
// 收到一个非法报文只丢弃一个报文
let (len, addr) = { router.socket.recv_from(&mut recv_buf).unwrap() }; // recv 出错直接 panic
let packet = unsafe { std::slice::from_raw_parts_mut(recv_buf.as_mut_ptr().cast(), len) };
// if addr.is_ipv6() { println!("{:X?}", packet) }
// 只有 ipv4 raw 会给 IP报头
let offset = if router.config.family == Domain::IPV4 && router.config.schema == IP {
(packet[0] & 0x0f) as usize * 4
} else {
0
} + META_SIZE;
{
let guard = crossbeam::epoch::pin();
let current_shared = router.endpoint.load(Ordering::Relaxed, &guard);
let is_same = unsafe { current_shared.as_ref() }.map(|c| *c == addr).unwrap_or(false);
if !is_same {
let new_shared = crossbeam::epoch::Owned::new(addr).into_shared(&guard);
let old_shared = router.endpoint.swap(new_shared, Ordering::Release, &guard);
unsafe { guard.defer_destroy(old_shared) }
}
}
let payload = &mut packet[offset..];
router.decrypt(payload, &local_secret);
router.tun.send(payload)?;
Ok(())
})();
}
}); });
} }
for router in routers.values().filter(|&r| r.config.schema == TCP && r.config.dst_port != 0) { for router in routers.values().filter(|&r| r.config.schema == TCP && r.config.dst_port != 0) {
s.spawn(|_| { s.spawn(|_| {
loop { loop {
if let Ok(connection) = router.tcp_connect(config.local_id) { if let Ok(connection) = router.connect_tcp(config.local_id) {
let _ = handle_tcp(router, connection, &local_secret); let _ = thread::scope(|s| {
s.spawn(|_| router.handle_outbound_tcp(&connection));
s.spawn(|_| router.handle_inbound_tcp(&connection, &local_secret));
});
} }
std::thread::sleep(Duration::from_secs(TCP_RECONNECT)); std::thread::sleep(Duration::from_secs(TCP_RECONNECT));
} }
...@@ -168,26 +117,22 @@ fn main() -> Result<()> { ...@@ -168,26 +117,22 @@ fn main() -> Result<()> {
.unique_by(|r| r.config.src_port) .unique_by(|r| r.config.src_port)
{ {
s.spawn(|_| { s.spawn(|_| {
// listen 或 accept 出错直接 panic // accept 出错直接 panic
let listener = router.tcp_listen().unwrap();
loop { loop {
let (connection, _) = listener.accept().unwrap(); let (connection, _) = router.socket.accept().unwrap();
let _ = (|| -> Result<()> {
// 为了写起来方便,每个 tcp connection 有两秒钟时间发送握手报文,如果没收到就关闭连接再响应下一个 let mut meta_bytes = [MaybeUninit::uninit(); META_SIZE];
// 正常的连接是 fast open 的,期望会瞬间连好 Router::recv_exact_tcp(&connection, &mut meta_bytes).unwrap();
connection.set_read_timeout(Option::from(Duration::from_secs(2)))?; let meta: &Meta = unsafe { &*(meta_bytes.as_ptr() as *const Meta) };
let mut meta_bytes = [MaybeUninit::uninit(); META_SIZE]; if meta.reversed == 0
Router::recv_exact(&connection, &mut meta_bytes)?; && meta.dst_id == config.local_id
let meta: &Meta = unsafe { &*(meta_bytes.as_ptr() as *const Meta) }; && let Some(router) = routers.get(&meta.src_id)
if meta.reversed == 0 {
&& meta.dst_id == config.local_id let _ = thread::scope(|s| {
&& let Some(router) = routers.get(&meta.src_id) s.spawn(|_| router.handle_outbound_tcp(&connection));
{ s.spawn(|_| router.handle_inbound_tcp(&connection, &local_secret));
connection.set_read_timeout(None)?; });
handle_tcp(router, connection, &local_secret)?; }
}
Ok(())
})();
} }
}); });
} }
...@@ -195,42 +140,3 @@ fn main() -> Result<()> { ...@@ -195,42 +140,3 @@ fn main() -> Result<()> {
.map_err(|_| anyhow::anyhow!("Thread panicked"))?; .map_err(|_| anyhow::anyhow!("Thread panicked"))?;
Ok(()) Ok(())
} }
fn handle_tcp(router: &Arc<Router>, connection: socket2::Socket, local_secret: &[u8; 32]) -> Result<()> {
thread::scope(|s| {
s.spawn(|_| {
let _ = (|| -> Result<()> {
let mut buffer = [0u8; 1500];
loop {
let n = router.tun.recv(&mut buffer)?;
router.encrypt(&mut buffer[..n]);
Router::send_all(&connection, &buffer[..n])?;
}
})();
let _ = connection.shutdown(Shutdown::Both);
});
s.spawn(|_| {
let _ = (|| -> Result<()> {
let mut buf = [MaybeUninit::uninit(); 1500];
let packet: &mut [u8] = unsafe { std::slice::from_raw_parts_mut(buf.as_mut_ptr().cast(), buf.len()) };
loop {
Router::recv_exact(&connection, &mut buf[0..6])?;
router.decrypt2(packet, &local_secret, 0..6);
let version = packet[0] >> 4;
let total_len = match version {
4 => u16::from_be_bytes([packet[2], packet[3]]) as usize,
6 => u16::from_be_bytes([packet[4], packet[5]]) as usize + 40,
_ => bail!("Invalid IP version"),
};
ensure!(6 < total_len && total_len <= buf.len(), "Invalid total length");
Router::recv_exact(&connection, &mut buf[6..total_len])?;
router.decrypt2(packet, &local_secret, 6..total_len);
router.tun.send(&packet[..total_len])?;
}
})();
let _ = connection.shutdown(Shutdown::Both);
});
})
.map_err(|_| anyhow::anyhow!("Thread panicked"))?;
Ok(())
}
This diff is collapsed.
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment