#![allow(dead_code)]

use bincode::{Decode, Encode};
use ipnet::Ipv4Net;
use itertools::{EitherOrBoth, Itertools};
use serde::{Deserialize, Serialize};
use std::{
    collections::BTreeMap,
    fmt::{Debug, Display},
    net::Ipv4Addr,
    sync::OnceLock,
};
use string_interner::{StringInterner, Symbol, backend::StringBackend};

/// A router record, loaded from `import/data/Router.json` by `DATABASE`.
///
/// While loading, each router's `name` is interned into `ROUTER_ID_REGISTRY` (via
/// `register`), so other records may refer to a router by name and `RouterID` can be
/// displayed as that name. JSON keys are camelCase.
#[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(rename_all = "camelCase")]
pub struct Router {
    /// Numeric id; also the router's symbol index in the name registry.
    pub id: RouterID,
    /// Human-readable name that gets interned into the registry.
    pub name: String,
    /// Location string. NOTE(review): presumably what `GatewayGroup::location_prefix`
    /// is matched against — confirm.
    pub location: String,
    /// The router's IPv4 address.
    pub address: Ipv4Addr,
}

/// A subnet attached to a router, loaded from `import/data/Subnet.json` by `DATABASE`.
///
/// JSON keys are camelCase.
#[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(rename_all = "camelCase")]
pub struct Subnet {
    /// Router owning the subnet; in JSON this may be a number, a numeric string, or a
    /// registered router name (see the `Deserialize` impl for `RouterID`).
    pub router: RouterID,
    /// The IPv4 network (an `ipnet::Ipv4Net`).
    pub subnet: Ipv4Net,
}

/// A gateway record, loaded from `import/data/Gateway.json` by `DATABASE`.
///
/// NOTE(review): unlike the other record types this struct has no
/// `#[serde(rename_all = "camelCase")]`, so its JSON keys are read verbatim
/// (`cost_outbound`, not `costOutbound`). Confirm that matches Gateway.json before
/// "harmonising" it — adding the attribute would change what the loader accepts.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct Gateway {
    /// Gateway identifier.
    pub id: GatewayID,
    /// Router this gateway is attached to (number, numeric string, or router name in JSON).
    pub router: RouterID,
    /// NOTE(review): outbound cost; units/semantics not visible in this file — confirm.
    pub cost_outbound: i32,
    /// NOTE(review): per-gateway metric list; meaning of each index not visible here.
    pub metrics: Vec<i32>,
}

/// A gateway group record, loaded from `import/data/GatewayGroup.json` by `DATABASE`.
///
/// JSON keys are camelCase. The `id` is written in JSON as the table value (20000 and
/// up) and stored in memory as the 20000-offset byte (see `GatewayGroupID`).
#[derive(Serialize, Deserialize, Clone, Debug)]
#[serde(rename_all = "camelCase")]
pub struct GatewayGroup {
    /// Group identifier (offset-encoded, see `GatewayGroupID`).
    pub id: GatewayGroupID,
    /// Display name of the group.
    pub name: String,
    /// NOTE(review): presumably router `location` prefixes selecting members — confirm.
    pub location_prefix: Vec<String>,
    /// NOTE(review): presumably routers forced into the group — confirm.
    pub include_routers: Vec<RouterID>,
    /// NOTE(review): presumably routers forced out of the group — confirm.
    pub exclude_routers: Vec<RouterID>,
    /// NOTE(review): child entries referenced by string; their key space is not visible here.
    pub children: Vec<String>,
    /// NOTE(review): destination mark value; its consumer is outside this file.
    pub dest_mark: u16,
}

/// A link entry stored in `Database::connections`, loaded from `import/connections.json`.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct Connection {
    /// Protocol used over this link.
    pub protocol: Schema,
    /// Link metric. NOTE(review): presumably lower is preferred — confirm with the consumer.
    pub metric: u32,
}

/// Protocol of a `Connection`; defaults to `IP` when constructed via `Default`.
#[derive(Serialize, Deserialize, Default, PartialEq, Eq, Clone, Copy, Debug)]
pub enum Schema {
    #[default]
    IP,
    UDP,
    TCP,
}

/// Region record loaded from `import/data/Region.json`; currently carries no fields, so
/// every JSON object deserializes to the same value.
#[derive(Serialize, Deserialize, PartialEq, Clone, Debug)]
pub struct Region {}

/// Router identifier stored in a single byte.
///
/// It is also the symbol type of `ROUTER_ID_REGISTRY`: `RouterID(n)` is the interner
/// index of router `n`'s name. Serde (de)serialization is implemented by hand later in
/// this file; bincode uses the derived `Encode`/`Decode`.
#[derive(Encode, Decode, Clone, Copy, Default, Ord, PartialOrd, Eq, PartialEq, Debug)]
pub struct RouterID(pub u8);

impl Symbol for RouterID {
    fn try_from_usize(index: usize) -> Option<Self> {
        if index <= u8::MAX as usize { Some(Self(index as u8)) } else { None }
    }

    fn to_usize(self) -> usize {
        self.0 as usize
    }
}

impl Display for RouterID {
    /// Writes the router's registered name, or the numeric id when the registry is not
    /// populated yet or has no entry for this id.
    ///
    /// `fmt::Error` is reserved for propagating failures of the underlying `Formatter`;
    /// returning it for a missing name would make `to_string()`/`format!` panic. The
    /// numeric fallback also matches the placeholder names `register` interns for gaps.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match ROUTER_ID_REGISTRY.get().and_then(|registry| registry.resolve(*self)) {
            // `pad` (unlike `write_str`) honours width/alignment flags such as `{:>12}`.
            Some(name) => f.pad(name),
            None => Display::fmt(&self.0, f),
        }
    }
}

/// Gateway identifier stored in a single byte.
#[derive(Serialize, Deserialize, Encode, Decode, Clone, Copy, Default, Ord, PartialOrd, Eq, PartialEq, Debug)]
pub struct GatewayID(pub u8);

/// Gateway group identifier. To save network traffic it is carried as a `u8`, holding
/// the configured table value minus 20000.
///
/// Serde goes through `u16` (`from`/`into`), so JSON carries the table value (e.g.
/// `20005` for `GatewayGroupID(5)`), while the derived bincode `Encode`/`Decode` carry
/// the raw byte.
#[derive(Serialize, Deserialize, Encode, Decode, Clone, Copy, Default, Ord, PartialOrd, Eq, PartialEq, Debug)]
#[serde(from = "u16", into = "u16")]
pub struct GatewayGroupID(pub u8);
impl From<GatewayGroupID> for u16 {
    fn from(val: GatewayGroupID) -> Self {
        val.0 as u16 + 20000
    }
}
impl From<u16> for GatewayGroupID {
    /// Converts a configured table value (20000..=20255) into the compact byte form by
    /// subtracting the 20000 offset.
    ///
    /// # Panics
    /// Panics with a descriptive message if `value` is outside 20000..=20255. The previous
    /// arithmetic (`value - 20000` then `as u8`) panicked opaquely in debug builds for
    /// small values and silently produced a wrong id otherwise. `From` is required by
    /// `#[serde(from = "u16")]`, so failure cannot be reported as an error here.
    fn from(value: u16) -> Self {
        value
            .checked_sub(20000)
            .and_then(|offset| u8::try_from(offset).ok())
            .map(Self)
            .unwrap_or_else(|| panic!("gateway group id {value} is outside the encodable range 20000..=20255"))
    }
}

/// Region identifier stored in a single byte.
#[derive(Serialize, Deserialize, Encode, Decode, Clone, Copy, Default, Ord, PartialOrd, Eq, PartialEq, Debug)]
pub struct RegionID(pub u8);

/// All imported configuration tables; the process-wide instance is `DATABASE`.
#[derive(Default)]
pub struct Database {
    /// Routers, sorted by id (`register` sorts them before returning).
    pub routers: Vec<Router>,
    pub gateways: Vec<Gateway>,
    pub gateway_groups: Vec<GatewayGroup>,
    pub regions: Vec<Region>,
    pub subnets: Vec<Subnet>,
    /// Connections in a two-level map keyed by router id (outer key, then inner key).
    /// NOTE(review): presumably outer = source and inner = destination — confirm.
    pub connections: BTreeMap<RouterID, BTreeMap<RouterID, Connection>>,
}

/// Global database, loaded from the `import/` directory on first access.
///
/// Paths are relative to the process working directory. Any missing or malformed file
/// panics inside `load_file`.
///
/// The field initialisers run in the order written, and that order matters: `routers`
/// is loaded first because `register` populates `ROUTER_ID_REGISTRY`. The `RouterID`
/// deserializer consults that registry when a later file names a router by string.
pub static DATABASE: std::sync::LazyLock<Database> = std::sync::LazyLock::new(|| Database {
    // Must stay first: fills ROUTER_ID_REGISTRY for the name lookups in the loads below.
    routers: register(load_file("import/data/Router.json"), |r| r.id, |r| r.name.clone(), &ROUTER_ID_REGISTRY),
    gateways: load_file("import/data/Gateway.json"),
    gateway_groups: load_file("import/data/GatewayGroup.json"),
    regions: load_file("import/data/Region.json"),
    subnets: load_file("import/data/Subnet.json"),
    // Note: this file lives directly under import/, not import/data/.
    connections: load_file("import/connections.json"),
});

/// Router name interner: symbol `RouterID(n)` resolves to router `n`'s name, or to `n`
/// as a decimal string for ids missing from Router.json. It is set once by `register`
/// while `DATABASE` loads its `routers` field; a second `set` would panic there.
static ROUTER_ID_REGISTRY: OnceLock<StringInterner<StringBackend<RouterID>>> = OnceLock::new();

/// Reads the JSON file at `path` and deserializes it into `T`.
///
/// `path` is resolved relative to the process working directory.
///
/// # Panics
/// Panics if the file cannot be read or its contents do not deserialize into `T`. The
/// message names the offending path and the underlying error, unlike a bare `unwrap()`.
fn load_file<T: serde::de::DeserializeOwned>(path: &str) -> T {
    let text = std::fs::read_to_string(path).unwrap_or_else(|e| panic!("failed to read {path}: {e}"));
    serde_json::from_str(&text).unwrap_or_else(|e| panic!("failed to parse {path}: {e}"))
}

/// Sorts `data` by id and builds a name interner whose symbol for id `n` resolves to
/// item `n`'s name. Stores the interner in `registry` and returns the sorted data.
///
/// Ids with no matching item (gaps below the largest id) are filled with the id
/// rendered as a decimal string, so symbol indices stay aligned with ids.
///
/// # Panics
/// - if two items share the same id;
/// - if a name (or a gap placeholder such as `"7"`) equals an earlier name. The interner
///   deduplicates strings, so no new symbol would be allocated and every later id would
///   silently resolve to the wrong name;
/// - if `registry` has already been set.
pub fn register<T, N, S, Sym>(mut data: Vec<T>, num: N, str: S, registry: &OnceLock<StringInterner<StringBackend<Sym>>>) -> Vec<T>
where
    N: Fn(&T) -> Sym,
    S: Fn(&T) -> String,
    Sym: Symbol + Debug + Ord,
{
    data.sort_by_key(&num);

    let mut interner = StringInterner::<StringBackend<Sym>>::new();
    if let Some(max_id) = data.last().map(&num) {
        // Walk every id 0..=max alongside the sorted data so gaps get placeholders and
        // each interned string lands on the symbol index equal to its id.
        for item in (0..=max_id.to_usize()).merge_join_by(data.iter(), |i, r| i.cmp(&num(r).to_usize())) {
            let (index, name) = match item {
                EitherOrBoth::Both(i, r) => (i, str(r)),
                EitherOrBoth::Left(i) => (i, i.to_string()),
                // The id range is dense, so an unmatched item can only be a repeated id.
                EitherOrBoth::Right(r) => panic!("duplicate id {:?} in registered data", num(r)),
            };
            let sym = interner.get_or_intern(&name);
            assert_eq!(sym.to_usize(), index, "name {name:?} for id {index} duplicates an earlier name; ids would no longer match symbols");
        }
    }
    interner.shrink_to_fit();
    if registry.set(interner).is_err() {
        panic!("name registry is already initialised; register must be called only once per registry");
    }
    data
}

impl<'de> Deserialize<'de> for RouterID {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        match serde_json::Value::deserialize(deserializer)? {
            serde_json::Value::Number(n) => Ok(RouterID(n.as_u64().ok_or_else(|| serde::de::Error::custom("Invalid router id"))? as u8)),
            serde_json::Value::String(s) => match s.parse::<u8>() {
                Ok(id) => Ok(RouterID(id)),
                Err(_) => ROUTER_ID_REGISTRY.get().unwrap().get(&s).ok_or_else(|| serde::de::Error::custom(format!("Unknown router {}", s))),
            },
            _ => Err(serde::de::Error::custom("Invalid router id type")),
        }
    }
}

impl Serialize for RouterID {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        serializer.serialize_u8(self.0)
    }
}
