Commit 5093db1d authored by nanamicat's avatar nanamicat

clean

parent 2bea4472
use bincode::{Decode, Encode};
use serde::{Deserialize, Serialize};
use std::fmt;
#[derive(Serialize, Deserialize, Encode, Decode, Clone, Copy, Default, Ord, PartialOrd, Eq, PartialEq, Debug)]
pub struct RouterID(pub u8);
#[derive(Serialize, Deserialize, Encode, Decode, Clone, Copy, Default, Ord, PartialOrd, Eq, PartialEq, Debug)]
pub struct GatewayID(pub u8);
// 为了节约流量,GatewayGroupID 在网络上使用 u8 格式,比表里配的值少 20000
#[derive(Serialize, Deserialize, Encode, Decode, Clone, Copy, Default, Ord, PartialOrd, Eq, PartialEq, Debug)]
#[serde(from = "u16", into = "u16")]
pub struct GatewayGroupID(pub u8);
impl From<GatewayGroupID> for u16 {
fn from(val: GatewayGroupID) -> Self {
val.0 as u16 + 20000
}
}
impl From<u16> for GatewayGroupID {
fn from(value: u16) -> Self {
Self((value - 20000) as u8)
}
}
#[derive(Serialize, Deserialize, Encode, Decode, Clone, Copy, Default, Ord, PartialOrd, Eq, PartialEq, Debug)]
pub struct RegionID(pub u8);
impl fmt::Display for RegionID {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(f, "{}", self.0)
}
}
#[derive(Serialize, Deserialize, Clone)]
pub struct Router {
pub id: RouterID,
......@@ -46,19 +15,10 @@ pub struct Router {
pub struct Gateway {
pub id: GatewayID,
pub router: String,
pub r#type: GatewayType,
pub cost_outbound: i32,
pub metrics: Vec<i32>,
}
#[derive(Serialize, Deserialize, PartialEq, Clone, Copy)]
#[serde(rename_all = "lowercase")]
pub enum GatewayType {
Common,
VPC,
Virtual,
}
#[derive(Serialize, Deserialize, Clone)]
#[serde(rename_all = "camelCase")]
pub struct GatewayGroup {
......@@ -87,3 +47,27 @@ pub enum Schema {
#[derive(Serialize, Deserialize, PartialEq, Clone)]
pub struct Region {}
#[derive(Serialize, Deserialize, Encode, Decode, Clone, Copy, Default, Ord, PartialOrd, Eq, PartialEq, Debug)]
pub struct RouterID(pub u8);
#[derive(Serialize, Deserialize, Encode, Decode, Clone, Copy, Default, Ord, PartialOrd, Eq, PartialEq, Debug)]
pub struct GatewayID(pub u8);
// 为了节约流量,GatewayGroupID 在网络上使用 u8 格式,比表里配的值少 20000
#[derive(Serialize, Deserialize, Encode, Decode, Clone, Copy, Default, Ord, PartialOrd, Eq, PartialEq, Debug)]
#[serde(from = "u16", into = "u16")]
pub struct GatewayGroupID(pub u8);
impl From<GatewayGroupID> for u16 {
fn from(val: GatewayGroupID) -> Self {
val.0 as u16 + 20000
}
}
impl From<u16> for GatewayGroupID {
fn from(value: u16) -> Self {
Self((value - 20000) as u8)
}
}
#[derive(Serialize, Deserialize, Encode, Decode, Clone, Copy, Default, Ord, PartialOrd, Eq, PartialEq, Debug)]
pub struct RegionID(pub u8);
use crate::data::{self, GatewayGroup};
use std::collections::BTreeSet;
impl GatewayGroup {
pub fn search_routers(&self, routers_data: &[data::Router], groups_data: &[data::GatewayGroup]) -> BTreeSet<String> {
let mut routers: BTreeSet<String> = self
.children
.iter()
.flat_map(|c| groups_data.iter().find(|g| &g.name == c))
.flat_map(|g| g.search_routers(routers_data, groups_data))
.chain(routers_data.iter().filter(|r| self.location_prefix.iter().any(|p| r.location.starts_with(p))).map(|r| r.name.clone()))
.chain(self.include_routers.iter().cloned())
.collect();
for r in &self.exclude_routers {
routers.remove(r);
}
routers
}
}
use crate::api::create_app;
use crate::data::{GatewayGroupID, GatewayID, GatewayType, RegionID, RouterID};
use crate::data::{GatewayGroupID, GatewayID, RegionID, RouterID};
use crate::protocol::{Downlink, Uplink};
use crate::router::Router;
use crate::settings::{Settings, TIMEOUT};
......@@ -7,7 +7,7 @@ use ::config::Config;
use anyhow::{Context, Result};
use config::Environment;
use net::SocketAddr;
use std::collections::{BTreeMap, BTreeSet};
use std::collections::BTreeMap;
use std::net;
use std::sync::Arc;
use tokio::net::UdpSocket;
......@@ -16,6 +16,7 @@ use tokio::time::Instant;
mod api;
mod data;
mod gateway_group;
mod protocol;
mod quality;
mod router;
......@@ -27,21 +28,6 @@ pub struct UpdatingState {
message: Downlink,
}
fn search_gateway_group(data: &data::GatewayGroup, routers_data: &Vec<data::Router>, gateways_groups_data: &Vec<data::GatewayGroup>) -> BTreeSet<String> {
let mut routers: BTreeSet<String> = data
.children
.iter()
.flat_map(|c| gateways_groups_data.iter().find(|g| &g.name == c))
.flat_map(|g| search_gateway_group(g, routers_data, gateways_groups_data))
.chain(routers_data.iter().filter(|r| data.location_prefix.iter().any(|p| r.location.starts_with(p))).map(|r| r.name.clone()))
.chain(data.include_routers.iter().cloned())
.collect();
data.exclude_routers.iter().for_each(|r| {
routers.remove(r);
});
routers
}
#[tokio::main]
async fn main() -> Result<()> {
tracing_subscriber::fmt::init();
......@@ -56,8 +42,8 @@ async fn main() -> Result<()> {
let gateways_group: BTreeMap<GatewayGroupID, Vec<&data::Gateway>> = gateways_groups_data
.iter()
.map(|g| {
let routers = search_gateway_group(g, &routers_data, &gateways_groups_data);
(g.id, gateways_data.iter().filter(|gw| gw.r#type != GatewayType::VPC && routers.contains(&gw.router)).collect())
let routers = g.search_routers(&routers_data, &gateways_groups_data);
(g.id, gateways_data.iter().filter(|gw| routers.contains(&gw.router)).collect())
})
.collect();
let gateway_routers: BTreeMap<GatewayID, RouterID> = gateways_data.iter().map(|gw| (gw.id, routers_data.iter().find(|r| r.name == gw.router).unwrap().id)).collect();
......@@ -69,15 +55,6 @@ async fn main() -> Result<()> {
.collect();
let routers = Arc::new(RwLock::new(routers));
// 3. Initialize State
// let all_router_ids: Vec<u8> = routers_data.iter().map(|r| r.id).collect();
// let mut gateway_groups_map: BTreeMap<u8, GatewayGroup> = BTreeMap::new();
// let raw_groups = gateway_groups_data.clone();
// for g_data in gateway_groups_data {
// let g = GatewayGroup::new(g_data, &raw_groups, &routers_data);
// gateway_groups_map.insert(g.id, g);
// }
let listener = tokio::net::TcpListener::bind(config.http_bind).await?;
let app = create_app(routers_data.clone(), connections_data.clone(), routers.clone());
......@@ -107,6 +84,7 @@ async fn main() -> Result<()> {
if updating.router_id != Default::default() && !routers.get(&updating.router_id).context("router not found")?.is_online() {
updating.router_id = Default::default();
}
tracing::debug!("recv {:?}", uplink);
// 处理收到的消息
if let Some(router) = routers.get_mut(&uplink.id)
......
......@@ -54,7 +54,7 @@ impl Router {
region,
gateway_groups
.iter()
.map(|(&group_id, gateways)| (group_id, gateways.iter().min_by_key(|gw| Self::guess_metric(data, gw, &region, gateway_router)).unwrap().id))
.map(|(&gid, gws)| (gid, gws.iter().min_by_key(|gw| Self::guess_metric(data, gw, gateway_router, &region)).unwrap().id))
.collect(),
)
})
......@@ -65,7 +65,7 @@ impl Router {
}
}
pub fn guess_metric(data: &data::Router, gw: &data::Gateway, region: &RegionID, gateway_router: &BTreeMap<GatewayID, RouterID>) -> i32 {
fn guess_metric(data: &data::Router, gw: &data::Gateway, gateway_router: &BTreeMap<GatewayID, RouterID>, region: &RegionID) -> i32 {
gw.metrics[region.0 as usize]
.saturating_add(gw.cost_outbound)
.saturating_add(if gateway_router[&gw.id] == data.id { 0 } else { 100 })
......@@ -156,55 +156,55 @@ impl Router {
gateway_router: &BTreeMap<GatewayID, RouterID>,
) -> Option<Downlink> {
let penalty = PENALTY_MIN + (PENALTY as f32 * f32::exp2(-now.duration_since(self.last_update).div_duration_f32(HALF_LIFE))) as i32;
let mut changed_via: BTreeMap<RouterID, RouterID> = BTreeMap::new();
let mut changed_plan: BTreeMap<RegionID, BTreeMap<GatewayGroupID, GatewayID>> = BTreeMap::new();
let mut metric: BTreeMap<RouterID, i32> = BTreeMap::new();
metric.insert(self.id, 0);
let mut changed_via = BTreeMap::new();
let mut changed_plan = BTreeMap::new();
let mut metrics = BTreeMap::new();
metrics.insert(self.id, 0);
let mut overcome = false;
// Route updates
for to in routers.values().filter(|&r| r != self) {
let current_router = routers.get(self.via.get(&to.id).unwrap()).unwrap();
let current_metric = self.route_quality(to, current_router, routers, connections).map_or(i32::MAX, |r| r.metric());
match connections[&self.id]
let current_via = &routers[&self.via[&to.id]];
let current_metric = self.route_metric(to, current_via, routers, connections);
let (best_via, best_metric) = match connections[&self.id]
.keys()
.map(|id| routers.get(id).unwrap())
.filter_map(|r| self.route_quality(to, r, routers, connections).map(|q| (r, q.metric())))
.map(|id| &routers[id])
.map(|r| (r, self.route_metric(to, r, routers, connections)))
.min_by_key(|(_, m)| *m)
.unwrap()
{
None if current_router != to => {
// 无论如何都不可达就标记为直连
(_, i32::MAX) => (to.id, i32::MAX),
(r, m) => (r.id, m),
};
metrics.insert(to.id, best_metric);
if current_via.id != best_via {
changed_via.insert(to.id, best_via);
if best_metric == i32::MAX || best_metric.saturating_add(penalty) < current_metric {
overcome = true;
changed_via.insert(to.id, to.id);
metric.insert(to.id, i32::MAX);
}
Some((best_router, best_metric)) if current_router != best_router && best_metric < current_metric => {
if best_metric + penalty < current_metric {
overcome = true
}
changed_via.insert(to.id, best_router.id);
metric.insert(to.id, best_metric);
}
_ => {}
}
}
for region in regions {
for (group_id, gateways) in gateway_groups.iter() {
let current_gateway = self.plan[&region][group_id];
let current_metric = metric[&gateway_router[&current_gateway]];
// Plan updates (Gateways)
for &region in regions {
for (&gid, gateways) in gateway_groups {
let current_gw = self.plan[&region][&gid];
let current_metric = metrics[&gateway_router[&current_gw]];
let (best_gateway, best_metric) = gateways
let (best_gw, best_metric) = gateways
.iter()
.map(|g| (g, metric[&gateway_router[&g.id]].saturating_add(g.cost_outbound).saturating_add(g.metrics[region.0 as usize])))
.map(|g| (g, metrics[&gateway_router[&g.id]].saturating_add(g.cost_outbound).saturating_add(g.metrics[region.0 as usize])))
.min_by_key(|(_, m)| *m)
.unwrap();
if current_gateway != best_gateway.id && best_metric < current_metric {
if current_gw != best_gw.id {
if best_metric.saturating_add(penalty) < current_metric {
overcome = true;
}
changed_plan.entry(*region).or_default().insert(*group_id, best_gateway.id);
changed_plan.entry(region).or_insert_with(BTreeMap::new).insert(gid, best_gw.id);
}
}
}
......@@ -222,23 +222,25 @@ impl Router {
}
}
pub fn route_quality(&self, to: &Router, via: &Router, routers: &BTreeMap<RouterID, Router>, connections: &BTreeMap<RouterID, BTreeMap<RouterID, data::Connection>>) -> Option<Quality> {
pub fn route_metric(&self, to: &Router, via: &Router, routers: &BTreeMap<RouterID, Router>, connections: &BTreeMap<RouterID, BTreeMap<RouterID, data::Connection>>) -> i32 {
assert!(self != to);
assert!(self != via);
let mut result: Quality = Default::default();
let mut route = vec![self];
let mut current = self;
while current != to {
let next = if current == self { via } else { &routers[&current.via[&to.id]] };
match next.peers.get(&current.id).filter(|_| next.is_online() && !route.contains(&next)) {
None => return None,
Some(quality) if quality.reliability == 0 => return None,
Some(quality) => result.concat(quality, connections[&current.id][&next.id].metric),
let quality = next.peers.get(&current.id);
if quality.is_none() || quality.unwrap().reliability == 0 || !next.is_online() || route.contains(&next) {
return i32::MAX;
}
result.concat(quality.unwrap(), connections[&current.id][&next.id].metric);
route.push(next);
current = next;
}
Some(result)
result.metric()
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment