Contributors: 2
Author Tokens Token Proportion Commits Commit Proportion
Benno Lossin 3234 99.97% 16 94.12%
Gary Guo 1 0.03% 1 5.88%
Total 3235 17


// SPDX-License-Identifier: Apache-2.0 OR MIT

use proc_macro2::{Span, TokenStream};
use quote::{format_ident, quote, quote_spanned};
use syn::{
    braced,
    parse::{End, Parse},
    parse_quote,
    punctuated::Punctuated,
    spanned::Spanned,
    token, Attribute, Block, Expr, ExprCall, ExprPath, Ident, Path, Token, Type,
};

use crate::diagnostics::{DiagCtxt, ErrorGuaranteed};

/// Parsed input of an initializer macro invocation.
///
/// The shape accepted by the [`Parse`] implementation is:
///
/// ```text
/// #[attr]* (&this in)? Path { field,* (..rest)? } (? ErrorType)?
/// ```
pub(crate) struct Initializer {
    /// Outer attributes written before the initializer; only `default_error` is accepted.
    attrs: Vec<InitializerAttribute>,
    /// Optional `&ident in` prefix naming a binding for the slot being initialized.
    this: Option<This>,
    /// Path of the struct that is being initialized.
    path: Path,
    /// The braces surrounding the field list (kept for diagnostics spans).
    brace_token: token::Brace,
    /// The comma separated field initializers.
    fields: Punctuated<InitializerField, Token![,]>,
    /// Optional trailing `..expr` inside the braces.
    rest: Option<(Token![..], Expr)>,
    /// Optional `? Type` suffix after the closing brace that names the error type.
    error: Option<(Token![?], Type)>,
}

/// The optional `&ident in` prefix of an initializer.
///
/// During expansion `ident` is bound to a `NonNull` created from the slot pointer.
struct This {
    _and_token: Token![&],
    /// Name under which the slot is made available to the field expressions.
    ident: Ident,
    _in_token: Token![in],
}

/// A single entry of the field list: its outer attributes plus the actual initializer.
struct InitializerField {
    /// Outer attributes written in front of the entry (for example `#[cfg(...)]`).
    attrs: Vec<Attribute>,
    /// What kind of entry this is and its contents.
    kind: InitializerKind,
}

/// The different forms an entry in the field list can take.
enum InitializerKind {
    /// `ident: expr` or the shorthand `ident`: the field is written by value.
    Value {
        ident: Ident,
        /// `None` for the shorthand form, in which case the variable `ident` is used as value.
        value: Option<(Token![:], Expr)>,
    },
    /// `ident <- expr`: the field is initialized in-place by the initializer `expr`.
    Init {
        ident: Ident,
        _left_arrow_token: Token![<-],
        value: Expr,
    },
    /// `_: { ... }`: a block of code that is emitted between the field initializations.
    Code {
        _underscore_token: Token![_],
        _colon_token: Token![:],
        block: Block,
    },
}

impl InitializerKind {
    /// Returns the name of the field that this entry initializes.
    ///
    /// Code blocks (`_: { ... }`) do not belong to a field and therefore yield `None`.
    fn ident(&self) -> Option<&Ident> {
        match self {
            Self::Code { .. } => None,
            Self::Value { ident, .. } => Some(ident),
            Self::Init { ident, .. } => Some(ident),
        }
    }
}

/// Attributes that are understood in front of an initializer.
enum InitializerAttribute {
    /// `#[default_error(<type>)]`: error type to use when no `? <type>` suffix is given.
    DefaultError(DefaultErrorAttribute),
}

/// Arguments of the `#[default_error(<type>)]` attribute.
struct DefaultErrorAttribute {
    /// The error type given inside the parentheses.
    ty: Box<Type>,
}

/// Expands a parsed [`Initializer`] into the token stream of an initializer expression.
///
/// The error type of the generated initializer is taken, in order of priority, from:
/// 1. the explicit `? <type>` suffix,
/// 2. the last `#[default_error(<type>)]` attribute,
/// 3. the `default_error` argument, parsed as a type.
///
/// If none of them is available, an error is reported through `dcx` and
/// `::core::convert::Infallible` is used so that the expansion can still be produced.
///
/// `pinned` selects which family of names is used in the generated code:
/// `HasPinData`/`PinData`/`__pin_data`/`pin_init_from_closure` when `true` and
/// `HasInitData`/`InitData`/`__init_data`/`init_from_closure` otherwise.
///
/// # Panics
///
/// Panics if `default_error` is needed but cannot be parsed as a type.
pub(crate) fn expand(
    Initializer {
        attrs,
        this,
        path,
        brace_token,
        fields,
        rest,
        error,
    }: Initializer,
    default_error: Option<&'static str>,
    pinned: bool,
    dcx: &mut DiagCtxt,
) -> Result<TokenStream, ErrorGuaranteed> {
    let error = error.map_or_else(
        || {
            // No explicit `? <type>`: look for a `#[default_error(..)]` attribute. The fold keeps
            // overwriting the accumulator, so the last such attribute wins.
            if let Some(default_error) = attrs.iter().fold(None, |acc, attr| {
                // There currently is only a single attribute variant, making the pattern
                // irrefutable.
                #[expect(irrefutable_let_patterns)]
                if let InitializerAttribute::DefaultError(DefaultErrorAttribute { ty }) = attr {
                    Some(ty.clone())
                } else {
                    acc
                }
            }) {
                default_error
            } else if let Some(default_error) = default_error {
                syn::parse_str(default_error).unwrap()
            } else {
                dcx.error(brace_token.span.close(), "expected `? <type>` after `}`");
                // Fall back to some type so that the rest of the expansion can proceed.
                parse_quote!(::core::convert::Infallible)
            }
        },
        |(_, err)| Box::new(err),
    );
    let slot = format_ident!("slot");
    let (has_data_trait, data_trait, get_data, init_from_closure) = if pinned {
        (
            format_ident!("HasPinData"),
            format_ident!("PinData"),
            format_ident!("__pin_data"),
            format_ident!("pin_init_from_closure"),
        )
    } else {
        (
            format_ident!("HasInitData"),
            format_ident!("InitData"),
            format_ident!("__init_data"),
            format_ident!("init_from_closure"),
        )
    };
    let init_kind = get_init_kind(rest, dcx);
    let zeroable_check = match init_kind {
        InitKind::Normal => quote!(),
        InitKind::Zeroing => quote! {
            // The user specified `..Zeroable::init_zeroed()` at the end of the list of fields.
            // Therefore we check if the struct implements `Zeroable` and then zero the memory.
            // This allows us to also remove the check that all fields are present (since we
            // already set the memory to zero and that is a valid bit pattern).
            fn assert_zeroable<T: ?::core::marker::Sized>(_: *mut T)
            where T: ::pin_init::Zeroable
            {}
            // Ensure that the struct is indeed `Zeroable`.
            assert_zeroable(#slot);
            // SAFETY: The type implements `Zeroable` by the check above.
            unsafe { ::core::ptr::write_bytes(#slot, 0, 1) };
        },
    };
    let this = match this {
        None => quote!(),
        Some(This { ident, .. }) => quote! {
            // Create the `this` so it can be referenced by the user inside of the
            // expressions creating the individual fields.
            let #ident = unsafe { ::core::ptr::NonNull::new_unchecked(slot) };
        },
    };
    // `mixed_site` ensures that the data is not accessible to the user-controlled code.
    let data = Ident::new("__data", Span::mixed_site());
    let init_fields = init_fields(&fields, pinned, &data, &slot);
    let field_check = make_field_check(&fields, init_kind, &path);
    Ok(quote! {{
        // Get the data about fields from the supplied type.
        // SAFETY: TODO
        let #data = unsafe {
            use ::pin_init::__internal::#has_data_trait;
            // Can't use `<#path as #has_data_trait>::#get_data`, since the user is able to omit
            // generics (which need to be present with that syntax).
            #path::#get_data()
        };
        // Ensure that `#data` really is of the type expected by `#data_trait::make_closure` and
        // help with type inference:
        let init = ::pin_init::__internal::#data_trait::make_closure::<_, #error>(
            #data,
            move |slot| {
                #zeroable_check
                #this
                #init_fields
                #field_check
                // SAFETY: we are the `init!` macro that is allowed to call this.
                Ok(unsafe { ::pin_init::__internal::InitOk::new() })
            }
        );
        // Discard the `InitOk` token so that the closure has the `Result<(), #error>` signature.
        let init = move |slot| -> ::core::result::Result<(), #error> {
            init(slot).map(|__InitOk| ())
        };
        // SAFETY: TODO
        let init = unsafe { ::pin_init::#init_from_closure::<_, #error>(init) };
        init
    }})
}

/// How the initializer treats fields that are not explicitly mentioned.
enum InitKind {
    /// No `..` rest expression was accepted: every field has to be mentioned.
    Normal,
    /// `..Zeroable::init_zeroed()` was given: the memory is zeroed before the listed fields are
    /// initialized.
    Zeroing,
}

/// Determines the [`InitKind`] from the optional `..expr` at the end of the field list.
///
/// Without a rest expression the result is [`InitKind::Normal`]. The only accepted rest
/// expression is the literal call `Zeroable::init_zeroed()` (exactly two path segments, no
/// leading `::`, no generic arguments, no call arguments), which yields [`InitKind::Zeroing`].
/// Anything else is reported as an error through `dcx` and treated as [`InitKind::Normal`].
fn get_init_kind(rest: Option<(Token![..], Expr)>, dcx: &mut DiagCtxt) -> InitKind {
    let (dotdot, expr) = match rest {
        Some(pair) => pair,
        None => return InitKind::Normal,
    };
    let is_init_zeroed_call = if let Expr::Call(ExprCall { func, args, .. }) = &expr {
        args.is_empty()
            && matches!(
                &**func,
                Expr::Path(ExprPath {
                    attrs,
                    qself: None,
                    path: Path {
                        leading_colon: None,
                        segments,
                    },
                }) if attrs.is_empty()
                    && segments.len() == 2
                    && segments[0].ident == "Zeroable"
                    && segments[0].arguments.is_none()
                    && segments[1].ident == "init_zeroed"
                    && segments[1].arguments.is_none()
            )
    } else {
        false
    };
    if is_init_zeroed_call {
        return InitKind::Zeroing;
    }
    // Point at `..expr` as a whole when the spans can be joined, otherwise only at `expr`.
    let error_span = dotdot.span().join(expr.span()).unwrap_or(expr.span());
    dcx.error(
        error_span,
        "expected nothing or `..Zeroable::init_zeroed()`.",
    );
    InitKind::Normal
}

/// Generate the code that initializes the fields of the struct using the initializers in `fields`.
///
/// For every entry, in source order, the generated code:
/// - initializes the field (by value via `ptr::write`, or in-place via its initializer), or emits
///   the user's code block verbatim for `_: { ... }` entries,
/// - for named fields, binds a local with the field's name to an accessor of the field, and
/// - for named fields, creates a drop guard for the field.
///
/// After all entries, every guard is passed to `mem::forget`.
///
/// `pinned` selects whether accessors/initializations go through `data` (the `__project_*` and
/// per-field functions) or use plain references and `Init::__init`. `slot` is the identifier of
/// the raw pointer to the struct being initialized.
fn init_fields(
    fields: &Punctuated<InitializerField, Token![,]>,
    pinned: bool,
    data: &Ident,
    slot: &Ident,
) -> TokenStream {
    let mut guards = vec![];
    let mut guard_attrs = vec![];
    let mut res = TokenStream::new();
    for InitializerField { attrs, kind } in fields {
        // Only the `cfg` attributes are repeated on the accessor binding, the guard and the
        // `forget` call; the full attribute list is put on the initialization itself.
        let cfgs = {
            let mut cfgs = attrs.clone();
            cfgs.retain(|attr| attr.path().is_ident("cfg"));
            cfgs
        };
        let init = match kind {
            InitializerKind::Value { ident, value } => {
                let mut value_ident = ident.clone();
                // For the shorthand `ident` form there is no value expression and the existing
                // variable named `ident` is written directly.
                let value_prep = value.as_ref().map(|value| &value.1).map(|value| {
                    // Setting the span of `value_ident` to `value`'s span improves error messages
                    // when the type of `value` is wrong.
                    value_ident.set_span(value.span());
                    quote!(let #value_ident = #value;)
                });
                // Again span for better diagnostics
                let write = quote_spanned!(ident.span()=> ::core::ptr::write);
                // NOTE: the field accessor ensures that the initialized field is properly aligned.
                // Unaligned fields will cause the compiler to emit E0793. We do not support
                // unaligned fields since `Init::__init` requires an aligned pointer; the call to
                // `ptr::write` below has the same requirement.
                let accessor = if pinned {
                    let project_ident = format_ident!("__project_{ident}");
                    quote! {
                        // SAFETY: TODO
                        unsafe { #data.#project_ident(&mut (*#slot).#ident) }
                    }
                } else {
                    quote! {
                        // SAFETY: TODO
                        unsafe { &mut (*#slot).#ident }
                    }
                };
                quote! {
                    #(#attrs)*
                    {
                        #value_prep
                        // SAFETY: TODO
                        unsafe { #write(::core::ptr::addr_of_mut!((*#slot).#ident), #value_ident) };
                    }
                    #(#cfgs)*
                    #[allow(unused_variables)]
                    let #ident = #accessor;
                }
            }
            InitializerKind::Init { ident, value, .. } => {
                // Again span for better diagnostics
                let init = format_ident!("init", span = value.span());
                // NOTE: the field accessor ensures that the initialized field is properly aligned.
                // Unaligned fields will cause the compiler to emit E0793. We do not support
                // unaligned fields since `Init::__init` requires an aligned pointer.
                let (value_init, accessor) = if pinned {
                    let project_ident = format_ident!("__project_{ident}");
                    (
                        quote! {
                            // SAFETY:
                            // - `slot` is valid, because we are inside of an initializer closure, we
                            //   return when an error/panic occurs.
                            // - We also use `#data` to require the correct trait (`Init` or `PinInit`)
                            //   for `#ident`.
                            unsafe { #data.#ident(::core::ptr::addr_of_mut!((*#slot).#ident), #init)? };
                        },
                        quote! {
                            // SAFETY: TODO
                            unsafe { #data.#project_ident(&mut (*#slot).#ident) }
                        },
                    )
                } else {
                    (
                        quote! {
                            // SAFETY: `slot` is valid, because we are inside of an initializer
                            // closure, we return when an error/panic occurs.
                            unsafe {
                                ::pin_init::Init::__init(
                                    #init,
                                    ::core::ptr::addr_of_mut!((*#slot).#ident),
                                )?
                            };
                        },
                        quote! {
                            // SAFETY: TODO
                            unsafe { &mut (*#slot).#ident }
                        },
                    )
                };
                quote! {
                    #(#attrs)*
                    {
                        let #init = #value;
                        #value_init
                    }
                    #(#cfgs)*
                    #[allow(unused_variables)]
                    let #ident = #accessor;
                }
            }
            // User supplied code is emitted as-is between the field initializations.
            InitializerKind::Code { block: value, .. } => quote! {
                #(#attrs)*
                #[allow(unused_braces)]
                #value
            },
        };
        res.extend(init);
        // Only named fields get a drop guard; code blocks do not initialize a field.
        if let Some(ident) = kind.ident() {
            // `mixed_site` ensures that the guard is not accessible to the user-controlled code.
            let guard = format_ident!("__{ident}_guard", span = Span::mixed_site());
            res.extend(quote! {
                #(#cfgs)*
                // Create the drop guard:
                //
                // We rely on macro hygiene to make it impossible for users to access this local
                // variable.
                // SAFETY: We forget the guard later when initialization has succeeded.
                let #guard = unsafe {
                    ::pin_init::__internal::DropGuard::new(
                        ::core::ptr::addr_of_mut!((*slot).#ident)
                    )
                };
            });
            guards.push(guard);
            guard_attrs.push(cfgs);
        }
    }
    quote! {
        #res
        // If execution reaches this point, all fields have been initialized. Therefore we can now
        // dismiss the guards by forgetting them.
        #(
            #(#guard_attrs)*
            ::core::mem::forget(#guards);
        )*
    }
}

/// Generate the check for ensuring that every field has been initialized.
///
/// The check is a closure that is never called and that contains a struct literal of `path`
/// naming every initialized field. The literal is still type-checked, so the compiler reports
/// unknown fields and fields that are mentioned more than once. With [`InitKind::Normal`] it also
/// reports forgotten fields; with [`InitKind::Zeroing`] the literal ends in
/// `..::core::mem::zeroed()`, so fields may be left out.
fn make_field_check(
    fields: &Punctuated<InitializerField, Token![,]>,
    init_kind: InitKind,
    path: &Path,
) -> TokenStream {
    // Only entries that name a field take part in the check; `_: { ... }` code blocks do not.
    let named: Vec<_> = fields
        .iter()
        .filter_map(|field| field.kind.ident().map(|ident| (&field.attrs, ident)))
        .collect();
    let field_attrs = named.iter().map(|(attrs, _)| *attrs);
    let field_name = named.iter().map(|(_, ident)| *ident);
    // The two kinds differ only in the allowed lints and in the struct update tail.
    let (allowed_lints, update_tail) = match init_kind {
        // Every field has to be mentioned exactly once.
        InitKind::Normal => (
            quote!(unreachable_code, clippy::diverging_sub_expression),
            quote!(),
        ),
        // The user specified `..Zeroable::init_zeroed()` at the end, so all missing fields are
        // zeroed and every field only has to be mentioned at most once.
        InitKind::Zeroing => (
            quote!(unreachable_code, clippy::diverging_sub_expression, unused_assignments),
            quote!(..::core::mem::zeroed()),
        ),
    };
    quote! {
        // We use unreachable code to check the mentioned fields: this struct initializer will
        // still be type-checked and complain with a very natural error message if a field is
        // mentioned more than once, doesn't exist, or (without a struct update tail) is
        // forgotten.
        #[allow(#allowed_lints)]
        // SAFETY: this code is never executed.
        let _ = || unsafe {
            ::core::ptr::write(slot, #path {
                #(
                    #(#field_attrs)*
                    #field_name: ::core::panic!(),
                )*
                #update_tail
            })
        };
    }
}

impl Parse for Initializer {
    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
        let attrs = input.call(Attribute::parse_outer)?;
        let this = input.peek(Token![&]).then(|| input.parse()).transpose()?;
        let path = input.parse()?;
        let content;
        let brace_token = braced!(content in input);
        let mut fields = Punctuated::new();
        loop {
            let lh = content.lookahead1();
            if lh.peek(End) || lh.peek(Token![..]) {
                break;
            } else if lh.peek(Ident) || lh.peek(Token![_]) || lh.peek(Token![#]) {
                fields.push_value(content.parse()?);
                let lh = content.lookahead1();
                if lh.peek(End) {
                    break;
                } else if lh.peek(Token![,]) {
                    fields.push_punct(content.parse()?);
                } else {
                    return Err(lh.error());
                }
            } else {
                return Err(lh.error());
            }
        }
        let rest = content
            .peek(Token![..])
            .then(|| Ok::<_, syn::Error>((content.parse()?, content.parse()?)))
            .transpose()?;
        let error = input
            .peek(Token![?])
            .then(|| Ok::<_, syn::Error>((input.parse()?, input.parse()?)))
            .transpose()?;
        let attrs = attrs
            .into_iter()
            .map(|a| {
                if a.path().is_ident("default_error") {
                    a.parse_args::<DefaultErrorAttribute>()
                        .map(InitializerAttribute::DefaultError)
                } else {
                    Err(syn::Error::new_spanned(a, "unknown initializer attribute"))
                }
            })
            .collect::<Result<Vec<_>, _>>()?;
        Ok(Self {
            attrs,
            this,
            path,
            brace_token,
            fields,
            rest,
            error,
        })
    }
}

impl Parse for DefaultErrorAttribute {
    /// Parses the argument of `#[default_error(<type>)]`, i.e. a single type.
    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
        let ty = input.parse()?;
        Ok(Self { ty })
    }
}

impl Parse for This {
    /// Parses the `&ident in` prefix, token by token in that order.
    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
        let _and_token = input.parse()?;
        let ident = input.parse()?;
        let _in_token = input.parse()?;
        Ok(Self {
            _and_token,
            ident,
            _in_token,
        })
    }
}

impl Parse for InitializerField {
    /// Parses the outer attributes of an entry followed by the entry itself.
    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
        let attrs = Attribute::parse_outer(input)?;
        let kind = input.parse()?;
        Ok(Self { attrs, kind })
    }
}

impl Parse for InitializerKind {
    /// Parses one of `_: { ... }`, `ident <- expr`, `ident: expr` or the shorthand `ident`.
    ///
    /// The shorthand is only accepted when `ident` is followed by `,` or the end of the input.
    fn parse(input: syn::parse::ParseStream<'_>) -> syn::Result<Self> {
        let lookahead = input.lookahead1();
        if lookahead.peek(Token![_]) {
            let _underscore_token = input.parse()?;
            let _colon_token = input.parse()?;
            let block = input.parse()?;
            return Ok(Self::Code {
                _underscore_token,
                _colon_token,
                block,
            });
        }
        if !lookahead.peek(Ident) {
            return Err(lookahead.error());
        }

        let ident = input.parse()?;
        // What follows the field name decides the kind of the entry.
        let lookahead = input.lookahead1();
        if lookahead.peek(Token![<-]) {
            let _left_arrow_token = input.parse()?;
            let value = input.parse()?;
            Ok(Self::Init {
                ident,
                _left_arrow_token,
                value,
            })
        } else if lookahead.peek(Token![:]) {
            let colon_token = input.parse()?;
            let expr = input.parse()?;
            Ok(Self::Value {
                ident,
                value: Some((colon_token, expr)),
            })
        } else if lookahead.peek(Token![,]) || lookahead.peek(End) {
            Ok(Self::Value { ident, value: None })
        } else {
            Err(lookahead.error())
        }
    }
}