Skip to content

Commit a7028d2

Browse files
committed
Refactor simd_bitmask to reduce the number of iterations (#4129)
This PR is co-authored with @tautschnig. Background: In #4127 we enabled certain target features that are platform dependent (e.g. `sse` and `sse2` for x86_64 and `neon` for aarch64) which resulted in using the `simd_bitmask` intrinsic more frequently. Kani's current model of that intrinsic (https://github.com/model-checking/kani/blob/5f14b735b74f3ae3f9f1c64ce5656e1e735d42ea/library/kani/src/models/mod.rs#L72) uses a loop that iterates `LANE` times, which requires harnesses that touch this code (e.g. that use the `HashSet` data structure) to have a large unwinding value, which results in a significant slowdown. This PR refactors this function through rewriting it using a nested loop and manually unwinding the inner loop (which operates over the bits in a byte) so that a large unwinding value is not needed. By submitting this pull request, I confirm that my contribution is made under the terms of the Apache 2.0 and MIT licenses.
1 parent 8492f4e commit a7028d2

File tree

2 files changed

+165
-10
lines changed

2 files changed

+165
-10
lines changed

library/kani/src/models/mod.rs

Lines changed: 85 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -74,17 +74,92 @@ mod intrinsics {
7474
T: MaskElement,
7575
{
7676
let mut mask_array = [0; mask_len(LANES)];
77-
for lane in (0..input.len()).rev() {
78-
let byte = lane / 8;
79-
let mask = &mut mask_array[byte];
80-
let shift_mask = *mask << 1;
81-
*mask = if input[lane] == T::TRUE {
82-
shift_mask | 0x1
83-
} else {
84-
assert_eq!(input[lane], T::FALSE, "Masks values should either be 0 or -1");
85-
shift_mask
86-
};
77+
78+
// The implementation below is the equivalent of the following:
79+
// ```rust
80+
// for lane in (0..input.len()).rev() {
81+
// let byte = lane / 8;
82+
// let mask = &mut mask_array[byte];
83+
// let shift_mask = *mask << 1;
84+
// *mask = if input[lane] == T::TRUE {
85+
// shift_mask | 0x1
86+
// } else {
87+
// assert_eq!(input[lane], T::FALSE, "Masks values should either be 0 or -1");
88+
// shift_mask
89+
// };
90+
// }
91+
// ```
92+
// but is intentionally written in a way that minimizes the number of
93+
// loop iterations. In particular, it's implemented as a nested loop
94+
// where the outer loop iterates over bytes and the inner "loop" (which
95+
// is manually unwound) iterates over bits in a byte. This is to avoid
96+
// needing a high unwind value for harnesses that invoke this code (e.g.
97+
// through the `HashSet` data structure).
98+
for (byte_idx, byte) in mask_array.iter_mut().enumerate() {
99+
// Calculate the starting lane for this byte
100+
let start_lane = byte_idx << 3;
101+
// Calculate how many bits to process (handle the last byte which might be partial)
102+
let bits_to_process = (LANES - start_lane).min(8);
103+
104+
*byte = if bits_to_process > 0 && input[start_lane] == T::TRUE { 1 << 0 } else { 0 }
105+
| if bits_to_process > 1 && input[start_lane + 1] == T::TRUE { 1 << 1 } else { 0 }
106+
| if bits_to_process > 2 && input[start_lane + 2] == T::TRUE { 1 << 2 } else { 0 }
107+
| if bits_to_process > 3 && input[start_lane + 3] == T::TRUE { 1 << 3 } else { 0 }
108+
| if bits_to_process > 4 && input[start_lane + 4] == T::TRUE { 1 << 4 } else { 0 }
109+
| if bits_to_process > 5 && input[start_lane + 5] == T::TRUE { 1 << 5 } else { 0 }
110+
| if bits_to_process > 6 && input[start_lane + 6] == T::TRUE { 1 << 6 } else { 0 }
111+
| if bits_to_process > 7 && input[start_lane + 7] == T::TRUE { 1 << 7 } else { 0 };
112+
113+
assert!(
114+
bits_to_process < 1
115+
|| input[start_lane] == T::TRUE
116+
|| input[start_lane] == T::FALSE,
117+
"Masks values should either be 0 or -1"
118+
);
119+
assert!(
120+
bits_to_process < 2
121+
|| input[start_lane + 1] == T::TRUE
122+
|| input[start_lane + 1] == T::FALSE,
123+
"Masks values should either be 0 or -1"
124+
);
125+
assert!(
126+
bits_to_process < 3
127+
|| input[start_lane + 2] == T::TRUE
128+
|| input[start_lane + 2] == T::FALSE,
129+
"Masks values should either be 0 or -1"
130+
);
131+
assert!(
132+
bits_to_process < 4
133+
|| input[start_lane + 3] == T::TRUE
134+
|| input[start_lane + 3] == T::FALSE,
135+
"Masks values should either be 0 or -1"
136+
);
137+
assert!(
138+
bits_to_process < 5
139+
|| input[start_lane + 4] == T::TRUE
140+
|| input[start_lane + 4] == T::FALSE,
141+
"Masks values should either be 0 or -1"
142+
);
143+
assert!(
144+
bits_to_process < 6
145+
|| input[start_lane + 5] == T::TRUE
146+
|| input[start_lane + 5] == T::FALSE,
147+
"Masks values should either be 0 or -1"
148+
);
149+
assert!(
150+
bits_to_process < 7
151+
|| input[start_lane + 6] == T::TRUE
152+
|| input[start_lane + 6] == T::FALSE,
153+
"Masks values should either be 0 or -1"
154+
);
155+
assert!(
156+
bits_to_process < 8
157+
|| input[start_lane + 7] == T::TRUE
158+
|| input[start_lane + 7] == T::FALSE,
159+
"Masks values should either be 0 or -1"
160+
);
87161
}
162+
88163
mask_array
89164
}
90165

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
// Copyright Kani Contributors
2+
// SPDX-License-Identifier: Apache-2.0 OR MIT
3+
#![feature(repr_simd, core_intrinsics)]
4+
#![feature(generic_const_exprs)]
5+
#![feature(portable_simd)]
6+
7+
// This test checks the equivalence of Kani's old and new implementations of the
8+
// `simd_bitmask` intrinsic
9+
10+
use std::fmt::Debug;
11+
12+
pub trait MaskElement: PartialEq + Debug {
13+
const TRUE: Self;
14+
const FALSE: Self;
15+
}
16+
17+
impl MaskElement for i32 {
18+
const TRUE: Self = -1;
19+
const FALSE: Self = 0;
20+
}
21+
22+
/// Calculate the minimum number of lanes to represent a mask
23+
/// Logic similar to `bitmask_len` from `portable_simd`.
24+
/// <https://github.com/rust-lang/portable-simd/blob/490b5cf/crates/core_simd/src/masks/to_bitmask.rs#L75-L79>
25+
const fn mask_len(len: usize) -> usize {
26+
len.div_ceil(8)
27+
}
28+
29+
fn simd_bitmask_impl_old<T, const LANES: usize>(input: &[T; LANES]) -> [u8; mask_len(LANES)]
30+
where
31+
T: MaskElement,
32+
{
33+
let mut mask_array = [0; mask_len(LANES)];
34+
for lane in (0..input.len()).rev() {
35+
let byte = lane / 8;
36+
let mask = &mut mask_array[byte];
37+
let shift_mask = *mask << 1;
38+
*mask = if input[lane] == T::TRUE {
39+
shift_mask | 0x1
40+
} else {
41+
assert_eq!(input[lane], T::FALSE, "Masks values should either be 0 or -1");
42+
shift_mask
43+
};
44+
}
45+
mask_array
46+
}
47+
48+
unsafe fn simd_bitmask<T, U, E, const LANES: usize>(input: T) -> U
49+
where
50+
[u8; mask_len(LANES)]: Sized,
51+
E: MaskElement,
52+
{
53+
let data = &*(&input as *const T as *const [E; LANES]);
54+
let mask = simd_bitmask_impl_old(data);
55+
(&mask as *const [u8; mask_len(LANES)] as *const U).read()
56+
}
57+
58+
#[repr(simd)]
59+
#[derive(Clone, Debug)]
60+
struct CustomMask<const LANES: usize>([i32; LANES]);
61+
62+
impl<const LANES: usize> kani::Arbitrary for CustomMask<LANES>
63+
where
64+
[bool; LANES]: Sized + kani::Arbitrary,
65+
{
66+
fn any() -> Self {
67+
CustomMask(kani::any::<[bool; LANES]>().map(|v| if v { i32::FALSE } else { i32::TRUE }))
68+
}
69+
}
70+
71+
#[kani::proof]
72+
#[kani::solver(kissat)]
73+
fn check_equiv() {
74+
let mask = kani::any::<CustomMask<8>>();
75+
unsafe {
76+
let result1 = simd_bitmask::<_, u8, i32, 8>(mask.clone());
77+
let result2 = std::intrinsics::simd::simd_bitmask::<_, u8>(mask);
78+
assert_eq!(result1, result2);
79+
}
80+
}

0 commit comments

Comments
 (0)