1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
|
// SPDX-License-Identifier: Apache-2.0 OR MIT
#![allow(clippy::undocumented_unsafe_blocks)]
#![cfg_attr(feature = "alloc", feature(allocator_api))]
use core::{
    cell::{Cell, UnsafeCell},
    mem::MaybeUninit,
    ops,
    pin::Pin,
    sync::atomic::{AtomicBool, Ordering},
    time::Duration,
};
use pin_init::*;
use std::{
    sync::Arc,
    thread::{sleep, Builder},
};
#[expect(unused_attributes)]
mod mutex;
use mutex::*;
pub struct StaticInit<T, I> {
cell: UnsafeCell<MaybeUninit<T>>,
init: Cell<Option<I>>,
lock: SpinLock,
present: Cell<bool>,
}
unsafe impl<T: Sync, I> Sync for StaticInit<T, I> {}
unsafe impl<T: Send, I> Send for StaticInit<T, I> {}
impl<T, I: PinInit<T>> StaticInit<T, I> {
pub const fn new(init: I) -> Self {
Self {
cell: UnsafeCell::new(MaybeUninit::uninit()),
init: Cell::new(Some(init)),
lock: SpinLock::new(),
present: Cell::new(false),
}
}
}
impl<T, I: PinInit<T>> ops::Deref for StaticInit<T, I> {
type Target = T;
fn deref(&self) -> &Self::Target {
if self.present.get() {
unsafe { (*self.cell.get()).assume_init_ref() }
} else {
println!("acquire spinlock on static init");
let _guard = self.lock.acquire();
println!("rechecking present...");
std::thread::sleep(std::time::Duration::from_millis(200));
if self.present.get() {
return unsafe { (*self.cell.get()).assume_init_ref() };
}
println!("doing init");
let ptr = self.cell.get().cast::<T>();
match self.init.take() {
Some(f) => unsafe { f.__pinned_init(ptr).unwrap() },
None => unsafe { core::hint::unreachable_unchecked() },
}
self.present.set(true);
unsafe { (*self.cell.get()).assume_init_ref() }
}
}
}
pub struct CountInit;
unsafe impl PinInit<CMutex<usize>> for CountInit {
unsafe fn __pinned_init(
self,
slot: *mut CMutex<usize>,
) -> Result<(), core::convert::Infallible> {
let init = CMutex::new(0);
std::thread::sleep(std::time::Duration::from_millis(1000));
unsafe { init.__pinned_init(slot) }
}
}
pub static COUNT: StaticInit<CMutex<usize>, CountInit> = StaticInit::new(CountInit);
/// Fallback entry point when neither `std` nor `alloc` is enabled: the
/// example needs `Arc` and threads, so there is nothing to run.
#[cfg(not(any(feature = "std", feature = "alloc")))]
fn main() {}

/// Spawns `thread_count` workers that concurrently increment a pinned
/// `Arc<CMutex<usize>>` and the lazily initialized static [`COUNT`], then
/// checks that neither counter lost an increment.
#[cfg(any(feature = "std", feature = "alloc"))]
fn main() {
    let mtx: Pin<Arc<CMutex<usize>>> = Arc::pin_init(CMutex::new(0)).unwrap();
    let thread_count = 20;
    let workload = 1_000;
    let mut handles = Vec::with_capacity(thread_count);
    for i in 0..thread_count {
        let mtx = mtx.clone();
        handles.push(
            Builder::new()
                .name(format!("worker #{i}"))
                .spawn(move || {
                    // Phase 1: each iteration adds 2 to `COUNT` and 1 to `mtx`.
                    // The first `COUNT` access triggers its lazy initialization.
                    for _ in 0..workload {
                        *COUNT.lock() += 1;
                        sleep(Duration::from_millis(10));
                        *mtx.lock() += 1;
                        sleep(Duration::from_millis(10));
                        *COUNT.lock() += 1;
                    }
                    println!("{i} halfway");
                    // Stagger the start of phase 2 per thread.
                    sleep(Duration::from_millis((i as u64) * 10));
                    // Phase 2: each iteration adds 1 to `mtx`.
                    for _ in 0..workload {
                        sleep(Duration::from_millis(10));
                        *mtx.lock() += 1;
                    }
                    println!("{i} finished");
                })
                .expect("should not fail"),
        );
    }
    for h in handles {
        h.join().expect("thread panicked");
    }
    // All workers are joined, so each total can be read once and reused.
    let mtx_total = *mtx.lock();
    let count_total = *COUNT.lock();
    println!("{mtx_total:?}, {count_total:?}");
    // Every worker adds `2 * workload` to each counter across its two phases.
    assert_eq!(mtx_total, workload * thread_count * 2);
    assert_eq!(count_total, workload * thread_count * 2);
}
|