include_dir_macros/
lib.rs

//! Implementation details of the `include_dir` crate.
2//!
3//! You probably don't want to use this crate directly.
4#![cfg_attr(feature = "nightly", feature(track_path, proc_macro_tracked_env))]
5
6use proc_macro::{TokenStream, TokenTree};
7use proc_macro2::Literal;
8use quote::quote;
9use std::{
10    error::Error,
11    fmt::{self, Display, Formatter},
12    path::{Path, PathBuf},
13    time::SystemTime,
14};
15
16/// Embed the contents of a directory in your crate.
17#[proc_macro]
18pub fn include_dir(input: TokenStream) -> TokenStream {
19    let tokens: Vec<_> = input.into_iter().collect();
20
21    let path = match tokens.as_slice() {
22        [TokenTree::Literal(lit)] => unwrap_string_literal(lit),
23        _ => panic!("This macro only accepts a single, non-empty string argument"),
24    };
25
26    let path = resolve_path(&path, get_env).unwrap();
27
28    expand_dir(&path, &path).into()
29}
30
/// Extract the text of a string-literal token.
///
/// Ordinary string literals have their escape sequences (`\\`, `\n`, `\x41`,
/// `\u{..}`, line continuations, ...) decoded. Raw string literals (`r"..."`,
/// `r#"..."#`) are taken verbatim.
///
/// # Panics
///
/// Panics if `lit` is not a string literal, or if its value is empty.
fn unwrap_string_literal(lit: &proc_macro::Literal) -> String {
    parse_string_literal(&lit.to_string())
        .filter(|value| !value.is_empty())
        .unwrap_or_else(|| panic!("This macro only accepts a single, non-empty string argument"))
}

/// Decode the source representation of a string literal into its value.
///
/// Returns `None` if `repr` is not a (possibly raw) string literal, or if it
/// contains a malformed escape sequence.
fn parse_string_literal(repr: &str) -> Option<String> {
    if let Some(raw) = repr.strip_prefix('r') {
        // Raw string: `r`, N hashes, `"`, contents, `"`, N hashes. Escapes are not processed.
        let hashes = raw.len() - raw.trim_start_matches('#').len();
        let fence = &raw[..hashes];
        let contents = raw[hashes..]
            .strip_suffix(fence)?
            .strip_prefix('"')?
            .strip_suffix('"')?;
        return Some(contents.to_string());
    }

    let escaped = repr.strip_prefix('"')?.strip_suffix('"')?;
    unescape(escaped)
}

/// Decode the escape sequences found in the body of a non-raw string literal.
///
/// Returns `None` on an unknown or malformed escape sequence.
fn unescape(escaped: &str) -> Option<String> {
    let mut value = String::with_capacity(escaped.len());
    let mut chars = escaped.chars().peekable();

    while let Some(c) = chars.next() {
        if c != '\\' {
            value.push(c);
            continue;
        }

        let decoded = match chars.next()? {
            'n' => '\n',
            'r' => '\r',
            't' => '\t',
            '0' => '\0',
            '\\' => '\\',
            '\'' => '\'',
            '"' => '"',
            'x' => {
                let hi = chars.next()?.to_digit(16)?;
                let lo = chars.next()?.to_digit(16)?;
                let code = hi * 16 + lo;
                // `\x` escapes in string literals are limited to ASCII.
                if code > 0x7F {
                    return None;
                }
                char::from_u32(code)?
            }
            'u' => {
                if chars.next()? != '{' {
                    return None;
                }
                let mut code: u32 = 0;
                let mut digits = 0;
                loop {
                    match chars.next()? {
                        '}' => break,
                        '_' => {}
                        d => {
                            code = code.checked_mul(16)?.checked_add(d.to_digit(16)?)?;
                            digits += 1;
                        }
                    }
                }
                if digits == 0 || digits > 6 {
                    return None;
                }
                char::from_u32(code)?
            }
            // Line continuation: the newline and the whitespace that starts the
            // next line are dropped.
            '\n' => {
                while chars
                    .next_if(|&ws| matches!(ws, ' ' | '\t' | '\n' | '\r'))
                    .is_some()
                {}
                continue;
            }
            _ => return None,
        };
        value.push(decoded);
    }

    Some(value)
}
42
43fn expand_dir(root: &Path, path: &Path) -> proc_macro2::TokenStream {
44    let children = read_dir(path).unwrap_or_else(|e| {
45        panic!(
46            "Unable to read the entries in \"{}\": {}",
47            path.display(),
48            e
49        )
50    });
51
52    let mut child_tokens = Vec::new();
53
54    for child in children {
55        if child.is_dir() {
56            let tokens = expand_dir(root, &child);
57            child_tokens.push(quote! {
58                include_dir::DirEntry::Dir(#tokens)
59            });
60        } else if child.is_file() {
61            let tokens = expand_file(root, &child);
62            child_tokens.push(quote! {
63                include_dir::DirEntry::File(#tokens)
64            });
65        } else {
66            panic!("\"{}\" is neither a file nor a directory", child.display());
67        }
68    }
69
70    let path = normalize_path(root, path);
71
72    quote! {
73        include_dir::Dir::new(#path, {
74            const ENTRIES: &'static [include_dir::DirEntry<'static>] = &[ #(#child_tokens),*];
75            ENTRIES
76    })
77    }
78}
79
80fn expand_file(root: &Path, path: &Path) -> proc_macro2::TokenStream {
81    let abs = path
82        .canonicalize()
83        .unwrap_or_else(|e| panic!("failed to resolve \"{}\": {}", path.display(), e));
84    let literal = match abs.to_str() {
85        Some(abs) => quote!(include_bytes!(#abs)),
86        None => {
87            let contents = read_file(path);
88            let literal = Literal::byte_string(&contents);
89            quote!(#literal)
90        }
91    };
92
93    let normalized_path = normalize_path(root, path);
94
95    let tokens = quote! {
96        include_dir::File::new(#normalized_path, #literal)
97    };
98
99    match metadata(path) {
100        Some(metadata) => quote!(#tokens.with_metadata(#metadata)),
101        None => tokens,
102    }
103}
104
105fn metadata(path: &Path) -> Option<proc_macro2::TokenStream> {
106    fn to_unix(t: SystemTime) -> u64 {
107        t.duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs()
108    }
109
110    if !cfg!(feature = "metadata") {
111        return None;
112    }
113
114    let meta = path.metadata().ok()?;
115    let accessed = meta.accessed().map(to_unix).ok()?;
116    let created = meta.created().map(to_unix).ok()?;
117    let modified = meta.modified().map(to_unix).ok()?;
118
119    Some(quote! {
120        include_dir::Metadata::new(
121            std::time::Duration::from_secs(#accessed),
122            std::time::Duration::from_secs(#created),
123            std::time::Duration::from_secs(#modified),
124        )
125    })
126}
127
/// Express `path` relative to `root`, using `/` as the separator on every host.
///
/// Components are joined with `/` rather than rewriting `\` characters. On
/// Unix a backslash is a legal file-name character, not a separator. Each
/// component is converted with `to_string_lossy`, so non-UTF-8 names are
/// lossily converted.
///
/// # Panics
///
/// Panics if `path` does not start with `root`.
fn normalize_path(root: &Path, path: &Path) -> String {
    let relative = path
        .strip_prefix(root)
        .expect("Should only ever be called using paths inside the root path");

    relative
        .components()
        .map(|component| component.as_os_str().to_string_lossy())
        .collect::<Vec<_>>()
        .join("/")
}
138
139fn read_dir(dir: &Path) -> Result<Vec<PathBuf>, Box<dyn Error>> {
140    if !dir.is_dir() {
141        panic!("\"{}\" is not a directory", dir.display());
142    }
143
144    track_path(dir);
145
146    let mut paths = Vec::new();
147
148    for entry in dir.read_dir()? {
149        let entry = entry?;
150        paths.push(entry.path());
151    }
152
153    paths.sort();
154
155    Ok(paths)
156}
157
158fn read_file(path: &Path) -> Vec<u8> {
159    track_path(path);
160    std::fs::read(path).unwrap_or_else(|e| panic!("Unable to read \"{}\": {}", path.display(), e))
161}
162
163fn resolve_path(
164    raw: &str,
165    get_env: impl Fn(&str) -> Option<String>,
166) -> Result<PathBuf, Box<dyn Error>> {
167    let mut unprocessed = raw;
168    let mut resolved = String::new();
169
170    while let Some(dollar_sign) = unprocessed.find('$') {
171        let (head, tail) = unprocessed.split_at(dollar_sign);
172        resolved.push_str(head);
173
174        match parse_identifier(&tail[1..]) {
175            Some((variable, rest)) => {
176                let value = get_env(variable).ok_or_else(|| MissingVariable {
177                    variable: variable.to_string(),
178                })?;
179                resolved.push_str(&value);
180                unprocessed = rest;
181            }
182            None => {
183                return Err(UnableToParseVariable { rest: tail.into() }.into());
184            }
185        }
186    }
187    resolved.push_str(unprocessed);
188
189    Ok(PathBuf::from(resolved))
190}
191
/// Error from `resolve_path` when `get_env` has no value for a referenced variable.
#[derive(Debug, PartialEq)]
struct MissingVariable {
    /// Name of the variable, without the leading `$`.
    variable: String,
}

impl Display for MissingVariable {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        f.write_str("Unable to resolve $")?;
        f.write_str(&self.variable)
    }
}

impl Error for MissingVariable {}
204
/// Error from `resolve_path` when a `$` is not followed by a valid variable name.
#[derive(Debug, PartialEq)]
struct UnableToParseVariable {
    /// The unprocessed input, starting at the offending `$`.
    rest: String,
}

impl Display for UnableToParseVariable {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        write!(f, "Unable to parse a variable from \"{rest}\"", rest = self.rest)
    }
}

impl Error for UnableToParseVariable {}
217
218fn parse_identifier(text: &str) -> Option<(&str, &str)> {
219    let mut calls = 0;
220
221    let (head, tail) = take_while(text, |c| {
222        calls += 1;
223
224        match c {
225            '_' => true,
226            letter if letter.is_ascii_alphabetic() => true,
227            digit if digit.is_ascii_digit() && calls > 1 => true,
228            _ => false,
229        }
230    });
231
232    if head.is_empty() {
233        None
234    } else {
235        Some((head, tail))
236    }
237}
238
/// Split `s` at the first character for which `predicate` returns `false`.
///
/// Returns `(prefix, remainder)`, where every character of `prefix` satisfied
/// `predicate`. `predicate` is called on characters in order and is not called
/// again after it first returns `false`.
fn take_while(s: &str, mut predicate: impl FnMut(char) -> bool) -> (&str, &str) {
    let split = s
        .char_indices()
        .find(|&(_, c)| !predicate(c))
        .map_or(s.len(), |(byte_index, _)| byte_index);

    s.split_at(split)
}
252
/// Look up an environment variable through `proc_macro::tracked_env`.
///
/// Returns `None` if the variable is unset or its value is not valid Unicode.
#[cfg(feature = "nightly")]
fn get_env(variable: &str) -> Option<String> {
    match proc_macro::tracked_env::var(variable) {
        Ok(value) => Some(value),
        Err(_) => None,
    }
}

/// Look up an environment variable through `std::env`.
///
/// Returns `None` if the variable is unset or its value is not valid Unicode.
#[cfg(not(feature = "nightly"))]
fn get_env(variable: &str) -> Option<String> {
    match std::env::var(variable) {
        Ok(value) => Some(value),
        Err(_) => None,
    }
}
262
/// Register `_path` as an input of the macro expansion.
///
/// This only does anything when the `nightly` feature is enabled, which turns
/// on the unstable `track_path` API via the crate-level `cfg_attr`. Otherwise
/// the function is a no-op.
fn track_path(_path: &Path) {
    #[cfg(feature = "nightly")]
    {
        let tracked = _path.to_string_lossy();
        proc_macro::tracked_path::path(tracked);
    }
}
267
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn resolve_path_with_no_environment_variables() {
        let path = "./file.txt";

        let resolved = resolve_path(path, |_| unreachable!()).unwrap();

        assert_eq!(resolved.to_str().unwrap(), path);
    }

    #[test]
    fn simple_environment_variable() {
        let path = "./$VAR";

        let resolved = resolve_path(path, |name| {
            assert_eq!(name, "VAR");
            Some("file.txt".to_string())
        })
        .unwrap();

        assert_eq!(resolved.to_str().unwrap(), "./file.txt");
    }

    #[test]
    fn multiple_environment_variables() {
        let resolved = resolve_path("$ROOT/assets/$NAME.txt", |name| match name {
            "ROOT" => Some("/srv".to_string()),
            "NAME" => Some("logo".to_string()),
            _ => unreachable!(),
        })
        .unwrap();

        assert_eq!(resolved.to_str().unwrap(), "/srv/assets/logo.txt");
    }

    #[test]
    fn dont_resolve_recursively() {
        let path = "./$TOP_LEVEL.txt";

        // Names are looked up without the leading `$`, so a recursive lookup of
        // the substituted value would ask for "NESTED".
        let resolved = resolve_path(path, |name| match name {
            "TOP_LEVEL" => Some("$NESTED".to_string()),
            "NESTED" => unreachable!("Shouldn't resolve recursively"),
            _ => unreachable!(),
        })
        .unwrap();

        assert_eq!(resolved.to_str().unwrap(), "./$NESTED.txt");
    }

    #[test]
    fn parse_valid_identifiers() {
        let inputs = vec![
            ("a", "a"),
            ("a_", "a_"),
            ("_asf", "_asf"),
            ("a1", "a1"),
            ("a1_#sd", "a1_"),
        ];

        for (src, expected) in inputs {
            let (got, rest) = parse_identifier(src).unwrap();
            assert_eq!(got.len() + rest.len(), src.len());
            assert_eq!(got, expected);
        }
    }

    #[test]
    fn reject_invalid_identifiers() {
        for src in &["", "1abc", "#", "/path"] {
            assert_eq!(parse_identifier(src), None);
        }
    }

    #[test]
    fn unknown_environment_variable() {
        let path = "$UNKNOWN";

        let err = resolve_path(path, |_| None).unwrap_err();

        let missing_variable = err.downcast::<MissingVariable>().unwrap();
        assert_eq!(
            *missing_variable,
            MissingVariable {
                variable: String::from("UNKNOWN"),
            }
        );
    }

    #[test]
    fn invalid_variables() {
        let inputs = &["$1", "$"];

        for input in inputs {
            let err = resolve_path(input, |_| unreachable!()).unwrap_err();

            let err = err.downcast::<UnableToParseVariable>().unwrap();
            assert_eq!(
                *err,
                UnableToParseVariable {
                    rest: input.to_string(),
                }
            );
        }
    }
}