1use proc_macro2::{Span, TokenStream};
12use syn::spanned::Spanned;
13use syn::{DeriveInput, Expr, Field, Ident, Path, Type, Variant};
14
15use crate::ast::*;
16use crate::attr::{self, ParamsMode, ParsedAttributes, StratMode};
17use crate::error::{self, Context, Ctx, DeriveResult};
18use crate::use_tracking::{UseMarkable, UseTracker};
19use crate::util::{fields_to_vec, is_unit_type, self_ty};
20use crate::void::IsUninhabited;
21
22pub fn impl_proptest_arbitrary(ast: DeriveInput) -> TokenStream {
27 let mut ctx = Context::default();
28 let result = derive_proptest_arbitrary(&mut ctx, ast);
29 match (result, ctx.check()) {
30 (Ok(derive), Ok(())) => derive,
31 (_, Err(err)) => err,
32 (Err(result), Ok(())) => panic!(
33 "[proptest_derive]: internal error, this is a bug! \
34 result: {:?}",
35 result
36 ),
37 }
38}
39
40struct DeriveData<B> {
43 ident: Ident,
44 attrs: ParsedAttributes,
45 tracker: UseTracker,
46 body: B,
47}
48
49fn derive_proptest_arbitrary(
51 ctx: Ctx,
52 ast: DeriveInput,
53) -> DeriveResult<TokenStream> {
54 use syn::Data::*;
55
56 error::if_has_lifetimes(ctx, &ast);
58
59 let attrs = attr::parse_top_attributes(ctx, &ast.attrs)?;
61
62 let mut tracker = UseTracker::new(ast.generics);
64 if attrs.no_bound {
65 tracker.no_track();
66 }
67
68 let the_impl = match ast.data {
70 Struct(data) => derive_struct(
72 ctx,
73 DeriveData {
74 tracker,
75 attrs,
76 ident: ast.ident,
77 body: fields_to_vec(data.fields),
78 },
79 ),
80 Enum(data) => derive_enum(
82 ctx,
83 DeriveData {
84 tracker,
85 attrs,
86 ident: ast.ident,
87 body: data.variants.into_iter().collect(),
88 },
89 ),
90 _ => error::not_struct_or_enum(ctx)?,
92 }?;
93
94 let q = the_impl.into_tokens(ctx)?;
96
97 Ok(q)
99}
100
101fn derive_struct(
107 ctx: Ctx,
108 mut ast: DeriveData<Vec<Field>>,
109) -> DeriveResult<Impl> {
110 error::if_enum_attrs_present(ctx, &ast.attrs, error::STRUCT);
112
113 error::if_strategy_present(ctx, &ast.attrs, error::STRUCT);
115
116 let v_path = ast.ident.clone().into();
117 let parts = if ast.body.is_empty() {
118 error::if_present_on_unit_struct(ctx, &ast.attrs);
120 let (strat, ctor) = pair_unit_self(&v_path);
121 (Params::empty(), strat, ctor)
122 } else {
123 if (&*ast.body).is_uninhabited() {
131 error::uninhabited_struct(ctx);
132 }
133
134 let closure = map_closure(v_path, &ast.body);
136
137 let parts = if let Some(param_ty) = ast.attrs.params.into_option() {
140 add_top_params(
142 param_ty,
143 derive_product_has_params(
144 ctx,
145 &mut ast.tracker,
146 error::STRUCT_FIELD,
147 closure,
148 ast.body,
149 )?,
150 )
151 } else {
152 derive_product_no_params(
154 ctx,
155 &mut ast.tracker,
156 ast.body,
157 error::STRUCT_FIELD,
158 )?
159 .finish(closure)
160 };
161
162 add_top_filter(ast.attrs.filter, parts)
164 };
165
166 Ok(Impl::new(ast.ident, ast.tracker, parts))
168}
169
170fn add_top_filter(filter: Vec<Expr>, parts: ImplParts) -> ImplParts {
172 let (params, strat, ctor) = parts;
173 let (strat, ctor) = add_filter_self(filter, (strat, ctor));
174 (params, strat, ctor)
175}
176
177fn add_filter_self(filter: Vec<Expr>, pair: StratPair) -> StratPair {
179 pair_filter(filter, self_ty(), pair)
180}
181
182fn add_top_params(
186 param_ty: Option<Type>,
187 (strat, ctor): StratPair,
188) -> ImplParts {
189 let params = Params::empty();
190 if let Some(params_ty) = param_ty {
191 (params + params_ty, strat, extract_api(ctor, FromReg::Top))
193 } else {
194 (params, strat, ctor)
195 }
196}
197
198fn derive_product_has_params(
201 ctx: Ctx,
202 ut: &mut UseTracker,
203 item: &str,
204 closure: MapClosure,
205 fields: Vec<Field>,
206) -> DeriveResult<StratPair> {
207 let len = fields.len();
211 fields
212 .into_iter()
213 .try_fold(StratAcc::new(len), |acc, field| {
214 let attrs = attr::parse_attributes(ctx, &field.attrs)?;
215
216 error::if_enum_attrs_present(ctx, &attrs, item);
218
219 error::if_specified_params(ctx, &attrs, item);
221
222 let span = field.span();
224 let ty = field.ty.clone();
225 let pair =
226 product_handle_default_params(ut, ty, span, attrs.strategy);
227 let pair = pair_filter(attrs.filter, field.ty, pair);
228 Ok(acc.add(pair))
229 })
230 .map(|acc| acc.finish(closure))
231}
232
233fn product_handle_default_params(
235 ut: &mut UseTracker,
236 ty: Type,
237 span: Span,
238 strategy: StratMode,
239) -> StratPair {
240 match strategy {
241 StratMode::Strategy(strat) => pair_existential(ty, strat),
244 StratMode::Value(value) => pair_value(ty, value),
246 StratMode::Regex(regex) => pair_regex(ty, regex),
248 StratMode::Arbitrary => {
250 ty.mark_uses(ut);
251 pair_any(ty, span)
252 }
253 }
254}
255
256fn derive_product_no_params(
259 ctx: Ctx,
260 ut: &mut UseTracker,
261 fields: Vec<Field>,
262 item: &str,
263) -> DeriveResult<PartsAcc<Ctor>> {
264 let acc = PartsAcc::new(fields.len());
268 fields.into_iter().try_fold(acc, |mut acc, field| {
269 let attrs = attr::parse_attributes(ctx, &field.attrs)?;
270
271 error::if_enum_attrs_present(ctx, &attrs, item);
273
274 let span = field.span();
275 let ty = field.ty;
276
277 let strat = pair_filter(
278 attrs.filter,
279 ty.clone(),
280 match attrs.params {
281 ParamsMode::Passthrough => match attrs.strategy {
283 StratMode::Strategy(strat) => pair_existential(ty, strat),
285 StratMode::Value(value) => pair_value(ty, value),
287 StratMode::Regex(regex) => pair_regex(ty, regex),
289 StratMode::Arbitrary => {
291 ty.mark_uses(ut);
292
293 let pref = acc.add_param(arbitrary_param(&ty));
295 pair_any_with(ty, pref, span)
296 }
297 },
298 ParamsMode::Default => {
300 product_handle_default_params(ut, ty, span, attrs.strategy)
301 }
302 ParamsMode::Specified(params_ty) =>
304 {
306 extract_nparam(
307 &mut acc,
308 params_ty,
309 match attrs.strategy {
310 StratMode::Strategy(strat) => {
312 pair_existential(ty, strat)
313 }
314 StratMode::Value(value) => {
316 pair_value_exist(ty, value)
317 }
318 StratMode::Regex(regex) => {
321 error::cant_set_param_and_regex(ctx, item);
322 pair_regex(ty, regex)
323 }
324 StratMode::Arbitrary => {
327 error::cant_set_param_but_not_strat(
328 ctx, &ty, item,
329 )?
330 }
331 },
332 )
333 }
334 },
335 );
336 Ok(acc.add_strat(strat))
337 })
338}
339
340fn extract_nparam<C>(
343 acc: &mut PartsAcc<C>,
344 params_ty: Type,
345 (strat, ctor): StratPair,
346) -> StratPair {
347 (
348 strat,
349 extract_api(ctor, FromReg::Num(acc.add_param(params_ty))),
350 )
351}
352
353fn derive_enum(
359 ctx: Ctx,
360 mut ast: DeriveData<Vec<Variant>>,
361) -> DeriveResult<Impl> {
362 error::if_skip_present(ctx, &ast.attrs, error::ENUM);
364
365 error::if_strategy_present(ctx, &ast.attrs, error::ENUM);
367
368 error::if_weight_present(ctx, &ast.attrs, error::ENUM);
370
371 if ast.body.is_empty() {
373 error::uninhabited_enum_with_no_variants(ctx)?;
374 }
375
376 if (&*ast.body).is_uninhabited() {
378 error::uninhabited_enum_variants_uninhabited(ctx)?;
379 }
380
381 let parts = if let Some(sty) = ast.attrs.params.into_option() {
384 derive_enum_has_params(ctx, &mut ast.tracker, &ast.ident, ast.body, sty)
386 } else {
387 derive_enum_no_params(ctx, &mut ast.tracker, &ast.ident, ast.body)
389 }?;
390
391 let parts = add_top_filter(ast.attrs.filter, parts);
392
393 Ok(Impl::new(ast.ident, ast.tracker, parts))
395}
396
397fn derive_enum_no_params(
399 ctx: Ctx,
400 ut: &mut UseTracker,
401 _self: &Ident,
402 variants: Vec<Variant>,
403) -> DeriveResult<ImplParts> {
404 let mut acc = PartsAcc::new(variants.len());
406
407 for variant in variants {
409 if let Some((weight, ident, fields, attrs)) =
410 keep_inhabited_variant(ctx, _self, variant)?
411 {
412 let path = parse_quote!( #_self::#ident );
413 let (strat, ctor) = if fields.is_empty() {
414 pair_unit_variant(ctx, &attrs, path)
416 } else {
417 derive_variant_with_fields(
419 ctx, ut, path, attrs, fields, &mut acc,
420 )?
421 };
422 acc = acc.add_strat((strat, (weight, ctor)));
423 }
424 }
425
426 ensure_union_has_strategies(ctx, &acc.strats);
427
428 Ok(acc.finish(ctx))
430}
431
432fn ensure_union_has_strategies<C>(ctx: Ctx, strats: &StratAcc<C>) {
434 if strats.is_empty() {
435 error::uninhabited_enum_because_of_skipped_variants(ctx);
438 }
439}
440
441fn derive_variant_with_fields<C>(
444 ctx: Ctx,
445 ut: &mut UseTracker,
446 v_path: Path,
447 attrs: ParsedAttributes,
448 fields: Vec<Field>,
449 acc: &mut PartsAcc<C>,
450) -> DeriveResult<StratPair> {
451 let filter = attrs.filter.clone();
452
453 let pair = match attrs.params {
454 ParamsMode::Passthrough => match attrs.strategy {
456 StratMode::Strategy(strat) => {
458 deny_all_attrs_on_fields(ctx, fields)?;
459 pair_existential_self(strat)
460 }
461 StratMode::Value(value) => {
463 deny_all_attrs_on_fields(ctx, fields)?;
464 pair_value_self(value)
465 }
466 StratMode::Regex(regex) => {
467 deny_all_attrs_on_fields(ctx, fields)?;
468 pair_regex_self(regex)
469 }
470 StratMode::Arbitrary => {
472 variant_no_explicit_strategy(ctx, ut, v_path, fields, acc)?
473 }
474 },
475 ParamsMode::Default => {
477 variant_handle_default_params(ctx, ut, v_path, attrs, fields)?
478 }
479 ParamsMode::Specified(params_ty) => extract_nparam(
481 acc,
482 params_ty,
483 match attrs.strategy {
484 StratMode::Strategy(strat) => {
486 deny_all_attrs_on_fields(ctx, fields)?;
487 pair_existential_self(strat)
488 }
489 StratMode::Value(value) => {
491 deny_all_attrs_on_fields(ctx, fields)?;
492 pair_value_exist_self(value)
493 }
494 StratMode::Regex(regex) => {
497 error::cant_set_param_and_regex(ctx, error::ENUM_VARIANT);
498 deny_all_attrs_on_fields(ctx, fields)?;
499 pair_regex_self(regex)
500 }
501 StratMode::Arbitrary => {
504 let ty = self_ty();
505 error::cant_set_param_but_not_strat(
506 ctx,
507 &ty,
508 error::ENUM_VARIANT,
509 )?
510 }
511 },
512 ),
513 };
514 let pair = add_filter_self(filter, pair);
515 Ok(pair)
516}
517
518fn variant_no_explicit_strategy<C>(
521 ctx: Ctx,
522 ut: &mut UseTracker,
523 v_path: Path,
524 fields: Vec<Field>,
525 acc: &mut PartsAcc<C>,
526) -> DeriveResult<StratPair> {
527 let closure = map_closure(v_path, &fields);
529 let fields_acc =
530 derive_product_no_params(ctx, ut, fields, error::ENUM_VARIANT_FIELD)?;
531 let (params, count) = fields_acc.params.consume();
532 let (strat, ctor) = fields_acc.strats.finish(closure);
533
534 let params_ty = params.into();
537 Ok((
538 strat,
539 if is_unit_type(¶ms_ty) {
540 ctor
541 } else {
542 let pref = acc.add_param(params_ty);
543 extract_all(ctor, count, FromReg::Num(pref))
544 },
545 ))
546}
547
548fn variant_handle_default_params(
550 ctx: Ctx,
551 ut: &mut UseTracker,
552 v_path: Path,
553 attrs: ParsedAttributes,
554 fields: Vec<Field>,
555) -> DeriveResult<StratPair> {
556 let pair = match attrs.strategy {
557 StratMode::Strategy(strat) => {
559 deny_all_attrs_on_fields(ctx, fields)?;
560 pair_existential_self(strat)
561 }
562 StratMode::Value(value) => {
564 deny_all_attrs_on_fields(ctx, fields)?;
565 pair_value_self(value)
566 }
567 StratMode::Regex(regex) => {
568 deny_all_attrs_on_fields(ctx, fields)?;
569 pair_regex_self(regex)
570 }
571 StratMode::Arbitrary =>
573 {
575 derive_product_has_params(
576 ctx,
577 ut,
578 error::ENUM_VARIANT_FIELD,
579 map_closure(v_path, &fields),
580 fields,
581 )?
582 }
583 };
584
585 Ok(pair)
586}
587
588fn deny_all_attrs_on_fields(ctx: Ctx, fields: Vec<Field>) -> DeriveResult<()> {
590 fields.into_iter().try_for_each(|field| {
591 let f_attr = attr::parse_attributes(ctx, &field.attrs)?;
592 error::if_anything_specified(ctx, &f_attr, error::ENUM_VARIANT_FIELD);
593 Ok(())
594 })
595}
596
597fn derive_enum_has_params(
600 ctx: Ctx,
601 ut: &mut UseTracker,
602 _self: &Ident,
603 variants: Vec<Variant>,
604 sty: Option<Type>,
605) -> DeriveResult<ImplParts> {
606 let mut acc = StratAcc::new(variants.len());
608
609 for variant in variants {
611 let parts = keep_inhabited_variant(ctx, _self, variant)?;
612 if let Some((weight, ident, fields, attrs)) = parts {
613 let path = parse_quote!( #_self::#ident );
614 let (strat, ctor) = if fields.is_empty() {
615 pair_unit_variant(ctx, &attrs, path)
617 } else {
618 let filter = attrs.filter.clone();
620 add_filter_self(
621 filter,
622 variant_handle_default_params(
623 ctx, ut, path, attrs, fields,
624 )?,
625 )
626 };
627 acc = acc.add((strat, (weight, ctor)));
628 }
629 }
630
631 ensure_union_has_strategies(ctx, &acc);
632
633 Ok(add_top_params(sty, acc.finish(ctx)))
634}
635
636fn keep_inhabited_variant(
638 ctx: Ctx,
639 _self: &Ident,
640 variant: Variant,
641) -> DeriveResult<Option<(u32, Ident, Vec<Field>, ParsedAttributes)>> {
642 let attrs = attr::parse_attributes(ctx, &variant.attrs)?;
643 let fields = fields_to_vec(variant.fields);
644
645 if attrs.skip {
646 ensure_has_only_skip_attr(ctx, &attrs, error::ENUM_VARIANT);
649 fields.into_iter().try_for_each(|field| {
650 let f_attrs = attr::parse_attributes(ctx, &field.attrs)?;
651 error::if_skip_present(ctx, &f_attrs, error::ENUM_VARIANT_FIELD);
652 ensure_has_only_skip_attr(ctx, &f_attrs, error::ENUM_VARIANT_FIELD);
653 Ok(())
654 })?;
655
656 return Ok(None);
657 }
658
659 if (&*fields).is_uninhabited() {
661 return Ok(None);
662 }
663
664 let weight = attrs.weight.unwrap_or(1);
666
667 Ok(Some((weight, variant.ident, fields, attrs)))
668}
669
670fn ensure_has_only_skip_attr(ctx: Ctx, attrs: &ParsedAttributes, item: &str) {
672 if attrs.params.is_set() {
673 error::skipped_variant_has_param(ctx, item);
674 }
675
676 if attrs.strategy.is_set() {
677 error::skipped_variant_has_strat(ctx, item);
678 }
679
680 if attrs.weight.is_some() {
681 error::skipped_variant_has_weight(ctx, item);
682 }
683
684 if !attrs.filter.is_empty() {
685 error::skipped_variant_has_filter(ctx, item);
686 }
687}
688
689fn pair_unit_variant(
691 ctx: Ctx,
692 attrs: &ParsedAttributes,
693 v_path: Path,
694) -> StratPair {
695 error::if_present_on_unit_variant(ctx, attrs);
696 pair_unit_self(&v_path)
697}
698
699struct PartsAcc<C> {
705 params: ParamAcc,
707 strats: StratAcc<C>,
709}
710
711impl<C> PartsAcc<C> {
712 fn new(size: usize) -> Self {
715 Self {
716 params: ParamAcc::empty(),
717 strats: StratAcc::new(size),
718 }
719 }
720
721 fn add_strat(self, pair: (Strategy, C)) -> Self {
723 Self {
724 strats: self.strats.add(pair),
725 params: self.params,
726 }
727 }
728
729 fn add_param(&mut self, ty: Type) -> usize {
732 self.params.add(ty)
733 }
734}
735
736impl PartsAcc<Ctor> {
737 fn finish(self, closure: MapClosure) -> ImplParts {
741 let (params, count) = self.params.consume();
742 let (strat, ctor) = self.strats.finish(closure);
743 (params, strat, extract_all(ctor, count, FromReg::Top))
744 }
745}
746
747impl PartsAcc<(u32, Ctor)> {
748 fn finish(self, ctx: Ctx) -> ImplParts {
752 let (params, count) = self.params.consume();
753 let (strat, ctor) = self.strats.finish(ctx);
754 (params, strat, extract_all(ctor, count, FromReg::Top))
755 }
756}
757
758struct ParamAcc {
764 types: Params,
766}
767
768impl ParamAcc {
769 fn empty() -> Self {
771 Self {
772 types: Params::empty(),
773 }
774 }
775
776 fn add(&mut self, ty: Type) -> usize {
778 let var = self.types.len();
779 self.types += ty;
780 var
781 }
782
783 fn consume(self) -> (Params, usize) {
785 let count = self.types.len();
786 (self.types, count)
787 }
788}
789
790struct StratAcc<C> {
796 types: Vec<Strategy>,
798 ctors: Vec<C>,
800}
801
802impl<C> StratAcc<C> {
803 fn new(size: usize) -> Self {
806 Self {
807 types: Vec::with_capacity(size),
808 ctors: Vec::with_capacity(size),
809 }
810 }
811
812 fn add(mut self, (strat, ctor): (Strategy, C)) -> Self {
815 self.types.push(strat);
816 self.ctors.push(ctor);
817 self
818 }
819
820 fn consume(self) -> (Vec<Strategy>, Vec<C>) {
824 (self.types, self.ctors)
825 }
826
827 fn is_empty(&self) -> bool {
829 self.types.is_empty()
830 }
831}
832
833impl StratAcc<Ctor> {
834 fn finish(self, closure: MapClosure) -> StratPair {
837 pair_map(self.consume(), closure)
838 }
839}
840
841impl StratAcc<(u32, Ctor)> {
842 fn finish(self, ctx: Ctx) -> StratPair {
846 if self
848 .ctors
849 .iter()
850 .map(|&(w, _)| w)
851 .try_fold(0u32, |acc, w| acc.checked_add(w))
852 .is_none()
853 {
854 error::weight_overflowing(ctx)
855 }
856
857 pair_oneof(self.consume())
858 }
859}