Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion src/ffi_avx512.rs
Original file line number Diff line number Diff line change
Expand Up @@ -74,13 +74,30 @@ pub unsafe fn hash_many<const N: usize>(

// Unsafe because this may only be called on platforms supporting AVX-512.
#[cfg(unix)]
#[allow(unused)]
#[inline]
pub unsafe fn xof_many(
cv: &CVWords,
block: &[u8; BLOCK_LEN],
block_len: u8,
counter: u64,
flags: u8,
out: &mut [u8],
) {
// SAFETY: We only write fully initialized bytes
let out: &mut [core::mem::MaybeUninit<u8>] = unsafe { core::mem::transmute(out) };
xof_many_uninit(cv, block, block_len, counter, flags, out)
}

// Unsafe because this may only be called on platforms supporting AVX-512.
#[cfg(unix)]
pub unsafe fn xof_many_uninit(
cv: &CVWords,
block: &[u8; BLOCK_LEN],
block_len: u8,
counter: u64,
flags: u8,
out: &mut [core::mem::MaybeUninit<u8>],
) {
unsafe {
debug_assert_eq!(0, out.len() % BLOCK_LEN, "whole blocks only");
Expand All @@ -90,7 +107,8 @@ pub unsafe fn xof_many(
block_len,
counter,
flags,
out.as_mut_ptr(),
// todo: use MaybeUninit::slice_as_mut_ptr when feature "maybe_uninit_slice" (issue = "63569") stabilizes
out.as_mut_ptr().cast::<u8>(),
out.len() / BLOCK_LEN,
);
}
Expand Down
24 changes: 20 additions & 4 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ use arrayref::{array_mut_ref, array_ref};
use arrayvec::{ArrayString, ArrayVec};
use core::cmp;
use core::fmt;
use core::mem::{transmute, MaybeUninit};
use platform::{Platform, MAX_SIMD_DEGREE, MAX_SIMD_DEGREE_OR_2};
#[cfg(feature = "zeroize")]
use zeroize::Zeroize;
Expand Down Expand Up @@ -1688,11 +1689,14 @@ impl OutputReader {
// This helper function handles both the case where the output buffer is
// shorter than one block, and the case where our position_within_block is
// non-zero.
fn fill_one_block(&mut self, buf: &mut &mut [u8]) {
fn fill_one_block(&mut self, buf: &mut &mut [MaybeUninit<u8>]) {
let output_block: [u8; BLOCK_LEN] = self.inner.root_output_block();
let output_bytes = &output_block[self.position_within_block as usize..];
let take = cmp::min(buf.len(), output_bytes.len());
buf[..take].copy_from_slice(&output_bytes[..take]);
let output_bytes: &[u8] = &output_bytes[..take];
// SAFETY: &[u8] and &[MaybeUninit<u8>] have the same layout
let output_bytes: &[MaybeUninit<u8>] = unsafe { transmute(output_bytes) };
buf[..take].copy_from_slice(output_bytes);
self.position_within_block += take as u8;
if self.position_within_block == BLOCK_LEN as u8 {
self.inner.counter += 1;
Expand All @@ -1717,7 +1721,19 @@ impl OutputReader {
/// reading further, the behavior is unspecified.
///
/// [`Read::read`]: #method.read
pub fn fill(&mut self, mut buf: &mut [u8]) {
#[inline]
pub fn fill(&mut self, buf: &mut [u8]) {
// SAFETY: We only write fully initialized bytes
let buf: &mut [MaybeUninit<u8>] = unsafe { transmute(buf) };
self.fill_uninit(buf)
}

/// Fills an uninitialized buffer with output bytes.
///
/// This method is similar to [`fill`](#method.fill), but accepts an uninitialized
/// buffer to avoid unnecessary zero-initialization overhead. This can provide
/// performance benefits when generating large amounts of output.
pub fn fill_uninit(&mut self, mut buf: &mut [MaybeUninit<u8>]) {
if buf.is_empty() {
return;
}
Expand All @@ -1731,7 +1747,7 @@ impl OutputReader {
let full_blocks_len = full_blocks * BLOCK_LEN;
if full_blocks > 0 {
debug_assert_eq!(0, self.position_within_block);
self.inner.platform.xof_many(
self.inner.platform.xof_many_uninit(
&self.inner.input_chaining_value,
&self.inner.block,
self.inner.block_len,
Expand Down
32 changes: 26 additions & 6 deletions src/platform.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use crate::{portable, CVWords, IncrementCounter, BLOCK_LEN};
use arrayref::{array_mut_ref, array_ref};
use core::mem::{transmute, MaybeUninit};

cfg_if::cfg_if! {
if #[cfg(any(target_arch = "x86", target_arch = "x86_64"))] {
Expand Down Expand Up @@ -163,7 +164,7 @@ impl Platform {
block_len: u8,
counter: u64,
flags: u8,
) -> [u8; 64] {
) -> [u8; BLOCK_LEN] {
match self {
Platform::Portable => portable::compress_xof(cv, block, block_len, counter, flags),
// Safe because detect() checked for platform support.
Expand Down Expand Up @@ -312,14 +313,29 @@ impl Platform {
}
}

#[inline]
pub fn xof_many(
&self,
cv: &CVWords,
block: &[u8; BLOCK_LEN],
block_len: u8,
mut counter: u64,
counter: u64,
flags: u8,
out: &mut [u8],
) {
// SAFETY: We only write fully initialized bytes
let out: &mut [MaybeUninit<u8>] = unsafe { transmute(out) };
self.xof_many_uninit(cv, block, block_len, counter, flags, out);
}

pub fn xof_many_uninit(
&self,
cv: &CVWords,
block: &[u8; BLOCK_LEN],
block_len: u8,
mut counter: u64,
flags: u8,
out: &mut [MaybeUninit<u8>],
) {
debug_assert_eq!(0, out.len() % BLOCK_LEN, "whole blocks only");
if out.is_empty() {
Expand All @@ -332,15 +348,19 @@ impl Platform {
#[cfg(unix)]
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
Platform::AVX512 => unsafe {
crate::avx512::xof_many(cv, block, block_len, counter, flags, out)
crate::avx512::xof_many_uninit(cv, block, block_len, counter, flags, out)
},
_ => {
// For platforms without an optimized xof_many, fall back to a loop over
// compress_xof. This is still faster than portable code.
// TODO: Use `as_chunks_mut` here once MSRV is >= 1.88.0.
for out_block in out.chunks_exact_mut(BLOCK_LEN) {
// TODO: Use array_chunks_mut here once that's stable.
let out_array: &mut [u8; BLOCK_LEN] = out_block.try_into().unwrap();
*out_array = self.compress_xof(cv, block, block_len, counter, flags);
let out_block = out_block.first_chunk_mut::<BLOCK_LEN>().unwrap();
// TODO: use `transpose` when "maybe_uninit_uninit_array_transpose"(issue = "96097") is stable
// SAFETY: T and MaybeUninit<T> have the same layout
let out_block: &mut MaybeUninit<[u8; BLOCK_LEN]> =
unsafe { transmute(out_block) };
out_block.write(self.compress_xof(cv, block, block_len, counter, flags));
counter += 1;
}
}
Expand Down