1use std::collections::BTreeMap;
2use std::io::Write;
3use std::path::{Path, PathBuf};
4use std::sync::{Arc, Mutex};
5use std::{env, fmt, fs};
6
7use crate::utils::is_ci;
8use crate::{
9 content::{yaml, Content},
10 elog,
11};
12
13use once_cell::sync::Lazy;
14
/// Process-wide cache used by `get_cargo_workspace`: maps a manifest directory (as passed
/// to `Workspace::DetectWithCargo`) to the `workspace_root` reported by `cargo metadata`,
/// so cargo is only consulted once per manifest directory.
static WORKSPACES: Lazy<Mutex<BTreeMap<String, Arc<PathBuf>>>> =
    Lazy::new(|| Mutex::new(BTreeMap::new()));
/// Process-wide cache used by `get_tool_config`: maps a workspace directory to the
/// `ToolConfig` built from it by `ToolConfig::from_workspace`, so later lookups reuse it.
static TOOL_CONFIGS: Lazy<Mutex<BTreeMap<PathBuf, Arc<ToolConfig>>>> =
    Lazy::new(|| Mutex::new(BTreeMap::new()));
19
20pub fn get_tool_config(workspace_dir: &Path) -> Arc<ToolConfig> {
21 TOOL_CONFIGS
22 .lock()
23 .unwrap()
24 .entry(workspace_dir.to_path_buf())
25 .or_insert_with(|| {
26 ToolConfig::from_workspace(workspace_dir)
27 .unwrap_or_else(|e| panic!("Error building config from {:?}: {}", workspace_dir, e))
28 .into()
29 })
30 .clone()
31}
32
/// Which test runner to use (only compiled with the `_cargo_insta_internal` feature).
///
/// Parsed by its `FromStr` impl from `"auto"`, `"cargo-test"` or `"nextest"`. It also
/// derives `clap::ValueEnum`, presumably so it can be chosen on a command line — TODO confirm.
#[cfg(feature = "_cargo_insta_internal")]
#[derive(Clone, Copy, Debug, PartialEq, Eq, clap::ValueEnum)]
pub enum TestRunner {
    /// `"auto"`; the value `ToolConfig::from_workspace` uses when neither
    /// `INSTA_TEST_RUNNER` nor `test.runner` is set.
    Auto,
    /// `"cargo-test"`.
    CargoTest,
    /// `"nextest"`; `resolve_fallback` may downgrade this to `Auto`.
    Nextest,
}
41
#[cfg(feature = "_cargo_insta_internal")]
impl TestRunner {
    /// Returns `Auto` instead of `Nextest` when fallback is enabled and
    /// `cargo nextest --version` cannot be run successfully; otherwise returns `self`.
    ///
    /// `test_runner_fallback` is the caller's opt-in for this downgrade.
    pub fn resolve_fallback(&self, test_runner_fallback: bool) -> &TestRunner {
        use crate::utils::get_cargo;

        // Only nextest can fall back, and only when enabled. Check these first so the
        // probe process is never spawned needlessly.
        if *self != TestRunner::Nextest || !test_runner_fallback {
            return self;
        }

        // A spawn failure counts the same as a non-zero exit: nextest is unavailable.
        let nextest_available = std::process::Command::new(get_cargo())
            .arg("nextest")
            .arg("--version")
            .output()
            .map_or(false, |probe| probe.status.success());

        if nextest_available {
            self
        } else {
            &TestRunner::Auto
        }
    }
}
63
/// Output mode, selected by `INSTA_OUTPUT` or the `behavior.output` config entry.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum OutputBehavior {
    /// Selected by `"diff"`; the default when no value is configured.
    Diff,
    /// Selected by `"summary"`.
    Summary,
    /// Selected by `"minimal"`.
    Minimal,
    /// Selected by `"none"`.
    Nothing,
}
76
/// Policy for snapshots that no test referenced (only compiled with the
/// `_cargo_insta_internal` feature).
///
/// Read from the `test.unreferenced` config entry (default `"ignore"`). The exact handling
/// of each policy lives outside this file — NOTE(review): confirm at the consumer.
#[cfg(feature = "_cargo_insta_internal")]
#[derive(Clone, Copy, Debug, PartialEq, Eq, clap::ValueEnum)]
pub enum UnreferencedSnapshots {
    /// `"auto"`.
    Auto,
    /// `"reject"`; `"error"` is accepted as an alias by the `FromStr` impl.
    Reject,
    /// `"delete"`.
    Delete,
    /// `"warn"`.
    Warn,
    /// `"ignore"`; the default.
    Ignore,
}
87
/// Configured snapshot update mode, from `INSTA_UPDATE` or the `behavior.update` config
/// entry (default `"auto"`). `snapshot_update_behavior` turns it into a concrete
/// `SnapshotUpdateBehavior`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SnapshotUpdate {
    /// `"always"` or `"1"`: always updates in place.
    Always,
    /// `"auto"` (default): no update on CI (per `is_ci`), otherwise a new file.
    Auto,
    /// `"unseen"`: a new file for unseen snapshots, otherwise in place.
    Unseen,
    /// `"new"`: always a new file.
    New,
    /// `"no"`: never updates.
    No,
    /// `"force"`, also set by the legacy `INSTA_FORCE_UPDATE=1` /
    /// `INSTA_FORCE_UPDATE_SNAPSHOTS=1` variables and `force_update: true` in config.
    /// NOTE(review): `snapshot_update_behavior` maps this the same as `Always`; any further
    /// difference is handled outside this file.
    Force,
}
98
/// Errors produced while building a [`ToolConfig`].
#[derive(Debug)]
pub enum Error {
    /// The config file that was found could not be parsed. Wraps the parser's error, which
    /// is exposed through `std::error::Error::source` rather than `Display`.
    Deserialize(crate::content::Error),
    /// An invalid setting value; the payload names the controlling environment variable.
    Env(&'static str),
    /// An invalid config-file setting value; the payload names the setting key.
    /// `#[allow(unused)]` suppresses dead-code warnings in builds that never construct it.
    #[allow(unused)]
    Config(&'static str),
}
106
107impl fmt::Display for Error {
108 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
109 match self {
110 Error::Deserialize(_) => write!(f, "failed to deserialize tool config"),
111 Error::Env(var) => write!(f, "invalid value for env var '{}'", var),
112 Error::Config(var) => write!(f, "invalid value for config '{}'", var),
113 }
114 }
115}
116
117impl std::error::Error for Error {
118 fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
119 match self {
120 Error::Deserialize(ref err) => Some(err),
121 _ => None,
122 }
123 }
124}
125
/// Resolved insta settings for one workspace, built by `ToolConfig::from_workspace` from
/// environment variables and the workspace's insta config file. Environment variables,
/// where one exists, take precedence over the config file.
#[derive(Debug, Clone)]
pub struct ToolConfig {
    // `INSTA_FORCE_PASS` / `behavior.force_pass`, default false.
    force_pass: bool,
    // `INSTA_REQUIRE_FULL_MATCH` / `behavior.require_full_match`, default false.
    require_full_match: bool,
    // `INSTA_OUTPUT` / `behavior.output`, default "diff".
    output: OutputBehavior,
    // `INSTA_UPDATE` / `behavior.update` (or deprecated `behavior.force_update`), default "auto".
    snapshot_update: SnapshotUpdate,
    // `INSTA_GLOB_FAIL_FAST` / `behavior.glob_fail_fast`, default false.
    #[cfg(feature = "glob")]
    glob_fail_fast: bool,
    // `INSTA_TEST_RUNNER_FALLBACK` / `test.runner_fallback`, default false.
    #[cfg(feature = "_cargo_insta_internal")]
    test_runner_fallback: bool,
    // `INSTA_TEST_RUNNER` / `test.runner`, default "auto".
    #[cfg(feature = "_cargo_insta_internal")]
    test_runner: TestRunner,
    // `test.unreferenced` (config only), default "ignore".
    #[cfg(feature = "_cargo_insta_internal")]
    test_unreferenced: UnreferencedSnapshots,
    // `test.auto_review` (config only), default false.
    #[cfg(feature = "_cargo_insta_internal")]
    auto_review: bool,
    // `test.auto_accept_unseen` (config only), default false.
    #[cfg(feature = "_cargo_insta_internal")]
    auto_accept_unseen: bool,
    // `review.include_ignored` (config only), default false.
    #[cfg(feature = "_cargo_insta_internal")]
    review_include_ignored: bool,
    // `review.include_hidden` (config only), default false.
    #[cfg(feature = "_cargo_insta_internal")]
    review_include_hidden: bool,
    // `review.warn_undiscovered` (config only), default true.
    #[cfg(feature = "_cargo_insta_internal")]
    review_warn_undiscovered: bool,
}
152
153impl ToolConfig {
154 pub fn from_workspace(workspace_dir: &Path) -> Result<ToolConfig, Error> {
156 let mut cfg = None;
157 for choice in &[".config/insta.yaml", "insta.yaml", ".insta.yaml"] {
158 let path = workspace_dir.join(choice);
159 match fs::read_to_string(&path) {
160 Ok(s) => {
161 cfg = Some(yaml::parse_str(&s, &path).map_err(Error::Deserialize)?);
162 break;
163 }
164 Err(_) => continue,
169 }
170 }
171 let cfg = cfg.unwrap_or_else(|| Content::Map(Default::default()));
172
173 let force_update_old_env_vars = if let Ok("1") = env::var("INSTA_FORCE_UPDATE").as_deref() {
187 true
194 } else if let Ok("1") = env::var("INSTA_FORCE_UPDATE_SNAPSHOTS").as_deref() {
195 elog!("INSTA_FORCE_UPDATE_SNAPSHOTS is deprecated, use INSTA_UPDATE=force. (If running from `cargo insta`, no action is required; upgrading `cargo-insta` will silence this warning.)");
202 true
203 } else {
204 false
205 };
206 if force_update_old_env_vars {
207 env::set_var("INSTA_UPDATE", "force");
208 }
209
210 Ok(ToolConfig {
211 require_full_match: match env::var("INSTA_REQUIRE_FULL_MATCH").as_deref() {
212 Err(_) | Ok("") => resolve(&cfg, &["behavior", "require_full_match"])
213 .and_then(|x| x.as_bool())
214 .unwrap_or(false),
215 Ok("0") => false,
216 Ok("1") => true,
217 _ => return Err(Error::Env("INSTA_REQUIRE_FULL_MATCH")),
218 },
219 force_pass: match env::var("INSTA_FORCE_PASS").as_deref() {
220 Err(_) | Ok("") => resolve(&cfg, &["behavior", "force_pass"])
221 .and_then(|x| x.as_bool())
222 .unwrap_or(false),
223 Ok("0") => false,
224 Ok("1") => true,
225 _ => return Err(Error::Env("INSTA_FORCE_PASS")),
226 },
227 output: {
228 let env_var = env::var("INSTA_OUTPUT");
229 let val = match env_var.as_deref() {
230 Err(_) | Ok("") => resolve(&cfg, &["behavior", "output"])
231 .and_then(|x| x.as_str())
232 .unwrap_or("diff"),
233 Ok(val) => val,
234 };
235 match val {
236 "diff" => OutputBehavior::Diff,
237 "summary" => OutputBehavior::Summary,
238 "minimal" => OutputBehavior::Minimal,
239 "none" => OutputBehavior::Nothing,
240 _ => return Err(Error::Env("INSTA_OUTPUT")),
241 }
242 },
243 snapshot_update: {
244 let env_var = env::var("INSTA_UPDATE");
245 let val = match env_var.as_deref() {
246 Err(_) | Ok("") => resolve(&cfg, &["behavior", "update"])
247 .and_then(|x| x.as_str())
248 .or(resolve(&cfg, &["behavior", "force_update"]).and_then(|x| {
250 elog!("`force_update: true` is deprecated in insta config files, use `update: force`");
251 match x.as_bool() {
252 Some(true) => Some("force"),
253 _ => None,
254 }
255 }))
256 .unwrap_or("auto"),
257 Ok(val) => val,
258 };
259 match val {
260 "auto" => SnapshotUpdate::Auto,
261 "always" | "1" => SnapshotUpdate::Always,
262 "new" => SnapshotUpdate::New,
263 "unseen" => SnapshotUpdate::Unseen,
264 "no" => SnapshotUpdate::No,
265 "force" => SnapshotUpdate::Force,
266 _ => return Err(Error::Env("INSTA_UPDATE")),
267 }
268 },
269 #[cfg(feature = "glob")]
270 glob_fail_fast: match env::var("INSTA_GLOB_FAIL_FAST").as_deref() {
271 Err(_) | Ok("") => resolve(&cfg, &["behavior", "glob_fail_fast"])
272 .and_then(|x| x.as_bool())
273 .unwrap_or(false),
274 Ok("1") => true,
275 Ok("0") => false,
276 _ => return Err(Error::Env("INSTA_GLOB_FAIL_FAST")),
277 },
278 #[cfg(feature = "_cargo_insta_internal")]
279 test_runner: {
280 let env_var = env::var("INSTA_TEST_RUNNER");
281 match env_var.as_deref() {
282 Err(_) | Ok("") => resolve(&cfg, &["test", "runner"])
283 .and_then(|x| x.as_str())
284 .unwrap_or("auto"),
285 Ok(val) => val,
286 }
287 .parse::<TestRunner>()
288 .map_err(|_| Error::Env("INSTA_TEST_RUNNER"))?
289 },
290 #[cfg(feature = "_cargo_insta_internal")]
291 test_runner_fallback: match env::var("INSTA_TEST_RUNNER_FALLBACK").as_deref() {
292 Err(_) | Ok("") => resolve(&cfg, &["test", "runner_fallback"])
293 .and_then(|x| x.as_bool())
294 .unwrap_or(false),
295 Ok("1") => true,
296 Ok("0") => false,
297 _ => return Err(Error::Env("INSTA_RUNNER_FALLBACK")),
298 },
299 #[cfg(feature = "_cargo_insta_internal")]
300 test_unreferenced: {
301 resolve(&cfg, &["test", "unreferenced"])
302 .and_then(|x| x.as_str())
303 .unwrap_or("ignore")
304 .parse::<UnreferencedSnapshots>()
305 .map_err(|_| Error::Config("unreferenced"))?
306 },
307 #[cfg(feature = "_cargo_insta_internal")]
308 auto_review: resolve(&cfg, &["test", "auto_review"])
309 .and_then(|x| x.as_bool())
310 .unwrap_or(false),
311 #[cfg(feature = "_cargo_insta_internal")]
312 auto_accept_unseen: resolve(&cfg, &["test", "auto_accept_unseen"])
313 .and_then(|x| x.as_bool())
314 .unwrap_or(false),
315 #[cfg(feature = "_cargo_insta_internal")]
316 review_include_hidden: resolve(&cfg, &["review", "include_hidden"])
317 .and_then(|x| x.as_bool())
318 .unwrap_or(false),
319 #[cfg(feature = "_cargo_insta_internal")]
320 review_include_ignored: resolve(&cfg, &["review", "include_ignored"])
321 .and_then(|x| x.as_bool())
322 .unwrap_or(false),
323 #[cfg(feature = "_cargo_insta_internal")]
324 review_warn_undiscovered: resolve(&cfg, &["review", "warn_undiscovered"])
325 .and_then(|x| x.as_bool())
326 .unwrap_or(true),
327 })
328 }
329
330 pub fn require_full_match(&self) -> bool {
334 self.require_full_match
335 }
336
337 pub fn force_pass(&self) -> bool {
339 self.force_pass
340 }
341
342 pub fn output_behavior(&self) -> OutputBehavior {
344 self.output
345 }
346
347 pub fn snapshot_update(&self) -> SnapshotUpdate {
349 self.snapshot_update
350 }
351
352 #[cfg(feature = "glob")]
354 pub fn glob_fail_fast(&self) -> bool {
355 self.glob_fail_fast
356 }
357}
358
#[cfg(feature = "_cargo_insta_internal")]
impl ToolConfig {
    /// Runner chosen through `INSTA_TEST_RUNNER` or the `test.runner` config entry.
    pub fn test_runner(&self) -> TestRunner {
        self.test_runner
    }

    /// Whether to fall back when the chosen runner is unavailable
    /// (`INSTA_TEST_RUNNER_FALLBACK` / `test.runner_fallback`).
    pub fn test_runner_fallback(&self) -> bool {
        self.test_runner_fallback
    }

    /// Policy for unreferenced snapshots (`test.unreferenced`).
    pub fn test_unreferenced(&self) -> UnreferencedSnapshots {
        self.test_unreferenced
    }

    /// The `test.auto_review` config flag.
    pub fn auto_review(&self) -> bool {
        self.auto_review
    }

    /// The `test.auto_accept_unseen` config flag.
    pub fn auto_accept_unseen(&self) -> bool {
        self.auto_accept_unseen
    }

    /// The `review.include_hidden` config flag.
    pub fn review_include_hidden(&self) -> bool {
        self.review_include_hidden
    }

    /// The `review.include_ignored` config flag.
    pub fn review_include_ignored(&self) -> bool {
        self.review_include_ignored
    }

    /// The `review.warn_undiscovered` config flag (defaults to `true`).
    pub fn review_warn_undiscovered(&self) -> bool {
        self.review_warn_undiscovered
    }
}
397
/// Concrete write action for a snapshot, chosen by `snapshot_update_behavior` from the
/// configured `SnapshotUpdate` mode.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SnapshotUpdateBehavior {
    /// Update the snapshot in place (used for `always`, `force`, and `unseen` on seen
    /// snapshots). Presumably overwrites the existing snapshot file — confirm at the write site.
    InPlace,
    /// Write the result to a separate new file (used for `new`, `unseen` on unseen
    /// snapshots, and `auto` off CI). NOTE(review): file naming is decided elsewhere.
    NewFile,
    /// Do not update (used for `no`, and for `auto` on CI).
    NoUpdate,
}
408
409pub fn snapshot_update_behavior(tool_config: &ToolConfig, unseen: bool) -> SnapshotUpdateBehavior {
411 match tool_config.snapshot_update() {
412 SnapshotUpdate::Always => SnapshotUpdateBehavior::InPlace,
413 SnapshotUpdate::Auto => {
414 if is_ci() {
415 SnapshotUpdateBehavior::NoUpdate
416 } else {
417 SnapshotUpdateBehavior::NewFile
418 }
419 }
420 SnapshotUpdate::Unseen => {
421 if unseen {
422 SnapshotUpdateBehavior::NewFile
423 } else {
424 SnapshotUpdateBehavior::InPlace
425 }
426 }
427 SnapshotUpdate::New => SnapshotUpdateBehavior::NewFile,
428 SnapshotUpdate::No => SnapshotUpdateBehavior::NoUpdate,
429 SnapshotUpdate::Force => SnapshotUpdateBehavior::InPlace,
430 }
431}
432
/// How `get_cargo_workspace` determines the workspace root. In both cases a set
/// `INSTA_WORKSPACE_ROOT` environment variable takes precedence.
pub enum Workspace {
    /// Run `cargo metadata` in this manifest directory and use its `workspace_root`.
    DetectWithCargo(&'static str),
    /// Use this path as the workspace root directly, without invoking cargo.
    UseAsIs(&'static str),
}
437
438pub fn get_cargo_workspace(workspace: Workspace) -> Arc<PathBuf> {
446 if let Ok(workspace_root) = env::var("INSTA_WORKSPACE_ROOT") {
450 return PathBuf::from(workspace_root).into();
451 }
452
453 let manifest_dir = match workspace {
458 Workspace::UseAsIs(workspace_root) => return PathBuf::from(workspace_root).into(),
459 Workspace::DetectWithCargo(manifest_dir) => manifest_dir,
460 };
461
462 let error_message = || {
463 format!(
464 "`cargo metadata --format-version=1 --no-deps` in path `{}`",
465 manifest_dir
466 )
467 };
468
469 WORKSPACES
470 .lock()
471 .unwrap()
473 .entry(manifest_dir.to_string())
474 .or_insert_with(|| {
475 let output = std::process::Command::new(
476 env::var("CARGO").unwrap_or_else(|_| "cargo".to_string()),
477 )
478 .args(["metadata", "--format-version=1", "--no-deps"])
479 .current_dir(manifest_dir)
480 .output()
481 .unwrap_or_else(|e| panic!("failed to run {}\n\n{}", error_message(), e));
482
483 crate::content::yaml::vendored::yaml::YamlLoader::load_from_str(
484 std::str::from_utf8(&output.stdout).unwrap(),
485 )
486 .map_err(|e| e.to_string())
487 .and_then(|docs| {
488 docs.into_iter()
489 .next()
490 .ok_or_else(|| "No content found in yaml".to_string())
491 })
492 .and_then(|metadata| {
493 metadata["workspace_root"]
494 .clone()
495 .into_string()
496 .ok_or_else(|| "Couldn't find `workspace_root`".to_string())
497 })
498 .map(|path| PathBuf::from(path).into())
499 .unwrap_or_else(|e| {
500 panic!(
501 "failed to parse cargo metadata output from {}: {}\n\n{:?}",
502 error_message(),
503 e,
504 output.stdout
505 )
506 })
507 })
508 .clone()
509}
510
#[test]
fn test_get_cargo_workspace_manifest_dir() {
    // Detecting from this crate's own manifest directory; the resolved root is expected
    // to end in `insta` (unless INSTA_WORKSPACE_ROOT overrides it).
    let manifest_dir = env!("CARGO_MANIFEST_DIR");
    let root = get_cargo_workspace(Workspace::DetectWithCargo(manifest_dir));
    assert!(root.ends_with("insta"));
}
517
#[test]
fn test_get_cargo_workspace_insta_workspace() {
    // A `UseAsIs` path is returned directly, without consulting cargo
    // (unless INSTA_WORKSPACE_ROOT overrides it).
    let given = "/tmp/insta_workspace_root";
    let root = get_cargo_workspace(Workspace::UseAsIs(given));
    assert!(root.ends_with("insta_workspace_root"));
}
524
#[cfg(feature = "_cargo_insta_internal")]
impl std::str::FromStr for TestRunner {
    type Err = ();

    /// Accepts exactly `"auto"`, `"cargo-test"` or `"nextest"`; anything else is `Err(())`.
    fn from_str(value: &str) -> Result<TestRunner, ()> {
        let runner = match value {
            "auto" => TestRunner::Auto,
            "cargo-test" => TestRunner::CargoTest,
            "nextest" => TestRunner::Nextest,
            _ => return Err(()),
        };
        Ok(runner)
    }
}
538
#[cfg(feature = "_cargo_insta_internal")]
impl std::str::FromStr for UnreferencedSnapshots {
    type Err = ();

    /// Accepts `"auto"`, `"reject"` (or its alias `"error"`), `"delete"`, `"warn"` and
    /// `"ignore"`; anything else is `Err(())`.
    fn from_str(value: &str) -> Result<UnreferencedSnapshots, ()> {
        let policy = match value {
            "auto" => UnreferencedSnapshots::Auto,
            "reject" | "error" => UnreferencedSnapshots::Reject,
            "delete" => UnreferencedSnapshots::Delete,
            "warn" => UnreferencedSnapshots::Warn,
            "ignore" => UnreferencedSnapshots::Ignore,
            _ => return Err(()),
        };
        Ok(policy)
    }
}
554
/// Records `snapshot_file` in the snapshot references file, if one is configured.
///
/// When `INSTA_SNAPSHOT_REFERENCES_FILE` is set, the snapshot's path is appended to that
/// file (created if missing) as one line. Non-UTF-8 paths in the variable are honored.
/// When the variable is unset, this does nothing.
///
/// # Panics
///
/// Panics if the references file cannot be opened or written; the message names the file.
pub fn memoize_snapshot_file(snapshot_file: &Path) {
    if let Some(path) = env::var_os("INSTA_SNAPSHOT_REFERENCES_FILE") {
        let path = Path::new(&path);
        let mut f = fs::OpenOptions::new()
            .append(true)
            .create(true)
            .open(path)
            .unwrap_or_else(|e| {
                panic!(
                    "failed to open snapshot references file {}: {}",
                    path.display(),
                    e
                )
            });
        // Format the whole line up front and hand it to a single `write_all`, rather than
        // writing the path and the newline separately (presumably so concurrent appenders
        // are less likely to interleave partial lines).
        f.write_all(format!("{}\n", snapshot_file.display()).as_bytes())
            .unwrap_or_else(|e| {
                panic!(
                    "failed to write snapshot references file {}: {}",
                    path.display(),
                    e
                )
            });
    }
}
567
568fn resolve<'a>(value: &'a Content, path: &[&str]) -> Option<&'a Content> {
569 path.iter()
570 .try_fold(value, |node, segment| match node.resolve_inner() {
571 Content::Map(fields) => fields
572 .iter()
573 .find(|x| x.0.as_str() == Some(segment))
574 .map(|x| &x.1),
575 Content::Struct(_, fields) | Content::StructVariant(_, _, _, fields) => {
576 fields.iter().find(|x| x.0 == *segment).map(|x| &x.1)
577 }
578 _ => None,
579 })
580}