mysql_async/conn/stmt_cache.rs
1use lru::LruCache;
10use twox_hash::XxHash64;
11
12use std::{
13 borrow::Borrow,
14 collections::HashMap,
15 hash::{BuildHasherDefault, Hash},
16 sync::Arc,
17};
18
19use crate::queryable::stmt::StmtInner;
20
/// Raw query text used as the lookup key of the statement cache.
///
/// The derived `Hash`/`Eq` delegate to the wrapped byte slice, so a `QueryString`
/// hashes and compares exactly like the equivalent `[u8]`. This is what allows map
/// lookups with a plain `&[u8]` through the `Borrow<[u8]>` implementation.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct QueryString(pub Arc<[u8]>);
23
24impl Borrow<[u8]> for QueryString {
25 fn borrow(&self) -> &[u8] {
26 self.0.as_ref()
27 }
28}
29
30impl PartialEq<[u8]> for QueryString {
31 fn eq(&self, other: &[u8]) -> bool {
32 self.0.as_ref() == other
33 }
34}
35
36pub struct Entry {
37 pub stmt: Arc<StmtInner>,
38 pub query: QueryString,
39}
40
41#[derive(Debug)]
42pub struct StmtCache {
43 cap: usize,
44 cache: LruCache<u32, Entry>,
45 query_map: HashMap<QueryString, u32, BuildHasherDefault<XxHash64>>,
46}
47
48impl StmtCache {
49 pub fn new(cap: usize) -> Self {
50 Self {
51 cap,
52 cache: LruCache::unbounded(),
53 query_map: Default::default(),
54 }
55 }
56
57 pub fn by_query<T>(&mut self, query: &T) -> Option<&Entry>
58 where
59 QueryString: Borrow<T>,
60 QueryString: PartialEq<T>,
61 T: Hash + Eq,
62 T: ?Sized,
63 {
64 let id = self.query_map.get(query).cloned();
65 match id {
66 Some(id) => self.cache.get(&id),
67 None => None,
68 }
69 }
70
71 pub fn put(&mut self, query: Arc<[u8]>, stmt: Arc<StmtInner>) -> Option<Arc<StmtInner>> {
72 if self.cap == 0 {
73 return None;
74 }
75
76 let query = QueryString(query);
77
78 self.query_map.insert(query.clone(), stmt.id());
79 self.cache.put(stmt.id(), Entry { stmt, query });
80
81 if self.cache.len() > self.cap {
82 if let Some((_, entry)) = self.cache.pop_lru() {
83 self.query_map.remove(entry.query.0.as_ref());
84 return Some(entry.stmt);
85 }
86 }
87
88 None
89 }
90
91 pub fn clear(&mut self) {
92 self.query_map.clear();
93 self.cache.clear();
94 }
95
96 pub fn remove(&mut self, id: u32) {
97 if let Some(entry) = self.cache.pop(&id) {
98 self.query_map.remove::<[u8]>(entry.query.borrow());
99 }
100 }
101
102 #[cfg(test)]
103 pub fn iter(&self) -> impl Iterator<Item = (&u32, &Entry)> {
104 self.cache.iter()
105 }
106
107 #[cfg(test)]
108 pub fn len(&self) -> usize {
109 self.cache.len()
110 }
111}
112
113impl super::Conn {
114 #[cfg(test)]
115 pub(crate) fn stmt_cache_ref(&self) -> &StmtCache {
116 &self.inner.stmt_cache
117 }
118
119 pub(crate) fn stmt_cache_mut(&mut self) -> &mut StmtCache {
120 &mut self.inner.stmt_cache
121 }
122
123 pub(crate) fn cache_stmt(&mut self, stmt: &Arc<StmtInner>) -> Option<Arc<StmtInner>> {
127 let query = stmt.raw_query.clone();
128 if self.inner.opts.stmt_cache_size() > 0 {
129 self.stmt_cache_mut().put(query, stmt.clone())
130 } else {
131 None
132 }
133 }
134
135 pub(crate) fn get_cached_stmt(&mut self, raw_query: &[u8]) -> Option<Arc<StmtInner>> {
139 self.stmt_cache_mut()
140 .by_query(raw_query)
141 .map(|entry| entry.stmt.clone())
142 }
143}