Commit d4fdd8b5 authored by nanahira's avatar nanahira

add multi endpoint

parent f5ffc9fe
Pipeline #37140 passed with stages
in 3 minutes and 16 seconds
...@@ -8,6 +8,7 @@ use std::intrinsics::transmute; ...@@ -8,6 +8,7 @@ use std::intrinsics::transmute;
use std::io::{Read, Write}; use std::io::{Read, Write};
use std::mem::MaybeUninit; use std::mem::MaybeUninit;
use std::sync::Arc; use std::sync::Arc;
use std::time::{Duration, Instant};
#[repr(C)] #[repr(C)]
pub struct Meta { pub struct Meta {
...@@ -93,10 +94,26 @@ fn main() -> Result<(), Box<dyn Error>> { ...@@ -93,10 +94,26 @@ fn main() -> Result<(), Box<dyn Error>> {
loop { loop {
let n = router.tun_reader.read(&mut buffer[meta_size..]).unwrap(); let n = router.tun_reader.read(&mut buffer[meta_size..]).unwrap();
if let Some(ref addr) = *router.endpoint.read().unwrap() {
router.encrypt(&mut buffer[meta_size..meta_size + n]); // 加密 payload
#[cfg(target_os = "linux")] router.encrypt(&mut buffer[meta_size..meta_size + n]);
let _ = router.socket.set_mark(router.config.mark);
#[cfg(target_os = "linux")]
let _ = router.socket.set_mark(router.config.mark);
// 计算是否超时
let now = Instant::now();
let last_seen = *router.last_seen.read().unwrap();
let elapsed = now.duration_since(last_seen);
if elapsed > Duration::from_secs(5) {
// ⏱ 超时,广播给所有 endpoints
let endpoints = router.endpoints.read().unwrap();
for addr in endpoints.iter() {
let _ = router.socket.send_to(&buffer[..meta_size + n], addr);
}
} else if let Some(ref addr) = *router.endpoint.read().unwrap() {
// ✅ 正常单播
let _ = router.socket.send_to(&buffer[..meta_size + n], addr); let _ = router.socket.send_to(&buffer[..meta_size + n], addr);
} }
} }
...@@ -125,6 +142,11 @@ fn main() -> Result<(), Box<dyn Error>> { ...@@ -125,6 +142,11 @@ fn main() -> Result<(), Box<dyn Error>> {
.get_mut(&meta.src_id) .get_mut(&meta.src_id)
.ok_or("missing router")?; .ok_or("missing router")?;
*router.endpoint.write().unwrap() = Some(addr); *router.endpoint.write().unwrap() = Some(addr);
*router.last_seen.write().unwrap() = Instant::now(); // ✅ 更新时间戳
let mut endpoints = router.endpoints.write().unwrap();
endpoints.insert(addr); // ✅ 加入新的地址
router.decrypt(payload, &local_secret); router.decrypt(payload, &local_secret);
router.tun_writer.write_all(payload)?; router.tun_writer.write_all(payload)?;
} }
......
...@@ -52,6 +52,8 @@ pub struct Router<'a> { ...@@ -52,6 +52,8 @@ pub struct Router<'a> {
pub tun_writer: Writer, pub tun_writer: Writer,
pub socket: Arc<Socket>, pub socket: Arc<Socket>,
pub endpoint: Arc<RwLock<Option<SockAddr>>>, pub endpoint: Arc<RwLock<Option<SockAddr>>>,
pub endpoints: Arc<RwLock<HashSet<SocketAddr>>>,
pub last_seen: Arc<RwLock<Instant>>,
} }
impl<'a> Router<'a> { impl<'a> Router<'a> {
...@@ -122,6 +124,13 @@ impl<'a> Router<'a> { ...@@ -122,6 +124,13 @@ impl<'a> Router<'a> {
let (tun_reader, tun_writer) = Self::create_tun_device(&config)?; let (tun_reader, tun_writer) = Self::create_tun_device(&config)?;
Self::run_up_script(&config)?; Self::run_up_script(&config)?;
let mut endpoints_set = HashSet::new();
if let Some(addr) = *endpoint.read().unwrap() {
endpoints_set.insert(addr);
}
let endpoints = Arc::new(RwLock::new(endpoints_set));
let last_seen = Arc::new(RwLock::new(Instant::now()));
let router = Router { let router = Router {
config, config,
secret, secret,
...@@ -129,6 +138,8 @@ impl<'a> Router<'a> { ...@@ -129,6 +138,8 @@ impl<'a> Router<'a> {
tun_reader, tun_reader,
tun_writer, tun_writer,
socket, socket,
endpoints,
last_seen,
}; };
Ok(router) Ok(router)
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment