mz_expr/
visit.rs

1// Copyright Materialize, Inc. and contributors. All rights reserved.
2//
3// Use of this software is governed by the Business Source License
4// included in the LICENSE file.
5//
6// As of the Change Date specified in that file, in accordance with
7// the Business Source License, use of this software will be governed
8// by the Apache License, Version 2.0.
9
10//! Visitor support for recursive data types.
11//!
12//! Recursive types can implement the [`VisitChildren`] trait, to
13//! specify how their recursive entries can be accessed. The extension
14//! trait [`Visit`] then adds support for recursively traversing
15//! instances of those types.
16//!
17//! # Naming
18//!
19//! Visitor methods follow this naming pattern:
20//!
21//! ```text
22//! [try_]visit_[mut_]{children,post,pre}
23//! ```
24//!
//! * The `try`-prefix specifies whether the visitor callback is
//!   fallible (prefix present) or infallible (prefix omitted).
//! * The `mut`-infix specifies whether the visitor callback gets
//!   access to mutable (infix present) or immutable (infix omitted)
//!   child references.
//! * The final suffix determines the nature of the traversal:
//!   * `children`: only visit direct children
//!   * `post`: recursively visit children in post-order
//!   * `pre`: recursively visit children in pre-order
//!   * no suffix: recursively visit children in pre- and post-order
//!     using a visitor object ([`Visitor`], [`VisitorMut`], [`TryVisitor`]
//!     or [`TryVisitorMut`]) that encapsulates the shared context.
//!
//! Methods ending in `_nolimit` are deprecated variants that do not
//! enforce a recursion limit.
36
37use std::marker::PhantomData;
38
39use mz_ore::stack::{CheckedRecursion, RecursionGuard, RecursionLimitError, maybe_grow};
40
41use crate::RECURSION_LIMIT;
42
43/// A trait for types that can visit their direct children of type `T`.
44///
45/// Implementing [`VisitChildren<Self>`] automatically also implements
46/// the [`Visit`] trait, which enables recursive traversal.
47///
48/// Note that care needs to be taken when implementing this trait for
49/// mutually recursive types (such as a type A where A has children
50/// of type B and vice versa). More specifically, at the moment it is
51/// not possible to implement versions of `VisitChildren<A> for A` such
52/// that A considers as its children all A-nodes occurring at leaf
53/// positions of B-children and vice versa for `VisitChildren<B> for B`.
54/// Doing this will result in recusion limit violations as indicated
55/// in the accompanying `test_recursive_types_b` test.
56pub trait VisitChildren<T> {
57    /// Apply an infallible immutable function `f` to each direct child.
58    fn visit_children<F>(&self, f: F)
59    where
60        F: FnMut(&T);
61
62    /// Apply an infallible mutable function `f` to each direct child.
63    fn visit_mut_children<F>(&mut self, f: F)
64    where
65        F: FnMut(&mut T);
66
67    /// Apply a fallible immutable function `f` to each direct child.
68    ///
69    /// For mutually recursive implementations (say consisting of two
70    /// types `A` and `B`), recursing through `B` in order to find all
71    /// `A`-children of a node of type `A` might cause lead to a
72    /// [`RecursionLimitError`], hence the bound on `E`.
73    fn try_visit_children<F, E>(&self, f: F) -> Result<(), E>
74    where
75        F: FnMut(&T) -> Result<(), E>,
76        E: From<RecursionLimitError>;
77
78    /// Apply a fallible mutable function `f` to each direct child.
79    ///
80    /// For mutually recursive implementations (say consisting of two
81    /// types `A` and `B`), recursing through `B` in order to find all
82    /// `A`-children of a node of type `A` might cause lead to a
83    /// [`RecursionLimitError`], hence the bound on `E`.
84    fn try_visit_mut_children<F, E>(&mut self, f: F) -> Result<(), E>
85    where
86        F: FnMut(&mut T) -> Result<(), E>,
87        E: From<RecursionLimitError>;
88}
89
90/// A trait for types that can recursively visit their children of the
91/// same type.
92///
93/// This trait is automatically implemented for all implementors of
94/// [`VisitChildren`].
95///
96/// All methods provided by this trait ensure that the stack is grown
97/// as needed, to avoid stack overflows when traversing deeply
98/// recursive objects. They also enforce a recursion limit of
99/// [`RECURSION_LIMIT`] by returning an error when that limit
100/// is exceeded.
101///
102/// There are also `*_nolimit` methods that don't enforce a recursion
103/// limit. Those methods are deprecated and should not be used.
104pub trait Visit {
105    /// Post-order immutable infallible visitor for `self`.
106    fn visit_post<F>(&self, f: &mut F) -> Result<(), RecursionLimitError>
107    where
108        F: FnMut(&Self);
109
110    /// Post-order immutable infallible visitor for `self`.
111    /// Does not enforce a recursion limit.
112    #[deprecated = "Use `visit_post` instead."]
113    fn visit_post_nolimit<F>(&self, f: &mut F)
114    where
115        F: FnMut(&Self);
116
117    /// Post-order mutable infallible visitor for `self`.
118    fn visit_mut_post<F>(&mut self, f: &mut F) -> Result<(), RecursionLimitError>
119    where
120        F: FnMut(&mut Self);
121
122    /// Post-order mutable infallible visitor for `self`.
123    /// Does not enforce a recursion limit.
124    #[deprecated = "Use `visit_mut_post` instead."]
125    fn visit_mut_post_nolimit<F>(&mut self, f: &mut F)
126    where
127        F: FnMut(&mut Self);
128
129    /// Post-order immutable fallible visitor for `self`.
130    fn try_visit_post<F, E>(&self, f: &mut F) -> Result<(), E>
131    where
132        F: FnMut(&Self) -> Result<(), E>,
133        E: From<RecursionLimitError>;
134
135    /// Post-order mutable fallible visitor for `self`.
136    fn try_visit_mut_post<F, E>(&mut self, f: &mut F) -> Result<(), E>
137    where
138        F: FnMut(&mut Self) -> Result<(), E>,
139        E: From<RecursionLimitError>;
140
141    /// Pre-order immutable infallible visitor for `self`.
142    fn visit_pre<F>(&self, f: &mut F) -> Result<(), RecursionLimitError>
143    where
144        F: FnMut(&Self);
145
146    /// Pre-order immutable infallible visitor for `self`, which also accumulates context
147    /// information along the path from the root to the current node's parent.
148    /// `acc_fun` is a similar closure as in `fold`. The accumulated context is passed to the
149    /// visitor, along with the current node.
150    ///
151    /// For example, one can use this on a `MirScalarExpr` to tell the visitor whether the current
152    /// subexpression has a negation somewhere above it.
153    ///
154    /// When using it on a `MirRelationExpr`, one has to be mindful that `Let` bindings are not
155    /// followed, i.e., the context won't include what happens with a `Let` binding in some other
156    /// `MirRelationExpr` where the binding occurs in a `Get`.
157    fn visit_pre_with_context<Context, AccFun, Visitor>(
158        &self,
159        init: Context,
160        acc_fun: &mut AccFun,
161        visitor: &mut Visitor,
162    ) -> Result<(), RecursionLimitError>
163    where
164        Context: Clone,
165        AccFun: FnMut(Context, &Self) -> Context,
166        Visitor: FnMut(&Context, &Self);
167
168    /// Pre-order immutable infallible visitor for `self`.
169    /// Does not enforce a recursion limit.
170    #[deprecated = "Use `visit_pre` instead."]
171    fn visit_pre_nolimit<F>(&self, f: &mut F)
172    where
173        F: FnMut(&Self);
174
175    /// Pre-order mutable infallible visitor for `self`.
176    fn visit_mut_pre<F>(&mut self, f: &mut F) -> Result<(), RecursionLimitError>
177    where
178        F: FnMut(&mut Self);
179
180    /// Pre-order mutable infallible visitor for `self`.
181    /// Does not enforce a recursion limit.
182    #[deprecated = "Use `visit_mut_pre` instead."]
183    fn visit_mut_pre_nolimit<F>(&mut self, f: &mut F)
184    where
185        F: FnMut(&mut Self);
186
187    /// Pre-order immutable fallible visitor for `self`.
188    fn try_visit_pre<F, E>(&self, f: &mut F) -> Result<(), E>
189    where
190        F: FnMut(&Self) -> Result<(), E>,
191        E: From<RecursionLimitError>;
192
193    /// Pre-order mutable fallible visitor for `self`.
194    fn try_visit_mut_pre<F, E>(&mut self, f: &mut F) -> Result<(), E>
195    where
196        F: FnMut(&mut Self) -> Result<(), E>,
197        E: From<RecursionLimitError>;
198
199    /// A generalization of [`Visit::visit_pre`] and [`Visit::visit_post`].
200    ///
201    /// The function `pre` runs on `self` before it runs on any of the children.
202    /// The function `post` runs on children first before the parent.
203    ///
204    /// Optionally, `pre` can return which children, if any, should be visited
205    /// (default is to visit all children).
206    fn visit_pre_post<F1, F2>(
207        &self,
208        pre: &mut F1,
209        post: &mut F2,
210    ) -> Result<(), RecursionLimitError>
211    where
212        F1: FnMut(&Self) -> Option<Vec<&Self>>,
213        F2: FnMut(&Self);
214
215    /// A generalization of [`Visit::visit_pre`] and [`Visit::visit_post`].
216    /// Does not enforce a recursion limit.
217    ///
218    /// The function `pre` runs on `self` before it runs on any of the children.
219    /// The function `post` runs on children first before the parent.
220    ///
221    /// Optionally, `pre` can return which children, if any, should be visited
222    /// (default is to visit all children).
223    #[deprecated = "Use `visit` instead."]
224    fn visit_pre_post_nolimit<F1, F2>(&self, pre: &mut F1, post: &mut F2)
225    where
226        F1: FnMut(&Self) -> Option<Vec<&Self>>,
227        F2: FnMut(&Self);
228
229    /// A generalization of [`Visit::visit_mut_pre`] and [`Visit::visit_mut_post`].
230    ///
231    /// The function `pre` runs on `self` before it runs on any of the children.
232    /// The function `post` runs on children first before the parent.
233    ///
234    /// Optionally, `pre` can return which children, if any, should be visited
235    /// (default is to visit all children).
236    #[deprecated = "Use `visit_mut` instead."]
237    fn visit_mut_pre_post<F1, F2>(
238        &mut self,
239        pre: &mut F1,
240        post: &mut F2,
241    ) -> Result<(), RecursionLimitError>
242    where
243        F1: FnMut(&mut Self) -> Option<Vec<&mut Self>>,
244        F2: FnMut(&mut Self);
245
246    /// A generalization of [`Visit::visit_mut_pre`] and [`Visit::visit_mut_post`].
247    /// Does not enforce a recursion limit.
248    ///
249    /// The function `pre` runs on `self` before it runs on any of the children.
250    /// The function `post` runs on children first before the parent.
251    ///
252    /// Optionally, `pre` can return which children, if any, should be visited
253    /// (default is to visit all children).
254    #[deprecated = "Use `visit_mut_pre_post` instead."]
255    fn visit_mut_pre_post_nolimit<F1, F2>(&mut self, pre: &mut F1, post: &mut F2)
256    where
257        F1: FnMut(&mut Self) -> Option<Vec<&mut Self>>,
258        F2: FnMut(&mut Self);
259
260    fn visit<V>(&self, visitor: &mut V) -> Result<(), RecursionLimitError>
261    where
262        Self: Sized,
263        V: Visitor<Self>;
264
265    fn visit_mut<V>(&mut self, visitor: &mut V) -> Result<(), RecursionLimitError>
266    where
267        Self: Sized,
268        V: VisitorMut<Self>;
269
270    fn try_visit<V, E>(&self, visitor: &mut V) -> Result<(), E>
271    where
272        Self: Sized,
273        V: TryVisitor<Self, E>,
274        E: From<RecursionLimitError>;
275
276    fn try_visit_mut<V, E>(&mut self, visitor: &mut V) -> Result<(), E>
277    where
278        Self: Sized,
279        V: TryVisitorMut<Self, E>,
280        E: From<RecursionLimitError>;
281}
282
/// A stateful visitor for immutable, infallible traversals driven by
/// [`Visit::visit`].
pub trait Visitor<T> {
    /// Called on `expr` before any of its children are visited.
    fn pre_visit(&mut self, expr: &T);
    /// Called on `expr` after all of its children have been visited.
    fn post_visit(&mut self, expr: &T);
}
287
/// A stateful visitor for mutable, infallible traversals driven by
/// [`Visit::visit_mut`].
pub trait VisitorMut<T> {
    /// Called on `expr` before any of its children are visited.
    fn pre_visit(&mut self, expr: &mut T);
    /// Called on `expr` after all of its children have been visited.
    fn post_visit(&mut self, expr: &mut T);
}
292
293pub trait TryVisitor<T, E: From<RecursionLimitError>> {
294    fn pre_visit(&mut self, expr: &T) -> Result<(), E>;
295    fn post_visit(&mut self, expr: &T) -> Result<(), E>;
296}
297
298pub trait TryVisitorMut<T, E: From<RecursionLimitError>> {
299    fn pre_visit(&mut self, expr: &mut T) -> Result<(), E>;
300    fn post_visit(&mut self, expr: &mut T) -> Result<(), E>;
301}
302
303impl<T: VisitChildren<T>> Visit for T {
304    fn visit_post<F>(&self, f: &mut F) -> Result<(), RecursionLimitError>
305    where
306        F: FnMut(&Self),
307    {
308        StackSafeVisit::new().visit_post(self, f)
309    }
310
311    fn visit_post_nolimit<F>(&self, f: &mut F)
312    where
313        F: FnMut(&Self),
314    {
315        StackSafeVisit::new().visit_post_nolimit(self, f)
316    }
317
318    fn visit_mut_post<F>(&mut self, f: &mut F) -> Result<(), RecursionLimitError>
319    where
320        F: FnMut(&mut Self),
321    {
322        StackSafeVisit::new().visit_mut_post(self, f)
323    }
324
325    fn visit_mut_post_nolimit<F>(&mut self, f: &mut F)
326    where
327        F: FnMut(&mut Self),
328    {
329        StackSafeVisit::new().visit_mut_post_nolimit(self, f)
330    }
331
332    fn try_visit_post<F, E>(&self, f: &mut F) -> Result<(), E>
333    where
334        F: FnMut(&Self) -> Result<(), E>,
335        E: From<RecursionLimitError>,
336    {
337        StackSafeVisit::new().try_visit_post(self, f)
338    }
339
340    fn try_visit_mut_post<F, E>(&mut self, f: &mut F) -> Result<(), E>
341    where
342        F: FnMut(&mut Self) -> Result<(), E>,
343        E: From<RecursionLimitError>,
344    {
345        StackSafeVisit::new().try_visit_mut_post(self, f)
346    }
347
348    fn visit_pre<F>(&self, f: &mut F) -> Result<(), RecursionLimitError>
349    where
350        F: FnMut(&Self),
351    {
352        StackSafeVisit::new().visit_pre(self, f)
353    }
354
355    fn visit_pre_with_context<Context, AccFun, Visitor>(
356        &self,
357        init: Context,
358        acc_fun: &mut AccFun,
359        visitor: &mut Visitor,
360    ) -> Result<(), RecursionLimitError>
361    where
362        Context: Clone,
363        AccFun: FnMut(Context, &Self) -> Context,
364        Visitor: FnMut(&Context, &Self),
365    {
366        StackSafeVisit::new().visit_pre_with_context(self, init, acc_fun, visitor)
367    }
368
369    fn visit_pre_nolimit<F>(&self, f: &mut F)
370    where
371        F: FnMut(&Self),
372    {
373        StackSafeVisit::new().visit_pre_nolimit(self, f)
374    }
375
376    fn visit_mut_pre<F>(&mut self, f: &mut F) -> Result<(), RecursionLimitError>
377    where
378        F: FnMut(&mut Self),
379    {
380        StackSafeVisit::new().visit_mut_pre(self, f)
381    }
382
383    fn visit_mut_pre_nolimit<F>(&mut self, f: &mut F)
384    where
385        F: FnMut(&mut Self),
386    {
387        StackSafeVisit::new().visit_mut_pre_nolimit(self, f)
388    }
389
390    fn try_visit_pre<F, E>(&self, f: &mut F) -> Result<(), E>
391    where
392        F: FnMut(&Self) -> Result<(), E>,
393        E: From<RecursionLimitError>,
394    {
395        StackSafeVisit::new().try_visit_pre(self, f)
396    }
397
398    fn try_visit_mut_pre<F, E>(&mut self, f: &mut F) -> Result<(), E>
399    where
400        F: FnMut(&mut Self) -> Result<(), E>,
401        E: From<RecursionLimitError>,
402    {
403        StackSafeVisit::new().try_visit_mut_pre(self, f)
404    }
405
406    fn visit_pre_post<F1, F2>(&self, pre: &mut F1, post: &mut F2) -> Result<(), RecursionLimitError>
407    where
408        F1: FnMut(&Self) -> Option<Vec<&Self>>,
409        F2: FnMut(&Self),
410    {
411        StackSafeVisit::new().visit_pre_post(self, pre, post)
412    }
413
414    fn visit_pre_post_nolimit<F1, F2>(&self, pre: &mut F1, post: &mut F2)
415    where
416        F1: FnMut(&Self) -> Option<Vec<&Self>>,
417        F2: FnMut(&Self),
418    {
419        StackSafeVisit::new().visit_pre_post_nolimit(self, pre, post)
420    }
421
422    fn visit_mut_pre_post<F1, F2>(
423        &mut self,
424        pre: &mut F1,
425        post: &mut F2,
426    ) -> Result<(), RecursionLimitError>
427    where
428        F1: FnMut(&mut Self) -> Option<Vec<&mut Self>>,
429        F2: FnMut(&mut Self),
430    {
431        StackSafeVisit::new().visit_mut_pre_post(self, pre, post)
432    }
433
434    fn visit_mut_pre_post_nolimit<F1, F2>(&mut self, pre: &mut F1, post: &mut F2)
435    where
436        F1: FnMut(&mut Self) -> Option<Vec<&mut Self>>,
437        F2: FnMut(&mut Self),
438    {
439        StackSafeVisit::new().visit_mut_pre_post_nolimit(self, pre, post)
440    }
441
442    fn visit<V>(&self, visitor: &mut V) -> Result<(), RecursionLimitError>
443    where
444        Self: Sized,
445        V: Visitor<Self>,
446    {
447        StackSafeVisit::new().visit(self, visitor)
448    }
449
450    fn visit_mut<V>(&mut self, visitor: &mut V) -> Result<(), RecursionLimitError>
451    where
452        Self: Sized,
453        V: VisitorMut<Self>,
454    {
455        StackSafeVisit::new().visit_mut(self, visitor)
456    }
457
458    fn try_visit<V, E>(&self, visitor: &mut V) -> Result<(), E>
459    where
460        Self: Sized,
461        V: TryVisitor<Self, E>,
462        E: From<RecursionLimitError>,
463    {
464        StackSafeVisit::new().try_visit(self, visitor)
465    }
466
467    fn try_visit_mut<V, E>(&mut self, visitor: &mut V) -> Result<(), E>
468    where
469        Self: Sized,
470        V: TryVisitorMut<Self, E>,
471        E: From<RecursionLimitError>,
472    {
473        StackSafeVisit::new().try_visit_mut(self, visitor)
474    }
475}
476
477struct StackSafeVisit<T> {
478    recursion_guard: RecursionGuard,
479    _type: PhantomData<T>,
480}
481
482impl<T> CheckedRecursion for StackSafeVisit<T> {
483    fn recursion_guard(&self) -> &RecursionGuard {
484        &self.recursion_guard
485    }
486}
487
488impl<T: VisitChildren<T>> StackSafeVisit<T> {
489    fn new() -> Self {
490        Self {
491            recursion_guard: RecursionGuard::with_limit(RECURSION_LIMIT),
492            _type: PhantomData,
493        }
494    }
495
496    fn visit_post<F>(&self, value: &T, f: &mut F) -> Result<(), RecursionLimitError>
497    where
498        F: FnMut(&T),
499    {
500        self.checked_recur(move |_| {
501            value.try_visit_children(|child| self.visit_post(child, f))?;
502            f(value);
503            Ok(())
504        })
505    }
506
507    fn visit_post_nolimit<F>(&self, value: &T, f: &mut F)
508    where
509        F: FnMut(&T),
510    {
511        maybe_grow(|| {
512            value.visit_children(|child| self.visit_post_nolimit(child, f));
513            f(value)
514        })
515    }
516
517    fn visit_mut_post<F>(&self, value: &mut T, f: &mut F) -> Result<(), RecursionLimitError>
518    where
519        F: FnMut(&mut T),
520    {
521        self.checked_recur(move |_| {
522            value.try_visit_mut_children(|child| self.visit_mut_post(child, f))?;
523            f(value);
524            Ok(())
525        })
526    }
527
528    fn visit_mut_post_nolimit<F>(&self, value: &mut T, f: &mut F)
529    where
530        F: FnMut(&mut T),
531    {
532        maybe_grow(|| {
533            value.visit_mut_children(|child| self.visit_mut_post_nolimit(child, f));
534            f(value)
535        })
536    }
537
538    fn try_visit_post<F, E>(&self, value: &T, f: &mut F) -> Result<(), E>
539    where
540        F: FnMut(&T) -> Result<(), E>,
541        E: From<RecursionLimitError>,
542    {
543        self.checked_recur(move |_| {
544            value.try_visit_children(|child| self.try_visit_post(child, f))?;
545            f(value)
546        })
547    }
548
549    fn try_visit_mut_post<F, E>(&self, value: &mut T, f: &mut F) -> Result<(), E>
550    where
551        F: FnMut(&mut T) -> Result<(), E>,
552        E: From<RecursionLimitError>,
553    {
554        self.checked_recur(move |_| {
555            value.try_visit_mut_children(|child| self.try_visit_mut_post(child, f))?;
556            f(value)
557        })
558    }
559
560    fn visit_pre<F>(&self, value: &T, f: &mut F) -> Result<(), RecursionLimitError>
561    where
562        F: FnMut(&T),
563    {
564        self.checked_recur(move |_| {
565            f(value);
566            value.try_visit_children(|child| self.visit_pre(child, f))
567        })
568    }
569
570    fn visit_pre_with_context<Context, AccFun, Visitor>(
571        &self,
572        node: &T,
573        init: Context,
574        acc_fun: &mut AccFun,
575        visitor: &mut Visitor,
576    ) -> Result<(), RecursionLimitError>
577    where
578        Context: Clone,
579        AccFun: FnMut(Context, &T) -> Context,
580        Visitor: FnMut(&Context, &T),
581    {
582        self.checked_recur(move |_| {
583            visitor(&init, node);
584            let context = acc_fun(init, node);
585            node.try_visit_children(|child| {
586                self.visit_pre_with_context(child, context.clone(), acc_fun, visitor)
587            })
588        })
589    }
590
591    fn visit_pre_nolimit<F>(&self, value: &T, f: &mut F)
592    where
593        F: FnMut(&T),
594    {
595        maybe_grow(|| {
596            f(value);
597            value.visit_children(|child| self.visit_pre_nolimit(child, f))
598        })
599    }
600
601    fn visit_mut_pre<F>(&self, value: &mut T, f: &mut F) -> Result<(), RecursionLimitError>
602    where
603        F: FnMut(&mut T),
604    {
605        self.checked_recur(move |_| {
606            f(value);
607            value.try_visit_mut_children(|child| self.visit_mut_pre(child, f))
608        })
609    }
610
611    fn visit_mut_pre_nolimit<F>(&self, value: &mut T, f: &mut F)
612    where
613        F: FnMut(&mut T),
614    {
615        maybe_grow(|| {
616            f(value);
617            value.visit_mut_children(|child| self.visit_mut_pre_nolimit(child, f))
618        })
619    }
620
621    fn try_visit_pre<F, E>(&self, value: &T, f: &mut F) -> Result<(), E>
622    where
623        F: FnMut(&T) -> Result<(), E>,
624        E: From<RecursionLimitError>,
625    {
626        self.checked_recur(move |_| {
627            f(value)?;
628            value.try_visit_children(|child| self.try_visit_pre(child, f))
629        })
630    }
631
632    fn try_visit_mut_pre<F, E>(&self, value: &mut T, f: &mut F) -> Result<(), E>
633    where
634        F: FnMut(&mut T) -> Result<(), E>,
635        E: From<RecursionLimitError>,
636    {
637        self.checked_recur(move |_| {
638            f(value)?;
639            value.try_visit_mut_children(|child| self.try_visit_mut_pre(child, f))
640        })
641    }
642
643    fn visit_pre_post<F1, F2>(
644        &self,
645        value: &T,
646        pre: &mut F1,
647        post: &mut F2,
648    ) -> Result<(), RecursionLimitError>
649    where
650        F1: FnMut(&T) -> Option<Vec<&T>>,
651        F2: FnMut(&T),
652    {
653        self.checked_recur(move |_| {
654            if let Some(to_visit) = pre(value) {
655                for child in to_visit {
656                    self.visit_pre_post(child, pre, post)?;
657                }
658            } else {
659                value.try_visit_children(|child| self.visit_pre_post(child, pre, post))?;
660            }
661            post(value);
662            Ok(())
663        })
664    }
665
666    fn visit_pre_post_nolimit<F1, F2>(&self, value: &T, pre: &mut F1, post: &mut F2)
667    where
668        F1: FnMut(&T) -> Option<Vec<&T>>,
669        F2: FnMut(&T),
670    {
671        maybe_grow(|| {
672            if let Some(to_visit) = pre(value) {
673                for child in to_visit {
674                    self.visit_pre_post_nolimit(child, pre, post);
675                }
676            } else {
677                value.visit_children(|child| self.visit_pre_post_nolimit(child, pre, post));
678            }
679            post(value);
680        })
681    }
682
683    fn visit_mut_pre_post<F1, F2>(
684        &self,
685        value: &mut T,
686        pre: &mut F1,
687        post: &mut F2,
688    ) -> Result<(), RecursionLimitError>
689    where
690        F1: FnMut(&mut T) -> Option<Vec<&mut T>>,
691        F2: FnMut(&mut T),
692    {
693        self.checked_recur(move |_| {
694            if let Some(to_visit) = pre(value) {
695                for child in to_visit {
696                    self.visit_mut_pre_post(child, pre, post)?;
697                }
698            } else {
699                value.try_visit_mut_children(|child| self.visit_mut_pre_post(child, pre, post))?;
700            }
701            post(value);
702            Ok(())
703        })
704    }
705
706    fn visit_mut_pre_post_nolimit<F1, F2>(&self, value: &mut T, pre: &mut F1, post: &mut F2)
707    where
708        F1: FnMut(&mut T) -> Option<Vec<&mut T>>,
709        F2: FnMut(&mut T),
710    {
711        maybe_grow(|| {
712            if let Some(to_visit) = pre(value) {
713                for child in to_visit {
714                    self.visit_mut_pre_post_nolimit(child, pre, post);
715                }
716            } else {
717                value.visit_mut_children(|child| self.visit_mut_pre_post_nolimit(child, pre, post));
718            }
719            post(value);
720        })
721    }
722
723    fn visit<V>(&self, value: &T, visitor: &mut V) -> Result<(), RecursionLimitError>
724    where
725        Self: Sized,
726        V: Visitor<T>,
727    {
728        self.checked_recur(move |this| {
729            visitor.pre_visit(value);
730            value.try_visit_children(|child| this.visit(child, visitor))?;
731            visitor.post_visit(value);
732            Ok(())
733        })
734    }
735
736    fn visit_mut<V>(&self, value: &mut T, visitor: &mut V) -> Result<(), RecursionLimitError>
737    where
738        Self: Sized,
739        V: VisitorMut<T>,
740    {
741        self.checked_recur(move |this| {
742            visitor.pre_visit(value);
743            value.try_visit_mut_children(|child| this.visit_mut(child, visitor))?;
744            visitor.post_visit(value);
745            Ok(())
746        })
747    }
748
749    fn try_visit<V, E>(&self, value: &T, visitor: &mut V) -> Result<(), E>
750    where
751        Self: Sized,
752        V: TryVisitor<T, E>,
753        E: From<RecursionLimitError>,
754    {
755        self.checked_recur(move |_| {
756            visitor.pre_visit(value)?;
757            value.try_visit_children(|child| self.try_visit(child, visitor))?;
758            visitor.post_visit(value)?;
759            Ok(())
760        })
761    }
762
763    fn try_visit_mut<V, E>(&self, value: &mut T, visitor: &mut V) -> Result<(), E>
764    where
765        Self: Sized,
766        V: TryVisitorMut<T, E>,
767        E: From<RecursionLimitError>,
768    {
769        self.checked_recur(move |_| {
770            visitor.pre_visit(value)?;
771            value.try_visit_mut_children(|child| self.try_visit_mut(child, visitor))?;
772            visitor.post_visit(value)?;
773            Ok(())
774        })
775    }
776}
777
778#[cfg(test)]
779mod tests {
780    use mz_ore::assert_ok;
781
782    use super::*;
783
784    #[derive(Debug, Eq, PartialEq)]
785    enum A {
786        Add(Box<A>, Box<A>),
787        Lit(u64),
788        FrB(Box<B>),
789    }
790
791    #[derive(Debug, Eq, PartialEq)]
792    enum B {
793        Mul(Box<B>, Box<B>),
794        Lit(u64),
795        FrA(Box<A>),
796    }
797
798    impl VisitChildren<A> for A {
799        fn visit_children<F>(&self, mut f: F)
800        where
801            F: FnMut(&A),
802        {
803            VisitChildren::visit_children(self, |expr: &B| {
804                #[allow(deprecated)]
805                Visit::visit_post_nolimit(expr, &mut |expr| match expr {
806                    B::FrA(expr) => f(expr.as_ref()),
807                    _ => (),
808                });
809            });
810
811            match self {
812                A::Add(lhs, rhs) => {
813                    f(lhs);
814                    f(rhs);
815                }
816                A::Lit(_) => (),
817                A::FrB(_) => (),
818            }
819        }
820
821        fn visit_mut_children<F>(&mut self, mut f: F)
822        where
823            F: FnMut(&mut A),
824        {
825            VisitChildren::visit_mut_children(self, |expr: &mut B| {
826                #[allow(deprecated)]
827                Visit::visit_mut_post_nolimit(expr, &mut |expr| match expr {
828                    B::FrA(expr) => f(expr.as_mut()),
829                    _ => (),
830                });
831            });
832
833            match self {
834                A::Add(lhs, rhs) => {
835                    f(lhs);
836                    f(rhs);
837                }
838                A::Lit(_) => (),
839                A::FrB(_) => (),
840            }
841        }
842
843        fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
844        where
845            F: FnMut(&A) -> Result<(), E>,
846            E: From<RecursionLimitError>,
847        {
848            VisitChildren::try_visit_children(self, |expr: &B| {
849                Visit::try_visit_post(expr, &mut |expr| match expr {
850                    B::FrA(expr) => f(expr.as_ref()),
851                    _ => Ok(()),
852                })
853            })?;
854
855            match self {
856                A::Add(lhs, rhs) => {
857                    f(lhs)?;
858                    f(rhs)?;
859                }
860                A::Lit(_) => (),
861                A::FrB(_) => (),
862            }
863            Ok(())
864        }
865
866        fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
867        where
868            F: FnMut(&mut A) -> Result<(), E>,
869            E: From<RecursionLimitError>,
870        {
871            VisitChildren::try_visit_mut_children(self, |expr: &mut B| {
872                Visit::try_visit_mut_post(expr, &mut |expr| match expr {
873                    B::FrA(expr) => f(expr.as_mut()),
874                    _ => Ok(()),
875                })
876            })?;
877
878            match self {
879                A::Add(lhs, rhs) => {
880                    f(lhs)?;
881                    f(rhs)?;
882                }
883                A::Lit(_) => (),
884                A::FrB(_) => (),
885            }
886            Ok(())
887        }
888    }
889
890    impl VisitChildren<B> for A {
891        fn visit_children<F>(&self, mut f: F)
892        where
893            F: FnMut(&B),
894        {
895            match self {
896                A::Add(_, _) => (),
897                A::Lit(_) => (),
898                A::FrB(expr) => f(expr),
899            }
900        }
901
902        fn visit_mut_children<F>(&mut self, mut f: F)
903        where
904            F: FnMut(&mut B),
905        {
906            match self {
907                A::Add(_, _) => (),
908                A::Lit(_) => (),
909                A::FrB(expr) => f(expr),
910            }
911        }
912
913        fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
914        where
915            F: FnMut(&B) -> Result<(), E>,
916            E: From<RecursionLimitError>,
917        {
918            match self {
919                A::Add(_, _) => Ok(()),
920                A::Lit(_) => Ok(()),
921                A::FrB(expr) => f(expr),
922            }
923        }
924
925        fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
926        where
927            F: FnMut(&mut B) -> Result<(), E>,
928            E: From<RecursionLimitError>,
929        {
930            match self {
931                A::Add(_, _) => Ok(()),
932                A::Lit(_) => Ok(()),
933                A::FrB(expr) => f(expr),
934            }
935        }
936    }
937
938    impl VisitChildren<B> for B {
939        fn visit_children<F>(&self, mut f: F)
940        where
941            F: FnMut(&B),
942        {
943            // VisitChildren::visit_children(self, |expr: &A| {
944            //     #[allow(deprecated)]
945            //     Visit::visit_post_nolimit(expr, &mut |expr| match expr {
946            //         A::FrB(expr) => f(expr.as_ref()),
947            //         _ => (),
948            //     });
949            // });
950
951            match self {
952                B::Mul(lhs, rhs) => {
953                    f(lhs);
954                    f(rhs);
955                }
956                B::Lit(_) => (),
957                B::FrA(_) => (),
958            }
959        }
960
961        fn visit_mut_children<F>(&mut self, mut f: F)
962        where
963            F: FnMut(&mut B),
964        {
965            // VisitChildren::visit_mut_children(self, |expr: &mut A| {
966            //     #[allow(deprecated)]
967            //     Visit::visit_mut_post_nolimit(expr, &mut |expr| match expr {
968            //         A::FrB(expr) => f(expr.as_mut()),
969            //         _ => (),
970            //     });
971            // });
972
973            match self {
974                B::Mul(lhs, rhs) => {
975                    f(lhs);
976                    f(rhs);
977                }
978                B::Lit(_) => (),
979                B::FrA(_) => (),
980            }
981        }
982
983        fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
984        where
985            F: FnMut(&B) -> Result<(), E>,
986            E: From<RecursionLimitError>,
987        {
988            // VisitChildren::try_visit_children(self, |expr: &A| {
989            //     Visit::try_visit_post(expr, &mut |expr| match expr {
990            //         A::FrB(expr) => f(expr.as_ref()),
991            //         _ => Ok(()),
992            //     })
993            // })?;
994
995            match self {
996                B::Mul(lhs, rhs) => {
997                    f(lhs)?;
998                    f(rhs)?;
999                }
1000                B::Lit(_) => (),
1001                B::FrA(_) => (),
1002            }
1003            Ok(())
1004        }
1005
1006        fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
1007        where
1008            F: FnMut(&mut B) -> Result<(), E>,
1009            E: From<RecursionLimitError>,
1010        {
1011            // VisitChildren::try_visit_mut_children(self, |expr: &mut A| {
1012            //     Visit::try_visit_mut_post(expr, &mut |expr| match expr {
1013            //         A::FrB(expr) => f(expr.as_mut()),
1014            //         _ => Ok(()),
1015            //     })
1016            // })?;
1017
1018            match self {
1019                B::Mul(lhs, rhs) => {
1020                    f(lhs)?;
1021                    f(rhs)?;
1022                }
1023                B::Lit(_) => (),
1024                B::FrA(_) => (),
1025            }
1026            Ok(())
1027        }
1028    }
1029
1030    impl VisitChildren<A> for B {
1031        fn visit_children<F>(&self, mut f: F)
1032        where
1033            F: FnMut(&A),
1034        {
1035            match self {
1036                B::Mul(_, _) => (),
1037                B::Lit(_) => (),
1038                B::FrA(expr) => f(expr),
1039            }
1040        }
1041
1042        fn visit_mut_children<F>(&mut self, mut f: F)
1043        where
1044            F: FnMut(&mut A),
1045        {
1046            match self {
1047                B::Mul(_, _) => (),
1048                B::Lit(_) => (),
1049                B::FrA(expr) => f(expr),
1050            }
1051        }
1052
1053        fn try_visit_children<F, E>(&self, mut f: F) -> Result<(), E>
1054        where
1055            F: FnMut(&A) -> Result<(), E>,
1056            E: From<RecursionLimitError>,
1057        {
1058            match self {
1059                B::Mul(_, _) => Ok(()),
1060                B::Lit(_) => Ok(()),
1061                B::FrA(expr) => f(expr),
1062            }
1063        }
1064
1065        fn try_visit_mut_children<F, E>(&mut self, mut f: F) -> Result<(), E>
1066        where
1067            F: FnMut(&mut A) -> Result<(), E>,
1068            E: From<RecursionLimitError>,
1069        {
1070            match self {
1071                B::Mul(_, _) => Ok(()),
1072                B::Lit(_) => Ok(()),
1073                B::FrA(expr) => f(expr),
1074            }
1075        }
1076    }
1077
1078    /// x + (y + z)
1079    fn test_term_a(x: A, y: A, z: A) -> A {
1080        let x = Box::new(x);
1081        let y = Box::new(y);
1082        let z = Box::new(z);
1083        A::Add(x, Box::new(A::Add(y, z)))
1084    }
1085
1086    /// u + (v + w)
1087    fn test_term_b(u: B, v: B, w: B) -> B {
1088        let u = Box::new(u);
1089        let v = Box::new(v);
1090        let w = Box::new(w);
1091        B::Mul(u, Box::new(B::Mul(v, w)))
1092    }
1093
1094    fn a_to_b(x: A) -> B {
1095        B::FrA(Box::new(x))
1096    }
1097
1098    fn b_to_a(x: B) -> A {
1099        A::FrB(Box::new(x))
1100    }
1101
1102    fn test_term_rec_b(b: u64) -> B {
1103        test_term_b(
1104            a_to_b(test_term_a(
1105                b_to_a(test_term_b(B::Lit(b + 11), B::Lit(b + 12), B::Lit(b + 13))),
1106                b_to_a(test_term_b(B::Lit(b + 14), B::Lit(b + 15), B::Lit(b + 16))),
1107                b_to_a(test_term_b(B::Lit(b + 17), B::Lit(b + 18), B::Lit(b + 19))),
1108            )),
1109            a_to_b(test_term_a(
1110                b_to_a(test_term_b(B::Lit(b + 21), B::Lit(b + 22), B::Lit(b + 23))),
1111                b_to_a(test_term_b(B::Lit(b + 24), B::Lit(b + 25), B::Lit(b + 26))),
1112                b_to_a(test_term_b(B::Lit(b + 27), B::Lit(b + 28), B::Lit(b + 29))),
1113            )),
1114            a_to_b(test_term_a(
1115                b_to_a(test_term_b(B::Lit(b + 31), B::Lit(b + 32), B::Lit(b + 33))),
1116                b_to_a(test_term_b(B::Lit(b + 34), B::Lit(b + 35), B::Lit(b + 36))),
1117                b_to_a(test_term_b(B::Lit(b + 37), B::Lit(b + 38), B::Lit(b + 39))),
1118            )),
1119        )
1120    }
1121
1122    fn test_term_rec_a(b: u64) -> A {
1123        test_term_a(
1124            b_to_a(test_term_b(
1125                a_to_b(test_term_a(A::Lit(b + 11), A::Lit(b + 12), A::Lit(b + 13))),
1126                a_to_b(test_term_a(A::Lit(b + 14), A::Lit(b + 15), A::Lit(b + 16))),
1127                a_to_b(test_term_a(A::Lit(b + 17), A::Lit(b + 18), A::Lit(b + 19))),
1128            )),
1129            b_to_a(test_term_b(
1130                a_to_b(test_term_a(A::Lit(b + 21), A::Lit(b + 22), A::Lit(b + 23))),
1131                a_to_b(test_term_a(A::Lit(b + 24), A::Lit(b + 25), A::Lit(b + 26))),
1132                a_to_b(test_term_a(A::Lit(b + 27), A::Lit(b + 28), A::Lit(b + 29))),
1133            )),
1134            b_to_a(test_term_b(
1135                a_to_b(test_term_a(A::Lit(b + 31), A::Lit(b + 32), A::Lit(b + 33))),
1136                a_to_b(test_term_a(A::Lit(b + 34), A::Lit(b + 35), A::Lit(b + 36))),
1137                a_to_b(test_term_a(A::Lit(b + 37), A::Lit(b + 38), A::Lit(b + 39))),
1138            )),
1139        )
1140    }
1141
1142    #[mz_ore::test]
1143    #[cfg_attr(miri, ignore)] // error: unsupported operation: can't call foreign function `rust_psm_stack_pointer` on OS `linux`
1144    fn test_recursive_types_a() {
1145        let mut act = test_term_rec_a(0);
1146        let exp = test_term_rec_a(20);
1147
1148        let res = act.visit_mut_pre(&mut |expr| match expr {
1149            A::Lit(x) => *x = *x + 20,
1150            _ => (),
1151        });
1152
1153        assert_ok!(res);
1154        assert_eq!(act, exp);
1155    }
1156
1157    /// This test currently fails with the following error:
1158    ///
1159    ///   reached the recursion limit while instantiating
1160    ///   `<visit::Visitor<B> as CheckedRec...ore::stack::RecursionLimitError>`
1161    ///
1162    /// The problem (I think) is in the fact that the lambdas passed in the
1163    /// VisitChildren<A> for A and the VisitChildren<B> for B definitions end
1164    /// up in an infinite loop.
1165    ///
1166    /// More specifically, we run into the following cycle:
1167    ///
1168    /// - `<A as VisitChildren<A>>::visit_children`
1169    ///   - <A as VisitChildren<B>>::visit_children`
1170    ///     - <B as Visit>::visit_post_nolimit`
1171    ///       - <B as VisitChildren<B>>::visit_children`
1172    ///         - <B as VisitChildren<A>>::visit_children`
1173    ///           - <A as Visit>::visit_post_nolimit`
1174    ///             - <A as VisitChildren<A>>::visit_children`
1175    #[mz_ore::test]
1176    #[ignore = "making the VisitChildren definitions symmetric breaks the compiler"]
1177    fn test_recursive_types_b() {
1178        let mut act = test_term_rec_b(0);
1179        let exp = test_term_rec_b(30);
1180
1181        let res = act.visit_mut_pre(&mut |expr| match expr {
1182            B::Lit(x) => *x = *x + 30,
1183            _ => (),
1184        });
1185
1186        assert_ok!(res);
1187        assert_eq!(act, exp);
1188    }
1189}