1#![cfg_attr(feature = "nightly", feature(track_path, proc_macro_tracked_env))]
5
6use proc_macro::{TokenStream, TokenTree};
7use proc_macro2::Literal;
8use quote::quote;
9use std::{
10 error::Error,
11 fmt::{self, Display, Formatter},
12 path::{Path, PathBuf},
13 time::SystemTime,
14};
15
16#[proc_macro]
18pub fn include_dir(input: TokenStream) -> TokenStream {
19 let tokens: Vec<_> = input.into_iter().collect();
20
21 let path = match tokens.as_slice() {
22 [TokenTree::Literal(lit)] => unwrap_string_literal(lit),
23 _ => panic!("This macro only accepts a single, non-empty string argument"),
24 };
25
26 let path = resolve_path(&path, get_env).unwrap();
27
28 expand_dir(&path, &path).into()
29}
30
31fn unwrap_string_literal(lit: &proc_macro::Literal) -> String {
32 let mut repr = lit.to_string();
33 if !repr.starts_with('"') || !repr.ends_with('"') {
34 panic!("This macro only accepts a single, non-empty string argument")
35 }
36
37 repr.remove(0);
38 repr.pop();
39
40 repr
41}
42
43fn expand_dir(root: &Path, path: &Path) -> proc_macro2::TokenStream {
44 let children = read_dir(path).unwrap_or_else(|e| {
45 panic!(
46 "Unable to read the entries in \"{}\": {}",
47 path.display(),
48 e
49 )
50 });
51
52 let mut child_tokens = Vec::new();
53
54 for child in children {
55 if child.is_dir() {
56 let tokens = expand_dir(root, &child);
57 child_tokens.push(quote! {
58 include_dir::DirEntry::Dir(#tokens)
59 });
60 } else if child.is_file() {
61 let tokens = expand_file(root, &child);
62 child_tokens.push(quote! {
63 include_dir::DirEntry::File(#tokens)
64 });
65 } else {
66 panic!("\"{}\" is neither a file nor a directory", child.display());
67 }
68 }
69
70 let path = normalize_path(root, path);
71
72 quote! {
73 include_dir::Dir::new(#path, {
74 const ENTRIES: &'static [include_dir::DirEntry<'static>] = &[ #(#child_tokens),*];
75 ENTRIES
76 })
77 }
78}
79
80fn expand_file(root: &Path, path: &Path) -> proc_macro2::TokenStream {
81 let abs = path
82 .canonicalize()
83 .unwrap_or_else(|e| panic!("failed to resolve \"{}\": {}", path.display(), e));
84 let literal = match abs.to_str() {
85 Some(abs) => quote!(include_bytes!(#abs)),
86 None => {
87 let contents = read_file(path);
88 let literal = Literal::byte_string(&contents);
89 quote!(#literal)
90 }
91 };
92
93 let normalized_path = normalize_path(root, path);
94
95 let tokens = quote! {
96 include_dir::File::new(#normalized_path, #literal)
97 };
98
99 match metadata(path) {
100 Some(metadata) => quote!(#tokens.with_metadata(#metadata)),
101 None => tokens,
102 }
103}
104
105fn metadata(path: &Path) -> Option<proc_macro2::TokenStream> {
106 fn to_unix(t: SystemTime) -> u64 {
107 t.duration_since(SystemTime::UNIX_EPOCH).unwrap().as_secs()
108 }
109
110 if !cfg!(feature = "metadata") {
111 return None;
112 }
113
114 let meta = path.metadata().ok()?;
115 let accessed = meta.accessed().map(to_unix).ok()?;
116 let created = meta.created().map(to_unix).ok()?;
117 let modified = meta.modified().map(to_unix).ok()?;
118
119 Some(quote! {
120 include_dir::Metadata::new(
121 std::time::Duration::from_secs(#accessed),
122 std::time::Duration::from_secs(#created),
123 std::time::Duration::from_secs(#modified),
124 )
125 })
126}
127
/// Returns `path` relative to `root`, using `/` as the separator on every
/// platform (e.g. `root/a/b.txt` -> `a/b.txt`; `root` itself -> `""`).
///
/// The components are joined explicitly rather than by rewriting every `\`
/// character. On Unix a backslash is an ordinary file-name character and is
/// preserved. On Windows, where it is a separator, it still becomes `/`.
/// Non-UTF-8 components are converted lossily.
///
/// # Panics
///
/// Panics if `path` does not start with `root`. Callers only pass entries
/// discovered underneath `root`.
fn normalize_path(root: &Path, path: &Path) -> String {
    let stripped = path
        .strip_prefix(root)
        .expect("Should only ever be called using paths inside the root path");

    stripped
        .components()
        .map(|component| component.as_os_str().to_string_lossy())
        .collect::<Vec<_>>()
        .join("/")
}
138
139fn read_dir(dir: &Path) -> Result<Vec<PathBuf>, Box<dyn Error>> {
140 if !dir.is_dir() {
141 panic!("\"{}\" is not a directory", dir.display());
142 }
143
144 track_path(dir);
145
146 let mut paths = Vec::new();
147
148 for entry in dir.read_dir()? {
149 let entry = entry?;
150 paths.push(entry.path());
151 }
152
153 paths.sort();
154
155 Ok(paths)
156}
157
158fn read_file(path: &Path) -> Vec<u8> {
159 track_path(path);
160 std::fs::read(path).unwrap_or_else(|e| panic!("Unable to read \"{}\": {}", path.display(), e))
161}
162
163fn resolve_path(
164 raw: &str,
165 get_env: impl Fn(&str) -> Option<String>,
166) -> Result<PathBuf, Box<dyn Error>> {
167 let mut unprocessed = raw;
168 let mut resolved = String::new();
169
170 while let Some(dollar_sign) = unprocessed.find('$') {
171 let (head, tail) = unprocessed.split_at(dollar_sign);
172 resolved.push_str(head);
173
174 match parse_identifier(&tail[1..]) {
175 Some((variable, rest)) => {
176 let value = get_env(variable).ok_or_else(|| MissingVariable {
177 variable: variable.to_string(),
178 })?;
179 resolved.push_str(&value);
180 unprocessed = rest;
181 }
182 None => {
183 return Err(UnableToParseVariable { rest: tail.into() }.into());
184 }
185 }
186 }
187 resolved.push_str(unprocessed);
188
189 Ok(PathBuf::from(resolved))
190}
191
/// Error: a `$NAME` reference in the macro's path could not be resolved by
/// the environment lookup.
#[derive(Debug, PartialEq)]
struct MissingVariable {
    /// The variable name as written after the `$`.
    variable: String,
}

impl Display for MissingVariable {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let name = &self.variable;
        write!(f, "Unable to resolve ${}", name)
    }
}

impl Error for MissingVariable {}
203}
204
/// Error: a `$` in the macro's path is not followed by a valid variable name.
#[derive(Debug, PartialEq)]
struct UnableToParseVariable {
    /// The unparsed remainder of the path, starting at the offending `$`.
    rest: String,
}

impl Display for UnableToParseVariable {
    fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
        let remainder = &self.rest;
        write!(f, "Unable to parse a variable from \"{}\"", remainder)
    }
}

impl Error for UnableToParseVariable {}
216}
217
218fn parse_identifier(text: &str) -> Option<(&str, &str)> {
219 let mut calls = 0;
220
221 let (head, tail) = take_while(text, |c| {
222 calls += 1;
223
224 match c {
225 '_' => true,
226 letter if letter.is_ascii_alphabetic() => true,
227 digit if digit.is_ascii_digit() && calls > 1 => true,
228 _ => false,
229 }
230 });
231
232 if head.is_empty() {
233 None
234 } else {
235 Some((head, tail))
236 }
237}
238
/// Splits `s` at the first character for which `predicate` returns `false`.
///
/// Returns `(prefix, rest)`, where every character of `prefix` satisfied
/// `predicate`. The predicate is called once per character, in order, up to
/// and including the first rejected one. The split always lands on a UTF-8
/// character boundary.
fn take_while(s: &str, mut predicate: impl FnMut(char) -> bool) -> (&str, &str) {
    let boundary = s
        .char_indices()
        .find(|&(_, c)| !predicate(c))
        .map_or(s.len(), |(index, _)| index);

    s.split_at(boundary)
}
252
/// Looks up an environment variable for `$NAME` expansion.
///
/// With the `nightly` feature this goes through `proc_macro::tracked_env`,
/// which per its documentation also lets the compiler track the variable as
/// a build input. A variable that is unset (or otherwise unreadable) yields
/// `None`.
#[cfg(feature = "nightly")]
fn get_env(variable: &str) -> Option<String> {
    match proc_macro::tracked_env::var(variable) {
        Ok(value) => Some(value),
        Err(_) => None,
    }
}

/// Looks up an environment variable for `$NAME` expansion via
/// `std::env::var`. A variable that is unset or not valid Unicode yields
/// `None`.
#[cfg(not(feature = "nightly"))]
fn get_env(variable: &str) -> Option<String> {
    match std::env::var(variable) {
        Ok(value) => Some(value),
        Err(_) => None,
    }
}
262
/// Reports `path` to the compiler as an input of this macro expansion.
///
/// This only happens when the `nightly` feature (and thus
/// `proc_macro::tracked_path`) is available. Otherwise it does nothing.
fn track_path(path: &Path) {
    #[cfg(feature = "nightly")]
    {
        let tracked = path.to_string_lossy();
        proc_macro::tracked_path::path(tracked);
    }

    // Stable builds have no tracking API; just mark the argument as used.
    #[cfg(not(feature = "nightly"))]
    let _ = path;
}
267
// Unit tests for the pure helpers: `$NAME` path resolution, identifier
// parsing, string splitting, error messages and path normalization.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn resolve_path_with_no_environment_variables() {
        let path = "./file.txt";

        let resolved = resolve_path(path, |_| unreachable!()).unwrap();

        assert_eq!(resolved.to_str().unwrap(), path);
    }

    #[test]
    fn simple_environment_variable() {
        let path = "./$VAR";

        let resolved = resolve_path(path, |name| {
            assert_eq!(name, "VAR");
            Some("file.txt".to_string())
        })
        .unwrap();

        assert_eq!(resolved.to_str().unwrap(), "./file.txt");
    }

    #[test]
    fn multiple_environment_variables() {
        let resolved = resolve_path("$ROOT/assets/$SUB_DIR", |name| match name {
            "ROOT" => Some("/tmp/project".to_string()),
            "SUB_DIR" => Some("images".to_string()),
            other => unreachable!("Unexpected variable: {}", other),
        })
        .unwrap();

        assert_eq!(resolved.to_str().unwrap(), "/tmp/project/assets/images");
    }

    #[test]
    fn dont_resolve_recursively() {
        let path = "./$TOP_LEVEL.txt";

        let resolved = resolve_path(path, |name| match name {
            "TOP_LEVEL" => Some("$NESTED".to_string()),
            "$NESTED" => unreachable!("Shouldn't resolve recursively"),
            _ => unreachable!(),
        })
        .unwrap();

        assert_eq!(resolved.to_str().unwrap(), "./$NESTED.txt");
    }

    #[test]
    fn parse_valid_identifiers() {
        let inputs = vec![
            ("a", "a"),
            ("a_", "a_"),
            ("_asf", "_asf"),
            ("a1", "a1"),
            ("a1_#sd", "a1_"),
        ];

        for (src, expected) in inputs {
            let (got, rest) = parse_identifier(src).unwrap();
            assert_eq!(got.len() + rest.len(), src.len());
            assert_eq!(got, expected);
        }
    }

    #[test]
    fn parse_invalid_identifiers() {
        for src in &["", "1abc", "#", "-x", "é"] {
            assert_eq!(parse_identifier(src), None, "{:?} should not parse", src);
        }
    }

    #[test]
    fn take_while_respects_char_boundaries() {
        assert_eq!(take_while("ab€cd", |c| c.is_ascii_alphabetic()), ("ab", "€cd"));
        assert_eq!(take_while("€€x", |c| c == '€'), ("€€", "x"));
        assert_eq!(take_while("", |_| true), ("", ""));
    }

    #[test]
    fn unknown_environment_variable() {
        let path = "$UNKNOWN";

        let err = resolve_path(path, |_| None).unwrap_err();

        let missing_variable = err.downcast::<MissingVariable>().unwrap();
        assert_eq!(
            *missing_variable,
            MissingVariable {
                variable: String::from("UNKNOWN"),
            }
        );
    }

    #[test]
    fn invalid_variables() {
        let inputs = &["$1", "$"];

        for input in inputs {
            let err = resolve_path(input, |_| unreachable!()).unwrap_err();

            let err = err.downcast::<UnableToParseVariable>().unwrap();
            assert_eq!(
                *err,
                UnableToParseVariable {
                    rest: input.to_string(),
                }
            );
        }
    }

    #[test]
    fn invalid_variable_in_the_middle_of_a_path() {
        let err = resolve_path("assets/$/images", |_| unreachable!()).unwrap_err();

        let err = err.downcast::<UnableToParseVariable>().unwrap();
        assert_eq!(
            *err,
            UnableToParseVariable {
                rest: String::from("$/images"),
            }
        );
    }

    #[test]
    fn error_messages() {
        let missing = MissingVariable {
            variable: String::from("HOME"),
        };
        assert_eq!(missing.to_string(), "Unable to resolve $HOME");

        let unparsable = UnableToParseVariable {
            rest: String::from("$1"),
        };
        assert_eq!(
            unparsable.to_string(),
            "Unable to parse a variable from \"$1\""
        );
    }

    #[test]
    fn normalize_path_is_relative_to_root() {
        let root = std::path::Path::new("root");

        assert_eq!(normalize_path(root, &root.join("a").join("b.txt")), "a/b.txt");
        assert_eq!(normalize_path(root, root), "");
    }
}