domain_macros/
lib.rs

1//! Procedural macros for [`domain`].
2//!
3//! [`domain`]: https://docs.rs/domain
4
5use proc_macro as pm;
6use proc_macro2::TokenStream;
7use quote::{format_ident, ToTokens};
8use syn::{Error, Ident, Result};
9
10mod impls;
11use impls::ImplSkeleton;
12
13mod data;
14use data::Struct;
15
16mod repr;
17use repr::Repr;
18
19//----------- SplitBytes -----------------------------------------------------
20
21#[proc_macro_derive(SplitBytes)]
22pub fn derive_split_bytes(input: pm::TokenStream) -> pm::TokenStream {
23    fn inner(input: syn::DeriveInput) -> Result<TokenStream> {
24        let data = match &input.data {
25            syn::Data::Struct(data) => data,
26            syn::Data::Enum(data) => {
27                return Err(Error::new_spanned(
28                    data.enum_token,
29                    "'SplitBytes' can only be 'derive'd for 'struct's",
30                ));
31            }
32            syn::Data::Union(data) => {
33                return Err(Error::new_spanned(
34                    data.union_token,
35                    "'SplitBytes' can only be 'derive'd for 'struct's",
36                ));
37            }
38        };
39
40        // Construct an 'ImplSkeleton' so that we can add trait bounds.
41        let mut skeleton = ImplSkeleton::new(&input, false);
42
43        // Add the parsing lifetime to the 'impl'.
44        let (lifetime, param) = skeleton.new_lifetime_param(
45            "bytes",
46            skeleton.lifetimes.iter().map(|l| l.lifetime.clone()),
47        );
48        skeleton.lifetimes.push(param);
49        skeleton.bound = Some(
50            syn::parse_quote!(::domain::new::base::wire::SplitBytes<#lifetime>),
51        );
52
53        // Inspect the 'struct' fields.
54        let data = Struct::new_as_self(&data.fields);
55        let builder = data.builder(field_prefixed);
56
57        // Establish bounds on the fields.
58        for field in data.fields() {
59            skeleton.require_bound(
60                field.ty.clone(),
61                syn::parse_quote!(::domain::new::base::wire::SplitBytes<#lifetime>),
62            );
63        }
64
65        // Define 'parse_bytes()'.
66        let init_vars = builder.init_vars();
67        let tys = data.fields().map(|f| &f.ty);
68        skeleton.contents.stmts.push(syn::parse_quote! {
69            fn split_bytes(
70                bytes: & #lifetime [::domain::__core::primitive::u8],
71            ) -> ::domain::__core::result::Result<
72                (Self, & #lifetime [::domain::__core::primitive::u8]),
73                ::domain::new::base::wire::ParseError,
74            > {
75                #(let (#init_vars, bytes) =
76                    <#tys as ::domain::new::base::wire::SplitBytes<#lifetime>>
77                    ::split_bytes(bytes)?;)*
78                Ok((#builder, bytes))
79            }
80        });
81
82        Ok(skeleton.into_token_stream())
83    }
84
85    let input = syn::parse_macro_input!(input as syn::DeriveInput);
86    inner(input)
87        .unwrap_or_else(syn::Error::into_compile_error)
88        .into()
89}
90
91//----------- ParseBytes -----------------------------------------------------
92
93#[proc_macro_derive(ParseBytes)]
94pub fn derive_parse_bytes(input: pm::TokenStream) -> pm::TokenStream {
95    fn inner(input: syn::DeriveInput) -> Result<TokenStream> {
96        let data = match &input.data {
97            syn::Data::Struct(data) => data,
98            syn::Data::Enum(data) => {
99                return Err(Error::new_spanned(
100                    data.enum_token,
101                    "'ParseBytes' can only be 'derive'd for 'struct's",
102                ));
103            }
104            syn::Data::Union(data) => {
105                return Err(Error::new_spanned(
106                    data.union_token,
107                    "'ParseBytes' can only be 'derive'd for 'struct's",
108                ));
109            }
110        };
111
112        // Construct an 'ImplSkeleton' so that we can add trait bounds.
113        let mut skeleton = ImplSkeleton::new(&input, false);
114
115        // Add the parsing lifetime to the 'impl'.
116        let (lifetime, param) = skeleton.new_lifetime_param(
117            "bytes",
118            skeleton.lifetimes.iter().map(|l| l.lifetime.clone()),
119        );
120        skeleton.lifetimes.push(param);
121        skeleton.bound = Some(
122            syn::parse_quote!(::domain::new::base::wire::ParseBytes<#lifetime>),
123        );
124
125        // Inspect the 'struct' fields.
126        let data = Struct::new_as_self(&data.fields);
127        let builder = data.builder(field_prefixed);
128
129        // Establish bounds on the fields.
130        for field in data.sized_fields() {
131            skeleton.require_bound(
132                field.ty.clone(),
133                syn::parse_quote!(::domain::new::base::wire::SplitBytes<#lifetime>),
134            );
135        }
136        if let Some(field) = data.unsized_field() {
137            skeleton.require_bound(
138                field.ty.clone(),
139                syn::parse_quote!(::domain::new::base::wire::ParseBytes<#lifetime>),
140            );
141        }
142
143        // Finish early if the 'struct' has no fields.
144        if data.is_empty() {
145            skeleton.contents.stmts.push(syn::parse_quote! {
146                fn parse_bytes(
147                    bytes: & #lifetime [::domain::__core::primitive::u8],
148                ) -> ::domain::__core::result::Result<
149                    Self,
150                    ::domain::new::base::wire::ParseError,
151                > {
152                    if bytes.is_empty() {
153                        Ok(#builder)
154                    } else {
155                        Err(::domain::new::base::wire::ParseError)
156                    }
157                }
158            });
159
160            return Ok(skeleton.into_token_stream());
161        }
162
163        // Define 'parse_bytes()'.
164        let init_vars = builder.sized_init_vars();
165        let tys = builder.sized_fields().map(|f| &f.ty);
166        let unsized_ty = &builder.unsized_field().unwrap().ty;
167        let unsized_init_var = builder.unsized_init_var().unwrap();
168        skeleton.contents.stmts.push(syn::parse_quote! {
169            fn parse_bytes(
170                bytes: & #lifetime [::domain::__core::primitive::u8],
171            ) -> ::domain::__core::result::Result<
172                Self,
173                ::domain::new::base::wire::ParseError,
174            > {
175                #(let (#init_vars, bytes) =
176                    <#tys as ::domain::new::base::wire::SplitBytes<#lifetime>>
177                    ::split_bytes(bytes)?;)*
178                let #unsized_init_var =
179                    <#unsized_ty as ::domain::new::base::wire::ParseBytes<#lifetime>>
180                    ::parse_bytes(bytes)?;
181                Ok(#builder)
182            }
183        });
184
185        Ok(skeleton.into_token_stream())
186    }
187
188    let input = syn::parse_macro_input!(input as syn::DeriveInput);
189    inner(input)
190        .unwrap_or_else(syn::Error::into_compile_error)
191        .into()
192}
193
194//----------- SplitBytesZC ---------------------------------------------------
195
196#[proc_macro_derive(SplitBytesZC)]
197pub fn derive_split_bytes_zc(input: pm::TokenStream) -> pm::TokenStream {
198    fn inner(input: syn::DeriveInput) -> Result<TokenStream> {
199        let data = match &input.data {
200            syn::Data::Struct(data) => data,
201            syn::Data::Enum(data) => {
202                return Err(Error::new_spanned(
203                    data.enum_token,
204                    "'SplitBytesZC' can only be 'derive'd for 'struct's",
205                ));
206            }
207            syn::Data::Union(data) => {
208                return Err(Error::new_spanned(
209                    data.union_token,
210                    "'SplitBytesZC' can only be 'derive'd for 'struct's",
211                ));
212            }
213        };
214
215        let _ = Repr::determine(&input.attrs, "SplitBytesZC")?;
216
217        // Construct an 'ImplSkeleton' so that we can add trait bounds.
218        let mut skeleton = ImplSkeleton::new(&input, true);
219        skeleton.bound =
220            Some(syn::parse_quote!(::domain::new::base::wire::SplitBytesZC));
221
222        // Inspect the 'struct' fields.
223        let data = Struct::new_as_self(&data.fields);
224
225        // Establish bounds on the fields.
226        for field in data.fields() {
227            skeleton.require_bound(
228                field.ty.clone(),
229                syn::parse_quote!(::domain::new::base::wire::SplitBytesZC),
230            );
231        }
232
233        // Finish early if the 'struct' has no fields.
234        if data.is_empty() {
235            skeleton.contents.stmts.push(syn::parse_quote! {
236                fn split_bytes_by_ref(
237                    bytes: &[::domain::__core::primitive::u8],
238                ) -> ::domain::__core::result::Result<
239                    (&Self, &[::domain::__core::primitive::u8]),
240                    ::domain::new::base::wire::ParseError,
241                > {
242                    Ok((
243                        // SAFETY: 'Self' is a 'struct' with no fields,
244                        // and so has size 0 and alignment 1.  It can be
245                        // constructed at any address.
246                        unsafe { &*bytes.as_ptr().cast::<Self>() },
247                        bytes,
248                    ))
249                }
250            });
251
252            return Ok(skeleton.into_token_stream());
253        }
254
255        // Define 'split_bytes_by_ref()'.
256        let tys = data.sized_fields().map(|f| &f.ty);
257        let unsized_ty = &data.unsized_field().unwrap().ty;
258        skeleton.contents.stmts.push(syn::parse_quote! {
259            fn split_bytes_by_ref(
260                bytes: &[::domain::__core::primitive::u8],
261            ) -> ::domain::__core::result::Result<
262                (&Self, &[::domain::__core::primitive::u8]),
263                ::domain::new::base::wire::ParseError,
264            > {
265                let start = bytes.as_ptr();
266                #(let (_, bytes) =
267                    <#tys as ::domain::new::base::wire::SplitBytesZC>
268                    ::split_bytes_by_ref(bytes)?;)*
269                let (last, rest) =
270                    <#unsized_ty as ::domain::new::base::wire::SplitBytesZC>
271                    ::split_bytes_by_ref(bytes)?;
272                let ptr =
273                    <#unsized_ty as ::domain::utils::dst::UnsizedCopy>
274                    ::ptr_with_addr(last, start as *const ());
275
276                // SAFETY:
277                // - The original 'bytes' contained a valid instance of every
278                //   field in 'Self', in succession.
279                // - Every field implements 'ParseBytesZC' and so has no
280                //   alignment restriction.
281                // - 'Self' is unaligned, since every field is unaligned, and
282                //   any explicit alignment modifiers only make it unaligned.
283                // - 'start' is thus the start of a valid instance of 'Self'.
284                // - 'ptr' has the same address as 'start' but can be cast to
285                //   'Self', since it has the right pointer metadata.
286                Ok((unsafe { &*(ptr as *const Self) }, rest))
287            }
288        });
289
290        Ok(skeleton.into_token_stream())
291    }
292
293    let input = syn::parse_macro_input!(input as syn::DeriveInput);
294    inner(input)
295        .unwrap_or_else(syn::Error::into_compile_error)
296        .into()
297}
298
299//----------- ParseBytesZC ---------------------------------------------------
300
301#[proc_macro_derive(ParseBytesZC)]
302pub fn derive_parse_bytes_zc(input: pm::TokenStream) -> pm::TokenStream {
303    fn inner(input: syn::DeriveInput) -> Result<TokenStream> {
304        let data = match &input.data {
305            syn::Data::Struct(data) => data,
306            syn::Data::Enum(data) => {
307                return Err(Error::new_spanned(
308                    data.enum_token,
309                    "'ParseBytesZC' can only be 'derive'd for 'struct's",
310                ));
311            }
312            syn::Data::Union(data) => {
313                return Err(Error::new_spanned(
314                    data.union_token,
315                    "'ParseBytesZC' can only be 'derive'd for 'struct's",
316                ));
317            }
318        };
319
320        let _ = Repr::determine(&input.attrs, "ParseBytesZC")?;
321
322        // Construct an 'ImplSkeleton' so that we can add trait bounds.
323        let mut skeleton = ImplSkeleton::new(&input, true);
324        skeleton.bound =
325            Some(syn::parse_quote!(::domain::new::base::wire::ParseBytesZC));
326
327        // Inspect the 'struct' fields.
328        let data = Struct::new_as_self(&data.fields);
329
330        // Establish bounds on the fields.
331        for field in data.sized_fields() {
332            skeleton.require_bound(
333                field.ty.clone(),
334                syn::parse_quote!(::domain::new::base::wire::SplitBytesZC),
335            );
336        }
337        if let Some(field) = data.unsized_field() {
338            skeleton.require_bound(
339                field.ty.clone(),
340                syn::parse_quote!(::domain::new::base::wire::ParseBytesZC),
341            );
342        }
343
344        // Finish early if the 'struct' has no fields.
345        if data.is_empty() {
346            skeleton.contents.stmts.push(syn::parse_quote! {
347                fn parse_bytes_by_ref(
348                    bytes: &[::domain::__core::primitive::u8],
349                ) -> ::domain::__core::result::Result<
350                    &Self,
351                    ::domain::new::base::wire::ParseError,
352                > {
353                    if bytes.is_empty() {
354                        // SAFETY: 'Self' is a 'struct' with no fields,
355                        // and so has size 0 and alignment 1.  It can be
356                        // constructed at any address.
357                        Ok(unsafe { &*bytes.as_ptr().cast::<Self>() })
358                    } else {
359                        Err(::domain::new::base::wire::ParseError)
360                    }
361                }
362            });
363
364            return Ok(skeleton.into_token_stream());
365        }
366
367        // Define 'parse_bytes_by_ref()'.
368        let tys = data.sized_fields().map(|f| &f.ty);
369        let unsized_ty = &data.unsized_field().unwrap().ty;
370        skeleton.contents.stmts.push(syn::parse_quote! {
371            fn parse_bytes_by_ref(
372                bytes: &[::domain::__core::primitive::u8],
373            ) -> ::domain::__core::result::Result<
374                &Self,
375                ::domain::new::base::wire::ParseError,
376            > {
377                let start = bytes.as_ptr();
378                #(let (_, bytes) =
379                    <#tys as ::domain::new::base::wire::SplitBytesZC>
380                    ::split_bytes_by_ref(bytes)?;)*
381                let last =
382                    <#unsized_ty as ::domain::new::base::wire::ParseBytesZC>
383                    ::parse_bytes_by_ref(bytes)?;
384                let ptr =
385                    <#unsized_ty as ::domain::utils::dst::UnsizedCopy>
386                    ::ptr_with_addr(last, start as *const ());
387
388                // SAFETY:
389                // - The original 'bytes' contained a valid instance of every
390                //   field in 'Self', in succession.
391                // - Every field implements 'ParseBytesZC' and so has no
392                //   alignment restriction.
393                // - 'Self' is unaligned, since every field is unaligned, and
394                //   any explicit alignment modifiers only make it unaligned.
395                // - 'start' is thus the start of a valid instance of 'Self'.
396                // - 'ptr' has the same address as 'start' but can be cast to
397                //   'Self', since it has the right pointer metadata.
398                Ok(unsafe { &*(ptr as *const Self) })
399            }
400        });
401
402        Ok(skeleton.into_token_stream())
403    }
404
405    let input = syn::parse_macro_input!(input as syn::DeriveInput);
406    inner(input)
407        .unwrap_or_else(syn::Error::into_compile_error)
408        .into()
409}
410
411//----------- BuildBytes -----------------------------------------------------
412
413#[proc_macro_derive(BuildBytes)]
414pub fn derive_build_bytes(input: pm::TokenStream) -> pm::TokenStream {
415    fn inner(input: syn::DeriveInput) -> Result<TokenStream> {
416        let data = match &input.data {
417            syn::Data::Struct(data) => data,
418            syn::Data::Enum(data) => {
419                return Err(Error::new_spanned(
420                    data.enum_token,
421                    "'BuildBytes' can only be 'derive'd for 'struct's",
422                ));
423            }
424            syn::Data::Union(data) => {
425                return Err(Error::new_spanned(
426                    data.union_token,
427                    "'BuildBytes' can only be 'derive'd for 'struct's",
428                ));
429            }
430        };
431
432        // Construct an 'ImplSkeleton' so that we can add trait bounds.
433        let mut skeleton = ImplSkeleton::new(&input, false);
434        skeleton.bound =
435            Some(syn::parse_quote!(::domain::new::base::wire::BuildBytes));
436
437        // Inspect the 'struct' fields.
438        let data = Struct::new_as_self(&data.fields);
439
440        // Get a lifetime for the input buffer.
441        let lifetime = skeleton.new_lifetime("bytes");
442
443        // Establish bounds on the fields.
444        for field in data.fields() {
445            skeleton.require_bound(
446                field.ty.clone(),
447                syn::parse_quote!(::domain::new::base::wire::BuildBytes),
448            );
449        }
450
451        // Define 'build_bytes()'.
452        let members = data.members();
453        let tys = data.fields().map(|f| &f.ty);
454        skeleton.contents.stmts.push(syn::parse_quote! {
455            fn build_bytes<#lifetime>(
456                &self,
457                mut bytes: & #lifetime mut [::domain::__core::primitive::u8],
458            ) -> ::domain::__core::result::Result<
459                & #lifetime mut [::domain::__core::primitive::u8],
460                ::domain::new::base::wire::TruncationError,
461            > {
462                #(bytes = <#tys as ::domain::new::base::wire::BuildBytes>
463                    ::build_bytes(&self.#members, bytes)?;)*
464                Ok(bytes)
465            }
466        });
467
468        // Define 'built_bytes_size()'.
469        let members = data.members();
470        let tys = data.fields().map(|f| &f.ty);
471        skeleton.contents.stmts.push(syn::parse_quote! {
472            fn built_bytes_size(&self) -> ::domain::__core::primitive::usize {
473                0 #(+ <#tys as ::domain::new::base::wire::BuildBytes>
474                        ::built_bytes_size(&self.#members))*
475            }
476        });
477
478        Ok(skeleton.into_token_stream())
479    }
480
481    let input = syn::parse_macro_input!(input as syn::DeriveInput);
482    inner(input)
483        .unwrap_or_else(syn::Error::into_compile_error)
484        .into()
485}
486
487//----------- AsBytes --------------------------------------------------------
488
489#[proc_macro_derive(AsBytes)]
490pub fn derive_as_bytes(input: pm::TokenStream) -> pm::TokenStream {
491    fn inner(input: syn::DeriveInput) -> Result<TokenStream> {
492        let data = match &input.data {
493            syn::Data::Struct(data) => data,
494            syn::Data::Enum(data) => {
495                return Err(Error::new_spanned(
496                    data.enum_token,
497                    "'AsBytes' can only be 'derive'd for 'struct's",
498                ));
499            }
500            syn::Data::Union(data) => {
501                return Err(Error::new_spanned(
502                    data.union_token,
503                    "'AsBytes' can only be 'derive'd for 'struct's",
504                ));
505            }
506        };
507
508        let _ = Repr::determine(&input.attrs, "AsBytes")?;
509
510        // Construct an 'ImplSkeleton' so that we can add trait bounds.
511        let mut skeleton = ImplSkeleton::new(&input, true);
512        skeleton.bound =
513            Some(syn::parse_quote!(::domain::new::base::wire::AsBytes));
514
515        // Establish bounds on the fields.
516        for field in data.fields.iter() {
517            skeleton.require_bound(
518                field.ty.clone(),
519                syn::parse_quote!(::domain::new::base::wire::AsBytes),
520            );
521        }
522
523        // The default implementation of 'as_bytes()' works perfectly.
524
525        Ok(skeleton.into_token_stream())
526    }
527
528    let input = syn::parse_macro_input!(input as syn::DeriveInput);
529    inner(input)
530        .unwrap_or_else(syn::Error::into_compile_error)
531        .into()
532}
533
534//----------- UnsizedCopy ----------------------------------------------------
535
536#[proc_macro_derive(UnsizedCopy)]
537pub fn derive_unsized_copy(input: pm::TokenStream) -> pm::TokenStream {
538    fn inner(input: syn::DeriveInput) -> Result<TokenStream> {
539        // Construct an 'ImplSkeleton' so that we can add trait bounds.
540        let mut skeleton = ImplSkeleton::new(&input, true);
541        skeleton.bound =
542            Some(syn::parse_quote!(::domain::utils::dst::UnsizedCopy));
543
544        let struct_data = match &input.data {
545            syn::Data::Struct(data) if !data.fields.is_empty() => {
546                let data = Struct::new_as_self(&data.fields);
547                for field in data.sized_fields() {
548                    skeleton.require_bound(
549                        field.ty.clone(),
550                        syn::parse_quote!(::domain::__core::marker::Copy),
551                    );
552                }
553
554                skeleton.require_bound(
555                    data.unsized_field().unwrap().ty.clone(),
556                    syn::parse_quote!(::domain::utils::dst::UnsizedCopy),
557                );
558
559                Some(data)
560            }
561
562            syn::Data::Struct(_) => None,
563
564            syn::Data::Enum(data) => {
565                for variant in data.variants.iter() {
566                    for field in variant.fields.iter() {
567                        skeleton.require_bound(
568                            field.ty.clone(),
569                            syn::parse_quote!(::domain::__core::marker::Copy),
570                        );
571                    }
572                }
573
574                None
575            }
576
577            syn::Data::Union(data) => {
578                for field in data.fields.named.iter() {
579                    skeleton.require_bound(
580                        field.ty.clone(),
581                        syn::parse_quote!(::domain::__core::marker::Copy),
582                    );
583                }
584
585                None
586            }
587        };
588
589        if let Some(data) = struct_data {
590            let sized_tys = data.sized_fields().map(|f| &f.ty);
591            let unsized_ty = &data.unsized_field().unwrap().ty;
592            let unsized_member = data.unsized_member().unwrap();
593
594            skeleton.contents.stmts.push(syn::parse_quote! {
595                type Alignment = (#(#sized_tys,)* <#unsized_ty as ::domain::utils::dst::UnsizedCopy>::Alignment,);
596            });
597
598            skeleton.contents.stmts.push(syn::parse_quote! {
599                fn ptr_with_addr(&self, addr: *const ()) -> *const Self {
600                    ::domain::utils::dst::UnsizedCopy::ptr_with_addr(
601                        &self.#unsized_member,
602                        addr,
603                    ) as *const Self
604                }
605            });
606        } else {
607            skeleton.contents.stmts.push(syn::parse_quote! {
608                type Alignment = Self;
609            });
610
611            skeleton.contents.stmts.push(syn::parse_quote! {
612                fn ptr_with_addr(&self, addr: *const ()) -> *const Self {
613                    addr as *const Self
614                }
615            });
616        }
617
618        Ok(skeleton.into_token_stream())
619    }
620
621    let input = syn::parse_macro_input!(input as syn::DeriveInput);
622    inner(input)
623        .unwrap_or_else(syn::Error::into_compile_error)
624        .into()
625}
626
627//----------- Utility Functions ----------------------------------------------
628
629/// Add a `field_` prefix to member names.
630fn field_prefixed(member: syn::Member) -> Ident {
631    format_ident!("field_{}", member)
632}