tower_sessions_core/session.rs
1//! A session which allows HTTP applications to associate data with visitors.
2use std::{
3 collections::HashMap,
4 fmt::{self, Display},
5 hash::Hash,
6 result,
7 str::{self, FromStr},
8 sync::{
9 atomic::{self, AtomicBool},
10 Arc,
11 },
12};
13
14use base64::{engine::general_purpose::URL_SAFE_NO_PAD, DecodeError, Engine as _};
15use serde::{de::DeserializeOwned, Deserialize, Serialize};
16use serde_json::Value;
17use time::{Duration, OffsetDateTime};
18use tokio::sync::{MappedMutexGuard, Mutex, MutexGuard};
19
20use crate::{session_store, SessionStore};
21
22const DEFAULT_DURATION: Duration = Duration::weeks(2);
23
24type Result<T> = result::Result<T, Error>;
25
26type Data = HashMap<String, Value>;
27
28/// Session errors.
29#[derive(thiserror::Error, Debug)]
30pub enum Error {
31 /// Maps `serde_json` errors.
32 #[error(transparent)]
33 SerdeJson(#[from] serde_json::Error),
34
35 /// Maps `session_store::Error` errors.
36 #[error(transparent)]
37 Store(#[from] session_store::Error),
38}
39
/// Shared state behind every clone of a [`Session`] handle.
#[derive(Debug)]
struct Inner {
    // The ID of the session this handle refers to. This will be `None` when:
    //
    // 1. We have not been provided a session cookie or have failed to parse it,
    // 2. The store has not found the session,
    // 3. The session was flushed, or its ID was cycled and the session has not
    //    been saved since (saving assigns the new ID).
    //
    // Sync lock, see: https://docs.rs/tokio/latest/tokio/sync/struct.Mutex.html#which-kind-of-mutex-should-you-use
    session_id: parking_lot::Mutex<Option<Id>>,

    // A lazy representation of the session's value, hydrated on a just-in-time basis. A
    // `None` value indicates we have not loaded it yet. Once `Session::get_record` has run it
    // holds `Some(Record)`; note that `Session::load` may reset it to `None` when the store has
    // no record for the current ID, in which case the next access loads it again.
    //
    // Async lock: the guard is held across `.await`s on the store.
    record: Mutex<Option<Record>>,

    // The configured expiry policy. `None` means no policy was set, in which case expiry
    // dates fall back to `DEFAULT_DURATION`.
    //
    // Sync lock, see: https://docs.rs/tokio/latest/tokio/sync/struct.Mutex.html#which-kind-of-mutex-should-you-use
    expiry: parking_lot::Mutex<Option<Expiry>>,

    // Set when the data, the expiry, or the ID is changed through this handle; read by
    // `Session::is_modified`.
    is_modified: AtomicBool,
}
60
/// A session which allows HTTP applications to associate key-value pairs with
/// visitors.
///
/// Cloning a `Session` is cheap: every clone shares the same store handle and
/// the same underlying state, so changes made through one clone are visible
/// through all the others.
#[derive(Debug, Clone)]
pub struct Session {
    // Backing store used to load, create, save, and delete records.
    store: Arc<dyn SessionStore>,
    // Shared mutable state: ID, lazily-loaded record, expiry, and modified flag.
    inner: Arc<Inner>,
}
68
69impl Session {
70 /// Creates a new session with the session ID, store, and expiry.
71 ///
72 /// This method is lazy and does not invoke the overhead of talking to the
73 /// backing store.
74 ///
75 /// # Examples
76 ///
77 /// ```rust
78 /// use std::sync::Arc;
79 ///
80 /// use tower_sessions::{MemoryStore, Session};
81 ///
82 /// let store = Arc::new(MemoryStore::default());
83 /// Session::new(None, store, None);
84 /// ```
85 pub fn new(
86 session_id: Option<Id>,
87 store: Arc<impl SessionStore>,
88 expiry: Option<Expiry>,
89 ) -> Self {
90 let inner = Inner {
91 session_id: parking_lot::Mutex::new(session_id),
92 record: Mutex::new(None), // `None` indicates we have not loaded from store.
93 expiry: parking_lot::Mutex::new(expiry),
94 is_modified: AtomicBool::new(false),
95 };
96
97 Self {
98 store,
99 inner: Arc::new(inner),
100 }
101 }
102
103 fn create_record(&self) -> Record {
104 Record::new(self.expiry_date())
105 }
106
107 #[tracing::instrument(skip(self), err)]
108 async fn get_record(&self) -> Result<MappedMutexGuard<Record>> {
109 let mut record_guard = self.inner.record.lock().await;
110
111 // Lazily load the record since `None` here indicates we have no yet loaded it.
112 if record_guard.is_none() {
113 tracing::trace!("record not loaded from store; loading");
114
115 let session_id = *self.inner.session_id.lock();
116 *record_guard = Some(if let Some(session_id) = session_id {
117 match self.store.load(&session_id).await? {
118 Some(loaded_record) => {
119 tracing::trace!("record found in store");
120 loaded_record
121 }
122
123 None => {
124 // A well-behaved user agent should not send session cookies after
125 // expiration. Even so it's possible for an expired session to be removed
126 // from the store after a request was initiated. However, such a race should
127 // be relatively uncommon and as such entering this branch could indicate
128 // malicious behavior.
129 tracing::warn!("possibly suspicious activity: record not found in store");
130 *self.inner.session_id.lock() = None;
131 self.create_record()
132 }
133 }
134 } else {
135 tracing::trace!("session id not found");
136 self.create_record()
137 })
138 }
139
140 Ok(MutexGuard::map(record_guard, |opt| {
141 opt.as_mut()
142 .expect("Record should always be `Option::Some` at this point")
143 }))
144 }
145
146 /// Inserts a `impl Serialize` value into the session.
147 ///
148 /// # Examples
149 ///
150 /// ```rust
151 /// # tokio_test::block_on(async {
152 /// use std::sync::Arc;
153 ///
154 /// use tower_sessions::{MemoryStore, Session};
155 ///
156 /// let store = Arc::new(MemoryStore::default());
157 /// let session = Session::new(None, store, None);
158 ///
159 /// session.insert("foo", 42).await.unwrap();
160 ///
161 /// let value = session.get::<usize>("foo").await.unwrap();
162 /// assert_eq!(value, Some(42));
163 /// # });
164 /// ```
165 ///
166 /// # Errors
167 ///
168 /// - This method can fail when [`serde_json::to_value`] fails.
169 /// - If the session has not been hydrated and loading from the store fails,
170 /// we fail with [`Error::Store`].
171 pub async fn insert(&self, key: &str, value: impl Serialize) -> Result<()> {
172 self.insert_value(key, serde_json::to_value(&value)?)
173 .await?;
174 Ok(())
175 }
176
177 /// Inserts a `serde_json::Value` into the session.
178 ///
179 /// If the key was not present in the underlying map, `None` is returned and
180 /// `modified` is set to `true`.
181 ///
182 /// If the underlying map did have the key and its value is the same as the
183 /// provided value, `None` is returned and `modified` is not set.
184 ///
185 /// # Examples
186 ///
187 /// ```rust
188 /// # tokio_test::block_on(async {
189 /// use std::sync::Arc;
190 ///
191 /// use tower_sessions::{MemoryStore, Session};
192 ///
193 /// let store = Arc::new(MemoryStore::default());
194 /// let session = Session::new(None, store, None);
195 ///
196 /// let value = session
197 /// .insert_value("foo", serde_json::json!(42))
198 /// .await
199 /// .unwrap();
200 /// assert!(value.is_none());
201 ///
202 /// let value = session
203 /// .insert_value("foo", serde_json::json!(42))
204 /// .await
205 /// .unwrap();
206 /// assert!(value.is_none());
207 ///
208 /// let value = session
209 /// .insert_value("foo", serde_json::json!("bar"))
210 /// .await
211 /// .unwrap();
212 /// assert_eq!(value, Some(serde_json::json!(42)));
213 /// # });
214 /// ```
215 ///
216 /// # Errors
217 ///
218 /// - If the session has not been hydrated and loading from the store fails,
219 /// we fail with [`Error::Store`].
220 pub async fn insert_value(&self, key: &str, value: Value) -> Result<Option<Value>> {
221 let mut record_guard = self.get_record().await?;
222 Ok(if record_guard.data.get(key) != Some(&value) {
223 self.inner
224 .is_modified
225 .store(true, atomic::Ordering::Release);
226 record_guard.data.insert(key.to_string(), value)
227 } else {
228 None
229 })
230 }
231
232 /// Gets a value from the store.
233 ///
234 /// # Examples
235 ///
236 /// ```rust
237 /// # tokio_test::block_on(async {
238 /// use std::sync::Arc;
239 ///
240 /// use tower_sessions::{MemoryStore, Session};
241 ///
242 /// let store = Arc::new(MemoryStore::default());
243 /// let session = Session::new(None, store, None);
244 ///
245 /// session.insert("foo", 42).await.unwrap();
246 ///
247 /// let value = session.get::<usize>("foo").await.unwrap();
248 /// assert_eq!(value, Some(42));
249 /// # });
250 /// ```
251 ///
252 /// # Errors
253 ///
254 /// - This method can fail when [`serde_json::from_value`] fails.
255 /// - If the session has not been hydrated and loading from the store fails,
256 /// we fail with [`Error::Store`].
257 pub async fn get<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>> {
258 Ok(self
259 .get_value(key)
260 .await?
261 .map(serde_json::from_value)
262 .transpose()?)
263 }
264
265 /// Gets a `serde_json::Value` from the store.
266 ///
267 /// # Examples
268 ///
269 /// ```rust
270 /// # tokio_test::block_on(async {
271 /// use std::sync::Arc;
272 ///
273 /// use tower_sessions::{MemoryStore, Session};
274 ///
275 /// let store = Arc::new(MemoryStore::default());
276 /// let session = Session::new(None, store, None);
277 ///
278 /// session.insert("foo", 42).await.unwrap();
279 ///
280 /// let value = session.get_value("foo").await.unwrap().unwrap();
281 /// assert_eq!(value, serde_json::json!(42));
282 /// # });
283 /// ```
284 ///
285 /// # Errors
286 ///
287 /// - If the session has not been hydrated and loading from the store fails,
288 /// we fail with [`Error::Store`].
289 pub async fn get_value(&self, key: &str) -> Result<Option<Value>> {
290 let record_guard = self.get_record().await?;
291 Ok(record_guard.data.get(key).cloned())
292 }
293
294 /// Removes a value from the store, retuning the value of the key if it was
295 /// present in the underlying map.
296 ///
297 /// # Examples
298 ///
299 /// ```rust
300 /// # tokio_test::block_on(async {
301 /// use std::sync::Arc;
302 ///
303 /// use tower_sessions::{MemoryStore, Session};
304 ///
305 /// let store = Arc::new(MemoryStore::default());
306 /// let session = Session::new(None, store, None);
307 ///
308 /// session.insert("foo", 42).await.unwrap();
309 ///
310 /// let value: Option<usize> = session.remove("foo").await.unwrap();
311 /// assert_eq!(value, Some(42));
312 ///
313 /// let value: Option<usize> = session.get("foo").await.unwrap();
314 /// assert!(value.is_none());
315 /// # });
316 /// ```
317 ///
318 /// # Errors
319 ///
320 /// - This method can fail when [`serde_json::from_value`] fails.
321 /// - If the session has not been hydrated and loading from the store fails,
322 /// we fail with [`Error::Store`].
323 pub async fn remove<T: DeserializeOwned>(&self, key: &str) -> Result<Option<T>> {
324 Ok(self
325 .remove_value(key)
326 .await?
327 .map(serde_json::from_value)
328 .transpose()?)
329 }
330
331 /// Removes a `serde_json::Value` from the session.
332 ///
333 /// # Examples
334 ///
335 /// ```rust
336 /// # tokio_test::block_on(async {
337 /// use std::sync::Arc;
338 ///
339 /// use tower_sessions::{MemoryStore, Session};
340 ///
341 /// let store = Arc::new(MemoryStore::default());
342 /// let session = Session::new(None, store, None);
343 ///
344 /// session.insert("foo", 42).await.unwrap();
345 /// let value = session.remove_value("foo").await.unwrap().unwrap();
346 /// assert_eq!(value, serde_json::json!(42));
347 ///
348 /// let value: Option<usize> = session.get("foo").await.unwrap();
349 /// assert!(value.is_none());
350 /// # });
351 /// ```
352 ///
353 /// # Errors
354 ///
355 /// - If the session has not been hydrated and loading from the store fails,
356 /// we fail with [`Error::Store`].
357 pub async fn remove_value(&self, key: &str) -> Result<Option<Value>> {
358 let mut record_guard = self.get_record().await?;
359 self.inner
360 .is_modified
361 .store(true, atomic::Ordering::Release);
362 Ok(record_guard.data.remove(key))
363 }
364
365 /// Clears the session of all data but does not delete it from the store.
366 ///
367 /// # Examples
368 ///
369 /// ```rust
370 /// # tokio_test::block_on(async {
371 /// use std::sync::Arc;
372 ///
373 /// use tower_sessions::{MemoryStore, Session};
374 ///
375 /// let store = Arc::new(MemoryStore::default());
376 ///
377 /// let session = Session::new(None, store.clone(), None);
378 /// session.insert("foo", 42).await.unwrap();
379 /// assert!(!session.is_empty().await);
380 ///
381 /// session.save().await.unwrap();
382 ///
383 /// session.clear().await;
384 ///
385 /// // Not empty! (We have an ID still.)
386 /// assert!(!session.is_empty().await);
387 /// // Data is cleared...
388 /// assert!(session.get::<usize>("foo").await.unwrap().is_none());
389 ///
390 /// // ...data is cleared before loading from the backend...
391 /// let session = Session::new(session.id(), store.clone(), None);
392 /// session.clear().await;
393 /// assert!(session.get::<usize>("foo").await.unwrap().is_none());
394 ///
395 /// let session = Session::new(session.id(), store, None);
396 /// // ...but data is not deleted from the store.
397 /// assert_eq!(session.get::<usize>("foo").await.unwrap(), Some(42));
398 /// # });
399 /// ```
400 pub async fn clear(&self) {
401 let mut record_guard = self.inner.record.lock().await;
402 if let Some(record) = record_guard.as_mut() {
403 record.data.clear();
404 } else if let Some(session_id) = *self.inner.session_id.lock() {
405 let mut new_record = self.create_record();
406 new_record.id = session_id;
407 *record_guard = Some(new_record);
408 }
409
410 self.inner
411 .is_modified
412 .store(true, atomic::Ordering::Release);
413 }
414
415 /// Returns `true` if there is no session ID and the session is empty.
416 ///
417 /// # Examples
418 ///
419 /// ```rust
420 /// # tokio_test::block_on(async {
421 /// use std::sync::Arc;
422 ///
423 /// use tower_sessions::{session::Id, MemoryStore, Session};
424 ///
425 /// let store = Arc::new(MemoryStore::default());
426 ///
427 /// let session = Session::new(None, store.clone(), None);
428 /// // Empty if we have no ID and record is not loaded.
429 /// assert!(session.is_empty().await);
430 ///
431 /// let session = Session::new(Some(Id::default()), store.clone(), None);
432 /// // Not empty if we have an ID but no record. (Record is not loaded here.)
433 /// assert!(!session.is_empty().await);
434 ///
435 /// let session = Session::new(Some(Id::default()), store.clone(), None);
436 /// session.insert("foo", 42).await.unwrap();
437 /// // Not empty after inserting.
438 /// assert!(!session.is_empty().await);
439 /// session.save().await.unwrap();
440 /// // Not empty after saving.
441 /// assert!(!session.is_empty().await);
442 ///
443 /// let session = Session::new(session.id(), store.clone(), None);
444 /// session.load().await.unwrap();
445 /// // Not empty after loading from store...
446 /// assert!(!session.is_empty().await);
447 /// // ...and not empty after accessing the session.
448 /// session.get::<usize>("foo").await.unwrap();
449 /// assert!(!session.is_empty().await);
450 ///
451 /// let session = Session::new(session.id(), store.clone(), None);
452 /// session.delete().await.unwrap();
453 /// // Not empty after deleting from store...
454 /// assert!(!session.is_empty().await);
455 /// session.get::<usize>("foo").await.unwrap();
456 /// // ...but empty after trying to access the deleted session.
457 /// assert!(session.is_empty().await);
458 ///
459 /// let session = Session::new(None, store, None);
460 /// session.insert("foo", 42).await.unwrap();
461 /// session.flush().await.unwrap();
462 /// // Empty after flushing.
463 /// assert!(session.is_empty().await);
464 /// # });
465 /// ```
466 pub async fn is_empty(&self) -> bool {
467 let record_guard = self.inner.record.lock().await;
468
469 // N.B.: Session IDs are `None` if:
470 //
471 // 1. The cookie was not provided or otherwise could not be parsed,
472 // 2. Or the session could not be loaded from the store.
473 let session_id = self.inner.session_id.lock();
474
475 let Some(record) = record_guard.as_ref() else {
476 return session_id.is_none();
477 };
478
479 session_id.is_none() && record.data.is_empty()
480 }
481
482 /// Get the session ID.
483 ///
484 /// # Examples
485 ///
486 /// ```rust
487 /// use std::sync::Arc;
488 ///
489 /// use tower_sessions::{session::Id, MemoryStore, Session};
490 ///
491 /// let store = Arc::new(MemoryStore::default());
492 ///
493 /// let session = Session::new(None, store.clone(), None);
494 /// assert!(session.id().is_none());
495 ///
496 /// let id = Some(Id::default());
497 /// let session = Session::new(id, store, None);
498 /// assert_eq!(id, session.id());
499 /// ```
500 pub fn id(&self) -> Option<Id> {
501 *self.inner.session_id.lock()
502 }
503
504 /// Get the session expiry.
505 ///
506 /// # Examples
507 ///
508 /// ```rust
509 /// use std::sync::Arc;
510 ///
511 /// use tower_sessions::{session::Expiry, MemoryStore, Session};
512 ///
513 /// let store = Arc::new(MemoryStore::default());
514 /// let session = Session::new(None, store, None);
515 ///
516 /// assert_eq!(session.expiry(), None);
517 /// ```
518 pub fn expiry(&self) -> Option<Expiry> {
519 *self.inner.expiry.lock()
520 }
521
522 /// Set `expiry` to the given value.
523 ///
524 /// This may be used within applications directly to alter the session's
525 /// time to live.
526 ///
527 /// # Examples
528 ///
529 /// ```rust
530 /// use std::sync::Arc;
531 ///
532 /// use time::OffsetDateTime;
533 /// use tower_sessions::{session::Expiry, MemoryStore, Session};
534 ///
535 /// let store = Arc::new(MemoryStore::default());
536 /// let session = Session::new(None, store, None);
537 ///
538 /// let expiry = Expiry::AtDateTime(OffsetDateTime::now_utc());
539 /// session.set_expiry(Some(expiry));
540 ///
541 /// assert_eq!(session.expiry(), Some(expiry));
542 /// ```
543 pub fn set_expiry(&self, expiry: Option<Expiry>) {
544 *self.inner.expiry.lock() = expiry;
545 self.inner
546 .is_modified
547 .store(true, atomic::Ordering::Release);
548 }
549
550 /// Get session expiry as `OffsetDateTime`.
551 ///
552 /// # Examples
553 ///
554 /// ```rust
555 /// use std::sync::Arc;
556 ///
557 /// use time::{Duration, OffsetDateTime};
558 /// use tower_sessions::{MemoryStore, Session};
559 ///
560 /// let store = Arc::new(MemoryStore::default());
561 /// let session = Session::new(None, store, None);
562 ///
563 /// // Our default duration is two weeks.
564 /// let expected_expiry = OffsetDateTime::now_utc().saturating_add(Duration::weeks(2));
565 ///
566 /// assert!(session.expiry_date() > expected_expiry.saturating_sub(Duration::seconds(1)));
567 /// assert!(session.expiry_date() < expected_expiry.saturating_add(Duration::seconds(1)));
568 /// ```
569 pub fn expiry_date(&self) -> OffsetDateTime {
570 let expiry = self.inner.expiry.lock();
571 match *expiry {
572 Some(Expiry::OnInactivity(duration)) => {
573 OffsetDateTime::now_utc().saturating_add(duration)
574 }
575 Some(Expiry::AtDateTime(datetime)) => datetime,
576 Some(Expiry::OnSessionEnd) | None => {
577 OffsetDateTime::now_utc().saturating_add(DEFAULT_DURATION) // TODO: The default should probably be configurable.
578 }
579 }
580 }
581
582 /// Get session expiry as `Duration`.
583 ///
584 /// # Examples
585 ///
586 /// ```rust
587 /// use std::sync::Arc;
588 ///
589 /// use time::Duration;
590 /// use tower_sessions::{MemoryStore, Session};
591 ///
592 /// let store = Arc::new(MemoryStore::default());
593 /// let session = Session::new(None, store, None);
594 ///
595 /// let expected_duration = Duration::weeks(2);
596 ///
597 /// assert!(session.expiry_age() > expected_duration.saturating_sub(Duration::seconds(1)));
598 /// assert!(session.expiry_age() < expected_duration.saturating_add(Duration::seconds(1)));
599 /// ```
600 pub fn expiry_age(&self) -> Duration {
601 std::cmp::max(
602 self.expiry_date() - OffsetDateTime::now_utc(),
603 Duration::ZERO,
604 )
605 }
606
607 /// Returns `true` if the session has been modified during the request.
608 ///
609 /// # Examples
610 ///
611 /// ```rust
612 /// # tokio_test::block_on(async {
613 /// use std::sync::Arc;
614 ///
615 /// use tower_sessions::{MemoryStore, Session};
616 ///
617 /// let store = Arc::new(MemoryStore::default());
618 /// let session = Session::new(None, store, None);
619 ///
620 /// // Not modified initially.
621 /// assert!(!session.is_modified());
622 ///
623 /// // Getting doesn't count as a modification.
624 /// session.get::<usize>("foo").await.unwrap();
625 /// assert!(!session.is_modified());
626 ///
627 /// // Insertions and removals do though.
628 /// session.insert("foo", 42).await.unwrap();
629 /// assert!(session.is_modified());
630 /// # });
631 /// ```
632 pub fn is_modified(&self) -> bool {
633 self.inner.is_modified.load(atomic::Ordering::Acquire)
634 }
635
636 /// Saves the session record to the store.
637 ///
638 /// Note that this method is generally not needed and is reserved for
639 /// situations where the session store must be updated during the
640 /// request.
641 ///
642 /// # Examples
643 ///
644 /// ```rust
645 /// # tokio_test::block_on(async {
646 /// use std::sync::Arc;
647 ///
648 /// use tower_sessions::{MemoryStore, Session};
649 ///
650 /// let store = Arc::new(MemoryStore::default());
651 /// let session = Session::new(None, store.clone(), None);
652 ///
653 /// session.insert("foo", 42).await.unwrap();
654 /// session.save().await.unwrap();
655 ///
656 /// let session = Session::new(session.id(), store, None);
657 /// assert_eq!(session.get::<usize>("foo").await.unwrap().unwrap(), 42);
658 /// # });
659 /// ```
660 ///
661 /// # Errors
662 ///
663 /// - If saving to the store fails, we fail with [`Error::Store`].
664 #[tracing::instrument(skip(self), err)]
665 pub async fn save(&self) -> Result<()> {
666 let mut record_guard = self.get_record().await?;
667 record_guard.expiry_date = self.expiry_date();
668
669 // Session ID is `None` if:
670 //
671 // 1. No valid cookie was found on the request or,
672 // 2. No valid session was found in the store.
673 //
674 // In either case, we must create a new session via the store interface.
675 //
676 // Potential ID collisions must be handled by session store implementers.
677 if self.inner.session_id.lock().is_none() {
678 self.store.create(&mut record_guard).await?;
679 *self.inner.session_id.lock() = Some(record_guard.id);
680 } else {
681 self.store.save(&record_guard).await?;
682 }
683 Ok(())
684 }
685
686 /// Loads the session record from the store.
687 ///
688 /// Note that this method is generally not needed and is reserved for
689 /// situations where the session must be updated during the request.
690 ///
691 /// # Examples
692 ///
693 /// ```rust
694 /// # tokio_test::block_on(async {
695 /// use std::sync::Arc;
696 ///
697 /// use tower_sessions::{session::Id, MemoryStore, Session};
698 ///
699 /// let store = Arc::new(MemoryStore::default());
700 /// let id = Some(Id::default());
701 /// let session = Session::new(id, store.clone(), None);
702 ///
703 /// session.insert("foo", 42).await.unwrap();
704 /// session.save().await.unwrap();
705 ///
706 /// let session = Session::new(session.id(), store, None);
707 /// session.load().await.unwrap();
708 ///
709 /// assert_eq!(session.get::<usize>("foo").await.unwrap().unwrap(), 42);
710 /// # });
711 /// ```
712 ///
713 /// # Errors
714 ///
715 /// - If loading from the store fails, we fail with [`Error::Store`].
716 #[tracing::instrument(skip(self), err)]
717 pub async fn load(&self) -> Result<()> {
718 let session_id = *self.inner.session_id.lock();
719 let Some(ref id) = session_id else {
720 tracing::warn!("called load with no session id");
721 return Ok(());
722 };
723 let loaded_record = self.store.load(id).await.map_err(Error::Store)?;
724 let mut record_guard = self.inner.record.lock().await;
725 *record_guard = loaded_record;
726 Ok(())
727 }
728
729 /// Deletes the session from the store.
730 ///
731 /// # Examples
732 ///
733 /// ```rust
734 /// # tokio_test::block_on(async {
735 /// use std::sync::Arc;
736 ///
737 /// use tower_sessions::{session::Id, MemoryStore, Session, SessionStore};
738 ///
739 /// let store = Arc::new(MemoryStore::default());
740 /// let session = Session::new(Some(Id::default()), store.clone(), None);
741 ///
742 /// // Save before deleting.
743 /// session.save().await.unwrap();
744 ///
745 /// // Delete from the store.
746 /// session.delete().await.unwrap();
747 ///
748 /// assert!(store.load(&session.id().unwrap()).await.unwrap().is_none());
749 /// # });
750 /// ```
751 ///
752 /// # Errors
753 ///
754 /// - If deleting from the store fails, we fail with [`Error::Store`].
755 #[tracing::instrument(skip(self), err)]
756 pub async fn delete(&self) -> Result<()> {
757 let session_id = *self.inner.session_id.lock();
758 let Some(ref session_id) = session_id else {
759 tracing::warn!("called delete with no session id");
760 return Ok(());
761 };
762 self.store.delete(session_id).await.map_err(Error::Store)?;
763 Ok(())
764 }
765
766 /// Flushes the session by removing all data contained in the session and
767 /// then deleting it from the store.
768 ///
769 /// # Examples
770 ///
771 /// ```rust
772 /// # tokio_test::block_on(async {
773 /// use std::sync::Arc;
774 ///
775 /// use tower_sessions::{MemoryStore, Session, SessionStore};
776 ///
777 /// let store = Arc::new(MemoryStore::default());
778 /// let session = Session::new(None, store.clone(), None);
779 ///
780 /// session.insert("foo", "bar").await.unwrap();
781 /// session.save().await.unwrap();
782 ///
783 /// let id = session.id().unwrap();
784 ///
785 /// session.flush().await.unwrap();
786 ///
787 /// assert!(session.id().is_none());
788 /// assert!(session.is_empty().await);
789 /// assert!(store.load(&id).await.unwrap().is_none());
790 /// # });
791 /// ```
792 ///
793 /// # Errors
794 ///
795 /// - If deleting from the store fails, we fail with [`Error::Store`].
796 pub async fn flush(&self) -> Result<()> {
797 self.clear().await;
798 self.delete().await?;
799 *self.inner.session_id.lock() = None;
800 Ok(())
801 }
802
803 /// Cycles the session ID while retaining any data that was associated with
804 /// it.
805 ///
806 /// Using this method helps prevent session fixation attacks by ensuring a
807 /// new ID is assigned to the session.
808 ///
809 /// # Examples
810 ///
811 /// ```rust
812 /// # tokio_test::block_on(async {
813 /// use std::sync::Arc;
814 ///
815 /// use tower_sessions::{session::Id, MemoryStore, Session};
816 ///
817 /// let store = Arc::new(MemoryStore::default());
818 /// let session = Session::new(None, store.clone(), None);
819 ///
820 /// session.insert("foo", 42).await.unwrap();
821 /// session.save().await.unwrap();
822 /// let id = session.id();
823 ///
824 /// let session = Session::new(session.id(), store.clone(), None);
825 /// session.cycle_id().await.unwrap();
826 ///
827 /// assert!(!session.is_empty().await);
828 /// assert!(session.is_modified());
829 ///
830 /// session.save().await.unwrap();
831 ///
832 /// let session = Session::new(session.id(), store, None);
833 ///
834 /// assert_ne!(id, session.id());
835 /// assert_eq!(session.get::<usize>("foo").await.unwrap().unwrap(), 42);
836 /// # });
837 /// ```
838 ///
839 /// # Errors
840 ///
841 /// - If deleting from the store fails or saving to the store fails, we fail
842 /// with [`Error::Store`].
843 pub async fn cycle_id(&self) -> Result<()> {
844 let mut record_guard = self.get_record().await?;
845
846 let old_session_id = record_guard.id;
847 record_guard.id = Id::default();
848 *self.inner.session_id.lock() = None; // Setting `None` ensures `save` invokes the store's
849 // `create` method.
850
851 self.store
852 .delete(&old_session_id)
853 .await
854 .map_err(Error::Store)?;
855
856 self.inner
857 .is_modified
858 .store(true, atomic::Ordering::Release);
859
860 Ok(())
861 }
862}
863
/// ID type for sessions.
///
/// Wraps a 128-bit integer (16 bytes); [`Id::default`] generates a random one.
/// The [`Display`] form is the unpadded URL-safe base64 encoding of the
/// integer's little-endian bytes (22 characters), which [`FromStr`] parses
/// back.
///
/// # Examples
///
/// ```rust
/// use tower_sessions::session::Id;
///
/// Id::default();
/// ```
#[derive(Copy, Clone, Debug, Deserialize, Serialize, Eq, Hash, PartialEq)]
pub struct Id(pub i128); // TODO: By this being public, it may be possible to override the
                         // session ID, which is undesirable.
878
879impl Default for Id {
880 fn default() -> Self {
881 use rand::prelude::*;
882
883 Self(rand::thread_rng().gen())
884 }
885}
886
887impl Display for Id {
888 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
889 let mut encoded = [0; 22];
890 URL_SAFE_NO_PAD
891 .encode_slice(self.0.to_le_bytes(), &mut encoded)
892 .expect("Encoded ID must be exactly 22 bytes");
893 let encoded = str::from_utf8(&encoded).expect("Encoded ID must be valid UTF-8");
894
895 f.write_str(encoded)
896 }
897}
898
899impl FromStr for Id {
900 type Err = base64::DecodeSliceError;
901
902 fn from_str(s: &str) -> result::Result<Self, Self::Err> {
903 let mut decoded = [0; 16];
904 let bytes_decoded = URL_SAFE_NO_PAD.decode_slice(s.as_bytes(), &mut decoded)?;
905 if bytes_decoded != 16 {
906 let err = DecodeError::InvalidLength(bytes_decoded);
907 return Err(base64::DecodeSliceError::DecodeError(err));
908 }
909
910 Ok(Self(i128::from_le_bytes(decoded)))
911 }
912}
913
/// Record type that's appropriate for encoding and decoding sessions to and
/// from session stores.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct Record {
    /// The ID under which this record is stored.
    pub id: Id,
    /// The session's key-value pairs, stored as JSON values.
    pub data: Data,
    /// When the session expires; `Session::save` refreshes this from the
    /// session's expiry configuration before writing to the store.
    pub expiry_date: OffsetDateTime,
}
922
923impl Record {
924 fn new(expiry_date: OffsetDateTime) -> Self {
925 Self {
926 id: Id::default(),
927 data: Data::default(),
928 expiry_date,
929 }
930 }
931}
932
/// Session expiry configuration.
///
/// # Examples
///
/// ```rust
/// use time::{Duration, OffsetDateTime};
/// use tower_sessions::Expiry;
///
/// // Will be expired on "session end".
/// let expiry = Expiry::OnSessionEnd;
///
/// // Will be expired five minutes after the last activity.
/// let expiry = Expiry::OnInactivity(Duration::minutes(5));
///
/// // Will be expired at the given timestamp.
/// let expired_at = OffsetDateTime::now_utc().saturating_add(Duration::weeks(2));
/// let expiry = Expiry::AtDateTime(expired_at);
/// ```
#[derive(Copy, Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub enum Expiry {
    /// Expire on [current session end][current-session-end], as defined by the
    /// browser.
    ///
    /// Server-side, [`Session::expiry_date`] treats this like an unset expiry
    /// and falls back to a default lifetime of two weeks.
    ///
    /// [current-session-end]: https://developer.mozilla.org/en-US/docs/Web/HTTP/Cookies#removal_defining_the_lifetime_of_a_cookie
    OnSessionEnd,

    /// Expire on inactivity.
    ///
    /// Reading a session is not considered activity for expiration purposes.
    /// [`Session`] expiration is computed from the last time the session was
    /// _modified_.
    OnInactivity(Duration),

    /// Expire at a specific date and time.
    ///
    /// This value may be extended manually with
    /// [`set_expiry`](Session::set_expiry).
    AtDateTime(OffsetDateTime),
}
972
#[cfg(test)]
mod tests {
    use async_trait::async_trait;
    use mockall::{
        mock,
        predicate::{self, always},
    };

    use super::*;

    mock! {
        #[derive(Debug)]
        pub Store {}

        #[async_trait]
        impl SessionStore for Store {
            async fn create(&self, record: &mut Record) -> session_store::Result<()>;
            async fn save(&self, record: &Record) -> session_store::Result<()>;
            async fn load(&self, session_id: &Id) -> session_store::Result<Option<Record>>;
            async fn delete(&self, session_id: &Id) -> session_store::Result<()>;
        }
    }

    /// `cycle_id` must delete the record stored under the old ID, keep the
    /// session data, and make the next `save` go through `create` so that a
    /// new ID is assigned.
    #[tokio::test]
    async fn test_cycle_id() {
        let original_id = Id::default();
        let replacement_id = Id::default();

        let mut store_mock = MockStore::new();

        // Hydration: the first data access loads the (empty) record once.
        store_mock
            .expect_load()
            .with(predicate::eq(original_id))
            .times(1)
            .returning(move |_| {
                Ok(Some(Record {
                    id: original_id,
                    data: Data::default(),
                    expiry_date: OffsetDateTime::now_utc(),
                }))
            });

        // First save: the ID is known, so the store's `save` is used.
        store_mock
            .expect_save()
            .with(always())
            .times(1)
            .returning(|_| Ok(()));

        // Cycling removes the record stored under the old ID.
        store_mock
            .expect_delete()
            .with(predicate::eq(original_id))
            .times(1)
            .returning(|_| Ok(()));

        // Second save: the ID was cleared, so `create` hands out a new one.
        store_mock
            .expect_create()
            .times(1)
            .returning(move |record| {
                record.id = replacement_id;
                Ok(())
            });

        let session = Session::new(Some(original_id), Arc::new(store_mock), None);

        session.insert("foo", 42).await.unwrap();
        session.save().await.unwrap();

        session.cycle_id().await.unwrap();

        // Until the next save there is no ID at all...
        let id_after_cycle = session.id();
        assert_ne!(id_after_cycle, Some(original_id));
        assert!(id_after_cycle.is_none());
        // ...but the data survives the cycle.
        assert_eq!(session.get::<i32>("foo").await.unwrap(), Some(42));

        // Saving goes through `create`, which supplies the replacement ID.
        session.save().await.unwrap();
        assert_eq!(session.id(), Some(replacement_id));
    }
}