Commit 55e8ed17 authored by nanamicat's avatar nanamicat

clean

parent 87991e29
use crate::router::SECRET_LENGTH;
use base64::prelude::BASE64_STANDARD;
use base64::Engine;
use serde::{Deserialize, Deserializer};
use socket2::Domain;
#[derive(Deserialize)]
pub struct Config {
pub local_id: u8,
#[serde(deserialize_with = "deserialize_secret")]
pub local_secret: [u8; SECRET_LENGTH],
pub routers: Vec<ConfigRouter>,
}
#[derive(Deserialize, Clone)]
pub struct ConfigRouter {
pub remote_id: u8,
#[serde(default)]
pub schema: Schema,
#[serde(default)]
pub proto: u8,
#[serde(default)]
pub src_port: u16,
#[serde(default)]
pub dst_port: u16,
#[serde(deserialize_with = "deserialize_domain")]
pub family: Domain,
pub mark: u32,
pub endpoint: String,
#[serde(deserialize_with = "deserialize_secret")]
pub remote_secret: [u8; SECRET_LENGTH],
pub dev: String,
pub up: String,
}
#[derive(Deserialize, Default, PartialEq, Clone, Copy)]
pub enum Schema {
#[default]
IP,
UDP,
TCP,
}
fn deserialize_domain<'de, D>(d: D) -> Result<Domain, D::Error>
where
D: Deserializer<'de>,
{
match u8::deserialize(d)? {
4 => Ok(Domain::IPV4),
6 => Ok(Domain::IPV6),
_ => Err(serde::de::Error::custom("Invalid domain")),
}
}
fn deserialize_secret<'de, D>(d: D) -> Result<[u8; SECRET_LENGTH], D::Error>
where
D: Deserializer<'de>,
{
BASE64_STANDARD
.decode(String::deserialize(d)?)
.map_err(serde::de::Error::custom)?
.as_slice()
.try_into()
.map_err(serde::de::Error::custom)
}
mod router; mod router;
mod config;
use crate::router::{Meta, Router, META_SIZE, SECRET_LENGTH}; use crate::config::{Config, Schema};
use crate::Schema::{TCP, UDP}; use crate::router::{Meta, Router, META_SIZE};
use anyhow::{Context, Result}; use anyhow::{Context, Result};
use crossbeam::epoch::{pin, Owned}; use crossbeam::epoch::{pin, Owned};
use crossbeam_utils::thread; use crossbeam_utils::thread;
use itertools::Itertools; use itertools::Itertools;
use serde::{Deserialize, Deserializer};
use socket2::Domain;
use std::net::Shutdown; use std::net::Shutdown;
use std::sync::atomic::Ordering; use std::sync::atomic::Ordering;
use std::sync::Arc; use std::sync::Arc;
use std::time::Duration; use std::time::Duration;
use std::{collections::HashMap, env, mem::MaybeUninit}; use std::{collections::HashMap, env, mem::MaybeUninit};
#[derive(Deserialize)]
pub struct Config {
pub local_id: u8,
pub local_secret: String,
pub routers: Vec<ConfigRouter>,
}
#[derive(Deserialize, Clone)]
pub struct ConfigRouter {
pub remote_id: u8,
#[serde(default)]
pub schema: Schema,
#[serde(default)]
pub proto: u8,
#[serde(default)]
pub src_port: u16,
#[serde(default)]
pub dst_port: u16,
#[serde(deserialize_with = "deserialize_domain")]
pub family: Domain,
pub mark: u32,
pub endpoint: String,
pub remote_secret: String,
pub dev: String,
pub up: String,
}
#[derive(Deserialize, Default, PartialEq, Clone, Copy)]
pub enum Schema {
#[default]
IP,
UDP,
TCP,
}
fn deserialize_domain<'de, D>(d: D) -> Result<Domain, D::Error>
where
D: Deserializer<'de>,
{
match u8::deserialize(d)? {
4 => Ok(Domain::IPV4),
6 => Ok(Domain::IPV6),
_ => Err(serde::de::Error::custom("Invalid domain")),
}
}
fn main() -> Result<()> { fn main() -> Result<()> {
println!("Starting"); println!("Starting");
let config = serde_json::from_str::<Config>(env::args().nth(1).context("need param")?.as_str())?; let config = serde_json::from_str::<Config>(env::args().nth(1).context("need param")?.as_str())?;
let local_secret: [u8; SECRET_LENGTH] = Router::create_secret(config.local_secret.as_str())?;
let routers = &config let routers = &config
.routers .routers
...@@ -77,7 +30,7 @@ fn main() -> Result<()> { ...@@ -77,7 +30,7 @@ fn main() -> Result<()> {
for (_, group) in &routers for (_, group) in &routers
.values() .values()
.filter(|r| r.config.schema == UDP && r.config.src_port != 0) .filter(|r| r.config.schema == Schema::UDP && r.config.src_port != 0)
.chunk_by(|r| r.config.src_port) .chunk_by(|r| r.config.src_port)
{ {
Router::attach_filter_udp(group.sorted_by_key(|r| r.config.remote_id).collect(), config.local_id)?; Router::attach_filter_udp(group.sorted_by_key(|r| r.config.remote_id).collect(), config.local_id)?;
...@@ -88,23 +41,23 @@ fn main() -> Result<()> { ...@@ -88,23 +41,23 @@ fn main() -> Result<()> {
thread::scope(|s| { thread::scope(|s| {
// IP, UDP // IP, UDP
for router in routers.values().filter(|&r| r.config.schema != TCP) { for router in routers.values().filter(|&r| r.config.schema != Schema::TCP) {
s.spawn(|_| { s.spawn(|_| {
router.handle_outbound_ip_udp(config.local_id); router.handle_outbound_ip_udp(config.local_id);
}); });
s.spawn(|_| { s.spawn(|_| {
router.handle_inbound_ip_udp(&local_secret); router.handle_inbound_ip_udp(&config.local_secret);
}); });
} }
for router in routers.values().filter(|&r| r.config.schema == TCP && r.config.dst_port != 0) { for router in routers.values().filter(|&r| r.config.schema == Schema::TCP && r.config.dst_port != 0) {
s.spawn(|_| { s.spawn(|_| {
loop { loop {
if let Ok(connection) = router.connect_tcp(config.local_id) { if let Ok(connection) = router.connect_tcp(config.local_id) {
let _ = thread::scope(|s| { let _ = thread::scope(|s| {
s.spawn(|_| router.handle_outbound_tcp(&connection)); s.spawn(|_| router.handle_outbound_tcp(&connection));
s.spawn(|_| router.handle_inbound_tcp(&connection, &local_secret)); s.spawn(|_| router.handle_inbound_tcp(&connection, &config.local_secret));
}); });
} }
std::thread::sleep(Duration::from_secs(TCP_RECONNECT)); std::thread::sleep(Duration::from_secs(TCP_RECONNECT));
...@@ -115,7 +68,7 @@ fn main() -> Result<()> { ...@@ -115,7 +68,7 @@ fn main() -> Result<()> {
// tcp listeners // tcp listeners
for router in routers for router in routers
.values() .values()
.filter(|&r| r.config.schema == TCP && r.config.dst_port == 0) .filter(|&r| r.config.schema == Schema::TCP && r.config.dst_port == 0)
.unique_by(|r| r.config.src_port) .unique_by(|r| r.config.src_port)
{ {
s.spawn(|s| { s.spawn(|s| {
...@@ -149,7 +102,7 @@ fn main() -> Result<()> { ...@@ -149,7 +102,7 @@ fn main() -> Result<()> {
let _ = thread::scope(|s| { let _ = thread::scope(|s| {
s.spawn(|_| router.handle_outbound_tcp(&connection)); s.spawn(|_| router.handle_outbound_tcp(&connection));
s.spawn(|_| router.handle_inbound_tcp(&connection, &local_secret)); s.spawn(|_| router.handle_inbound_tcp(&connection, &config.local_secret));
}); });
} }
}); });
......
use crate::{ConfigRouter, Schema}; use anyhow::{bail, ensure, Result};
use anyhow::{Error, Result, bail, ensure};
use base64::Engine;
use base64::prelude::BASE64_STANDARD;
use socket2::{Domain, Protocol, SockAddr, SockFilter, Socket, Type}; use socket2::{Domain, Protocol, SockAddr, SockFilter, Socket, Type};
use std::net::Shutdown; use std::net::Shutdown;
use std::sync::Arc; use std::sync::Arc;
...@@ -16,10 +13,11 @@ use std::{ ...@@ -16,10 +13,11 @@ use std::{
}; };
use tun::Device; use tun::Device;
use crossbeam::epoch::{Atomic, pin}; use crate::config::{ConfigRouter, Schema};
use crossbeam::epoch::{pin, Atomic};
use libc::{ use libc::{
BPF_ABS, BPF_B, BPF_IND, BPF_JEQ, BPF_JMP, BPF_K, BPF_LD, BPF_LDX, BPF_MSH, BPF_RET, BPF_W, MSG_FASTOPEN, SO_ATTACH_REUSEPORT_CBPF, SOL_SOCKET, setsockopt, setsockopt, sock_filter, sock_fprog, socklen_t, BPF_ABS, BPF_B, BPF_IND, BPF_JEQ, BPF_JMP, BPF_K, BPF_LD, BPF_LDX, BPF_MSH, BPF_RET, BPF_W,
sock_filter, sock_fprog, socklen_t, MSG_FASTOPEN, SOL_SOCKET, SO_ATTACH_REUSEPORT_CBPF,
}; };
pub const SECRET_LENGTH: usize = 32; pub const SECRET_LENGTH: usize = 32;
...@@ -43,7 +41,6 @@ impl Meta { ...@@ -43,7 +41,6 @@ impl Meta {
pub struct Router { pub struct Router {
pub config: ConfigRouter, pub config: ConfigRouter,
pub secret: [u8; SECRET_LENGTH],
pub tun: Device, pub tun: Device,
pub socket: Socket, pub socket: Socket,
pub endpoint: Atomic<SockAddr>, pub endpoint: Atomic<SockAddr>,
...@@ -52,10 +49,6 @@ pub struct Router { ...@@ -52,10 +49,6 @@ pub struct Router {
} }
impl Router { impl Router {
pub(crate) fn create_secret(config: &str) -> Result<[u8; SECRET_LENGTH]> {
BASE64_STANDARD.decode(config)?.as_slice().try_into().map_err(Error::from)
}
pub(crate) fn decrypt(&self, data: &mut [u8], secret: &[u8; SECRET_LENGTH]) { pub(crate) fn decrypt(&self, data: &mut [u8], secret: &[u8; SECRET_LENGTH]) {
for (i, b) in data.iter_mut().enumerate() { for (i, b) in data.iter_mut().enumerate() {
*b ^= secret[i % SECRET_LENGTH]; *b ^= secret[i % SECRET_LENGTH];
...@@ -69,7 +62,7 @@ impl Router { ...@@ -69,7 +62,7 @@ impl Router {
pub(crate) fn encrypt(&self, data: &mut [u8]) { pub(crate) fn encrypt(&self, data: &mut [u8]) {
for (i, b) in data.iter_mut().enumerate() { for (i, b) in data.iter_mut().enumerate() {
*b ^= self.secret[i % SECRET_LENGTH]; *b ^= self.config.remote_secret[i % SECRET_LENGTH];
} }
} }
...@@ -354,7 +347,6 @@ impl Router { ...@@ -354,7 +347,6 @@ impl Router {
pub fn new(config: ConfigRouter, local_id: u8) -> Result<Router> { pub fn new(config: ConfigRouter, local_id: u8) -> Result<Router> {
let router = Router { let router = Router {
secret: Self::create_secret(config.remote_secret.as_str())?,
tun: Self::create_tun_device(&config)?, tun: Self::create_tun_device(&config)?,
endpoint: Self::create_endpoint(&config), endpoint: Self::create_endpoint(&config),
socket: Self::create_socket(&config, local_id)?, socket: Self::create_socket(&config, local_id)?,
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment