1#![deny(missing_docs)]
16
17use super::p2c::Balance;
18use crate::discover::Change;
19use crate::load::Load;
20use crate::make::MakeService;
21use futures_core::{ready, Stream};
22use pin_project_lite::pin_project;
23use slab::Slab;
24use std::{
25    fmt,
26    future::Future,
27    pin::Pin,
28    task::{Context, Poll},
29};
30use tower_service::Service;
31
32#[cfg(test)]
33mod test;
34
/// Coarse classification of the pool's current load, used to decide whether the
/// discoverer should grow, shrink, or leave the set of services alone.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
enum Level {
    /// Load is low -- the discoverer may remove a service (if more than one exists).
    Low,
    /// Load is normal -- the set of services is left as it is.
    Normal,
    /// Load is high -- the discoverer should make another service.
    High,
}
44
pin_project! {
    /// A stream of service-set changes backed by a `MakeService`.
    ///
    /// It makes a new service when the recorded load level is high and asks for a service to be
    /// removed when the load level is low. See [`Pool`].
    pub struct PoolDiscoverer<MS, Target, Request>
    where
        MS: MakeService<Target, Request>,
    {
        // Factory used to construct new services for `target`.
        maker: MS,
        // The service currently being constructed, if any.
        #[pin]
        making: Option<MS::Future>,
        // Cloned and handed to `maker` every time a service is made.
        target: Target,
        // Load level most recently recorded by the owning `Pool`.
        load: Level,
        // Allocates the ids of live services; the slab stores no payload.
        services: Slab<()>,
        // Cloned into every `DropNotifyService` so it can report its id when dropped.
        died_tx: tokio::sync::mpsc::UnboundedSender<usize>,
        // Receives the ids of dropped services so their slab entries can be freed.
        #[pin]
        died_rx: tokio::sync::mpsc::UnboundedReceiver<usize>,
        // Optional cap on the number of services; `None` means no cap.
        limit: Option<usize>,
    }
}
64
65impl<MS, Target, Request> fmt::Debug for PoolDiscoverer<MS, Target, Request>
66where
67    MS: MakeService<Target, Request> + fmt::Debug,
68    Target: fmt::Debug,
69{
70    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
71        f.debug_struct("PoolDiscoverer")
72            .field("maker", &self.maker)
73            .field("making", &self.making.is_some())
74            .field("target", &self.target)
75            .field("load", &self.load)
76            .field("services", &self.services)
77            .field("limit", &self.limit)
78            .finish()
79    }
80}
81
82impl<MS, Target, Request> Stream for PoolDiscoverer<MS, Target, Request>
83where
84    MS: MakeService<Target, Request>,
85    MS::MakeError: Into<crate::BoxError>,
86    MS::Error: Into<crate::BoxError>,
87    Target: Clone,
88{
89    type Item = Result<Change<usize, DropNotifyService<MS::Service>>, MS::MakeError>;
90
91    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
92        let mut this = self.project();
93
94        while let Poll::Ready(Some(sid)) = this.died_rx.as_mut().poll_recv(cx) {
95            this.services.remove(sid);
96            tracing::trace!(
97                pool.services = this.services.len(),
98                message = "removing dropped service"
99            );
100        }
101
102        if this.services.is_empty() && this.making.is_none() {
103            let _ = ready!(this.maker.poll_ready(cx))?;
104            tracing::trace!("construct initial pool connection");
105            this.making
106                .set(Some(this.maker.make_service(this.target.clone())));
107        }
108
109        if let Level::High = this.load {
110            if this.making.is_none() {
111                if this
112                    .limit
113                    .map(|limit| this.services.len() >= limit)
114                    .unwrap_or(false)
115                {
116                    return Poll::Pending;
117                }
118
119                tracing::trace!(
120                    pool.services = this.services.len(),
121                    message = "decided to add service to loaded pool"
122                );
123                ready!(this.maker.poll_ready(cx))?;
124                tracing::trace!("making new service");
125                this.making
127                    .set(Some(this.maker.make_service(this.target.clone())));
128            }
129        }
130
131        if let Some(fut) = this.making.as_mut().as_pin_mut() {
132            let svc = ready!(fut.poll(cx))?;
133            this.making.set(None);
134
135            let id = this.services.insert(());
136            let svc = DropNotifyService {
137                svc,
138                id,
139                notify: this.died_tx.clone(),
140            };
141            tracing::trace!(
142                pool.services = this.services.len(),
143                message = "finished creating new service"
144            );
145            *this.load = Level::Normal;
146            return Poll::Ready(Some(Ok(Change::Insert(id, svc))));
147        }
148
149        match this.load {
150            Level::High => {
151                unreachable!("found high load but no Service being made");
152            }
153            Level::Normal => Poll::Pending,
154            Level::Low if this.services.len() == 1 => Poll::Pending,
155            Level::Low => {
156                *this.load = Level::Normal;
157                let rm = this.services.iter().next().unwrap().0;
159                tracing::trace!(
162                    pool.services = this.services.len(),
163                    message = "removing service for over-provisioned pool"
164                );
165                Poll::Ready(Some(Ok(Change::Remove(rm))))
166            }
167        }
168    }
169}
170
#[derive(Copy, Clone, Debug)]
/// A builder that configures how a [`Pool`] decides whether its underlying services are
/// under- or over-loaded. See the builder's methods for the individual settings.
pub struct Builder {
    // Load estimate below which the pool is treated as over-provisioned.
    low: f64,
    // Load estimate above which the pool is treated as under-provisioned.
    high: f64,
    // Starting value of the load estimate.
    init: f64,
    // Weight given to each new sample of the moving-average load estimate.
    alpha: f64,
    // Optional cap on the number of services; `None` means no cap.
    limit: Option<usize>,
}
184
185impl Default for Builder {
186    fn default() -> Self {
187        Builder {
188            init: 0.1,
189            low: 0.00001,
190            high: 0.2,
191            alpha: 0.03,
192            limit: None,
193        }
194    }
195}
196
197impl Builder {
198    pub fn new() -> Self {
202        Self::default()
203    }
204
205    pub fn underutilized_below(&mut self, low: f64) -> &mut Self {
211        self.low = low;
212        self
213    }
214
215    pub fn loaded_above(&mut self, high: f64) -> &mut Self {
222        self.high = high;
223        self
224    }
225
226    pub fn initial(&mut self, init: f64) -> &mut Self {
233        self.init = init;
234        self
235    }
236
237    pub fn urgency(&mut self, alpha: f64) -> &mut Self {
251        self.alpha = alpha.max(0.0).min(1.0);
252        self
253    }
254
255    pub fn max_services(&mut self, limit: Option<usize>) -> &mut Self {
262        self.limit = limit;
263        self
264    }
265
266    pub fn build<MS, Target, Request>(
268        &self,
269        make_service: MS,
270        target: Target,
271    ) -> Pool<MS, Target, Request>
272    where
273        MS: MakeService<Target, Request>,
274        MS::Service: Load,
275        <MS::Service as Load>::Metric: std::fmt::Debug,
276        MS::MakeError: Into<crate::BoxError>,
277        MS::Error: Into<crate::BoxError>,
278        Target: Clone,
279    {
280        let (died_tx, died_rx) = tokio::sync::mpsc::unbounded_channel();
281        let d = PoolDiscoverer {
282            maker: make_service,
283            making: None,
284            target,
285            load: Level::Normal,
286            services: Slab::new(),
287            died_tx,
288            died_rx,
289            limit: self.limit,
290        };
291
292        Pool {
293            balance: Balance::new(Box::pin(d)),
294            options: *self,
295            ewma: self.init,
296        }
297    }
298}
299
/// A dynamically sized, load-balanced pool of services made by a `MakeService`.
///
/// The pool tracks a moving-average estimate of how often its balancer is not ready and
/// uses it to tell its [`PoolDiscoverer`] when to add or remove services. Configure it
/// with [`Builder`].
pub struct Pool<MS, Target, Request>
where
    MS: MakeService<Target, Request>,
    MS::MakeError: Into<crate::BoxError>,
    MS::Error: Into<crate::BoxError>,
    Target: Clone,
{
    // Balancer over the services produced by the pinned, boxed discoverer.
    balance: Balance<Pin<Box<PoolDiscoverer<MS, Target, Request>>>, Request>,
    // The settings this pool was built with.
    options: Builder,
    // Current load estimate (exponentially weighted moving average).
    ewma: f64,
}
313
314impl<MS, Target, Request> fmt::Debug for Pool<MS, Target, Request>
315where
316    MS: MakeService<Target, Request> + fmt::Debug,
317    MS::MakeError: Into<crate::BoxError>,
318    MS::Error: Into<crate::BoxError>,
319    Target: Clone + fmt::Debug,
320    MS::Service: fmt::Debug,
321{
322    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
323        f.debug_struct("Pool")
324            .field("balance", &self.balance)
325            .field("options", &self.options)
326            .field("ewma", &self.ewma)
327            .finish()
328    }
329}
330
331impl<MS, Target, Request> Pool<MS, Target, Request>
332where
333    MS: MakeService<Target, Request>,
334    MS::Service: Load,
335    <MS::Service as Load>::Metric: std::fmt::Debug,
336    MS::MakeError: Into<crate::BoxError>,
337    MS::Error: Into<crate::BoxError>,
338    Target: Clone,
339{
340    pub fn new(make_service: MS, target: Target) -> Self {
347        Builder::new().build(make_service, target)
348    }
349}
350
// Shorthand for a `Balance` whose discoverer `S` is pinned and boxed.
type PinBalance<S, Request> = Balance<Pin<Box<S>>, Request>;
352
impl<MS, Target, Req> Service<Req> for Pool<MS, Target, Req>
where
    MS: MakeService<Target, Req>,
    MS::Service: Load,
    <MS::Service as Load>::Metric: std::fmt::Debug,
    MS::MakeError: Into<crate::BoxError>,
    MS::Error: Into<crate::BoxError>,
    Target: Clone,
{
    // The pool exposes exactly the response, error, and future types of its balancer.
    type Response = <PinBalance<PoolDiscoverer<MS, Target, Req>, Req> as Service<Req>>::Response;
    type Error = <PinBalance<PoolDiscoverer<MS, Target, Req>, Req> as Service<Req>>::Error;
    type Future = <PinBalance<PoolDiscoverer<MS, Target, Req>, Req> as Service<Req>>::Future;

    /// Polls the balancer for readiness and updates the load estimate from the outcome.
    ///
    /// A ready balancer counts as a `0` sample and a not-ready balancer as a `1` sample of
    /// the moving average. The resulting estimate is compared against the configured
    /// thresholds to set the discoverer's load level, which determines whether it adds or
    /// removes a service on its next poll.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
        if let Poll::Ready(()) = self.balance.poll_ready(cx)? {
            // The balancer was ready: fold a 0 sample into the estimate.
            self.ewma *= 1.0 - self.options.alpha;

            let discover = self.balance.discover_mut().as_mut().project();
            if self.ewma < self.options.low {
                // Only log on the transition into `Low`.
                if *discover.load != Level::Low {
                    tracing::trace!({ ewma = %self.ewma }, "pool is over-provisioned");
                }
                *discover.load = Level::Low;

                if discover.services.len() > 1 {
                    // Reset the estimate so the pool does not ask to remove yet another
                    // service on the very next ready poll.
                    self.ewma = self.options.init;
                }
            } else {
                // Only log on the transition into `Normal`.
                if *discover.load != Level::Normal {
                    tracing::trace!({ ewma = %self.ewma }, "pool is appropriately provisioned");
                }
                *discover.load = Level::Normal;
            }

            return Poll::Ready(Ok(()));
        }

        let discover = self.balance.discover_mut().as_mut().project();
        // Skip the sample while a service is already being made.
        if discover.making.is_none() {
            // The balancer was not ready: fold a 1 sample into the estimate.
            self.ewma = self.options.alpha + (1.0 - self.options.alpha) * self.ewma;

            if self.ewma > self.options.high {
                // Only log on the transition into `High`.
                if *discover.load != Level::High {
                    tracing::trace!({ ewma = %self.ewma }, "pool is under-provisioned");
                }
                *discover.load = Level::High;

                // Clamp the estimate to the threshold instead of resetting it, so it does
                // not keep growing while the pool cannot add services (e.g. at its limit).
                self.ewma = self.options.high;

                // Poll the balancer again so the new load level takes effect now
                // (presumably this is what polls the discoverer -- confirm in `Balance`).
                return self.balance.poll_ready(cx);
            } else {
                *discover.load = Level::Normal;
            }
        }

        Poll::Pending
    }

    /// Dispatches `req` through the balancer.
    fn call(&mut self, req: Req) -> Self::Future {
        self.balance.call(req)
    }
}
425
#[doc(hidden)]
#[derive(Debug)]
/// A service wrapper that reports its `id` over `notify` when it is dropped.
pub struct DropNotifyService<Svc> {
    // The wrapped service that all calls are forwarded to.
    svc: Svc,
    // Identifier sent over `notify` on drop.
    id: usize,
    // Channel on which `id` is sent when this wrapper is dropped.
    notify: tokio::sync::mpsc::UnboundedSender<usize>,
}
433
434impl<Svc> Drop for DropNotifyService<Svc> {
435    fn drop(&mut self) {
436        let _ = self.notify.send(self.id).is_ok();
437    }
438}
439
440impl<Svc: Load> Load for DropNotifyService<Svc> {
441    type Metric = Svc::Metric;
442    fn load(&self) -> Self::Metric {
443        self.svc.load()
444    }
445}
446
447impl<Request, Svc: Service<Request>> Service<Request> for DropNotifyService<Svc> {
448    type Response = Svc::Response;
449    type Future = Svc::Future;
450    type Error = Svc::Error;
451
452    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
453        self.svc.poll_ready(cx)
454    }
455
456    fn call(&mut self, req: Request) -> Self::Future {
457        self.svc.call(req)
458    }
459}