diff options
Diffstat (limited to 'rust/pin-init/internal')
-rw-r--r-- | rust/pin-init/internal/src/helpers.rs | 152 | ||||
-rw-r--r-- | rust/pin-init/internal/src/lib.rs | 48 | ||||
-rw-r--r-- | rust/pin-init/internal/src/pin_data.rs | 132 | ||||
-rw-r--r-- | rust/pin-init/internal/src/pinned_drop.rs | 52 | ||||
-rw-r--r-- | rust/pin-init/internal/src/zeroable.rs | 76 |
5 files changed, 460 insertions, 0 deletions
diff --git a/rust/pin-init/internal/src/helpers.rs b/rust/pin-init/internal/src/helpers.rs new file mode 100644 index 000000000000..236f989a50f2 --- /dev/null +++ b/rust/pin-init/internal/src/helpers.rs @@ -0,0 +1,152 @@ +// SPDX-License-Identifier: Apache-2.0 OR MIT + +#[cfg(not(kernel))] +use proc_macro2 as proc_macro; + +use proc_macro::{TokenStream, TokenTree}; + +/// Parsed generics. +/// +/// See the field documentation for an explanation what each of the fields represents. +/// +/// # Examples +/// +/// ```rust,ignore +/// # let input = todo!(); +/// let (Generics { decl_generics, impl_generics, ty_generics }, rest) = parse_generics(input); +/// quote! { +/// struct Foo<$($decl_generics)*> { +/// // ... +/// } +/// +/// impl<$impl_generics> Foo<$ty_generics> { +/// fn foo() { +/// // ... +/// } +/// } +/// } +/// ``` +pub(crate) struct Generics { + /// The generics with bounds and default values (e.g. `T: Clone, const N: usize = 0`). + /// + /// Use this on type definitions e.g. `struct Foo<$decl_generics> ...` (or `union`/`enum`). + pub(crate) decl_generics: Vec<TokenTree>, + /// The generics with bounds (e.g. `T: Clone, const N: usize`). + /// + /// Use this on `impl` blocks e.g. `impl<$impl_generics> Trait for ...`. + pub(crate) impl_generics: Vec<TokenTree>, + /// The generics without bounds and without default values (e.g. `T, N`). + /// + /// Use this when you use the type that is declared with these generics e.g. + /// `Foo<$ty_generics>`. + pub(crate) ty_generics: Vec<TokenTree>, +} + +/// Parses the given `TokenStream` into `Generics` and the rest. +/// +/// The generics are not present in the rest, but a where clause might remain. +pub(crate) fn parse_generics(input: TokenStream) -> (Generics, Vec<TokenTree>) { + // The generics with bounds and default values. + let mut decl_generics = vec![]; + // `impl_generics`, the declared generics with their bounds. + let mut impl_generics = vec![]; + // Only the names of the generics, without any bounds. + let mut ty_generics = vec![]; + // Tokens not related to the generics e.g. the `where` token and definition. + let mut rest = vec![]; + // The current level of `<`. + let mut nesting = 0; + let mut toks = input.into_iter(); + // If we are at the beginning of a generic parameter. + let mut at_start = true; + let mut skip_until_comma = false; + while let Some(tt) = toks.next() { + if nesting == 1 && matches!(&tt, TokenTree::Punct(p) if p.as_char() == '>') { + // Found the end of the generics. + break; + } else if nesting >= 1 { + decl_generics.push(tt.clone()); + } + match tt.clone() { + TokenTree::Punct(p) if p.as_char() == '<' => { + if nesting >= 1 && !skip_until_comma { + // This is inside of the generics and part of some bound. + impl_generics.push(tt); + } + nesting += 1; + } + TokenTree::Punct(p) if p.as_char() == '>' => { + // This is a parsing error, so we just end it here. + if nesting == 0 { + break; + } else { + nesting -= 1; + if nesting >= 1 && !skip_until_comma { + // We are still inside of the generics and part of some bound. + impl_generics.push(tt); + } + } + } + TokenTree::Punct(p) if skip_until_comma && p.as_char() == ',' => { + if nesting == 1 { + impl_generics.push(tt.clone()); + impl_generics.push(tt); + skip_until_comma = false; + } + } + _ if !skip_until_comma => { + match nesting { + // If we haven't entered the generics yet, we still want to keep these tokens. + 0 => rest.push(tt), + 1 => { + // Here depending on the token, it might be a generic variable name. + match tt.clone() { + TokenTree::Ident(i) if at_start && i.to_string() == "const" => { + let Some(name) = toks.next() else { + // Parsing error. + break; + }; + impl_generics.push(tt); + impl_generics.push(name.clone()); + ty_generics.push(name.clone()); + decl_generics.push(name); + at_start = false; + } + TokenTree::Ident(_) if at_start => { + impl_generics.push(tt.clone()); + ty_generics.push(tt); + at_start = false; + } + TokenTree::Punct(p) if p.as_char() == ',' => { + impl_generics.push(tt.clone()); + ty_generics.push(tt); + at_start = true; + } + // Lifetimes begin with `'`. + TokenTree::Punct(p) if p.as_char() == '\'' && at_start => { + impl_generics.push(tt.clone()); + ty_generics.push(tt); + } + // Generics can have default values, we skip these. + TokenTree::Punct(p) if p.as_char() == '=' => { + skip_until_comma = true; + } + _ => impl_generics.push(tt), + } + } + _ => impl_generics.push(tt), + } + } + _ => {} + } + } + rest.extend(toks); + ( + Generics { + impl_generics, + decl_generics, + ty_generics, + }, + rest, + ) +} diff --git a/rust/pin-init/internal/src/lib.rs b/rust/pin-init/internal/src/lib.rs new file mode 100644 index 000000000000..babe5e878550 --- /dev/null +++ b/rust/pin-init/internal/src/lib.rs @@ -0,0 +1,48 @@ +// SPDX-License-Identifier: Apache-2.0 OR MIT + +// When fixdep scans this, it will find this string `CONFIG_RUSTC_VERSION_TEXT` +// and thus add a dependency on `include/config/RUSTC_VERSION_TEXT`, which is +// touched by Kconfig when the version string from the compiler changes. + +//! `pin-init` proc macros. + +#![cfg_attr(not(RUSTC_LINT_REASONS_IS_STABLE), feature(lint_reasons))] +// Allow `.into()` to convert +// - `proc_macro2::TokenStream` into `proc_macro::TokenStream` in the user-space version. +// - `proc_macro::TokenStream` into `proc_macro::TokenStream` in the kernel version. +// Clippy warns on this conversion, but it's required by the user-space version. +// +// Remove once we have `proc_macro2` in the kernel. +#![allow(clippy::useless_conversion)] +// Documentation is done in the pin-init crate instead. +#![allow(missing_docs)] + +use proc_macro::TokenStream; + +#[cfg(kernel)] +#[path = "../../../macros/quote.rs"] +#[macro_use] +mod quote; +#[cfg(not(kernel))] +#[macro_use] +extern crate quote; + +mod helpers; +mod pin_data; +mod pinned_drop; +mod zeroable; + +#[proc_macro_attribute] +pub fn pin_data(inner: TokenStream, item: TokenStream) -> TokenStream { + pin_data::pin_data(inner.into(), item.into()).into() +} + +#[proc_macro_attribute] +pub fn pinned_drop(args: TokenStream, input: TokenStream) -> TokenStream { + pinned_drop::pinned_drop(args.into(), input.into()).into() +} + +#[proc_macro_derive(Zeroable)] +pub fn derive_zeroable(input: TokenStream) -> TokenStream { + zeroable::derive(input.into()).into() +} diff --git a/rust/pin-init/internal/src/pin_data.rs b/rust/pin-init/internal/src/pin_data.rs new file mode 100644 index 000000000000..87d4a7eb1d35 --- /dev/null +++ b/rust/pin-init/internal/src/pin_data.rs @@ -0,0 +1,132 @@ +// SPDX-License-Identifier: Apache-2.0 OR MIT + +#[cfg(not(kernel))] +use proc_macro2 as proc_macro; + +use crate::helpers::{parse_generics, Generics}; +use proc_macro::{Group, Punct, Spacing, TokenStream, TokenTree}; + +pub(crate) fn pin_data(args: TokenStream, input: TokenStream) -> TokenStream { + // This proc-macro only does some pre-parsing and then delegates the actual parsing to + // `pin_init::__pin_data!`. + + let ( + Generics { + impl_generics, + decl_generics, + ty_generics, + }, + rest, + ) = parse_generics(input); + // The struct definition might contain the `Self` type. Since `__pin_data!` will define a new + // type with the same generics and bounds, this poses a problem, since `Self` will refer to the + // new type as opposed to this struct definition. Therefore we have to replace `Self` with the + // concrete name. + + // Errors that occur when replacing `Self` with `struct_name`. + let mut errs = TokenStream::new(); + // The name of the struct with ty_generics. + let struct_name = rest + .iter() + .skip_while(|tt| !matches!(tt, TokenTree::Ident(i) if i.to_string() == "struct")) + .nth(1) + .and_then(|tt| match tt { + TokenTree::Ident(_) => { + let tt = tt.clone(); + let mut res = vec![tt]; + if !ty_generics.is_empty() { + // We add this, so it is maximally compatible with e.g. `Self::CONST` which + // will be replaced by `StructName::<$generics>::CONST`. + res.push(TokenTree::Punct(Punct::new(':', Spacing::Joint))); + res.push(TokenTree::Punct(Punct::new(':', Spacing::Alone))); + res.push(TokenTree::Punct(Punct::new('<', Spacing::Alone))); + res.extend(ty_generics.iter().cloned()); + res.push(TokenTree::Punct(Punct::new('>', Spacing::Alone))); + } + Some(res) + } + _ => None, + }) + .unwrap_or_else(|| { + // If we did not find the name of the struct then we will use `Self` as the replacement + // and add a compile error to ensure it does not compile. + errs.extend( + "::core::compile_error!(\"Could not locate type name.\");" + .parse::<TokenStream>() + .unwrap(), + ); + "Self".parse::<TokenStream>().unwrap().into_iter().collect() + }); + let impl_generics = impl_generics + .into_iter() + .flat_map(|tt| replace_self_and_deny_type_defs(&struct_name, tt, &mut errs)) + .collect::<Vec<_>>(); + let mut rest = rest + .into_iter() + .flat_map(|tt| { + // We ignore top level `struct` tokens, since they would emit a compile error. + if matches!(&tt, TokenTree::Ident(i) if i.to_string() == "struct") { + vec![tt] + } else { + replace_self_and_deny_type_defs(&struct_name, tt, &mut errs) + } + }) + .collect::<Vec<_>>(); + // This should be the body of the struct `{...}`. + let last = rest.pop(); + let mut quoted = quote!(::pin_init::__pin_data! { + parse_input: + @args(#args), + @sig(#(#rest)*), + @impl_generics(#(#impl_generics)*), + @ty_generics(#(#ty_generics)*), + @decl_generics(#(#decl_generics)*), + @body(#last), + }); + quoted.extend(errs); + quoted +} + +/// Replaces `Self` with `struct_name` and errors on `enum`, `trait`, `struct` `union` and `impl` +/// keywords. +/// +/// The error is appended to `errs` to allow normal parsing to continue. +fn replace_self_and_deny_type_defs( + struct_name: &Vec<TokenTree>, + tt: TokenTree, + errs: &mut TokenStream, +) -> Vec<TokenTree> { + match tt { + TokenTree::Ident(ref i) + if i.to_string() == "enum" + || i.to_string() == "trait" + || i.to_string() == "struct" + || i.to_string() == "union" + || i.to_string() == "impl" => + { + errs.extend( + format!( + "::core::compile_error!(\"Cannot use `{i}` inside of struct definition with \ + `#[pin_data]`.\");" + ) + .parse::<TokenStream>() + .unwrap() + .into_iter() + .map(|mut tok| { + tok.set_span(tt.span()); + tok + }), + ); + vec![tt] + } + TokenTree::Ident(i) if i.to_string() == "Self" => struct_name.clone(), + TokenTree::Literal(_) | TokenTree::Punct(_) | TokenTree::Ident(_) => vec![tt], + TokenTree::Group(g) => vec![TokenTree::Group(Group::new( + g.delimiter(), + g.stream() + .into_iter() + .flat_map(|tt| replace_self_and_deny_type_defs(struct_name, tt, errs)) + .collect(), + ))], + } +} diff --git a/rust/pin-init/internal/src/pinned_drop.rs b/rust/pin-init/internal/src/pinned_drop.rs new file mode 100644 index 000000000000..c824dd8b436d --- /dev/null +++ b/rust/pin-init/internal/src/pinned_drop.rs @@ -0,0 +1,52 @@ +// SPDX-License-Identifier: Apache-2.0 OR MIT + +#[cfg(not(kernel))] +use proc_macro2 as proc_macro; + +use proc_macro::{TokenStream, TokenTree}; + +pub(crate) fn pinned_drop(_args: TokenStream, input: TokenStream) -> TokenStream { + let mut toks = input.into_iter().collect::<Vec<_>>(); + assert!(!toks.is_empty()); + // Ensure that we have an `impl` item. + assert!(matches!(&toks[0], TokenTree::Ident(i) if i.to_string() == "impl")); + // Ensure that we are implementing `PinnedDrop`. + let mut nesting: usize = 0; + let mut pinned_drop_idx = None; + for (i, tt) in toks.iter().enumerate() { + match tt { + TokenTree::Punct(p) if p.as_char() == '<' => { + nesting += 1; + } + TokenTree::Punct(p) if p.as_char() == '>' => { + nesting = nesting.checked_sub(1).unwrap(); + continue; + } + _ => {} + } + if i >= 1 && nesting == 0 { + // Found the end of the generics, this should be `PinnedDrop`. + assert!( + matches!(tt, TokenTree::Ident(i) if i.to_string() == "PinnedDrop"), + "expected 'PinnedDrop', found: '{:?}'", + tt + ); + pinned_drop_idx = Some(i); + break; + } + } + let idx = pinned_drop_idx + .unwrap_or_else(|| panic!("Expected an `impl` block implementing `PinnedDrop`.")); + // Fully qualify the `PinnedDrop`, as to avoid any tampering. + toks.splice(idx..idx, quote!(::pin_init::)); + // Take the `{}` body and call the declarative macro. + if let Some(TokenTree::Group(last)) = toks.pop() { + let last = last.stream(); + quote!(::pin_init::__pinned_drop! { + @impl_sig(#(#toks)*), + @impl_body(#last), + }) + } else { + TokenStream::from_iter(toks) + } +} diff --git a/rust/pin-init/internal/src/zeroable.rs b/rust/pin-init/internal/src/zeroable.rs new file mode 100644 index 000000000000..acc94008c152 --- /dev/null +++ b/rust/pin-init/internal/src/zeroable.rs @@ -0,0 +1,76 @@ +// SPDX-License-Identifier: GPL-2.0 + +#[cfg(not(kernel))] +use proc_macro2 as proc_macro; + +use crate::helpers::{parse_generics, Generics}; +use proc_macro::{TokenStream, TokenTree}; + +pub(crate) fn derive(input: TokenStream) -> TokenStream { + let ( + Generics { + impl_generics, + decl_generics: _, + ty_generics, + }, + mut rest, + ) = parse_generics(input); + // This should be the body of the struct `{...}`. + let last = rest.pop(); + // Now we insert `Zeroable` as a bound for every generic parameter in `impl_generics`. + let mut new_impl_generics = Vec::with_capacity(impl_generics.len()); + // Are we inside of a generic where we want to add `Zeroable`? + let mut in_generic = !impl_generics.is_empty(); + // Have we already inserted `Zeroable`? + let mut inserted = false; + // Level of `<>` nestings. + let mut nested = 0; + for tt in impl_generics { + match &tt { + // If we find a `,`, then we have finished a generic/constant/lifetime parameter. + TokenTree::Punct(p) if nested == 0 && p.as_char() == ',' => { + if in_generic && !inserted { + new_impl_generics.extend(quote! { : ::pin_init::Zeroable }); + } + in_generic = true; + inserted = false; + new_impl_generics.push(tt); + } + // If we find `'`, then we are entering a lifetime. + TokenTree::Punct(p) if nested == 0 && p.as_char() == '\'' => { + in_generic = false; + new_impl_generics.push(tt); + } + TokenTree::Punct(p) if nested == 0 && p.as_char() == ':' => { + new_impl_generics.push(tt); + if in_generic { + new_impl_generics.extend(quote! { ::pin_init::Zeroable + }); + inserted = true; + } + } + TokenTree::Punct(p) if p.as_char() == '<' => { + nested += 1; + new_impl_generics.push(tt); + } + TokenTree::Punct(p) if p.as_char() == '>' => { + assert!(nested > 0); + nested -= 1; + new_impl_generics.push(tt); + } + _ => new_impl_generics.push(tt), + } + } + assert_eq!(nested, 0); + if in_generic && !inserted { + new_impl_generics.extend(quote! { : ::pin_init::Zeroable }); + } + quote! { + ::pin_init::__derive_zeroable!( + parse_input: + @sig(#(#rest)*), + @impl_generics(#(#new_impl_generics)*), + @ty_generics(#(#ty_generics)*), + @body(#last), + ); + } +} |