use std::{net::IpAddr, sync::Arc, time::Duration};
use governor::{clock::QuantaClock, state::keyed::DashMapStateStore, RateLimiter};
use mas_config::RateLimitingConfig;
use mas_data_model::User;
use ulid::Ulid;
#[derive(Debug, Clone, thiserror::Error)]
pub enum AccountRecoveryLimitedError {
    #[error("Too many account recovery requests for requester {0}")]
    Requester(RequesterFingerprint),
    #[error("Too many account recovery requests for e-mail {0}")]
    Email(String),
}
#[derive(Debug, Clone, Copy, thiserror::Error)]
pub enum PasswordCheckLimitedError {
    #[error("Too many password checks for requester {0}")]
    Requester(RequesterFingerprint),
    #[error("Too many password checks for user {0}")]
    User(Ulid),
}
#[derive(Debug, Clone, thiserror::Error)]
pub enum RegistrationLimitedError {
    #[error("Too many account registration requests for requester {0}")]
    Requester(RequesterFingerprint),
}
/// Identifies the origin of a request, used as the key for per-requester
/// rate limiting.
///
/// Currently this is only the client IP address, which may be unknown.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct RequesterFingerprint {
    ip: Option<IpAddr>,
}

impl std::fmt::Display for RequesterFingerprint {
    /// Renders the client IP, or a placeholder when no IP is known.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self.ip {
            Some(ip) => write!(f, "{ip}"),
            None => f.write_str("(NO CLIENT IP)"),
        }
    }
}

impl RequesterFingerprint {
    /// Fingerprint for a request whose client IP address is unknown.
    pub const EMPTY: Self = Self { ip: None };

    /// Builds the fingerprint of a request coming from `ip`.
    #[must_use]
    pub const fn new(ip: IpAddr) -> Self {
        let ip = Some(ip);
        Self { ip }
    }
}
/// Handle to the shared set of rate limiters.
///
/// Cloning is cheap: every clone points at the same limiter state through an
/// [`Arc`], so all clones enforce the same quotas.
#[derive(Clone, Debug)]
pub struct Limiter {
    inner: Arc<LimiterInner>,
}
/// A governor rate limiter keyed by `K`, storing per-key state in a
/// `DashMapStateStore` and reading time from a `QuantaClock`.
type KeyedRateLimiter<K> = RateLimiter<K, DashMapStateStore<K>, QuantaClock>;

/// The individual keyed limiters, shared by every [`Limiter`] handle.
#[derive(Debug)]
struct LimiterInner {
    /// Account recovery requests, keyed by requester.
    account_recovery_per_requester: KeyedRateLimiter<RequesterFingerprint>,
    /// Account recovery requests, keyed by (lowercased) e-mail address.
    account_recovery_per_email: KeyedRateLimiter<String>,
    /// Password checks, keyed by requester.
    password_check_for_requester: KeyedRateLimiter<RequesterFingerprint>,
    /// Password checks, keyed by the ID of the user being checked.
    password_check_for_user: KeyedRateLimiter<Ulid>,
    /// Account registrations, keyed by requester.
    registration_per_requester: KeyedRateLimiter<RequesterFingerprint>,
}
impl LimiterInner {
    /// Builds every keyed limiter from the rate-limiting configuration.
    ///
    /// Returns `None` as soon as one of the configured rates yields no quota
    /// (its `to_quota()` returns `None`); rates are evaluated in the same order
    /// as the struct fields.
    fn new(config: &RateLimitingConfig) -> Option<Self> {
        let recovery = &config.account_recovery;
        let login = &config.login;

        let recovery_ip_quota = recovery.per_ip.to_quota()?;
        let recovery_address_quota = recovery.per_address.to_quota()?;
        let login_ip_quota = login.per_ip.to_quota()?;
        let login_account_quota = login.per_account.to_quota()?;
        let registration_quota = config.registration.to_quota()?;

        let limiters = Self {
            account_recovery_per_requester: RateLimiter::keyed(recovery_ip_quota),
            account_recovery_per_email: RateLimiter::keyed(recovery_address_quota),
            password_check_for_requester: RateLimiter::keyed(login_ip_quota),
            password_check_for_user: RateLimiter::keyed(login_account_quota),
            registration_per_requester: RateLimiter::keyed(registration_quota),
        };
        Some(limiters)
    }
}
impl Limiter {
    #[must_use]
    pub fn new(config: &RateLimitingConfig) -> Option<Self> {
        Some(Self {
            inner: Arc::new(LimiterInner::new(config)?),
        })
    }
    pub fn start(&self) {
        let this = self.clone();
        tokio::spawn(async move {
            let mut interval = tokio::time::interval(Duration::from_secs(60));
            interval.set_missed_tick_behavior(tokio::time::MissedTickBehavior::Skip);
            loop {
                this.inner.account_recovery_per_email.retain_recent();
                this.inner.account_recovery_per_requester.retain_recent();
                this.inner.password_check_for_requester.retain_recent();
                this.inner.password_check_for_user.retain_recent();
                this.inner.registration_per_requester.retain_recent();
                interval.tick().await;
            }
        });
    }
    pub fn check_account_recovery(
        &self,
        requester: RequesterFingerprint,
        email_address: &str,
    ) -> Result<(), AccountRecoveryLimitedError> {
        self.inner
            .account_recovery_per_requester
            .check_key(&requester)
            .map_err(|_| AccountRecoveryLimitedError::Requester(requester))?;
        let canonical_email = email_address.to_lowercase();
        self.inner
            .account_recovery_per_email
            .check_key(&canonical_email)
            .map_err(|_| AccountRecoveryLimitedError::Email(canonical_email))?;
        Ok(())
    }
    pub fn check_password(
        &self,
        key: RequesterFingerprint,
        user: &User,
    ) -> Result<(), PasswordCheckLimitedError> {
        self.inner
            .password_check_for_requester
            .check_key(&key)
            .map_err(|_| PasswordCheckLimitedError::Requester(key))?;
        self.inner
            .password_check_for_user
            .check_key(&user.id)
            .map_err(|_| PasswordCheckLimitedError::User(user.id))?;
        Ok(())
    }
    pub fn check_registration(
        &self,
        requester: RequesterFingerprint,
    ) -> Result<(), RegistrationLimitedError> {
        self.inner
            .registration_per_requester
            .check_key(&requester)
            .map_err(|_| RegistrationLimitedError::Requester(requester))?;
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use mas_data_model::User;
    use mas_storage::{clock::MockClock, Clock};
    use rand::SeedableRng;
    use super::*;
    /// Exercises the two password-check limiters (per requester and per user)
    /// using the default `RateLimitingConfig`.
    ///
    /// NOTE(review): the limiters read time from a real `QuantaClock`, so this
    /// test assumes it runs much faster than the quotas replenish.
    #[test]
    fn test_password_check_limiter() {
        // `now` is used to timestamp the test users and seed their ULIDs.
        let now = MockClock::default().now();
        // Seeded RNG so the generated user IDs are reproducible.
        let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
        let limiter = Limiter::new(&RateLimitingConfig::default()).unwrap();
        // 768 distinct IPv4 requesters of the form a.a.b.b, with a in 0..=255
        // and b in 0..3.
        let requesters: [_; 768] = (0..=255)
            .flat_map(|a| (0..3).map(move |b| RequesterFingerprint::new([a, a, b, b].into())))
            .collect::<Vec<_>>()
            .try_into()
            .unwrap();
        let alice = User {
            id: Ulid::from_datetime_with_source(now.into(), &mut rng),
            username: "alice".to_owned(),
            sub: "123-456".to_owned(),
            primary_user_email_id: None,
            created_at: now,
            locked_at: None,
            can_request_admin: false,
        };
        let bob = User {
            id: Ulid::from_datetime_with_source(now.into(), &mut rng),
            username: "bob".to_owned(),
            sub: "123-456".to_owned(),
            primary_user_email_id: None,
            created_at: now,
            locked_at: None,
            can_request_admin: false,
        };
        // As encoded by these assertions, a single requester gets three checks
        // in a burst; the fourth is rejected.
        assert!(limiter.check_password(requesters[0], &alice).is_ok());
        assert!(limiter.check_password(requesters[0], &alice).is_ok());
        assert!(limiter.check_password(requesters[0], &alice).is_ok());
        assert!(limiter.check_password(requesters[0], &alice).is_err());
        // The per-requester limit applies regardless of which user is checked.
        assert!(limiter.check_password(requesters[0], &bob).is_err());
        // Another requester can still check alice's password.
        assert!(limiter.check_password(requesters[1], &alice).is_ok());
        // 598 more requesters each use up their own burst of three checks for
        // alice. The fourth check of each round fails on the per-requester
        // limit, which `check_password` evaluates first, so that rejected check
        // does not consume alice's per-user quota.
        for requester in requesters.iter().skip(2).take(598) {
            assert!(limiter.check_password(*requester, &alice).is_ok());
            assert!(limiter.check_password(*requester, &alice).is_ok());
            assert!(limiter.check_password(*requester, &alice).is_ok());
            assert!(limiter.check_password(*requester, &alice).is_err());
        }
        // alice has had 3 + 1 + 598 * 3 = 1798 successful checks so far. These
        // assertions expect two more to succeed (1800 in total) before the
        // per-user limit rejects a fresh requester.
        assert!(limiter.check_password(requesters[600], &alice).is_ok());
        assert!(limiter.check_password(requesters[601], &alice).is_ok());
        assert!(limiter.check_password(requesters[602], &alice).is_err());
        // bob's per-user quota is tracked independently of alice's.
        assert!(limiter.check_password(requesters[603], &bob).is_ok());
    }
}