askama_derive/
input.rs

1use std::collections::hash_map::HashMap;
2use std::path::{Path, PathBuf};
3use std::str::FromStr;
4
5use mime::Mime;
6use quote::ToTokens;
7use syn::punctuated::Punctuated;
8
9use crate::config::{get_template_source, read_config_file, Config};
10use crate::CompileError;
11use parser::{Node, Parsed, Syntax};
12
13pub(crate) struct TemplateInput<'a> {
14    pub(crate) ast: &'a syn::DeriveInput,
15    pub(crate) config: &'a Config<'a>,
16    pub(crate) syntax: &'a Syntax<'a>,
17    pub(crate) source: &'a Source,
18    pub(crate) print: Print,
19    pub(crate) escaper: &'a str,
20    pub(crate) ext: Option<&'a str>,
21    pub(crate) mime_type: String,
22    pub(crate) path: PathBuf,
23}
24
25impl TemplateInput<'_> {
26    /// Extract the template metadata from the `DeriveInput` structure. This
27    /// mostly recovers the data for the `TemplateInput` fields from the
28    /// `template()` attribute list fields.
29    pub(crate) fn new<'n>(
30        ast: &'n syn::DeriveInput,
31        config: &'n Config<'_>,
32        args: &'n TemplateArgs,
33    ) -> Result<TemplateInput<'n>, CompileError> {
34        let TemplateArgs {
35            source,
36            print,
37            escaping,
38            ext,
39            syntax,
40            ..
41        } = args;
42
43        // Validate the `source` and `ext` value together, since they are
44        // related. In case `source` was used instead of `path`, the value
45        // of `ext` is merged into a synthetic `path` value here.
46        let source = source
47            .as_ref()
48            .expect("template path or source not found in attributes");
49        let path = match (&source, &ext) {
50            (Source::Path(path), _) => config.find_template(path, None)?,
51            (&Source::Source(_), Some(ext)) => PathBuf::from(format!("{}.{}", ast.ident, ext)),
52            (&Source::Source(_), None) => {
53                return Err("must include 'ext' attribute when using 'source' attribute".into())
54            }
55        };
56
57        // Validate syntax
58        let syntax = syntax.as_deref().map_or_else(
59            || Ok(config.syntaxes.get(config.default_syntax).unwrap()),
60            |s| {
61                config
62                    .syntaxes
63                    .get(s)
64                    .ok_or_else(|| CompileError::from(format!("attribute syntax {s} not exist")))
65            },
66        )?;
67
68        // Match extension against defined output formats
69
70        let escaping = escaping
71            .as_deref()
72            .unwrap_or_else(|| path.extension().map(|s| s.to_str().unwrap()).unwrap_or(""));
73
74        let mut escaper = None;
75        for (extensions, path) in &config.escapers {
76            if extensions.contains(escaping) {
77                escaper = Some(path);
78                break;
79            }
80        }
81
82        let escaper = escaper.ok_or_else(|| {
83            CompileError::from(format!("no escaper defined for extension '{escaping}'"))
84        })?;
85
86        let mime_type =
87            extension_to_mime_type(ext_default_to_path(ext.as_deref(), &path).unwrap_or("txt"))
88                .to_string();
89
90        Ok(TemplateInput {
91            ast,
92            config,
93            syntax,
94            source,
95            print: *print,
96            escaper,
97            ext: ext.as_deref(),
98            mime_type,
99            path,
100        })
101    }
102
103    pub(crate) fn find_used_templates(
104        &self,
105        map: &mut HashMap<PathBuf, Parsed>,
106    ) -> Result<(), CompileError> {
107        let source = match &self.source {
108            Source::Source(s) => s.clone(),
109            Source::Path(_) => get_template_source(&self.path)?,
110        };
111
112        let mut dependency_graph = Vec::new();
113        let mut check = vec![(self.path.clone(), source)];
114        while let Some((path, source)) = check.pop() {
115            let parsed = Parsed::new(source, self.syntax)?;
116
117            let mut top = true;
118            let mut nested = vec![parsed.nodes()];
119            while let Some(nodes) = nested.pop() {
120                for n in nodes {
121                    let mut add_to_check = |path: PathBuf| -> Result<(), CompileError> {
122                        if !map.contains_key(&path) {
123                            // Add a dummy entry to `map` in order to prevent adding `path`
124                            // multiple times to `check`.
125                            map.insert(path.clone(), Parsed::default());
126                            let source = get_template_source(&path)?;
127                            check.push((path, source));
128                        }
129                        Ok(())
130                    };
131
132                    use Node::*;
133                    match n {
134                        Extends(extends) if top => {
135                            let extends = self.config.find_template(extends.path, Some(&path))?;
136                            let dependency_path = (path.clone(), extends.clone());
137                            if dependency_graph.contains(&dependency_path) {
138                                return Err(format!(
139                                    "cyclic dependency in graph {:#?}",
140                                    dependency_graph
141                                        .iter()
142                                        .map(|e| format!("{:#?} --> {:#?}", e.0, e.1))
143                                        .collect::<Vec<String>>()
144                                )
145                                .into());
146                            }
147                            dependency_graph.push(dependency_path);
148                            add_to_check(extends)?;
149                        }
150                        Macro(m) if top => {
151                            nested.push(&m.nodes);
152                        }
153                        Import(import) if top => {
154                            let import = self.config.find_template(import.path, Some(&path))?;
155                            add_to_check(import)?;
156                        }
157                        Include(include) => {
158                            let include = self.config.find_template(include.path, Some(&path))?;
159                            add_to_check(include)?;
160                        }
161                        BlockDef(b) => {
162                            nested.push(&b.nodes);
163                        }
164                        If(i) => {
165                            for cond in &i.branches {
166                                nested.push(&cond.nodes);
167                            }
168                        }
169                        Loop(l) => {
170                            nested.push(&l.body);
171                            nested.push(&l.else_nodes);
172                        }
173                        Match(m) => {
174                            for arm in &m.arms {
175                                nested.push(&arm.nodes);
176                            }
177                        }
178                        Lit(_)
179                        | Comment(_)
180                        | Expr(_, _)
181                        | Call(_)
182                        | Extends(_)
183                        | Let(_)
184                        | Import(_)
185                        | Macro(_)
186                        | Raw(_)
187                        | Continue(_)
188                        | Break(_) => {}
189                    }
190                }
191                top = false;
192            }
193            map.insert(path, parsed);
194        }
195        Ok(())
196    }
197
198    #[inline]
199    pub(crate) fn extension(&self) -> Option<&str> {
200        ext_default_to_path(self.ext, &self.path)
201    }
202}
203
204#[derive(Debug, Default)]
205pub(crate) struct TemplateArgs {
206    source: Option<Source>,
207    print: Print,
208    escaping: Option<String>,
209    ext: Option<String>,
210    syntax: Option<String>,
211    config: Option<String>,
212    pub(crate) whitespace: Option<String>,
213}
214
215impl TemplateArgs {
216    pub(crate) fn new(ast: &'_ syn::DeriveInput) -> Result<Self, CompileError> {
217        // Check that an attribute called `template()` exists once and that it is
218        // the proper type (list).
219        let mut template_args = None;
220        for attr in &ast.attrs {
221            if !attr.path().is_ident("template") {
222                continue;
223            }
224
225            match attr.parse_args_with(Punctuated::<syn::Meta, syn::Token![,]>::parse_terminated) {
226                Ok(args) if template_args.is_none() => template_args = Some(args),
227                Ok(_) => return Err("duplicated 'template' attribute".into()),
228                Err(e) => return Err(format!("unable to parse template arguments: {e}").into()),
229            };
230        }
231
232        let template_args =
233            template_args.ok_or_else(|| CompileError::from("no attribute 'template' found"))?;
234
235        let mut args = Self::default();
236        // Loop over the meta attributes and find everything that we
237        // understand. Return a CompileError if something is not right.
238        // `source` contains an enum that can represent `path` or `source`.
239        for item in template_args {
240            let pair = match item {
241                syn::Meta::NameValue(pair) => pair,
242                _ => {
243                    return Err(format!(
244                        "unsupported attribute argument {:?}",
245                        item.to_token_stream()
246                    )
247                    .into())
248                }
249            };
250
251            let ident = match pair.path.get_ident() {
252                Some(ident) => ident,
253                None => unreachable!("not possible in syn::Meta::NameValue(…)"),
254            };
255
256            let value = match pair.value {
257                syn::Expr::Lit(lit) => lit,
258                syn::Expr::Group(group) => match *group.expr {
259                    syn::Expr::Lit(lit) => lit,
260                    _ => {
261                        return Err(format!("unsupported argument value type for {ident:?}").into())
262                    }
263                },
264                _ => return Err(format!("unsupported argument value type for {ident:?}").into()),
265            };
266
267            if ident == "path" {
268                if let syn::Lit::Str(s) = value.lit {
269                    if args.source.is_some() {
270                        return Err("must specify 'source' or 'path', not both".into());
271                    }
272                    args.source = Some(Source::Path(s.value()));
273                } else {
274                    return Err("template path must be string literal".into());
275                }
276            } else if ident == "source" {
277                if let syn::Lit::Str(s) = value.lit {
278                    if args.source.is_some() {
279                        return Err("must specify 'source' or 'path', not both".into());
280                    }
281                    args.source = Some(Source::Source(s.value()));
282                } else {
283                    return Err("template source must be string literal".into());
284                }
285            } else if ident == "print" {
286                if let syn::Lit::Str(s) = value.lit {
287                    args.print = s.value().parse()?;
288                } else {
289                    return Err("print value must be string literal".into());
290                }
291            } else if ident == "escape" {
292                if let syn::Lit::Str(s) = value.lit {
293                    args.escaping = Some(s.value());
294                } else {
295                    return Err("escape value must be string literal".into());
296                }
297            } else if ident == "ext" {
298                if let syn::Lit::Str(s) = value.lit {
299                    args.ext = Some(s.value());
300                } else {
301                    return Err("ext value must be string literal".into());
302                }
303            } else if ident == "syntax" {
304                if let syn::Lit::Str(s) = value.lit {
305                    args.syntax = Some(s.value())
306                } else {
307                    return Err("syntax value must be string literal".into());
308                }
309            } else if ident == "config" {
310                if let syn::Lit::Str(s) = value.lit {
311                    args.config = Some(s.value());
312                } else {
313                    return Err("config value must be string literal".into());
314                }
315            } else if ident == "whitespace" {
316                if let syn::Lit::Str(s) = value.lit {
317                    args.whitespace = Some(s.value())
318                } else {
319                    return Err("whitespace value must be string literal".into());
320                }
321            } else {
322                return Err(format!("unsupported attribute key {ident:?} found").into());
323            }
324        }
325
326        Ok(args)
327    }
328
329    pub(crate) fn config(&self) -> Result<String, CompileError> {
330        read_config_file(self.config.as_deref())
331    }
332}
333
334#[inline]
335fn ext_default_to_path<'a>(ext: Option<&'a str>, path: &'a Path) -> Option<&'a str> {
336    ext.or_else(|| extension(path))
337}
338
/// Return the extension that names the template's output format.
///
/// A trailing Jinja-style extension (`j2`, `jinja`, `jinja2`) is skipped in
/// favour of the extension before it, so `foo.html.j2` yields `html`. When no
/// extension precedes it, the Jinja extension itself is returned. Returns
/// `None` when the path has no extension or the extension is not valid UTF-8.
fn extension(path: &Path) -> Option<&str> {
    const JINJA_EXTENSIONS: [&str; 3] = ["j2", "jinja", "jinja2"];

    // A non-UTF-8 extension cannot name a known format; treat it as absent
    // rather than panicking inside the proc-macro.
    let ext = path.extension().and_then(|s| s.to_str())?;
    if !JINJA_EXTENSIONS.contains(&ext) {
        return Some(ext);
    }

    // Look one level deeper, e.g. `html` in `foo.html.j2`.
    let inner = path
        .file_stem()
        .and_then(|stem| Path::new(stem).extension())
        .and_then(|s| s.to_str());
    inner.or(Some(ext))
}
352
/// Origin of a template's text.
#[derive(Debug)]
pub(crate) enum Source {
    /// `path = "…"`: a template file, located with `Config::find_template`.
    Path(String),
    /// `source = "…"`: the template text given inline in the attribute.
    Source(String),
}
358
/// Value of the `print = "…"` attribute argument.
#[derive(Clone, Copy, Debug, PartialEq)]
pub(crate) enum Print {
    /// `"all"`
    All,
    /// `"ast"`
    Ast,
    /// `"code"`
    Code,
    /// `"none"`
    None,
}
366
367impl FromStr for Print {
368    type Err = CompileError;
369
370    fn from_str(s: &str) -> Result<Print, Self::Err> {
371        use self::Print::*;
372        Ok(match s {
373            "all" => All,
374            "ast" => Ast,
375            "code" => Code,
376            "none" => None,
377            v => return Err(format!("invalid value for print option: {v}",).into()),
378        })
379    }
380}
381
382impl Default for Print {
383    fn default() -> Self {
384        Self::None
385    }
386}
387
388pub(crate) fn extension_to_mime_type(ext: &str) -> Mime {
389    let basic_type = mime_guess::from_ext(ext).first_or_octet_stream();
390    for (simple, utf_8) in &TEXT_TYPES {
391        if &basic_type == simple {
392            return utf_8.clone();
393        }
394    }
395    basic_type
396}
397
398const TEXT_TYPES: [(Mime, Mime); 7] = [
399    (mime::TEXT_PLAIN, mime::TEXT_PLAIN_UTF_8),
400    (mime::TEXT_HTML, mime::TEXT_HTML_UTF_8),
401    (mime::TEXT_CSS, mime::TEXT_CSS_UTF_8),
402    (mime::TEXT_CSV, mime::TEXT_CSV_UTF_8),
403    (
404        mime::TEXT_TAB_SEPARATED_VALUES,
405        mime::TEXT_TAB_SEPARATED_VALUES_UTF_8,
406    ),
407    (
408        mime::APPLICATION_JAVASCRIPT,
409        mime::APPLICATION_JAVASCRIPT_UTF_8,
410    ),
411    (mime::IMAGE_SVG, mime::IMAGE_SVG),
412];
413
#[cfg(test)]
mod tests {
    use super::*;

    /// Assert that `extension` maps every `(path, expected)` pair as listed.
    fn check(cases: &[(&str, &str)]) {
        for &(path, expected) in cases {
            assert_eq!(extension(Path::new(path)), Some(expected), "path: {path}");
        }
    }

    #[test]
    fn test_ext() {
        check(&[
            ("foo-bar.txt", "txt"),
            ("foo-bar.html", "html"),
            ("foo-bar.unknown", "unknown"),
            ("foo-bar.svg", "svg"),
            ("foo/bar/baz.txt", "txt"),
            ("foo/bar/baz.html", "html"),
            ("foo/bar/baz.unknown", "unknown"),
            ("foo/bar/baz.svg", "svg"),
        ]);
    }

    #[test]
    fn test_double_ext() {
        check(&[
            ("foo-bar.html.txt", "txt"),
            ("foo-bar.txt.html", "html"),
            ("foo-bar.txt.unknown", "unknown"),
            ("foo/bar/baz.html.txt", "txt"),
            ("foo/bar/baz.txt.html", "html"),
            ("foo/bar/baz.txt.unknown", "unknown"),
        ]);
    }

    #[test]
    fn test_skip_jinja_ext() {
        check(&[
            ("foo-bar.html.j2", "html"),
            ("foo-bar.html.jinja", "html"),
            ("foo-bar.html.jinja2", "html"),
            ("foo/bar/baz.txt.j2", "txt"),
            ("foo/bar/baz.txt.jinja", "txt"),
            ("foo/bar/baz.txt.jinja2", "txt"),
        ]);
    }

    #[test]
    fn test_only_jinja_ext() {
        check(&[
            ("foo-bar.j2", "j2"),
            ("foo-bar.jinja", "jinja"),
            ("foo-bar.jinja2", "jinja2"),
        ]);
    }
}
462}