1#![cfg_attr(not(any(feature = "punycode", feature = "std")), no_std)]
4#![forbid(unsafe_code)]
5
6extern crate alloc;
7
8mod error;
9
10#[cfg(feature = "anycase")]
11use alloc::borrow::Cow;
12use alloc::borrow::ToOwned;
13#[cfg(not(any(feature = "hashbrown", feature = "punycode", feature = "std")))]
14use alloc::collections::BTreeMap as Map;
15#[cfg(not(feature = "anycase"))]
16use alloc::vec::Vec;
17use core::str::{from_utf8, FromStr};
18#[cfg(feature = "hashbrown")]
19use hashbrown::HashMap as Map;
20#[cfg(all(not(feature = "hashbrown"), any(feature = "punycode", feature = "std")))]
21use std::collections::HashMap as Map;
22#[cfg(feature = "anycase")]
23use unicase::UniCase;
24
25pub use error::Error;
26pub use psl_types::{Domain, Info, List as Psl, Suffix, Type};
27
/// Location of the Public Suffix List data file published at publicsuffix.org.
pub const LIST_URL: &str = "https://publicsuffix.org/list/public_suffix_list.dat";
30
/// Child nodes of a trie node, keyed by a single domain label.
///
/// Without the `anycase` feature, keys are the raw label bytes and lookups are
/// byte-exact.
#[cfg(not(feature = "anycase"))]
type Children = Map<Vec<u8>, Node>;

/// Child nodes of a trie node, keyed by a single domain label.
///
/// With the `anycase` feature, keys are wrapped in `UniCase` so that labels are
/// compared case-insensitively.
#[cfg(feature = "anycase")]
type Children = Map<UniCase<Cow<'static, str>>, Node>;

/// Label used in rules (e.g. `*.ck`) to match any single label at that position.
/// `find` consults it only when no child exactly matches the current label.
const WILDCARD: &str = "*";
38
/// A node in the rule trie.
///
/// Each edge is one domain label. The path from the root spells a rule's labels
/// from right to left, so `foo.bar.com` is stored as `com` -> `bar` -> `foo`.
#[derive(Debug, Clone, Default, Eq, PartialEq)]
struct Node {
    /// Child nodes, keyed by the next label to the left.
    children: Children,
    /// The rule that ends exactly at this node, if any; purely intermediate
    /// nodes (e.g. `uk` when only `com.uk` is listed) have `None`.
    leaf: Option<Leaf>,
}
44
/// Data recorded on the trie node where a rule ends.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
struct Leaf {
    /// `true` for exception rules, i.e. rules written with a leading `!`.
    is_exception: bool,
    /// Section type of the list section the rule was read from
    /// (`Type::Icann` or `Type::Private` when built by `from_str`).
    typ: Type,
}
50
/// A parsed public suffix list.
///
/// Rules are stored in a trie keyed by label, starting from the rightmost
/// (top-level) label.
#[derive(Debug, Clone, Default, Eq, PartialEq)]
pub struct List {
    /// Root of the rule trie; its children are the top-level labels.
    rules: Node,
    /// Optional section filter. When `Some`, `find` only accepts rules of this
    /// type when extending a match beyond the top-level label; `None` accepts
    /// every rule.
    typ: Option<Type>,
}
57
58impl List {
59    #[inline]
61    #[must_use]
62    pub fn new() -> Self {
63        Self::default()
64    }
65
66    #[inline]
73    pub fn from_bytes(bytes: &[u8]) -> Result<Self, Error> {
74        from_utf8(bytes)
75            .map_err(|_| Error::ListNotUtf8Encoded)?
76            .parse()
77    }
78
79    #[inline]
81    #[must_use]
82    pub fn is_empty(&self) -> bool {
83        self.rules.children.is_empty()
84    }
85
86    #[inline]
87    fn append(&mut self, mut rule: &str, typ: Type) -> Result<(), Error> {
88        let mut is_exception = false;
89        if rule.starts_with('!') {
90            if !rule.contains('.') {
91                return Err(Error::ExceptionAtFirstLabel(rule.to_owned()));
92            }
93            is_exception = true;
94            rule = &rule[1..];
95        }
96
97        let mut current = &mut self.rules;
98        for label in rule.rsplit('.') {
99            if label.is_empty() {
100                return Err(Error::EmptyLabel(rule.to_owned()));
101            }
102
103            #[cfg(not(feature = "anycase"))]
104            let key = label.as_bytes().to_owned();
105            #[cfg(feature = "anycase")]
106            let key = UniCase::new(Cow::from(label.to_owned()));
107
108            current = current.children.entry(key).or_default();
109        }
110
111        current.leaf = Some(Leaf { is_exception, typ });
112
113        Ok(())
114    }
115}
116
// Builds a case-insensitive trie key from a byte label.
//
// If the label is not valid UTF-8 it cannot equal any `str` key, so the macro
// makes the *enclosing function* return `Info { len: 0, typ: None }` early.
#[cfg(feature = "anycase")]
macro_rules! anycase_key {
    ($label:ident) => {
        match from_utf8($label) {
            Ok(label) => UniCase::new(Cow::from(label)),
            Err(_) => return Info { len: 0, typ: None },
        }
    };
}
126
impl Psl for List {
    /// Finds the public suffix of a domain whose labels are supplied right to left
    /// (top-level label first).
    ///
    /// Returns an [`Info`] where `len` is the byte length of the matched suffix
    /// (including the dots between its labels) and `typ` is the type of the rule
    /// that produced it, or `None` if no listed rule matched.
    ///
    /// * An empty iterator yields `Info { len: 0, typ: None }`.
    /// * If the top-level label is not in the list, it alone is the suffix
    ///   (the implicit `*` rule), with `typ: None`.
    /// * Exception rules stop the search, and the suffix is the rule minus its
    ///   leftmost label.
    ///
    /// NOTE(review): the top-level label's `typ` is taken from its node without
    /// consulting the `self.typ` filter; only later labels are filtered. Confirm
    /// this is intended for `IcannList`/`PrivateList`.
    #[inline]
    fn find<'a, T>(&self, mut labels: T) -> Info
    where
        T: Iterator<Item = &'a [u8]>,
    {
        let mut rules = &self.rules;

        // The top-level label always counts as a suffix (implicit `*` rule), so it
        // is handled separately: its length seeds `info` whether or not it is listed.
        let mut info = match labels.next() {
            Some(label) => {
                let mut info = Info {
                    len: label.len(),
                    typ: None,
                };
                #[cfg(not(feature = "anycase"))]
                let node_opt = rules.children.get(label);
                #[cfg(feature = "anycase")]
                let node_opt = rules.children.get(&anycase_key!(label));
                match node_opt {
                    Some(node) => {
                        info.typ = node.leaf.map(|leaf| leaf.typ);
                        rules = node;
                    }
                    // Unlisted TLD: nothing deeper can match.
                    None => return info,
                }
                info
            }
            None => return Info { len: 0, typ: None },
        };

        // Byte length of the labels consumed so far, including separating dots.
        let mut len_so_far = info.len;
        for label in labels {
            // Prefer an exact child for this label; fall back to a `*` child.
            // NOTE(review): the wildcard is only tried when no exact child exists, so
            // an exact-label node without a leaf hides a sibling `*` rule.
            #[cfg(not(feature = "anycase"))]
            let node_opt = rules.children.get(label);
            #[cfg(feature = "anycase")]
            let node_opt = rules.children.get(&anycase_key!(label));
            match node_opt {
                Some(node) => rules = node,
                None => {
                    #[cfg(not(feature = "anycase"))]
                    let node_opt = rules.children.get(WILDCARD.as_bytes());
                    #[cfg(feature = "anycase")]
                    let node_opt = rules.children.get(&UniCase::new(Cow::from(WILDCARD)));
                    match node_opt {
                        Some(node) => rules = node,
                        // No path continues in the trie; keep the best match so far.
                        None => break,
                    }
                }
            }
            // This label plus the dot joining it to the suffix counted so far.
            let label_plus_dot = label.len() + 1;
            if let Some(leaf) = rules.leaf {
                // Only accept rules allowed by the list's section filter.
                if self.typ.is_none() || self.typ == Some(leaf.typ) {
                    info.typ = Some(leaf.typ);
                    if leaf.is_exception {
                        // Exception: the suffix excludes this (leftmost) label.
                        info.len = len_so_far;
                        break;
                    }
                    info.len = len_so_far + label_plus_dot;
                }
            }
            len_so_far += label_plus_dot;
        }

        info
    }
}
197
198impl FromStr for List {
199    type Err = Error;
200
201    #[inline]
202    fn from_str(s: &str) -> Result<Self, Self::Err> {
203        let mut typ = None;
204        let mut list = List::new();
205        for line in s.lines() {
206            match line {
207                line if line.contains("BEGIN ICANN DOMAINS") => {
208                    typ = Some(Type::Icann);
209                }
210                line if line.contains("BEGIN PRIVATE DOMAINS") => {
211                    typ = Some(Type::Private);
212                }
213                line if line.starts_with("//") => {
214                    continue;
215                }
216                line => match typ {
217                    Some(typ) => {
218                        let rule = match line.split_whitespace().next() {
219                            Some(rule) => rule,
220                            None => continue,
221                        };
222                        list.append(rule, typ)?;
223                        #[cfg(feature = "punycode")]
224                        {
225                            let ascii = idna::domain_to_ascii(rule)
226                                .map_err(|_| Error::InvalidRule(rule.to_owned()))?;
227                            list.append(&ascii, typ)?;
228                        }
229                    }
230                    None => {
231                        continue;
232                    }
233                },
234            }
235        }
236        if list.is_empty() {
237            return Err(Error::InvalidList);
238        }
239        Ok(list)
240    }
241}
242
243#[derive(Debug, Clone, Default, Eq, PartialEq)]
245pub struct IcannList(List);
246
247impl From<List> for IcannList {
248    #[inline]
249    fn from(mut list: List) -> Self {
250        list.typ = Some(Type::Icann);
251        Self(list)
252    }
253}
254
255impl From<IcannList> for List {
256    #[inline]
257    fn from(IcannList(mut list): IcannList) -> Self {
258        list.typ = None;
259        list
260    }
261}
262
263impl IcannList {
264    #[inline]
271    pub fn from_bytes(bytes: &[u8]) -> Result<Self, Error> {
272        let list = List::from_bytes(bytes)?;
273        Ok(list.into())
274    }
275
276    #[inline]
278    #[must_use]
279    pub fn is_empty(&self) -> bool {
280        self.0.is_empty()
281    }
282}
283
284impl FromStr for IcannList {
285    type Err = Error;
286
287    #[inline]
288    fn from_str(s: &str) -> Result<Self, Self::Err> {
289        let list = List::from_str(s)?;
290        Ok(list.into())
291    }
292}
293
294impl Psl for IcannList {
295    #[inline]
296    fn find<'a, T>(&self, labels: T) -> Info
297    where
298        T: Iterator<Item = &'a [u8]>,
299    {
300        self.0.find(labels)
301    }
302}
303
304#[derive(Debug, Clone, Default, Eq, PartialEq)]
306pub struct PrivateList(List);
307
308impl From<List> for PrivateList {
309    #[inline]
310    fn from(mut list: List) -> Self {
311        list.typ = Some(Type::Private);
312        Self(list)
313    }
314}
315
316impl From<PrivateList> for List {
317    #[inline]
318    fn from(PrivateList(mut list): PrivateList) -> Self {
319        list.typ = None;
320        list
321    }
322}
323
324impl PrivateList {
325    #[inline]
332    pub fn from_bytes(bytes: &[u8]) -> Result<Self, Error> {
333        let list = List::from_bytes(bytes)?;
334        Ok(list.into())
335    }
336
337    #[inline]
339    #[must_use]
340    pub fn is_empty(&self) -> bool {
341        self.0.is_empty()
342    }
343}
344
345impl FromStr for PrivateList {
346    type Err = Error;
347
348    #[inline]
349    fn from_str(s: &str) -> Result<Self, Self::Err> {
350        let list = List::from_str(s)?;
351        Ok(list.into())
352    }
353}
354
355impl Psl for PrivateList {
356    #[inline]
357    fn find<'a, T>(&self, labels: T) -> Info
358    where
359        T: Iterator<Item = &'a [u8]>,
360    {
361        self.0.find(labels)
362    }
363}
364
// Unit tests for list parsing and suffix lookup against a tiny one-rule list.
#[cfg(test)]
mod tests {
    use super::*;

    // A minimal list: a single ICANN rule, `com.uk`. The indentation inside the
    // literal is deliberate; parsing must tolerate surrounding whitespace.
    const LIST: &[u8] = b"
        // BEGIN ICANN DOMAINS
        com.uk
        ";

    // Parsing `LIST` must yield the trie root -> `uk` (no leaf) -> `com` (ICANN leaf),
    // with no type filter on the list itself.
    #[test]
    fn list_construction() {
        let list = List::from_bytes(LIST).unwrap();
        let expected = List {
            typ: None,
            rules: Node {
                children: {
                    let mut children = Children::default();
                    children.insert(
                        #[cfg(not(feature = "anycase"))]
                        b"uk".to_vec(),
                        #[cfg(feature = "anycase")]
                        UniCase::new(Cow::from("uk")),
                        Node {
                            children: {
                                let mut children = Children::default();
                                children.insert(
                                    #[cfg(not(feature = "anycase"))]
                                    b"com".to_vec(),
                                    #[cfg(feature = "anycase")]
                                    UniCase::new(Cow::from("com")),
                                    Node {
                                        children: Default::default(),
                                        leaf: Some(Leaf {
                                            is_exception: false,
                                            typ: Type::Icann,
                                        }),
                                    },
                                );
                                children
                            },
                            leaf: None,
                        },
                    );
                    children
                },
                leaf: None,
            },
        };
        assert_eq!(list, expected);
    }

    // An unlisted top-level label falls back to the implicit `*` rule: the whole
    // label is the suffix and no rule type is reported.
    #[test]
    fn find_localhost() {
        let list = List::from_bytes(LIST).unwrap();
        let labels = b"localhost".rsplit(|x| *x == b'.');
        assert_eq!(list.find(labels), Info { len: 9, typ: None });
    }

    // `uk` exists only as an intermediate trie node (no rule ends there), so it is
    // the suffix but carries no type.
    #[test]
    fn find_uk() {
        let list = List::from_bytes(LIST).unwrap();
        let labels = b"uk".rsplit(|x| *x == b'.');
        assert_eq!(list.find(labels), Info { len: 2, typ: None });
    }

    // The full two-label rule matches: suffix length includes the separating dot.
    #[test]
    fn find_com_uk() {
        let list = List::from_bytes(LIST).unwrap();
        let labels = b"com.uk".rsplit(|x| *x == b'.');
        assert_eq!(
            list.find(labels),
            Info {
                len: 6,
                typ: Some(Type::Icann)
            }
        );
    }

    // A three-label rule whose intermediate nodes have no leaves still matches in
    // full once the last label is reached.
    #[test]
    fn find_ide_kyoto_jp() {
        let list = List::from_bytes(b"// BEGIN ICANN DOMAINS\nide.kyoto.jp").unwrap();
        let labels = b"ide.kyoto.jp".rsplit(|x| *x == b'.');
        assert_eq!(
            list.find(labels),
            Info {
                len: 12,
                typ: Some(Type::Icann)
            }
        );
    }
}