Contributors: 1
Author Tokens Token Proportion Commits Commit Proportion
Benno Lossin 1301 100.00% 3 100.00%
Total 1301 3


// SPDX-License-Identifier: Apache-2.0 OR MIT

#![allow(clippy::undocumented_unsafe_blocks)]
#![cfg_attr(feature = "alloc", feature(allocator_api))]
#![cfg_attr(not(RUSTC_LINT_REASONS_IS_STABLE), feature(lint_reasons))]
#![allow(clippy::missing_safety_doc)]

use core::{
    cell::{Cell, UnsafeCell},
    marker::PhantomPinned,
    ops::{Deref, DerefMut},
    pin::Pin,
    sync::atomic::{AtomicBool, Ordering},
};
use std::{
    sync::Arc,
    thread::{self, park, sleep, Builder, Thread},
    time::Duration,
};

use pin_init::*;
#[expect(unused_attributes)]
#[path = "./linked_list.rs"]
pub mod linked_list;
use linked_list::*;

/// A minimal spin lock built on a single [`AtomicBool`].
///
/// `false` means unlocked; `true` means a [`SpinLockGuard`] is currently alive.
pub struct SpinLock {
    inner: AtomicBool,
}

impl SpinLock {
    /// Spins until the lock has been taken, then returns a guard that releases it on drop.
    ///
    /// While the lock is held by someone else the caller only performs relaxed loads and
    /// yields its time slice between attempts.
    #[inline]
    pub fn acquire(&self) -> SpinLockGuard<'_> {
        // A spurious failure of the weak variant is harmless because we retry in a loop anyway.
        while self
            .inner
            .compare_exchange_weak(false, true, Ordering::Acquire, Ordering::Relaxed)
            .is_err()
        {
            // Wait on a plain load until the lock looks free before attempting the
            // read-modify-write again.
            while self.inner.load(Ordering::Relaxed) {
                thread::yield_now();
            }
        }
        SpinLockGuard(self)
    }

    /// Creates a new, unlocked spin lock.
    #[inline]
    pub const fn new() -> Self {
        Self {
            inner: AtomicBool::new(false),
        }
    }
}

impl Default for SpinLock {
    /// Equivalent to [`SpinLock::new`]: the lock starts out unlocked.
    #[inline]
    fn default() -> Self {
        Self::new()
    }
}

/// Proof that the referenced [`SpinLock`] is held; dropping it releases the lock.
#[must_use = "the spin lock is released as soon as the guard is dropped"]
pub struct SpinLockGuard<'a>(&'a SpinLock);

impl Drop for SpinLockGuard<'_> {
    #[inline]
    fn drop(&mut self) {
        // `Release` pairs with the `Acquire` in `SpinLock::acquire`.
        self.0.inner.store(false, Ordering::Release);
    }
}

/// A mutex that parks contending threads on an intrusive wait list.
///
/// The struct has pinned fields, so it is created through the in-place initializer returned
/// by `CMutex::new`.
#[pin_data]
pub struct CMutex<T> {
    /// Head of the list that waiting threads link their `WaitEntry` into.
    #[pin]
    wait_list: ListHead,
    /// Short-lived lock held around every access to `locked` and `wait_list` in this file.
    spin_lock: SpinLock,
    /// `true` while a `CMutexGuard` for this mutex is alive.
    locked: Cell<bool>,
    /// The protected value, reached through `CMutexGuard` or `get_data_mut`.
    #[pin]
    data: UnsafeCell<T>,
}

impl<T> CMutex<T> {
    /// Returns an in-place initializer for an unlocked mutex whose protected value is
    /// initialized by `val`.
    #[inline]
    pub fn new(val: impl PinInit<T>) -> impl PinInit<Self> {
        pin_init!(CMutex {
            wait_list <- ListHead::new(),
            spin_lock: SpinLock::new(),
            locked: Cell::new(false),
            // Run `val` directly on the storage inside the `UnsafeCell`. The pointer cast
            // relies on `UnsafeCell<T>` having the same in-memory representation as `T`
            // (documented by the standard library).
            data <- unsafe {
                pin_init_from_closure(|slot: *mut UnsafeCell<T>| {
                    val.__pinned_init(slot.cast::<T>())
                })
            },
        })
    }

    /// Blocks the calling thread until the mutex is acquired and returns a pinned guard
    /// that gives access to the data and unlocks the mutex when dropped.
    #[inline]
    pub fn lock(&self) -> Pin<CMutexGuard<'_, T>> {
        // All accesses to `locked` and `wait_list` below happen with the spin lock held.
        let mut sguard = self.spin_lock.acquire();
        if self.locked.get() {
            // Contended: register this thread in the wait list. The entry lives on this
            // stack frame, hence the stack pinning.
            stack_pin_init!(let wait_entry = WaitEntry::insert_new(&self.wait_list));
            // println!("wait list length: {}", self.wait_list.size());
            // Re-check `locked` after every wake-up: `park` may return spuriously, and
            // another thread may have taken the mutex in the meantime.
            while self.locked.get() {
                // Release the spin lock while sleeping so the owner can unlock and unpark us.
                drop(sguard);
                park();
                sguard = self.spin_lock.acquire();
            }
            // This does have an effect, as the ListHead inside wait_entry implements Drop!
            // (Presumably it unlinks the entry from `wait_list` -- see `linked_list.rs`.)
            // It happens while `sguard` is still held.
            #[expect(clippy::drop_non_drop)]
            drop(wait_entry);
        }
        self.locked.set(true);
        // NOTE(review): no safety argument is given in the original. Presumably this is fine
        // because `data` is a `#[pin]` field that is never moved out of the mutex -- confirm.
        unsafe {
            Pin::new_unchecked(CMutexGuard {
                mtx: self,
                _pin: PhantomPinned,
            })
        }
        // `sguard` is dropped here, releasing the spin lock after `locked` was set.
    }

    /// Returns a mutable reference to the protected data without locking.
    #[allow(dead_code)]
    pub fn get_data_mut(self: Pin<&mut Self>) -> &mut T {
        // SAFETY: we have an exclusive reference and thus nobody has access to data.
        unsafe { &mut *self.data.get() }
    }
}

// NOTE(review): the original gives no safety argument for these impls. Presumably they rely on
// `spin_lock` serializing every access to `locked` and `wait_list`, and on `lock` handing the
// data to one guard at a time, so that `T: Send` is the only requirement -- confirm.
unsafe impl<T: Send> Send for CMutex<T> {}
unsafe impl<T: Send> Sync for CMutex<T> {}

/// Guard returned (pinned) by `CMutex::lock`.
///
/// Dereferences to the protected data; dropping it unlocks the mutex and unparks the first
/// waiting thread, if there is one.
pub struct CMutexGuard<'a, T> {
    /// The mutex this guard keeps locked.
    mtx: &'a CMutex<T>,
    /// Opts the guard out of `Unpin`.
    _pin: PhantomPinned,
}

impl<T> Drop for CMutexGuard<'_, T> {
    /// Unlocks the mutex and unparks the thread whose entry follows the wait-list head,
    /// if `wait_list.next()` yields one.
    #[inline]
    fn drop(&mut self) {
        let mutex = self.mtx;
        // Held until the end of this function, i.e. across the flag update and the wake-up.
        let _spin = mutex.spin_lock.acquire();
        mutex.locked.set(false);

        let Some(node) = mutex.wait_list.next() else {
            return;
        };
        // `WaitEntry` is `#[repr(C)]` with its `ListHead` as the first field, so a pointer to
        // that list node is also a pointer to the enclosing entry.
        let waiter = node.as_ptr().cast::<WaitEntry>();
        unsafe { (*waiter).thread.unpark() };
    }
}

impl<T> Deref for CMutexGuard<'_, T> {
    type Target = T;

    /// Borrows the data protected by the mutex this guard belongs to.
    #[inline]
    fn deref(&self) -> &Self::Target {
        let data: *mut T = self.mtx.data.get();
        unsafe { &*data }
    }
}

impl<T> DerefMut for CMutexGuard<'_, T> {
    /// Mutably borrows the data protected by the mutex this guard belongs to.
    #[inline]
    fn deref_mut(&mut self) -> &mut Self::Target {
        let data: *mut T = self.mtx.data.get();
        unsafe { &mut *data }
    }
}

/// Wait-list node of one thread blocked in `CMutex::lock`.
///
/// The layout is load-bearing: `CMutexGuard::drop` casts a pointer to a list node taken from
/// the wait list to `*mut WaitEntry`, which needs `#[repr(C)]` and `wait_list` as the first
/// field.
#[pin_data]
#[repr(C)]
struct WaitEntry {
    /// Link into the mutex's wait list; must remain the first field.
    #[pin]
    wait_list: ListHead,
    /// Handle used to unpark the waiting thread.
    thread: Thread,
}

impl WaitEntry {
    /// Returns an in-place initializer for an entry that records the calling thread and links
    /// itself into `list` through `ListHead::insert_prev`.
    ///
    /// The initializer borrows `list`, hence the `'_` bound on the return type.
    #[inline]
    fn insert_new(list: &ListHead) -> impl PinInit<Self> + '_ {
        pin_init!(Self {
            // Written first so the thread handle is set up before the entry is linked
            // (assuming `pin_init!` initializes fields in the order written -- confirm).
            thread: thread::current(),
            wait_list <- ListHead::insert_prev(list),
        })
    }
}

/// Empty entry point used when neither the `std` nor the `alloc` feature is enabled; the real
/// `main` below is compiled out in that configuration.
#[cfg(not(any(feature = "std", feature = "alloc")))]
fn main() {}

/// Stress-tests `CMutex`: several named worker threads each increment a shared counter in two
/// bursts separated by a per-thread sleep, and the final count is checked against the
/// expected total. Also runs as a test when compiled in test mode.
#[allow(dead_code)]
#[cfg_attr(test, test)]
#[cfg(any(feature = "std", feature = "alloc"))]
fn main() {
    let thread_count = 20;
    // Keep the run short under Miri.
    let workload = if cfg!(miri) { 100 } else { 1_000 };
    let mtx: Pin<Arc<CMutex<usize>>> = Arc::pin_init(CMutex::new(0)).unwrap();

    let handles: Vec<_> = (0..thread_count)
        .map(|i| {
            let mtx = mtx.clone();
            let worker = move || {
                // First burst of increments.
                for _ in 0..workload {
                    *mtx.lock() += 1;
                }
                println!("{i} halfway");
                // Stagger the second burst so the threads come back at different times.
                sleep(Duration::from_millis((i as u64) * 10));
                for _ in 0..workload {
                    *mtx.lock() += 1;
                }
                println!("{i} finished");
            };
            Builder::new()
                .name(format!("worker #{i}"))
                .spawn(worker)
                .expect("should not fail")
        })
        .collect();

    for handle in handles {
        handle.join().expect("thread panicked");
    }

    println!("{:?}", &*mtx.lock());
    assert_eq!(*mtx.lock(), workload * thread_count * 2);
}