Commit de4c1f95 authored by nanamicat's avatar nanamicat

.

parent a7a7f9d9
Pipeline #42511 passed with stages
in 5 minutes and 23 seconds
......@@ -989,8 +989,8 @@ dependencies = [
"rtnetlink",
"saturating_cast",
"serde",
"serde_derive",
"serde_json",
"string-interner",
"tokio",
"tracing",
"tracing-subscriber",
......@@ -1260,6 +1260,16 @@ version = "1.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6ce2be8dc25455e1f91df71bfa12ad37d7af1092ae736f3a6cd0e37bc7810596"
[[package]]
name = "string-interner"
version = "0.19.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "23de088478b31c349c9ba67816fa55d9355232d63c3afea8bf513e31f0f1d2c0"
dependencies = [
"hashbrown 0.15.5",
"serde",
]
[[package]]
name = "syn"
version = "2.0.111"
......
......@@ -9,7 +9,6 @@ edition = "2024"
config = "0.15.19"
serde_json = "1.0.145"
serde = { version = "1.0.228", features = ["derive"] }
serde_derive = "1.0"
tokio = { version = "1.48", features = ["full"] }
anyhow = "1.0.100"
bincode = { version = "2.0.1", features = ["derive"] }
......@@ -23,3 +22,4 @@ rand = "0.9.2"
saturating_cast = "0.1.0"
ipnet = { version = "2.11.0", features = ["serde"] }
itertools = "0.14.0"
string-interner = "0.19.0"
use serde_derive::Deserialize;
#[derive(Deserialize)]
pub(crate) struct Connection {
// pub metric: u32,
// pub protocol: Schema,
}
// #[derive(Deserialize)]
// pub(crate) enum Schema {
// IP,
// UDP,
// TCP,
// }
use std::net::Ipv4Addr;
use ipnet::Ipv4Net;
use serde::Deserialize;
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Router {
pub id: u8,
pub name: String,
pub address: Ipv4Addr,
// pub location: String,
// pub user: String,
// pub host: String,
// pub ssh_port: u16,
// pub ssh_system: Option<u16>,
// pub next_mark: u16,
// pub dest_mark: u16,
// pub port: u16,
// pub port2: u16,
// pub wg_private_key: Option<String>,
// pub masq_interfaces: Vec<String>,
// pub os: String,
// pub arch: String,
// pub ocserv_port: u16,
// pub offset: u16,
// pub offset2: u16,
}
// #[derive(Deserialize)]
// #[serde(rename_all = "camelCase")]
// pub struct GatewayGroup {
// pub id: u16,
// pub name: String,
// pub description: String,
// pub children: HashSet<String>,
// pub dest_mark: u16,
// pub location_prefix: Vec<String>,
// pub include_routers: HashSet<String>,
// pub exclude_routers: HashSet<String>,
// }
#[derive(Clone, Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
pub struct Subnet {
// pub id: u32,
pub router: String,
pub subnet: Ipv4Net,
// pub interface: Option<String>,
// pub comment: Option<String>,
// pub oc_server: Option<String>,
}
// use std::collections::HashSet;
//
// use crate::data::{GatewayGroup, Router};
//
// impl GatewayGroup {
// pub fn routers(&self, groups: &[GatewayGroup], routers: &[Router]) -> HashSet<u8> {
// routers
// .iter()
// .filter(|r| self.include_routers.contains(&r.name))
// .chain(
// self.location_prefix
// .iter()
// .flat_map(|p| routers.iter().filter(move |r| r.location.starts_with(p))),
// )
// .filter(|r| !self.exclude_routers.contains(&r.name))
// .map(|r| r.id)
// .chain(
// groups
// .iter()
// .filter(|g| self.children.contains(&g.name))
// .flat_map(|g1| g1.routers(groups, routers)),
// )
// .collect()
// }
// }
mod connection;
mod data;
mod gateway_group;
mod protocol;
mod router;
mod server;
mod settings;
mod shared;
use crate::{
connection::Connection,
protocol::{Hello, MessageType, Uplink},
router::Router,
server::Server,
settings::{Settings, INTERVAL, WINDOW},
settings::{CONFIG, INTERVAL, WINDOW},
shared::{
data::{self, GatewayGroupID},
protocol::{Hello, MessageType, Uplink},
},
};
use config::Config;
use hickory_resolver::Resolver;
use itertools::Itertools;
use std::{collections::BTreeMap, fs, time::SystemTime};
use std::{collections::BTreeMap, time::SystemTime};
use tokio::{
net::UdpSocket,
time::{self, Instant},
};
use crate::shared::data::DATABASE;
#[tokio::main]
async fn main() -> anyhow::Result<()> {
tracing_subscriber::fmt::init();
let config: Settings = Config::builder().add_source(config::Environment::default()).build()?.try_deserialize()?;
let mut routers = {
let routers_data = serde_json::from_slice::<Vec<data::Router>>(&fs::read("import/data/Router.json")?)?;
let subnets_data = serde_json::from_slice::<Vec<data::Subnet>>(&fs::read("import/data/Subnet.json")?)?;
let mut subnets_map = subnets_data.into_iter().into_group_map_by(|s| s.router.clone());
routers_data
.into_iter()
.map(|r| (r.id, Router::new(subnets_map.remove(&r.name).unwrap_or_default().into_iter().map(|s| s.subnet).collect(), r, &config)))
.collect::<BTreeMap<u8, Router>>()
};
let connections = serde_json::from_slice::<BTreeMap<u8, BTreeMap<u8, Connection>>>(&fs::read("import/connections.json")?)?;
// let groups: Vec<GatewayGroup> = serde_json::from_slice(&fs::read("import/GatewayGroup.json")?)?;
let mut server = Server::new(
config.id, &routers,
// groups
// .iter()
// .map(|g| (g.id, g.routers(&groups, &routers_data)))
// .collect::<BTreeMap<u32, HashSet<u8>>>(),
);
let mut routers: BTreeMap<data::RouterID, Router> = DATABASE.routers.iter().map(|r| (r.id, Router::new(&r))).collect();
let mut server = Server::new();
let socket = UdpSocket::bind(config.bind).await?;
tracing::info!("Listening on {}", config.bind);
let socket = UdpSocket::bind(CONFIG.bind).await?;
tracing::info!("Listening on {}", CONFIG.bind);
let mut timer = time::interval(INTERVAL);
let mut buf = [0; 1500];
......@@ -58,21 +38,21 @@ async fn main() -> anyhow::Result<()> {
let resolver = Resolver::builder_tokio()?.build();
loop {
let server_addr = config.server.to_socket_addrs(&resolver).await?;
let server_addr = CONFIG.server.to_socket_addrs(&resolver).await?;
tokio::select! {
biased; // 优先处理上面的
result = socket.recv_from(&mut buf) => {
let (len, addr) = result?;
if addr.port() == config.bind.port()
if addr.port() == CONFIG.bind.port()
&& let Some(peer) = Router::get(&mut routers, addr)
&& let Ok((hello, _)) = bincode::decode_from_slice(&buf[..len], bincode::config::standard())
{
peer.on_message(&hello);
} else if addr.port() == config.server.port
} else if addr.port() == CONFIG.server.port
&& let Ok((downlink, _)) = bincode::decode_from_slice(&buf[..len], bincode::config::standard())
&& let Some(uplink) = server.on_message(downlink, &routers, &connections[&config.id], &config).await
&& let Some(uplink) = server.on_message(downlink, &routers).await
{
let len = bincode::encode_into_slice(uplink, &mut buf, bincode::config::standard())?;
let _ = socket.send_to(&buf[..len], addr).await;
......@@ -83,19 +63,19 @@ async fn main() -> anyhow::Result<()> {
let hello = Hello { time: SystemTime::now().duration_since(SystemTime::UNIX_EPOCH)?.as_millis() as u32 };
let len = bincode::encode_into_slice(&hello, &mut buf, bincode::config::standard())?;
for id in connections[&config.id].keys() {
for id in DATABASE.connections[&CONFIG.id].keys() {
let router = &routers[id];
let _ = socket.send_to(&buf[..len], router.link_address).await;
}
// to server
let uplink = Uplink {
id: config.id,
id: CONFIG.id,
action: if server.online {MessageType::Update} else {MessageType::Query},
version: server.version,
peers: if now.duration_since(start) < INTERVAL * WINDOW { Default::default() } else { connections
peers: if now.duration_since(start) < INTERVAL * WINDOW { Default::default() } else { DATABASE.connections
.iter()
.filter(|(_, to)| to.contains_key(&config.id))
.filter(|(_, to)| to.contains_key(&CONFIG.id))
.map(|(from,_)|routers.get_mut(from).unwrap().update(now))
.collect()},
via: Default::default(),
......
use crate::{
data,
protocol::{Hello, PeerQuality},
settings::{INTERVAL, Settings, WINDOW},
settings::{CONFIG, INTERVAL, Settings, WINDOW},
shared::{
data::{self, DATABASE, RouterID},
protocol::{Hello, PeerQuality},
},
};
use ipnet::Ipv4Net;
use saturating_cast::SaturatingCast;
......@@ -20,31 +22,31 @@ pub struct Router {
receive: u64,
remote_time: u32,
local_time: Instant,
pub(crate) data: data::Router,
pub(crate) subnets: Vec<Ipv4Net>,
pub addresses: Vec<Ipv4Net>,
}
impl Router {
pub fn link_address(from: u8, to: u8) -> Ipv4Addr {
Ipv4Addr::from([10, 200, to, from])
pub fn link_address(from: RouterID, to: RouterID) -> Ipv4Addr {
Ipv4Addr::from([10, 200, to.0, from.0])
}
pub fn get(routers: &mut BTreeMap<u8, Router>, link_address: SocketAddr) -> Option<&mut Router> {
pub fn get(routers: &mut BTreeMap<RouterID, Router>, link_address: SocketAddr) -> Option<&mut Router> {
match link_address {
SocketAddr::V4(addr) => routers.get_mut(&addr.ip().octets()[2]),
SocketAddr::V4(addr) => routers.get_mut(&RouterID(addr.ip().octets()[2])),
SocketAddr::V6(_) => None,
}
}
pub fn new(subnets: Vec<Ipv4Net>, data: data::Router, config: &Settings) -> Router {
pub fn new(data: &data::Router) -> Router {
Router {
link_address: SocketAddr::new(IpAddr::V4(Router::link_address(config.id, data.id)), config.bind.port()),
link_address: SocketAddr::new(IpAddr::V4(Router::link_address(CONFIG.id, data.id)), CONFIG.bind.port()),
remote_time: rand::random(),
receive: 0,
jitter: 0,
prev_delay: 0,
delay: 0,
local_time: Instant::now(),
data,
subnets,
addresses: std::iter::once(Ipv4Net::from(data.address))
.chain(DATABASE.subnets.iter().filter(|s| s.router == data.id).map(|s| s.subnet))
.collect(),
}
}
......@@ -81,7 +83,7 @@ impl Router {
self.prev_delay = delay;
}
pub(crate) fn update(&mut self, now: Instant) -> PeerQuality {
pub fn update(&mut self, now: Instant) -> PeerQuality {
let reliability = self.receive.count_ones() as u8;
if reliability > 0 {
let duration = now.duration_since(self.local_time).div_duration_f32(INTERVAL) as u8;
......
use crate::{
connection::Connection,
protocol::{Downlink, MessageType, Uplink},
router::Router,
settings::{ROUTE_PROTOCOL, Settings},
settings::{CONFIG, ROUTE_PROTOCOL},
shared::{
data::{self, DATABASE, GatewayGroupID, GatewayID, RegionID, RouterID},
gateway_group::GATEWAYGROUPINDEX,
protocol::{Downlink, MessageType, Uplink},
},
};
use rtnetlink::RouteMessageBuilder;
use std::{collections::BTreeMap, net::Ipv4Addr};
pub struct Server {
pub(crate) online: bool,
pub(crate) version: u32,
pub(crate) via: BTreeMap<u8, u8>,
pub(crate) plan: BTreeMap<u8, BTreeMap<u8, u8>>,
pub(crate) handle: rtnetlink::Handle,
pub online: bool,
pub version: u32,
pub via: BTreeMap<RouterID, RouterID>,
pub plan: BTreeMap<RegionID, BTreeMap<GatewayGroupID, GatewayID>>,
pub handle: rtnetlink::Handle,
}
impl Server {
pub fn new(id: u8, routers: &BTreeMap<u8, Router>) -> Self {
pub fn new() -> Self {
let (connection, handle, _) = rtnetlink::new_connection().unwrap();
tokio::spawn(connection);
let id = CONFIG.id;
Server {
online: false,
version: rand::random(),
via: routers.keys().filter(|&&i| i != id).map(|&i| (i, i)).collect(),
plan: Default::default(),
via: data::GatewayGroup::default_via(id),
plan: data::GatewayGroup::default_plan(id),
handle,
}
}
pub async fn on_message(&mut self, mut message: Downlink, routers: &BTreeMap<u8, Router>, connections: &BTreeMap<u8, Connection>, config: &Settings) -> Option<Uplink> {
fn guess(id: RouterID, gw: &data::Gateway, region: usize) -> i32 {
gw.metrics[region].saturating_add(gw.cost_outbound).saturating_add(if gw.router == id { 0 } else { 100 })
}
pub async fn on_message(&mut self, mut message: Downlink, routers: &BTreeMap<RouterID, Router>) -> Option<Uplink> {
if message.ack != self.version {
return None;
}
......@@ -40,11 +48,9 @@ impl Server {
for (to, via) in self.via.iter_mut() {
*via = *to;
}
self.via.append(&mut message.via);
self.plan.append(&mut message.plan);
self.write(&self.via, routers, connections, config).await;
self.apply(&mut message.via, &mut message.plan, routers).await;
Some(Uplink {
id: config.id,
id: CONFIG.id,
action: MessageType::Update,
version: self.version,
peers: Default::default(),
......@@ -53,7 +59,7 @@ impl Server {
})
}
(true, MessageType::Query) => Some(Uplink {
id: config.id,
id: CONFIG.id,
action: MessageType::Full,
version: self.version,
peers: Default::default(),
......@@ -61,11 +67,9 @@ impl Server {
plan: self.plan.clone(),
}),
(true, MessageType::Update) => {
self.via.append(&mut message.via);
self.plan.append(&mut message.plan);
self.write(&self.via, routers, connections, config).await;
self.apply(&mut message.via, &mut message.plan, routers).await;
Some(Uplink {
id: config.id,
id: CONFIG.id,
action: MessageType::Update,
version: self.version,
peers: Default::default(),
......@@ -76,18 +80,52 @@ impl Server {
_ => None,
}
}
pub async fn write(&self, via: &BTreeMap<u8, u8>, routers: &BTreeMap<u8, Router>, connections: &BTreeMap<u8, Connection>, config: &Settings) {
pub async fn apply(&mut self, via: &mut BTreeMap<RouterID, RouterID>, plan: &mut BTreeMap<RegionID, BTreeMap<GatewayGroupID, GatewayID>>, routers: &BTreeMap<RouterID, Router>) {
self.via.append(via);
for (region, mut plan) in std::mem::take(plan) {
self.plan.entry(region).or_default().append(&mut plan);
}
self.write(&self.via, &self.plan, routers).await;
}
pub async fn write(&self, via: &BTreeMap<RouterID, RouterID>, plan: &BTreeMap<RegionID, BTreeMap<GatewayGroupID, GatewayID>>, routers: &BTreeMap<RouterID, Router>) {
for (to_id, via_id) in via.iter() {
let to = &routers[to_id];
for (destination, prefix) in std::iter::once((to.data.address, 32)).chain(to.subnets.iter().map(|s| (s.addr(), s.prefix_len()))) {
let builder = RouteMessageBuilder::<Ipv4Addr>::new().destination_prefix(destination, prefix).protocol(ROUTE_PROTOCOL);
let msg = if connections.contains_key(via_id) {
builder.gateway(Router::link_address(config.id, *via_id)).build()
for address in to.addresses.iter() {
let builder = RouteMessageBuilder::<Ipv4Addr>::new().destination_prefix(address.addr(), address.prefix_len()).protocol(ROUTE_PROTOCOL);
let msg = if DATABASE.connections.contains_key(via_id) {
builder.gateway(Router::link_address(CONFIG.id, *via_id)).build()
} else {
builder.kind(netlink_packet_route::route::RouteType::Unreachable).build()
};
self.handle.route().add(msg).replace().execute().await.unwrap_or_else(|e| panic!("{}", e));
tracing::info!("{:?}", msg);
if let Err(e) = self.handle.route().add(msg).replace().execute().await {
eprintln!("{}", e);
}
}
}
// if let Some(global) = plan.get(&RegionID(0)) {
// for (group_id, &gateway_id) in global.iter() {
// if GATEWAYGROUPINDEX[group_id].iter().any(|g| g.router == CONFIG.id) {
// continue;
// }
// let group = DATABASE.gateway_groups.iter().find(|g| g.id == *group_id).unwrap();
// let gateway = GATEWAYGROUPINDEX[group_id].iter().find(|g| g.id == gateway_id).unwrap();
// if let Some(&via_id) = self.via.get(&gateway.router) {
// let gateway_ip = Router::link_address(CONFIG.id, via_id);
// let msg = RouteMessageBuilder::<Ipv4Addr>::new()
// .destination_prefix(Ipv4Addr::UNSPECIFIED, 0)
// .table_id(group.dest_mark as u32)
// .protocol(ROUTE_PROTOCOL)
// .gateway(gateway_ip)
// .build();
// tracing::info!("Default route: {:?}", msg);
// if let Err(e) = self.handle.route().add(msg).replace().execute().await {
// eprintln!("Error adding default route: {}", e);
// }
// }
// }
// }
}
}
use config::{Config, Environment};
use hickory_resolver::Resolver;
use hickory_resolver::name_server::GenericConnector;
use hickory_resolver::proto::runtime::TokioRuntimeProvider;
......@@ -6,33 +7,29 @@ use serde::Deserialize;
use std::net::SocketAddr;
use std::time::Duration;
use crate::shared::data::RouterID;
#[derive(Deserialize)]
pub struct Settings {
pub id: u8,
pub id: RouterID,
pub server: Endpoint,
pub bind: SocketAddr,
}
#[derive(Deserialize)]
#[serde(try_from = "String")]
#[serde(from = "String")]
pub struct Endpoint {
pub host: String,
pub port: u16,
}
impl TryFrom<String> for Endpoint {
type Error = String;
fn try_from(value: String) -> Result<Self, Self::Error> {
let parts: Vec<&str> = value.rsplitn(2, ':').collect();
if parts.len() != 2 {
return Err(format!("Invalid endpoint format: {}", value));
}
let port = parts[0].parse::<u16>().map_err(|e| format!("Invalid port: {}", e))?;
let host = parts[1].to_string();
Ok(Endpoint { host, port })
impl From<String> for Endpoint {
fn from(value: String) -> Self {
let (host, port) = value.rsplit_once(':').unwrap();
let port = port.parse::<u16>().unwrap();
// 处理 IPv6 的方括号
let host = if host.starts_with('[') && host.ends_with(']') { &host[1..host.len() - 1] } else { host };
Endpoint { host: host.to_string(), port }
}
}
......@@ -46,3 +43,5 @@ impl Endpoint {
pub const WINDOW: u32 = 64;
pub const INTERVAL: Duration = Duration::from_secs(1);
pub const ROUTE_PROTOCOL: RouteProtocol = RouteProtocol::Other(252);
pub static CONFIG: std::sync::LazyLock<Settings> = std::sync::LazyLock::new(|| Config::builder().add_source(Environment::default()).build().unwrap().try_deserialize().unwrap());
#![allow(dead_code)]
use bincode::{Decode, Encode};
use ipnet::Ipv4Net;
use itertools::{EitherOrBoth, Itertools};
use serde::{Deserialize, Serialize};
use std::{
collections::BTreeMap,
fmt::{Debug, Display},
net::Ipv4Addr,
sync::OnceLock,
};
use string_interner::{StringInterner, Symbol, backend::StringBackend};
#[derive(Serialize, Deserialize, Clone)]
#[serde(rename_all = "camelCase")]
pub struct Router {
pub id: RouterID,
pub name: String,
pub location: String,
pub address: Ipv4Addr,
}
#[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(rename_all = "camelCase")]
pub struct Subnet {
pub router: RouterID,
pub subnet: Ipv4Net,
}
#[derive(Serialize, Deserialize, Clone)]
pub struct Gateway {
pub id: GatewayID,
pub router: RouterID,
pub cost_outbound: i32,
pub metrics: Vec<i32>,
}
#[derive(Serialize, Deserialize, Clone)]
#[serde(rename_all = "camelCase")]
pub struct GatewayGroup {
pub id: GatewayGroupID,
pub name: String,
pub location_prefix: Vec<String>,
pub include_routers: Vec<RouterID>,
pub exclude_routers: Vec<RouterID>,
pub children: Vec<String>,
pub dest_mark: u16,
}
#[derive(Serialize, Deserialize, Clone)]
pub struct Connection {
pub protocol: Schema,
pub metric: u32,
}
#[derive(Serialize, Deserialize, Default, PartialEq, Clone, Copy)]
pub enum Schema {
#[default]
IP,
UDP,
TCP,
}
#[derive(Serialize, Deserialize, PartialEq, Clone)]
pub struct Region {}
#[derive(Encode, Decode, Clone, Copy, Default, Ord, PartialOrd, Eq, PartialEq, Debug)]
pub struct RouterID(pub u8);
impl Symbol for RouterID {
fn try_from_usize(index: usize) -> Option<Self> {
if index <= u8::MAX as usize { Some(Self(index as u8)) } else { None }
}
fn to_usize(self) -> usize {
self.0 as usize
}
}
impl Display for RouterID {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_str(ROUTER_ID_REGISTRY.get().and_then(|r| r.resolve(*self)).ok_or(std::fmt::Error)?)
}
}
#[derive(Serialize, Deserialize, Encode, Decode, Clone, Copy, Default, Ord, PartialOrd, Eq, PartialEq, Debug)]
pub struct GatewayID(pub u8);
// 为了节约流量,GatewayGroupID 在网络上使用 u8 格式,比表里配的值少 20000
#[derive(Serialize, Deserialize, Encode, Decode, Clone, Copy, Default, Ord, PartialOrd, Eq, PartialEq, Debug)]
#[serde(from = "u16", into = "u16")]
pub struct GatewayGroupID(pub u8);
impl From<GatewayGroupID> for u16 {
fn from(val: GatewayGroupID) -> Self {
val.0 as u16 + 20000
}
}
impl From<u16> for GatewayGroupID {
fn from(value: u16) -> Self {
Self((value - 20000) as u8)
}
}
#[derive(Serialize, Deserialize, Encode, Decode, Clone, Copy, Default, Ord, PartialOrd, Eq, PartialEq, Debug)]
pub struct RegionID(pub u8);
#[derive(Default)]
pub struct Database {
pub routers: Vec<Router>,
pub gateways: Vec<Gateway>,
pub gateway_groups: Vec<GatewayGroup>,
pub regions: Vec<Region>,
pub subnets: Vec<Subnet>,
pub connections: BTreeMap<RouterID, BTreeMap<RouterID, Connection>>,
}
pub static DATABASE: std::sync::LazyLock<Database> = std::sync::LazyLock::new(|| Database {
routers: register(load_file("import/data/Router.json"), |r| r.id, |r| r.name.clone(), &ROUTER_ID_REGISTRY),
gateways: load_file("import/data/Gateway.json"),
gateway_groups: load_file("import/data/GatewayGroup.json"),
regions: load_file("import/data/Region.json"),
subnets: load_file("import/data/Subnet.json"),
connections: load_file("import/connections.json"),
});
static ROUTER_ID_REGISTRY: OnceLock<StringInterner<StringBackend<RouterID>>> = OnceLock::new();
fn load_file<T: serde::de::DeserializeOwned>(path: &str) -> T {
serde_json::from_str(&std::fs::read_to_string(path).unwrap()).unwrap()
}
pub fn register<T, N, S, Sym>(mut data: Vec<T>, num: N, str: S, registry: &OnceLock<StringInterner<StringBackend<Sym>>>) -> Vec<T>
where
T: serde::de::DeserializeOwned,
N: Fn(&T) -> Sym,
S: Fn(&T) -> String,
Sym: Symbol + Debug + Ord,
{
data.sort_by_key(&num);
let mut interner = StringInterner::<StringBackend<Sym>>::new();
if let Some(max_id) = data.last().map(&num) {
for item in (0..=max_id.to_usize()).merge_join_by(data.iter(), |i, r| i.cmp(&num(r).to_usize())) {
let name = match item {
EitherOrBoth::Both(_, r) => str(r),
EitherOrBoth::Left(i) => i.to_string(),
EitherOrBoth::Right(_) => unreachable!(),
};
interner.get_or_intern(name);
}
}
interner.shrink_to_fit();
registry.set(interner).unwrap();
data
}
impl<'de> Deserialize<'de> for RouterID {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
match serde_json::Value::deserialize(deserializer)? {
serde_json::Value::Number(n) => Ok(RouterID(n.as_u64().ok_or_else(|| serde::de::Error::custom("Invalid router id"))? as u8)),
serde_json::Value::String(s) => match s.parse::<u8>() {
Ok(id) => Ok(RouterID(id)),
Err(_) => ROUTER_ID_REGISTRY.get().unwrap().get(&s).ok_or_else(|| serde::de::Error::custom(format!("Unknown router {}", s))),
},
_ => Err(serde::de::Error::custom("Invalid router id type")),
}
}
}
impl Serialize for RouterID {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.serialize_u8(self.0)
}
}
use crate::shared::data::{self, DATABASE, GatewayGroup, GatewayGroupID, GatewayID, RegionID, RouterID};
use std::collections::{BTreeMap, BTreeSet};
use std::sync::LazyLock;
pub static GATEWAYGROUPINDEX: LazyLock<BTreeMap<GatewayGroupID, Vec<&'static data::Gateway>>> = LazyLock::new(|| {
DATABASE
.gateway_groups
.iter()
.map(|g| {
let routers = g.search_routers(&DATABASE.routers, &DATABASE.gateway_groups);
(g.id, DATABASE.gateways.iter().filter(|gw| routers.contains(&gw.router)).collect())
})
.collect()
});
impl GatewayGroup {
fn search_routers(&self, routers_data: &[data::Router], groups_data: &[Self]) -> BTreeSet<RouterID> {
let mut routers: BTreeSet<RouterID> = self
.children
.iter()
.flat_map(|c| groups_data.iter().find(|g| &g.name == c))
.flat_map(|g| g.search_routers(routers_data, groups_data))
.chain(routers_data.iter().filter(|r| self.location_prefix.iter().any(|p| r.location.starts_with(p))).map(|r| r.id))
.chain(self.include_routers.iter().cloned())
.collect();
for r in &self.exclude_routers {
routers.remove(r);
}
routers
}
pub fn default_via(id: RouterID) -> BTreeMap<RouterID, RouterID> {
DATABASE.routers.iter().filter(|r| r.id != id).map(|r| (r.id, r.id)).collect()
}
pub fn default_plan(id: RouterID) -> BTreeMap<RegionID, BTreeMap<GatewayGroupID, GatewayID>> {
DATABASE
.regions
.iter()
.enumerate()
.map(|(r, _)| {
(
RegionID(r as u8),
GATEWAYGROUPINDEX.iter().map(|(&gid, g)| (gid, g.iter().min_by_key(|gw| Self::guess(id, gw, r)).unwrap().id)).collect(),
)
})
.collect()
}
fn guess(id: RouterID, gw: &data::Gateway, region: usize) -> i32 {
gw.metrics[region].saturating_add(gw.cost_outbound).saturating_add(if gw.router == id { 0 } else { 100 })
}
}
pub mod data;
pub mod gateway_group;
pub mod protocol;
#![allow(dead_code)]
use bincode::{Decode, Encode};
use serde_derive::Serialize;
use serde::Serialize;
use std::collections::BTreeMap;
use crate::shared::data::{GatewayGroupID, GatewayID, RegionID, RouterID};
#[derive(Encode, Decode)]
pub struct Hello {
pub time: u32,
......@@ -17,12 +21,12 @@ pub enum MessageType {
#[derive(Encode, Decode, Default, Debug, Clone)]
pub struct Uplink {
pub id: u8,
pub id: RouterID,
pub action: MessageType,
pub version: u32,
pub peers: Vec<PeerQuality>,
pub via: BTreeMap<u8, u8>,
pub plan: BTreeMap<u8, BTreeMap<u8, u8>>,
pub via: BTreeMap<RouterID, RouterID>,
pub plan: BTreeMap<RegionID, BTreeMap<GatewayGroupID, GatewayID>>,
}
#[derive(Encode, Decode, Default, Debug, Clone)]
......@@ -30,8 +34,8 @@ pub struct Downlink {
pub action: MessageType,
pub version: u32,
pub ack: u32,
pub via: BTreeMap<u8, u8>,
pub plan: BTreeMap<u8, BTreeMap<u8, u8>>,
pub via: BTreeMap<RouterID, RouterID>,
pub plan: BTreeMap<RegionID, BTreeMap<GatewayGroupID, GatewayID>>,
}
#[derive(Encode, Decode, Serialize, Copy, Clone, Debug, Default)]
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment