tower_sessions_core/
session_store.rs

1//! A session backend for managing session state.
2//!
3//! This crate provides the ability to use custom backends for session
4//! management by implementing the [`SessionStore`] trait. This trait defines
5//! the necessary operations for creating, saving, loading, and deleting session
6//! records.
7//!
8//! # Implementing a Custom Store
9//!
//! Below is an example of implementing a custom session store using an
//! in-memory [`HashMap`](std::collections::HashMap). This example is for
//! illustration purposes only; you can use the provided `MemoryStore` directly
//! without implementing it yourself.
14//!
15//! ```rust
16//! use std::{collections::HashMap, sync::Arc};
17//!
18//! use async_trait::async_trait;
19//! use time::OffsetDateTime;
20//! use tokio::sync::Mutex;
21//! use tower_sessions_core::{
22//!     session::{Id, Record},
23//!     session_store, SessionStore,
24//! };
25//!
26//! #[derive(Clone, Debug, Default)]
27//! pub struct MemoryStore(Arc<Mutex<HashMap<Id, Record>>>);
28//!
29//! #[async_trait]
30//! impl SessionStore for MemoryStore {
31//!     async fn create(&self, record: &mut Record) -> session_store::Result<()> {
32//!         let mut store_guard = self.0.lock().await;
33//!         while store_guard.contains_key(&record.id) {
34//!             // Session ID collision mitigation.
35//!             record.id = Id::default();
36//!         }
37//!         store_guard.insert(record.id, record.clone());
38//!         Ok(())
39//!     }
40//!
41//!     async fn save(&self, record: &Record) -> session_store::Result<()> {
42//!         self.0.lock().await.insert(record.id, record.clone());
43//!         Ok(())
44//!     }
45//!
46//!     async fn load(&self, session_id: &Id) -> session_store::Result<Option<Record>> {
47//!         Ok(self
48//!             .0
49//!             .lock()
50//!             .await
51//!             .get(session_id)
52//!             .filter(|Record { expiry_date, .. }| is_active(*expiry_date))
53//!             .cloned())
54//!     }
55//!
56//!     async fn delete(&self, session_id: &Id) -> session_store::Result<()> {
57//!         self.0.lock().await.remove(session_id);
58//!         Ok(())
59//!     }
60//! }
61//!
62//! fn is_active(expiry_date: OffsetDateTime) -> bool {
63//!     expiry_date > OffsetDateTime::now_utc()
64//! }
65//! ```
66//!
67//! # Session Store Trait
68//!
69//! The [`SessionStore`] trait defines the interface for session management.
70//! Implementations must handle session creation, saving, loading, and deletion.
71//!
72//! # CachingSessionStore
73//!
74//! The [`CachingSessionStore`] provides a layered caching mechanism with a
75//! cache as the frontend and a store as the backend. This can improve read
76//! performance by reducing the need to access the backend store for frequently
77//! accessed sessions.
78//!
79//! # ExpiredDeletion
80//!
//! The [`ExpiredDeletion`] trait provides a method for deleting expired
//! sessions. With the `deletion-task` feature enabled, it also provides a
//! default `continuously_delete_expired` method that repeatedly deletes
//! expired sessions at a specified interval.
84use std::fmt::Debug;
85
86use async_trait::async_trait;
87
88use crate::session::{Id, Record};
89
90/// Stores must map any errors that might occur during their use to this type.
91#[derive(thiserror::Error, Debug)]
92pub enum Error {
93    #[error("Encoding failed with: {0}")]
94    Encode(String),
95
96    #[error("Decoding failed with: {0}")]
97    Decode(String),
98
99    #[error("{0}")]
100    Backend(String),
101}
102
103pub type Result<T> = std::result::Result<T, Error>;
104
105/// Defines the interface for session management.
106///
107/// See [`session_store`](crate::session_store) for more details.
108#[async_trait]
109pub trait SessionStore: Debug + Send + Sync + 'static {
110    /// Creates a new session in the store with the provided session record.
111    ///
112    /// Implementers must decide how to handle potential ID collisions. For
113    /// example, they might generate a new unique ID or return `Error::Backend`.
114    ///
115    /// The record is given as an exclusive reference to allow modifications,
116    /// such as assigning a new ID, during the creation process.
117    async fn create(&self, session_record: &mut Record) -> Result<()> {
118        default_create(self, session_record).await
119    }
120
121    /// Saves the provided session record to the store.
122    ///
123    /// This method is intended for updating the state of an existing session.
124    async fn save(&self, session_record: &Record) -> Result<()>;
125
126    /// Loads an existing session record from the store using the provided ID.
127    ///
128    /// If a session with the given ID exists, it is returned. If the session
129    /// does not exist or has been invalidated (e.g., expired), `None` is
130    /// returned.
131    async fn load(&self, session_id: &Id) -> Result<Option<Record>>;
132
133    /// Deletes a session record from the store using the provided ID.
134    ///
135    /// If the session exists, it is removed from the store.
136    async fn delete(&self, session_id: &Id) -> Result<()>;
137}
138
139async fn default_create<S: SessionStore + ?Sized>(
140    store: &S,
141    session_record: &mut Record,
142) -> Result<()> {
143    tracing::warn!(
144        "The default implementation of `SessionStore::create` is being used, which relies on \
145         `SessionStore::save`. To properly handle potential ID collisions, it is recommended that \
146         stores implement their own version of `SessionStore::create`."
147    );
148    store.save(session_record).await?;
149    Ok(())
150}
151
152/// Provides a layered caching mechanism with a cache as the frontend and a
153/// store as the backend..
154///
155/// Contains both a cache, which acts as a frontend, and a store which acts as a
156/// backend. Both cache and store implement `SessionStore`.
157///
158/// By using a cache, the cost of reads can be greatly reduced as once cached,
159/// reads need only interact with the frontend, forgoing the cost of retrieving
160/// the session record from the backend.
161///
162/// # Examples
163///
164/// ```rust,ignore
165/// # tokio_test::block_on(async {
166/// use tower_sessions::CachingSessionStore;
167/// use tower_sessions_moka_store::MokaStore;
168/// use tower_sessions_sqlx_store::{SqlitePool, SqliteStore};
169/// let pool = SqlitePool::connect("sqlite::memory:").await.unwrap();
170/// let sqlite_store = SqliteStore::new(pool);
171/// let moka_store = MokaStore::new(Some(2_000));
172/// let caching_store = CachingSessionStore::new(moka_store, sqlite_store);
173/// # })
174/// ```
175#[derive(Debug, Clone)]
176pub struct CachingSessionStore<Cache: SessionStore, Store: SessionStore> {
177    cache: Cache,
178    store: Store,
179}
180
181impl<Cache: SessionStore, Store: SessionStore> CachingSessionStore<Cache, Store> {
182    /// Create a new `CachingSessionStore`.
183    pub fn new(cache: Cache, store: Store) -> Self {
184        Self { cache, store }
185    }
186}
187
188#[async_trait]
189impl<Cache, Store> SessionStore for CachingSessionStore<Cache, Store>
190where
191    Cache: SessionStore,
192    Store: SessionStore,
193{
194    async fn create(&self, record: &mut Record) -> Result<()> {
195        self.store.create(record).await?;
196        self.cache.create(record).await?;
197        Ok(())
198    }
199
200    async fn save(&self, record: &Record) -> Result<()> {
201        let store_save_fut = self.store.save(record);
202        let cache_save_fut = self.cache.save(record);
203
204        futures::try_join!(store_save_fut, cache_save_fut)?;
205
206        Ok(())
207    }
208
209    async fn load(&self, session_id: &Id) -> Result<Option<Record>> {
210        match self.cache.load(session_id).await {
211            // We found a session in the cache, so let's use it.
212            Ok(Some(session_record)) => Ok(Some(session_record)),
213
214            // We didn't find a session in the cache, so we'll try loading from the backend.
215            //
216            // When we find a session in the backend, we'll hydrate our cache with it.
217            Ok(None) => {
218                let session_record = self.store.load(session_id).await?;
219
220                if let Some(ref session_record) = session_record {
221                    self.cache.save(session_record).await?;
222                }
223
224                Ok(session_record)
225            }
226
227            // Some error occurred with our cache so we'll bubble this up.
228            Err(err) => Err(err),
229        }
230    }
231
232    async fn delete(&self, session_id: &Id) -> Result<()> {
233        let store_delete_fut = self.store.delete(session_id);
234        let cache_delete_fut = self.cache.delete(session_id);
235
236        futures::try_join!(store_delete_fut, cache_delete_fut)?;
237
238        Ok(())
239    }
240}
241
242/// Provides a method for deleting expired sessions.
243#[async_trait]
244pub trait ExpiredDeletion: SessionStore
245where
246    Self: Sized,
247{
248    /// A method for deleting expired sessions from the store.
249    async fn delete_expired(&self) -> Result<()>;
250
251    /// This function will keep running indefinitely, deleting expired rows and
252    /// then waiting for the specified period before deleting again.
253    ///
254    /// Generally this will be used as a task, for example via
255    /// `tokio::task::spawn`.
256    ///
257    /// # Errors
258    ///
259    /// This function returns a `Result` that contains an error of type
260    /// `sqlx::Error` if the deletion operation fails.
261    ///
262    /// # Examples
263    ///
264    /// ```rust,no_run,ignore
265    /// use tower_sessions::session_store::ExpiredDeletion;
266    /// use tower_sessions_sqlx_store::{sqlx::SqlitePool, SqliteStore};
267    ///
268    /// # {
269    /// # tokio_test::block_on(async {
270    /// let pool = SqlitePool::connect("sqlite::memory:").await.unwrap();
271    /// let session_store = SqliteStore::new(pool);
272    ///
273    /// tokio::task::spawn(
274    ///     session_store
275    ///         .clone()
276    ///         .continuously_delete_expired(tokio::time::Duration::from_secs(60)),
277    /// );
278    /// # })
279    /// ```
280    #[cfg(feature = "deletion-task")]
281    #[cfg_attr(docsrs, doc(cfg(feature = "deletion-task")))]
282    async fn continuously_delete_expired(self, period: tokio::time::Duration) -> Result<()> {
283        let mut interval = tokio::time::interval(period);
284        interval.tick().await; // The first tick completes immediately; skip.
285        loop {
286            interval.tick().await;
287            self.delete_expired().await?;
288        }
289    }
290}
291
#[cfg(test)]
mod tests {
    use mockall::{mock, predicate};
    use time::{Duration, OffsetDateTime};

    use super::*;

    mock! {
        #[derive(Debug)]
        pub Cache {}

        #[async_trait]
        impl SessionStore for Cache {
            async fn create(&self, record: &mut Record) -> Result<()>;
            async fn save(&self, record: &Record) -> Result<()>;
            async fn load(&self, session_id: &Id) -> Result<Option<Record>>;
            async fn delete(&self, session_id: &Id) -> Result<()>;
        }
    }

    mock! {
        #[derive(Debug)]
        pub Store {}

        #[async_trait]
        impl SessionStore for Store {
            async fn create(&self, record: &mut Record) -> Result<()>;
            async fn save(&self, record: &Record) -> Result<()>;
            async fn load(&self, session_id: &Id) -> Result<Option<Record>>;
            async fn delete(&self, session_id: &Id) -> Result<()>;
        }
    }

    // Deliberately omits `create`, so calling it exercises the trait's default
    // implementation (which forwards to `save`).
    mock! {
        #[derive(Debug)]
        pub CollidingStore {}

        #[async_trait]
        impl SessionStore for CollidingStore {
            async fn save(&self, record: &Record) -> Result<()>;
            async fn load(&self, session_id: &Id) -> Result<Option<Record>>;
            async fn delete(&self, session_id: &Id) -> Result<()>;
        }
    }

    /// Builds a record with a default ID that expires 30 minutes from now.
    fn active_record() -> Record {
        Record {
            id: Default::default(),
            data: Default::default(),
            expiry_date: OffsetDateTime::now_utc() + Duration::minutes(30),
        }
    }

    #[tokio::test]
    async fn test_create() {
        let mut store = MockCollidingStore::new();
        let mut record = active_record();

        store
            .expect_save()
            .with(predicate::eq(record.clone()))
            .times(1)
            .returning(|_| Ok(()));
        let result = store.create(&mut record).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_save() {
        let mut store = MockStore::new();
        let record = active_record();
        store
            .expect_save()
            .with(predicate::eq(record.clone()))
            .times(1)
            .returning(|_| Ok(()));

        let result = store.save(&record).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_load() {
        let mut store = MockStore::new();
        let session_id = Id::default();
        let record = active_record();
        let expected_record = record.clone();

        store
            .expect_load()
            .with(predicate::eq(session_id))
            .times(1)
            .returning(move |_| Ok(Some(record.clone())));

        let result = store.load(&session_id).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), Some(expected_record));
    }

    #[tokio::test]
    async fn test_delete() {
        let mut store = MockStore::new();
        let session_id = Id::default();

        store
            .expect_delete()
            .with(predicate::eq(session_id))
            .times(1)
            .returning(|_| Ok(()));

        let result = store.delete(&session_id).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_caching_store_create() {
        let mut cache = MockCache::new();
        let mut store = MockStore::new();
        let mut record = active_record();

        cache.expect_create().times(1).returning(|_| Ok(()));
        store.expect_create().times(1).returning(|_| Ok(()));

        let caching_store = CachingSessionStore::new(cache, store);
        let result = caching_store.create(&mut record).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_caching_store_save() {
        let mut cache = MockCache::new();
        let mut store = MockStore::new();
        let record = active_record();

        cache
            .expect_save()
            .with(predicate::eq(record.clone()))
            .times(1)
            .returning(|_| Ok(()));
        store
            .expect_save()
            .with(predicate::eq(record.clone()))
            .times(1)
            .returning(|_| Ok(()));

        let caching_store = CachingSessionStore::new(cache, store);
        let result = caching_store.save(&record).await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_caching_store_load() {
        let mut cache = MockCache::new();
        let mut store = MockStore::new();
        let session_id = Id::default();
        let record = active_record();
        let expected_record = record.clone();

        cache
            .expect_load()
            .with(predicate::eq(session_id))
            .times(1)
            .returning(move |_| Ok(Some(record.clone())));
        // Store load should not be called since cache returns a record
        store.expect_load().times(0);

        let caching_store = CachingSessionStore::new(cache, store);
        let result = caching_store.load(&session_id).await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), Some(expected_record));
    }

    #[tokio::test]
    async fn test_caching_store_load_miss_hydrates_cache() {
        let mut cache = MockCache::new();
        let mut store = MockStore::new();
        let session_id = Id::default();
        let record = active_record();
        let expected_record = record.clone();
        let backend_record = record.clone();

        cache
            .expect_load()
            .with(predicate::eq(session_id))
            .times(1)
            .returning(|_| Ok(None));
        store
            .expect_load()
            .with(predicate::eq(session_id))
            .times(1)
            .returning(move |_| Ok(Some(backend_record.clone())));
        // The record found in the backend must be written back into the cache.
        cache
            .expect_save()
            .with(predicate::eq(record))
            .times(1)
            .returning(|_| Ok(()));

        let caching_store = CachingSessionStore::new(cache, store);
        let result = caching_store.load(&session_id).await;
        assert_eq!(result.unwrap(), Some(expected_record));
    }

    #[tokio::test]
    async fn test_caching_store_load_miss_everywhere() {
        let mut cache = MockCache::new();
        let mut store = MockStore::new();
        let session_id = Id::default();

        cache
            .expect_load()
            .with(predicate::eq(session_id))
            .times(1)
            .returning(|_| Ok(None));
        store
            .expect_load()
            .with(predicate::eq(session_id))
            .times(1)
            .returning(|_| Ok(None));
        // Nothing was found, so nothing should be written to the cache.
        cache.expect_save().times(0);

        let caching_store = CachingSessionStore::new(cache, store);
        let result = caching_store.load(&session_id).await;
        assert_eq!(result.unwrap(), None);
    }

    #[tokio::test]
    async fn test_caching_store_load_cache_error() {
        let mut cache = MockCache::new();
        let mut store = MockStore::new();
        let session_id = Id::default();

        cache
            .expect_load()
            .with(predicate::eq(session_id))
            .times(1)
            .returning(|_| Err(Error::Backend("cache unavailable".to_string())));
        // A cache failure is surfaced rather than masked by a backend lookup.
        store.expect_load().times(0);

        let caching_store = CachingSessionStore::new(cache, store);
        let result = caching_store.load(&session_id).await;
        assert!(matches!(result, Err(Error::Backend(_))));
    }

    #[tokio::test]
    async fn test_caching_store_delete() {
        let mut cache = MockCache::new();
        let mut store = MockStore::new();
        let session_id = Id::default();

        cache
            .expect_delete()
            .with(predicate::eq(session_id))
            .times(1)
            .returning(|_| Ok(()));
        store
            .expect_delete()
            .with(predicate::eq(session_id))
            .times(1)
            .returning(|_| Ok(()));

        let caching_store = CachingSessionStore::new(cache, store);
        let result = caching_store.delete(&session_id).await;
        assert!(result.is_ok());
    }
}