Skip to content

Commit df2d52a

Browse files
committed
follow up on the fix of multiply with overflow
1 parent 371dba9 commit df2d52a

File tree

1 file changed

+15
-14
lines changed

1 file changed

+15
-14
lines changed

bitpacker/src/bitpacker.rs

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ impl BitPacker {
6565

6666
#[derive(Clone, Debug, Default, Copy)]
6767
pub struct BitUnpacker {
68-
num_bits: u32,
68+
num_bits: usize,
6969
mask: u64,
7070
}
7171

@@ -83,7 +83,7 @@ impl BitUnpacker {
8383
(1u64 << num_bits) - 1u64
8484
};
8585
BitUnpacker {
86-
num_bits: u32::from(num_bits),
86+
num_bits: usize::from(num_bits),
8787
mask,
8888
}
8989
}
@@ -94,7 +94,7 @@ impl BitUnpacker {
9494

9595
#[inline]
9696
pub fn get(&self, idx: u32, data: &[u8]) -> u64 {
97-
let addr_in_bits = idx as usize * self.num_bits as usize;
97+
let addr_in_bits = idx as usize * self.num_bits;
9898
let addr = addr_in_bits >> 3;
9999
if addr + 8 > data.len() {
100100
if self.num_bits == 0 {
@@ -129,24 +129,25 @@ impl BitUnpacker {
129129
//
130130
// This methods panics if `num_bits` is > 32.
131131
fn get_batch_u32s(&self, start_idx: u32, data: &[u8], output: &mut [u32]) {
132+
let start_idx = start_idx as usize;
132133
assert!(
133134
self.bit_width() <= 32,
134135
"Bitwidth must be <= 32 to use this method."
135136
);
136137

137-
let end_idx = start_idx + output.len() as u32;
138+
let end_idx = start_idx + output.len();
138139

139140
let end_bit_read = end_idx * self.num_bits;
140141
let end_byte_read = (end_bit_read + 7) / 8;
141142
assert!(
142-
end_byte_read as usize <= data.len(),
143+
end_byte_read <= data.len(),
143144
"Requested index is out of bounds."
144145
);
145146

146147
// Simple slow implementation of get_batch_u32s, to deal with our ramps.
147-
let get_batch_ramp = |start_idx: u32, output: &mut [u32]| {
148+
let get_batch_ramp = |start_idx: usize, output: &mut [u32]| {
148149
for (out, idx) in output.iter_mut().zip(start_idx..) {
149-
*out = self.get(idx, data) as u32;
150+
*out = self.get(idx as u32, data) as u32;
150151
}
151152
};
152153

@@ -161,23 +162,23 @@ impl BitUnpacker {
161162
// so highway start is the closest multiple of 8 that is >= start_idx.
162163
let entrance_ramp_len = 8 - (start_idx % 8) % 8;
163164

164-
let highway_start: u32 = start_idx + entrance_ramp_len;
165+
let highway_start: usize = start_idx + entrance_ramp_len;
165166

166-
if highway_start + BitPacker1x::BLOCK_LEN as u32 > end_idx {
167+
if highway_start + BitPacker1x::BLOCK_LEN > end_idx {
167168
// We don't have enough values to have even a single block of highway.
168169
// Let's just supply the values the simple way.
169170
get_batch_ramp(start_idx, output);
170171
return;
171172
}
172173

173-
let num_blocks: u32 = (end_idx - highway_start) / BitPacker1x::BLOCK_LEN as u32;
174+
let num_blocks: usize = (end_idx - highway_start) / BitPacker1x::BLOCK_LEN;
174175

175176
// Entrance ramp
176-
get_batch_ramp(start_idx, &mut output[..entrance_ramp_len as usize]);
177+
get_batch_ramp(start_idx, &mut output[..entrance_ramp_len]);
177178

178179
// Highway
179-
let mut offset = (highway_start * self.num_bits) as usize / 8;
180-
let mut output_cursor = (highway_start - start_idx) as usize;
180+
let mut offset = (highway_start * self.num_bits) / 8;
181+
let mut output_cursor = highway_start - start_idx;
181182
for _ in 0..num_blocks {
182183
offset += BitPacker1x.decompress(
183184
&data[offset..],
@@ -188,7 +189,7 @@ impl BitUnpacker {
188189
}
189190

190191
// Exit ramp
191-
let highway_end = highway_start + num_blocks * BitPacker1x::BLOCK_LEN as u32;
192+
let highway_end = highway_start + num_blocks * BitPacker1x::BLOCK_LEN;
192193
get_batch_ramp(highway_end, &mut output[output_cursor..]);
193194
}
194195

0 commit comments

Comments
 (0)