diff options
author | Benno Lossin <benno.lossin@proton.me> | 2025-03-08 11:04:18 +0000 |
---|---|---|
committer | Miguel Ojeda <ojeda@kernel.org> | 2025-03-16 21:59:18 +0100 |
commit | 84837cf6fa541150a3012ea233225a7ecfa8771a (patch) | |
tree | 3b5f5d06d58b42cd86ad123937c1fbc59447ac60 /rust/pin-init/examples/pthread_mutex.rs | |
parent | 4b11798e82d6f340c2afc94c57823b6fbc109fad (diff) | |
download | linux-84837cf6fa541150a3012ea233225a7ecfa8771a.tar.gz linux-84837cf6fa541150a3012ea233225a7ecfa8771a.tar.bz2 linux-84837cf6fa541150a3012ea233225a7ecfa8771a.zip |
rust: pin-init: change examples to the user-space version
Replace the examples in the documentation by the ones from the
user-space version and introduce the standalone examples from the
user-space version such as the `CMutex<T>` type.
The `CMutex<T>` example from the pinned-init repository [1] is used in
several documentation examples in the user-space version instead of the
kernel `Mutex<T>` type (as it's not available). In order to split off
the pin-init crate, all examples need to be free of kernel-specific
types.
Link: https://github.com/rust-for-Linux/pinned-init [1]
Signed-off-by: Benno Lossin <benno.lossin@proton.me>
Reviewed-by: Fiona Behrens <me@kloenk.dev>
Tested-by: Andreas Hindborg <a.hindborg@kernel.org>
Link: https://lore.kernel.org/r/20250308110339.2997091-6-benno.lossin@proton.me
Signed-off-by: Miguel Ojeda <ojeda@kernel.org>
Diffstat (limited to 'rust/pin-init/examples/pthread_mutex.rs')
-rw-r--r-- | rust/pin-init/examples/pthread_mutex.rs | 178 |
1 files changed, 178 insertions, 0 deletions
diff --git a/rust/pin-init/examples/pthread_mutex.rs b/rust/pin-init/examples/pthread_mutex.rs new file mode 100644 index 000000000000..9164298c44c0 --- /dev/null +++ b/rust/pin-init/examples/pthread_mutex.rs @@ -0,0 +1,178 @@ +// SPDX-License-Identifier: Apache-2.0 OR MIT + +// inspired by https://github.com/nbdd0121/pin-init/blob/trunk/examples/pthread_mutex.rs +#![allow(clippy::undocumented_unsafe_blocks)] +#![cfg_attr(feature = "alloc", feature(allocator_api))] +#[cfg(not(windows))] +mod pthread_mtx { + #[cfg(feature = "alloc")] + use core::alloc::AllocError; + use core::{ + cell::UnsafeCell, + marker::PhantomPinned, + mem::MaybeUninit, + ops::{Deref, DerefMut}, + pin::Pin, + }; + use pin_init::*; + use std::convert::Infallible; + + #[pin_data(PinnedDrop)] + pub struct PThreadMutex<T> { + #[pin] + raw: UnsafeCell<libc::pthread_mutex_t>, + data: UnsafeCell<T>, + #[pin] + pin: PhantomPinned, + } + + unsafe impl<T: Send> Send for PThreadMutex<T> {} + unsafe impl<T: Send> Sync for PThreadMutex<T> {} + + #[pinned_drop] + impl<T> PinnedDrop for PThreadMutex<T> { + fn drop(self: Pin<&mut Self>) { + unsafe { + libc::pthread_mutex_destroy(self.raw.get()); + } + } + } + + #[derive(Debug)] + pub enum Error { + #[expect(dead_code)] + IO(std::io::Error), + Alloc, + } + + impl From<Infallible> for Error { + fn from(e: Infallible) -> Self { + match e {} + } + } + + #[cfg(feature = "alloc")] + impl From<AllocError> for Error { + fn from(_: AllocError) -> Self { + Self::Alloc + } + } + + impl<T> PThreadMutex<T> { + pub fn new(data: T) -> impl PinInit<Self, Error> { + fn init_raw() -> impl PinInit<UnsafeCell<libc::pthread_mutex_t>, Error> { + let init = |slot: *mut UnsafeCell<libc::pthread_mutex_t>| { + // we can cast, because `UnsafeCell` has the same layout as T. + let slot: *mut libc::pthread_mutex_t = slot.cast(); + let mut attr = MaybeUninit::uninit(); + let attr = attr.as_mut_ptr(); + // SAFETY: ptr is valid + let ret = unsafe { libc::pthread_mutexattr_init(attr) }; + if ret != 0 { + return Err(Error::IO(std::io::Error::from_raw_os_error(ret))); + } + // SAFETY: attr is initialized + let ret = unsafe { + libc::pthread_mutexattr_settype(attr, libc::PTHREAD_MUTEX_NORMAL) + }; + if ret != 0 { + // SAFETY: attr is initialized + unsafe { libc::pthread_mutexattr_destroy(attr) }; + return Err(Error::IO(std::io::Error::from_raw_os_error(ret))); + } + // SAFETY: slot is valid + unsafe { slot.write(libc::PTHREAD_MUTEX_INITIALIZER) }; + // SAFETY: attr and slot are valid ptrs and attr is initialized + let ret = unsafe { libc::pthread_mutex_init(slot, attr) }; + // SAFETY: attr was initialized + unsafe { libc::pthread_mutexattr_destroy(attr) }; + if ret != 0 { + return Err(Error::IO(std::io::Error::from_raw_os_error(ret))); + } + Ok(()) + }; + // SAFETY: mutex has been initialized + unsafe { pin_init_from_closure(init) } + } + try_pin_init!(Self { + data: UnsafeCell::new(data), + raw <- init_raw(), + pin: PhantomPinned, + }? Error) + } + + pub fn lock(&self) -> PThreadMutexGuard<'_, T> { + // SAFETY: raw is always initialized + unsafe { libc::pthread_mutex_lock(self.raw.get()) }; + PThreadMutexGuard { mtx: self } + } + } + + pub struct PThreadMutexGuard<'a, T> { + mtx: &'a PThreadMutex<T>, + } + + impl<T> Drop for PThreadMutexGuard<'_, T> { + fn drop(&mut self) { + // SAFETY: raw is always initialized + unsafe { libc::pthread_mutex_unlock(self.mtx.raw.get()) }; + } + } + + impl<T> Deref for PThreadMutexGuard<'_, T> { + type Target = T; + + fn deref(&self) -> &Self::Target { + unsafe { &*self.mtx.data.get() } + } + } + + impl<T> DerefMut for PThreadMutexGuard<'_, T> { + fn deref_mut(&mut self) -> &mut Self::Target { + unsafe { &mut *self.mtx.data.get() } + } + } +} + +#[cfg_attr(test, test)] +fn main() { + #[cfg(all(any(feature = "std", feature = "alloc"), not(windows)))] + { + use core::pin::Pin; + use pin_init::*; + use pthread_mtx::*; + use std::{ + sync::Arc, + thread::{sleep, Builder}, + time::Duration, + }; + let mtx: Pin<Arc<PThreadMutex<usize>>> = Arc::try_pin_init(PThreadMutex::new(0)).unwrap(); + let mut handles = vec![]; + let thread_count = 20; + let workload = 1_000_000; + for i in 0..thread_count { + let mtx = mtx.clone(); + handles.push( + Builder::new() + .name(format!("worker #{i}")) + .spawn(move || { + for _ in 0..workload { + *mtx.lock() += 1; + } + println!("{i} halfway"); + sleep(Duration::from_millis((i as u64) * 10)); + for _ in 0..workload { + *mtx.lock() += 1; + } + println!("{i} finished"); + }) + .expect("should not fail"), + ); + } + for h in handles { + h.join().expect("thread panicked"); + } + println!("{:?}", &*mtx.lock()); + assert_eq!(*mtx.lock(), workload * thread_count * 2); + } +} |