summaryrefslogtreecommitdiffstats
path: root/rust/pin-init/examples/pthread_mutex.rs
diff options
context:
space:
mode:
authorBenno Lossin <benno.lossin@proton.me>2025-03-08 11:04:18 +0000
committerMiguel Ojeda <ojeda@kernel.org>2025-03-16 21:59:18 +0100
commit84837cf6fa541150a3012ea233225a7ecfa8771a (patch)
tree3b5f5d06d58b42cd86ad123937c1fbc59447ac60 /rust/pin-init/examples/pthread_mutex.rs
parent4b11798e82d6f340c2afc94c57823b6fbc109fad (diff)
downloadlinux-84837cf6fa541150a3012ea233225a7ecfa8771a.tar.gz
linux-84837cf6fa541150a3012ea233225a7ecfa8771a.tar.bz2
linux-84837cf6fa541150a3012ea233225a7ecfa8771a.zip
rust: pin-init: change examples to the user-space version
Replace the examples in the documentation by the ones from the user-space version and introduce the standalone examples from the user-space version such as the `CMutex<T>` type. The `CMutex<T>` example from the pinned-init repository [1] is used in several documentation examples in the user-space version instead of the kernel `Mutex<T>` type (as it's not available). In order to split off the pin-init crate, all examples need to be free of kernel-specific types. Link: https://github.com/rust-for-Linux/pinned-init [1] Signed-off-by: Benno Lossin <benno.lossin@proton.me> Reviewed-by: Fiona Behrens <me@kloenk.dev> Tested-by: Andreas Hindborg <a.hindborg@kernel.org> Link: https://lore.kernel.org/r/20250308110339.2997091-6-benno.lossin@proton.me Signed-off-by: Miguel Ojeda <ojeda@kernel.org>
Diffstat (limited to 'rust/pin-init/examples/pthread_mutex.rs')
-rw-r--r--rust/pin-init/examples/pthread_mutex.rs178
1 files changed, 178 insertions, 0 deletions
diff --git a/rust/pin-init/examples/pthread_mutex.rs b/rust/pin-init/examples/pthread_mutex.rs
new file mode 100644
index 000000000000..9164298c44c0
--- /dev/null
+++ b/rust/pin-init/examples/pthread_mutex.rs
@@ -0,0 +1,178 @@
+// SPDX-License-Identifier: Apache-2.0 OR MIT
+
+// inspired by https://github.com/nbdd0121/pin-init/blob/trunk/examples/pthread_mutex.rs
+#![allow(clippy::undocumented_unsafe_blocks)]
+#![cfg_attr(feature = "alloc", feature(allocator_api))]
+#[cfg(not(windows))]
+mod pthread_mtx {
+ #[cfg(feature = "alloc")]
+ use core::alloc::AllocError;
+ use core::{
+ cell::UnsafeCell,
+ marker::PhantomPinned,
+ mem::MaybeUninit,
+ ops::{Deref, DerefMut},
+ pin::Pin,
+ };
+ use pin_init::*;
+ use std::convert::Infallible;
+
+ #[pin_data(PinnedDrop)]
+ pub struct PThreadMutex<T> {
+ #[pin]
+ raw: UnsafeCell<libc::pthread_mutex_t>,
+ data: UnsafeCell<T>,
+ #[pin]
+ pin: PhantomPinned,
+ }
+
+ unsafe impl<T: Send> Send for PThreadMutex<T> {}
+ unsafe impl<T: Send> Sync for PThreadMutex<T> {}
+
+ #[pinned_drop]
+ impl<T> PinnedDrop for PThreadMutex<T> {
+ fn drop(self: Pin<&mut Self>) {
+ unsafe {
+ libc::pthread_mutex_destroy(self.raw.get());
+ }
+ }
+ }
+
+ #[derive(Debug)]
+ pub enum Error {
+ #[expect(dead_code)]
+ IO(std::io::Error),
+ Alloc,
+ }
+
+ impl From<Infallible> for Error {
+ fn from(e: Infallible) -> Self {
+ match e {}
+ }
+ }
+
+ #[cfg(feature = "alloc")]
+ impl From<AllocError> for Error {
+ fn from(_: AllocError) -> Self {
+ Self::Alloc
+ }
+ }
+
+ impl<T> PThreadMutex<T> {
+ pub fn new(data: T) -> impl PinInit<Self, Error> {
+ fn init_raw() -> impl PinInit<UnsafeCell<libc::pthread_mutex_t>, Error> {
+ let init = |slot: *mut UnsafeCell<libc::pthread_mutex_t>| {
+ // we can cast, because `UnsafeCell` has the same layout as T.
+ let slot: *mut libc::pthread_mutex_t = slot.cast();
+ let mut attr = MaybeUninit::uninit();
+ let attr = attr.as_mut_ptr();
+ // SAFETY: ptr is valid
+ let ret = unsafe { libc::pthread_mutexattr_init(attr) };
+ if ret != 0 {
+ return Err(Error::IO(std::io::Error::from_raw_os_error(ret)));
+ }
+ // SAFETY: attr is initialized
+ let ret = unsafe {
+ libc::pthread_mutexattr_settype(attr, libc::PTHREAD_MUTEX_NORMAL)
+ };
+ if ret != 0 {
+ // SAFETY: attr is initialized
+ unsafe { libc::pthread_mutexattr_destroy(attr) };
+ return Err(Error::IO(std::io::Error::from_raw_os_error(ret)));
+ }
+ // SAFETY: slot is valid
+ unsafe { slot.write(libc::PTHREAD_MUTEX_INITIALIZER) };
+ // SAFETY: attr and slot are valid ptrs and attr is initialized
+ let ret = unsafe { libc::pthread_mutex_init(slot, attr) };
+ // SAFETY: attr was initialized
+ unsafe { libc::pthread_mutexattr_destroy(attr) };
+ if ret != 0 {
+ return Err(Error::IO(std::io::Error::from_raw_os_error(ret)));
+ }
+ Ok(())
+ };
+ // SAFETY: mutex has been initialized
+ unsafe { pin_init_from_closure(init) }
+ }
+ try_pin_init!(Self {
+ data: UnsafeCell::new(data),
+ raw <- init_raw(),
+ pin: PhantomPinned,
+ }? Error)
+ }
+
+ pub fn lock(&self) -> PThreadMutexGuard<'_, T> {
+ // SAFETY: raw is always initialized
+ unsafe { libc::pthread_mutex_lock(self.raw.get()) };
+ PThreadMutexGuard { mtx: self }
+ }
+ }
+
+ pub struct PThreadMutexGuard<'a, T> {
+ mtx: &'a PThreadMutex<T>,
+ }
+
+ impl<T> Drop for PThreadMutexGuard<'_, T> {
+ fn drop(&mut self) {
+ // SAFETY: raw is always initialized
+ unsafe { libc::pthread_mutex_unlock(self.mtx.raw.get()) };
+ }
+ }
+
+ impl<T> Deref for PThreadMutexGuard<'_, T> {
+ type Target = T;
+
+ fn deref(&self) -> &Self::Target {
+ unsafe { &*self.mtx.data.get() }
+ }
+ }
+
+ impl<T> DerefMut for PThreadMutexGuard<'_, T> {
+ fn deref_mut(&mut self) -> &mut Self::Target {
+ unsafe { &mut *self.mtx.data.get() }
+ }
+ }
+}
+
+#[cfg_attr(test, test)]
+fn main() {
+ #[cfg(all(any(feature = "std", feature = "alloc"), not(windows)))]
+ {
+ use core::pin::Pin;
+ use pin_init::*;
+ use pthread_mtx::*;
+ use std::{
+ sync::Arc,
+ thread::{sleep, Builder},
+ time::Duration,
+ };
+ let mtx: Pin<Arc<PThreadMutex<usize>>> = Arc::try_pin_init(PThreadMutex::new(0)).unwrap();
+ let mut handles = vec![];
+ let thread_count = 20;
+ let workload = 1_000_000;
+ for i in 0..thread_count {
+ let mtx = mtx.clone();
+ handles.push(
+ Builder::new()
+ .name(format!("worker #{i}"))
+ .spawn(move || {
+ for _ in 0..workload {
+ *mtx.lock() += 1;
+ }
+ println!("{i} halfway");
+ sleep(Duration::from_millis((i as u64) * 10));
+ for _ in 0..workload {
+ *mtx.lock() += 1;
+ }
+ println!("{i} finished");
+ })
+ .expect("should not fail"),
+ );
+ }
+ for h in handles {
+ h.join().expect("thread panicked");
+ }
+ println!("{:?}", &*mtx.lock());
+ assert_eq!(*mtx.lock(), workload * thread_count * 2);
+ }
+}