Commit d5b52d94 authored by nanamicat's avatar nanamicat

config

parent 683e1d35
...@@ -4,15 +4,9 @@ use base64::Engine; ...@@ -4,15 +4,9 @@ use base64::Engine;
use serde::{Deserialize, Deserializer}; use serde::{Deserialize, Deserializer};
use socket2::Domain; use socket2::Domain;
#[derive(Deserialize)]
pub struct Config {
pub local_id: u8,
#[serde(deserialize_with = "deserialize_secret")]
pub local_secret: [u8; SECRET_LENGTH],
pub routers: Vec<ConfigRouter>,
}
#[derive(Deserialize, Clone)] #[derive(Deserialize, Clone)]
pub struct ConfigRouter { pub struct ConfigRouter {
pub local_id: u8,
pub remote_id: u8, pub remote_id: u8,
#[serde(default)] #[serde(default)]
pub schema: Schema, pub schema: Schema,
...@@ -27,6 +21,8 @@ pub struct ConfigRouter { ...@@ -27,6 +21,8 @@ pub struct ConfigRouter {
pub mark: u32, pub mark: u32,
pub endpoint: String, pub endpoint: String,
#[serde(deserialize_with = "deserialize_secret")] #[serde(deserialize_with = "deserialize_secret")]
pub local_secret: [u8; SECRET_LENGTH],
#[serde(deserialize_with = "deserialize_secret")]
pub remote_secret: [u8; SECRET_LENGTH], pub remote_secret: [u8; SECRET_LENGTH],
pub dev: String, pub dev: String,
pub up: String, pub up: String,
......
mod router;
mod config; mod config;
mod router;
use crate::config::{Config, Schema}; use crate::config::{ConfigRouter, Schema};
use crate::router::{Meta, Router, META_SIZE}; use crate::router::{Meta, Router, META_SIZE};
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use crossbeam::epoch::{pin, Owned}; use crossbeam::epoch::{pin, Owned};
...@@ -15,16 +15,14 @@ use std::{collections::HashMap, env, mem::MaybeUninit}; ...@@ -15,16 +15,14 @@ use std::{collections::HashMap, env, mem::MaybeUninit};
fn main() -> Result<()> { fn main() -> Result<()> {
println!("Starting"); println!("Starting");
let config = serde_json::from_str::<Config>(env::args().nth(1).context("need param")?.as_str())?; let config = serde_json::from_str::<Vec<ConfigRouter>>(env::args().nth(1).context("need param")?.as_str())?;
let routers = &config let routers = &config
.routers .into_iter()
.iter()
.cloned()
.sorted_by_key(|r| r.remote_id) .sorted_by_key(|r| r.remote_id)
.map(|c| { .map(|c| {
let remote_id = c.remote_id; let remote_id = c.remote_id;
Router::new(c, config.local_id).map(|r| (remote_id, r)) Router::new(c).map(|router| (remote_id, router))
}) })
.collect::<Result<HashMap<u8, Router>, _>>()?; .collect::<Result<HashMap<u8, Router>, _>>()?;
...@@ -33,7 +31,7 @@ fn main() -> Result<()> { ...@@ -33,7 +31,7 @@ fn main() -> Result<()> {
.filter(|r| r.config.schema == Schema::UDP && r.config.src_port != 0) .filter(|r| r.config.schema == Schema::UDP && r.config.src_port != 0)
.chunk_by(|r| r.config.src_port) .chunk_by(|r| r.config.src_port)
{ {
Router::attach_filter_udp(group.sorted_by_key(|r| r.config.remote_id).collect(), config.local_id)?; Router::attach_filter_udp(group.sorted_by_key(|r| r.config.remote_id).collect())?;
} }
println!("created tuns"); println!("created tuns");
...@@ -43,21 +41,21 @@ fn main() -> Result<()> { ...@@ -43,21 +41,21 @@ fn main() -> Result<()> {
// IP, UDP // IP, UDP
for router in routers.values().filter(|&r| r.config.schema != Schema::TCP) { for router in routers.values().filter(|&r| r.config.schema != Schema::TCP) {
s.spawn(|_| { s.spawn(|_| {
router.handle_outbound_ip_udp(config.local_id); router.handle_outbound_ip_udp();
}); });
s.spawn(|_| { s.spawn(|_| {
router.handle_inbound_ip_udp(&config.local_secret); router.handle_inbound_ip_udp();
}); });
} }
for router in routers.values().filter(|&r| r.config.schema == Schema::TCP && r.config.dst_port != 0) { for router in routers.values().filter(|&r| r.config.schema == Schema::TCP && r.config.dst_port != 0) {
s.spawn(|_| { s.spawn(|_| {
loop { loop {
if let Ok(connection) = router.connect_tcp(config.local_id) { if let Ok(connection) = router.connect_tcp() {
let _ = thread::scope(|s| { let _ = thread::scope(|s| {
s.spawn(|_| router.handle_outbound_tcp(&connection)); s.spawn(|_| router.handle_outbound_tcp(&connection));
s.spawn(|_| router.handle_inbound_tcp(&connection, &config.local_secret)); s.spawn(|_| router.handle_inbound_tcp(&connection));
}); });
} }
std::thread::sleep(Duration::from_secs(TCP_RECONNECT)); std::thread::sleep(Duration::from_secs(TCP_RECONNECT));
...@@ -82,8 +80,8 @@ fn main() -> Result<()> { ...@@ -82,8 +80,8 @@ fn main() -> Result<()> {
Router::recv_exact_tcp(&connection, &mut meta_bytes).unwrap(); Router::recv_exact_tcp(&connection, &mut meta_bytes).unwrap();
let meta: &Meta = Meta::from_bytes(&meta_bytes); let meta: &Meta = Meta::from_bytes(&meta_bytes);
if meta.reversed == 0 if meta.reversed == 0
&& meta.dst_id == config.local_id
&& let Some(router) = routers.get(&meta.src_id) && let Some(router) = routers.get(&meta.src_id)
&& meta.dst_id == router.config.local_id
{ {
let connection = Arc::new(connection); let connection = Arc::new(connection);
...@@ -102,7 +100,7 @@ fn main() -> Result<()> { ...@@ -102,7 +100,7 @@ fn main() -> Result<()> {
let _ = thread::scope(|s| { let _ = thread::scope(|s| {
s.spawn(|_| router.handle_outbound_tcp(&connection)); s.spawn(|_| router.handle_outbound_tcp(&connection));
s.spawn(|_| router.handle_inbound_tcp(&connection, &config.local_secret)); s.spawn(|_| router.handle_inbound_tcp(&connection));
}); });
} }
}); });
......
...@@ -66,12 +66,12 @@ impl Router { ...@@ -66,12 +66,12 @@ impl Router {
} }
} }
pub fn create_socket(config: &ConfigRouter, local_id: u8) -> Result<Socket> { pub fn create_socket(config: &ConfigRouter) -> Result<Socket> {
match config.schema { match config.schema {
Schema::IP => { Schema::IP => {
let result = Socket::new(config.family, Type::RAW, Some(Protocol::from(config.proto as i32)))?; let result = Socket::new(config.family, Type::RAW, Some(Protocol::from(config.proto as i32)))?;
result.set_mark(config.mark)?; result.set_mark(config.mark)?;
Self::attach_filter_ip(config, local_id, &result)?; Self::attach_filter_ip(config, &result)?;
Ok(result) Ok(result)
} }
Schema::UDP => { Schema::UDP => {
...@@ -100,7 +100,7 @@ impl Router { ...@@ -100,7 +100,7 @@ impl Router {
} }
} }
pub fn connect_tcp(&self, local_id: u8) -> Result<Socket> { pub fn connect_tcp(&self) -> Result<Socket> {
// tcp client 的 socket 不要在初始化时创建,在循环里创建 // tcp client 的 socket 不要在初始化时创建,在循环里创建
// 创建 socket 和 获取 endpoint 失败会 panic,连接失败会 error // 创建 socket 和 获取 endpoint 失败会 panic,连接失败会 error
let result = Socket::new(self.config.family, Type::STREAM, Some(Protocol::TCP)).unwrap(); let result = Socket::new(self.config.family, Type::STREAM, Some(Protocol::TCP)).unwrap();
...@@ -113,7 +113,7 @@ impl Router { ...@@ -113,7 +113,7 @@ impl Router {
} }
let meta = Meta { let meta = Meta {
src_id: local_id, src_id: self.config.local_id,
dst_id: self.config.remote_id, dst_id: self.config.remote_id,
reversed: 0, reversed: 0,
}; };
...@@ -125,11 +125,11 @@ impl Router { ...@@ -125,11 +125,11 @@ impl Router {
Ok(result) Ok(result)
} }
fn attach_filter_ip(config: &ConfigRouter, local_id: u8, socket: &Socket) -> Result<()> { fn attach_filter_ip(config: &ConfigRouter, socket: &Socket) -> Result<()> {
// 由于多个对端可能会使用相同的 ipprpto 号,这里确保每个 socket 上只会收到自己对应的对端发来的消息 // 由于多个对端可能会使用相同的 ipprpto 号,这里确保每个 socket 上只会收到自己对应的对端发来的消息
let meta = Meta { let meta = Meta {
src_id: config.remote_id, src_id: config.remote_id,
dst_id: local_id, dst_id: config.local_id,
reversed: 0, reversed: 0,
}; };
let value = u32::from_be_bytes(*meta.as_bytes()); let value = u32::from_be_bytes(*meta.as_bytes());
...@@ -173,13 +173,13 @@ impl Router { ...@@ -173,13 +173,13 @@ impl Router {
Ok(()) Ok(())
} }
pub fn attach_filter_udp(group: Vec<&Router>, local_id: u8) -> Result<()> { pub fn attach_filter_udp(group: Vec<&Router>) -> Result<()> {
let values: Vec<u32> = group let values: Vec<u32> = group
.iter() .iter()
.map(|&f| { .map(|&f| {
let meta = Meta { let meta = Meta {
src_id: f.config.remote_id, src_id: f.config.remote_id,
dst_id: local_id, dst_id: f.config.local_id,
reversed: 0, reversed: 0,
}; };
u32::from_be_bytes(*meta.as_bytes()) u32::from_be_bytes(*meta.as_bytes())
...@@ -221,12 +221,12 @@ impl Router { ...@@ -221,12 +221,12 @@ impl Router {
Ok(()) Ok(())
} }
pub(crate) fn handle_outbound_ip_udp(&self, local_id: u8) { pub(crate) fn handle_outbound_ip_udp(&self) {
let mut buffer = [0u8; 1500]; let mut buffer = [0u8; 1500];
// Pre-initialize with our Meta header (local -> remote) // Pre-initialize with our Meta header (local -> remote)
let meta = Meta { let meta = Meta {
src_id: local_id, src_id: self.config.local_id,
dst_id: self.config.remote_id, dst_id: self.config.remote_id,
reversed: 0, reversed: 0,
}; };
...@@ -243,7 +243,7 @@ impl Router { ...@@ -243,7 +243,7 @@ impl Router {
} }
} }
pub(crate) fn handle_inbound_ip_udp(&self, local_secret: &[u8; 32]) { pub(crate) fn handle_inbound_ip_udp(&self) {
let mut recv_buf = [MaybeUninit::uninit(); 1500]; let mut recv_buf = [MaybeUninit::uninit(); 1500];
loop { loop {
// 收到一个非法报文只丢弃一个报文 // 收到一个非法报文只丢弃一个报文
...@@ -269,7 +269,7 @@ impl Router { ...@@ -269,7 +269,7 @@ impl Router {
} }
let payload = &mut packet[offset..]; let payload = &mut packet[offset..];
self.decrypt(payload, &local_secret); self.decrypt(payload, &self.config.local_secret);
let _ = self.tun.send(payload); let _ = self.tun.send(payload);
} }
} }
...@@ -285,13 +285,13 @@ impl Router { ...@@ -285,13 +285,13 @@ impl Router {
})(); })();
let _ = connection.shutdown(Shutdown::Both); let _ = connection.shutdown(Shutdown::Both);
} }
pub(crate) fn handle_inbound_tcp(&self, connection: &Socket, local_secret: &[u8; 32]) { pub(crate) fn handle_inbound_tcp(&self, connection: &Socket) {
let _ = (|| -> Result<()> { let _ = (|| -> Result<()> {
let mut buf = [MaybeUninit::uninit(); 1500]; let mut buf = [MaybeUninit::uninit(); 1500];
let packet: &mut [u8] = unsafe { std::slice::from_raw_parts_mut(buf.as_mut_ptr().cast(), buf.len()) }; let packet: &mut [u8] = unsafe { std::slice::from_raw_parts_mut(buf.as_mut_ptr().cast(), buf.len()) };
loop { loop {
Router::recv_exact_tcp(&connection, &mut buf[0..6])?; Router::recv_exact_tcp(&connection, &mut buf[0..6])?;
self.decrypt2(packet, &local_secret, 0..6); self.decrypt2(packet, &self.config.local_secret, 0..6);
let version = packet[0] >> 4; let version = packet[0] >> 4;
let total_len = match version { let total_len = match version {
4 => u16::from_be_bytes([packet[2], packet[3]]) as usize, 4 => u16::from_be_bytes([packet[2], packet[3]]) as usize,
...@@ -300,7 +300,7 @@ impl Router { ...@@ -300,7 +300,7 @@ impl Router {
}; };
ensure!(6 < total_len && total_len <= buf.len(), "Invalid total length"); ensure!(6 < total_len && total_len <= buf.len(), "Invalid total length");
Router::recv_exact_tcp(&connection, &mut buf[6..total_len])?; Router::recv_exact_tcp(&connection, &mut buf[6..total_len])?;
self.decrypt2(packet, &local_secret, 6..total_len); self.decrypt2(packet, &self.config.local_secret, 6..total_len);
self.tun.send(&packet[..total_len])?; self.tun.send(&packet[..total_len])?;
} }
})(); })();
...@@ -351,11 +351,11 @@ impl Router { ...@@ -351,11 +351,11 @@ impl Router {
} }
} }
pub fn new(config: ConfigRouter, local_id: u8) -> Result<Router> { pub fn new(config: ConfigRouter) -> Result<Router> {
let router = Router { let router = Router {
tun: Self::create_tun_device(&config)?, tun: Self::create_tun_device(&config)?,
endpoint: Self::create_endpoint(&config), endpoint: Self::create_endpoint(&config),
socket: Self::create_socket(&config, local_id)?, socket: Self::create_socket(&config)?,
tcp_listener_connection: Atomic::null(), tcp_listener_connection: Atomic::null(),
config, config,
}; };
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment