Contributors: 5
Author Tokens Token Proportion Commits Commit Proportion
José Expósito 717 75.47% 1 11.11%
Kaibo Ma 143 15.05% 1 11.11%
Miguel Ojeda Sandonis 63 6.63% 4 44.44%
Gary Guo 17 1.79% 2 22.22%
Unknown 10 1.05% 1 11.11%
Total 950 9


// SPDX-License-Identifier: GPL-2.0

//! Procedural macro to run KUnit tests using a user-space like syntax.
//!
//! Copyright (c) 2023 José Expósito <jose.exposito89@gmail.com>

use proc_macro::{Delimiter, Group, TokenStream, TokenTree};
use std::collections::HashMap;
use std::fmt::Write;

/// Expands `#[kunit_tests(test_suite_name)]` applied to a module.
///
/// `attr` is the token stream inside the attribute's parentheses and is used verbatim as the
/// KUnit test suite name. `ts` is the annotated item, which must be a module with a brace
/// delimited body as its last token tree.
///
/// The returned token stream is the same module, gated by `#[cfg(CONFIG_KUNIT="y")]`, with:
///
/// - KUnit-backed `assert!` and `assert_eq!` overrides prepended to its body,
/// - every `#[test]` attribute removed from its body, and
/// - one `extern "C"` wrapper per test, the `TEST_CASES` array and the
///   `kunit_unsafe_test_suite!` invocation appended to its body.
///
/// # Panics
///
/// Panics (which surfaces as a compile error at the macro call site) if the suite name is empty
/// or longer than 255 bytes, if no `mod` keyword is found in `ts`, or if the last token tree of
/// `ts` is not a brace-delimited group.
pub(crate) fn kunit_tests(attr: TokenStream, ts: TokenStream) -> TokenStream {
    let attr = attr.to_string();

    if attr.is_empty() {
        panic!("Missing test name in `#[kunit_tests(test_name)]` macro")
    }

    if attr.len() > 255 {
        panic!("The test suite name `{attr}` exceeds the maximum length of 255 bytes")
    }

    let mut tokens: Vec<_> = ts.into_iter().collect();

    // Scan for the `mod` keyword.
    let is_module = tokens
        .iter()
        .any(|token| matches!(token, TokenTree::Ident(ident) if ident.to_string() == "mod"));
    if !is_module {
        panic!("`#[kunit_tests(test_name)]` attribute should only be applied to modules");
    }

    // Retrieve the main body. The main body should be the last token tree.
    let body = match tokens.pop() {
        Some(TokenTree::Group(group)) if group.delimiter() == Delimiter::Brace => group,
        _ => panic!("Cannot locate main body of module"),
    };

    // Get the functions set as tests, together with their `cfg` attributes.
    let tests = collect_tests(&body);

    // Add `#[cfg(CONFIG_KUNIT="y")]` before the module declaration.
    let config_kunit: TokenStream = "#[cfg(CONFIG_KUNIT=\"y\")]".parse().unwrap();
    tokens.insert(
        0,
        TokenTree::Group(Group::new(Delimiter::None, config_kunit)),
    );

    // Generate the KUnit test suite and a test case for each `#[test]`.
    // The code generated for the following test module:
    //
    // ```
    // #[kunit_tests(kunit_test_suite_name)]
    // mod tests {
    //     #[test]
    //     fn foo() {
    //         assert_eq!(1, 1);
    //     }
    //
    //     #[test]
    //     #[cfg(CONFIG_BAR)]
    //     fn bar() {
    //         assert_eq!(2, 2);
    //     }
    // }
    // ```
    //
    // Looks like (appended to the module body, after the original items):
    //
    // ```
    // unsafe extern "C" fn kunit_rust_wrapper_foo(_test: *mut ::kernel::bindings::kunit) {
    //     (*_test).status = ::kernel::bindings::kunit_status_KUNIT_SKIPPED;
    //     {
    //         (*_test).status = ::kernel::bindings::kunit_status_KUNIT_SUCCESS;
    //         use ::kernel::kunit::is_test_result_ok;
    //         assert!(is_test_result_ok(foo()));
    //     }
    // }
    // unsafe extern "C" fn kunit_rust_wrapper_bar(_test: *mut ::kernel::bindings::kunit) {
    //     (*_test).status = ::kernel::bindings::kunit_status_KUNIT_SKIPPED;
    //     #[cfg(CONFIG_BAR)] {
    //         (*_test).status = ::kernel::bindings::kunit_status_KUNIT_SUCCESS;
    //         use ::kernel::kunit::is_test_result_ok;
    //         assert!(is_test_result_ok(bar()));
    //     }
    // }
    //
    // static mut TEST_CASES: [::kernel::bindings::kunit_case; 3] = [
    //     ::kernel::kunit::kunit_case(::kernel::c_str!("foo"), kunit_rust_wrapper_foo),
    //     ::kernel::kunit::kunit_case(::kernel::c_str!("bar"), kunit_rust_wrapper_bar),
    //     ::kernel::kunit::kunit_case_null(),
    // ];
    //
    // ::kernel::kunit_unsafe_test_suite!(kunit_test_suite_name, TEST_CASES);
    // ```
    //
    // In addition, `assert!` and `assert_eq!` overrides that forward to `kernel::kunit_assert!`
    // and `kernel::kunit_assert_eq!` are prepended to the module body.
    let mut kunit_macros = String::new();
    let mut test_cases = String::new();
    let mut assert_macros = String::new();
    let path = crate::helpers::file();
    let num_tests = tests.len();
    for (test, cfg_attr) in tests {
        let kunit_wrapper_fn_name = format!("kunit_rust_wrapper_{test}");
        // Append any `cfg` attributes the user might have written on their tests so we don't
        // attempt to call them when they are `cfg`'d out. The status starts as "skipped" and is
        // only switched to "success" inside the `cfg`-guarded block, so a test that is compiled
        // out is reported as skipped. An extra `use` is used here to reduce the length of the
        // assert message.
        let kunit_wrapper = format!(
            r#"unsafe extern "C" fn {kunit_wrapper_fn_name}(_test: *mut ::kernel::bindings::kunit)
            {{
                (*_test).status = ::kernel::bindings::kunit_status_KUNIT_SKIPPED;
                {cfg_attr} {{
                    (*_test).status = ::kernel::bindings::kunit_status_KUNIT_SUCCESS;
                    use ::kernel::kunit::is_test_result_ok;
                    assert!(is_test_result_ok({test}()));
                }}
            }}"#,
        );
        writeln!(kunit_macros, "{kunit_wrapper}").unwrap();
        writeln!(
            test_cases,
            "    ::kernel::kunit::kunit_case(::kernel::c_str!(\"{test}\"), {kunit_wrapper_fn_name}),"
        )
        .unwrap();
        // NOTE(review): one pair of macro definitions is emitted per test and all of them end up
        // at the top of the module body, so later definitions presumably shadow earlier ones and
        // the reported test name would be that of the last test -- confirm this is intended.
        writeln!(
            assert_macros,
            r#"
/// Overrides the usual [`assert!`] macro with one that calls KUnit instead.
#[allow(unused)]
macro_rules! assert {{
    ($cond:expr $(,)?) => {{{{
        kernel::kunit_assert!("{test}", "{path}", 0, $cond);
    }}}}
}}

/// Overrides the usual [`assert_eq!`] macro with one that calls KUnit instead.
#[allow(unused)]
macro_rules! assert_eq {{
    ($left:expr, $right:expr $(,)?) => {{{{
        kernel::kunit_assert_eq!("{test}", "{path}", 0, $left, $right);
    }}}}
}}
        "#
        )
        .unwrap();
    }

    // The array holds one entry per test plus the terminating null case.
    writeln!(kunit_macros).unwrap();
    writeln!(
        kunit_macros,
        "static mut TEST_CASES: [::kernel::bindings::kunit_case; {}] = [\n{test_cases}    ::kernel::kunit::kunit_case_null(),\n];",
        num_tests + 1
    )
    .unwrap();

    writeln!(
        kunit_macros,
        "::kernel::kunit_unsafe_test_suite!({attr}, TEST_CASES);"
    )
    .unwrap();

    // Remove the `#[test]` macros.
    let new_body = strip_test_attributes(&body);

    // Assemble the new module body: assert overrides, original items, generated KUnit glue.
    let mut final_body = TokenStream::new();
    final_body.extend::<TokenStream>(assert_macros.parse().unwrap());
    final_body.extend(new_body);
    final_body.extend::<TokenStream>(kunit_macros.parse().unwrap());

    tokens.push(TokenTree::Group(Group::new(Delimiter::Brace, final_body)));

    tokens.into_iter().collect()
}

/// Collects the functions marked `#[test]` at the top level of the module `body`.
///
/// Each entry is the test function's name and the `#[cfg(...)]` attributes written on it (an
/// empty stream when there are none), in declaration order. Only a `fn` keyword that directly
/// follows a run of attributes containing `#[test]` is recognised.
fn collect_tests(body: &Group) -> Vec<(String, TokenStream)> {
    let mut body_it = body.stream().into_iter();
    let mut tests = Vec::new();
    let mut attributes: HashMap<String, TokenStream> = HashMap::new();
    while let Some(token) = body_it.next() {
        match token {
            TokenTree::Punct(ref p) if p.as_char() == '#' => match body_it.next() {
                Some(TokenTree::Group(g)) if g.delimiter() == Delimiter::Bracket => {
                    if let Some(TokenTree::Ident(name)) = g.stream().into_iter().next() {
                        // Collect attributes because we need to find which are tests. We also
                        // need to copy `cfg` attributes so tests can be conditionally enabled.
                        attributes
                            .entry(name.to_string())
                            .or_default()
                            .extend([token, TokenTree::Group(g)]);
                    }
                    // Keep accumulating: more attributes or the `fn` itself may follow.
                    continue;
                }
                _ => (),
            },
            TokenTree::Ident(i) if i.to_string() == "fn" && attributes.contains_key("test") => {
                if let Some(TokenTree::Ident(test_name)) = body_it.next() {
                    tests.push((
                        test_name.to_string(),
                        attributes.remove("cfg").unwrap_or_default(),
                    ))
                }
            }

            _ => (),
        }
        // Any token other than an attribute ends the current run of attributes.
        attributes.clear();
    }
    tests
}

/// Returns the tokens of the module `body` with every top-level `#[test]` attribute removed.
///
/// This is done at the token level, in order to preserve span information.
fn strip_test_attributes(body: &Group) -> Vec<TokenTree> {
    let mut new_body = vec![];
    let mut body_it = body.stream().into_iter();

    while let Some(token) = body_it.next() {
        match token {
            TokenTree::Punct(ref c) if c.as_char() == '#' => match body_it.next() {
                // Drop both the `#` and the `[test]` group.
                Some(TokenTree::Group(group)) if group.to_string() == "[test]" => (),
                Some(next) => new_body.extend([token, next]),
                None => new_body.push(token),
            },
            _ => new_body.push(token),
        }
    }

    new_body
}