Skip to content

Commit eca1d13

Browse files
authored
refactor: mmap query and preprocessing PPs (#237)
* feat: mmap query and preprocessing param files
* fix: inverted logic
1 parent dbebb26 commit eca1d13

File tree

7 files changed

+53
-80
lines changed

7 files changed

+53
-80
lines changed

Cargo.lock

Lines changed: 3 additions & 2 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

lgn-provers/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@ verifiable-db = { workspace = true }
2525

2626
lgn-messages = { path = "../lgn-messages" }
2727
exponential-backoff = "2.0.1"
28+
memmap2 = "0.9.7"
2829

2930
[features]
3031
dummy-prover = []

lgn-provers/src/params/mod.rs

Lines changed: 25 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,15 @@
11
use std::collections::HashMap;
2-
use std::path::Path;
32
use std::path::PathBuf;
43
use std::str::FromStr;
54
use std::time::Duration;
65

76
use anyhow::bail;
87
use anyhow::ensure;
98
use anyhow::Context;
10-
use bytes::Bytes;
119
use futures::StreamExt;
1210
use reqwest::StatusCode;
1311
use tokio::fs::File;
1412
use tokio::io::AsyncWriteExt;
15-
use tracing::error;
1613
use tracing::info;
1714
use tracing::warn;
1815

@@ -39,7 +36,6 @@ const DOWNLOAD_BACKOFF_MIN_MILLIS: u64 = 100;
3936
const DOWNLOAD_BACKOFF_MAX_MILLIS: u64 = 10_000;
4037

4138
/// Ratio to convert byte to megabytes.
42-
const BYTES_TO_MEGABYTES_USIZE: usize = 1024 * 1024;
4339
const BYTES_TO_MEGABYTES_U64: u64 = 1024 * 1024;
4440

4541
/// Download and verify `file_name`.
@@ -55,7 +51,7 @@ pub async fn download_and_checksum(
5551
param_dir: &str,
5652
file_name: &str,
5753
checksums: &HashMap<String, blake3::Hash>,
58-
) -> anyhow::Result<Bytes> {
54+
) -> anyhow::Result<PathBuf> {
5955
let mut filepath = PathBuf::from(param_dir);
6056
filepath.push(file_name);
6157

@@ -84,16 +80,14 @@ pub async fn download_and_checksum(
8480
.with_context(|| format!("Failed to create file. filepath: {}", filepath.display()))?;
8581

8682
let mut hasher = blake3::Hasher::new();
87-
let mut buf = std::fs::read(&filepath)
88-
.with_context(|| format!("Reading file failed. filepath: {}", filepath.display()))?;
89-
hasher.update_rayon(&buf);
83+
hasher.update_mmap_rayon(&filepath)?;
9084

9185
if *expected_checksum == hasher.finalize() {
9286
info!(
9387
"Found file matching checksum, skipping download. filepath: {}",
9488
filepath.display()
9589
);
96-
return Ok(Bytes::from(buf));
90+
return Ok(filepath);
9791
}
9892

9993
let fileurl = format!("{base_url}/{file_name}");
@@ -111,10 +105,7 @@ pub async fn download_and_checksum(
111105

112106
for (retry, duration) in backoff.iter().enumerate() {
113107
let result = resume_download(
114-
&mut buf,
115-
base_url,
116108
&mut file,
117-
&filepath,
118109
&fileurl,
119110
retry,
120111
&client,
@@ -125,12 +116,8 @@ pub async fn download_and_checksum(
125116

126117
match result {
127118
Ok(()) => {
128-
info!(
129-
"Params loaded. filepath: {} size: {}MiB",
130-
filepath.display(),
131-
buf.len() / BYTES_TO_MEGABYTES_USIZE,
132-
);
133-
return Ok(Bytes::from(buf));
119+
info!("Downloaded file. filepath: {}", filepath.display());
120+
return Ok(filepath);
134121
},
135122
Err(err) => {
136123
if let Some(duration) = duration {
@@ -154,10 +141,7 @@ pub async fn download_and_checksum(
154141

155142
#[allow(clippy::too_many_arguments)]
156143
async fn resume_download(
157-
buf: &mut Vec<u8>,
158-
base_url: &str,
159144
file: &mut File,
160-
filepath: &Path,
161145
fileurl: &str,
162146
retry: usize,
163147
client: &reqwest::Client,
@@ -182,66 +166,44 @@ async fn resume_download(
182166
Ok(parsed)
183167
})?;
184168

169+
if response.status() == StatusCode::RANGE_NOT_SATISFIABLE {
170+
hasher.reset();
171+
file.set_len(0).await?;
172+
bail!("Local file is bigger than remote, reset file and restart download");
173+
}
174+
175+
ensure!(
176+
response.status().is_success(),
177+
"Requesting params failed. status: {}",
178+
response.status(),
179+
);
180+
185181
info!(
186-
"Downloading params. base_url: {} filepath: {} present: {}MiB download: {}MiB retry: {}",
187-
base_url,
188-
filepath.display(),
182+
"Downloading params. fileurl: {} present: {}MiB download: {}MiB retry: {}",
183+
fileurl,
189184
metadata.len() / BYTES_TO_MEGABYTES_U64,
190185
length / BYTES_TO_MEGABYTES_U64,
191186
retry,
192187
);
193188

194-
if response.status() == StatusCode::RANGE_NOT_SATISFIABLE {
195-
warn!(
196-
"Local file is bigger than remote, resetting length and checking checksum. filepath: {}",
197-
filepath.display(),
198-
);
199-
200-
hasher.reset();
201-
buf.resize(
202-
length.try_into().expect("File size should fit in a usize"),
203-
0,
204-
);
205-
file.set_len(length).await?;
206-
hasher.update_rayon(buf);
207-
} else {
208-
ensure!(
209-
response.status().is_success(),
210-
"Requesting params failed. status: {} filepath: {}",
211-
response.status(),
212-
filepath.display(),
213-
);
214-
215-
let mut stream = response.bytes_stream();
216-
while let Some(data) = stream.next().await {
217-
let data = data?;
218-
file.write_all(&data).await?;
219-
hasher.update_rayon(&data);
220-
buf.extend(data);
221-
}
189+
let mut stream = response.bytes_stream();
190+
while let Some(data) = stream.next().await {
191+
let data = data?;
192+
file.write_all(&data).await?;
193+
hasher.update_rayon(&data);
222194
}
223195

224196
let found_checksum = hasher.finalize();
225197
if found_checksum != *expected_checksum {
226-
error!(
227-
"Checksum failed, restarting download. checksum: {} expected: {} filepath: {}",
228-
found_checksum.to_hex(),
229-
expected_checksum.to_hex(),
230-
filepath.display(),
231-
);
232-
233198
hasher.reset();
234-
buf.clear();
235199
file.set_len(0).await?;
236200

237201
bail!(
238-
"Checksum failed, restarting download. checksum: {} expected: {} filepath: {}",
202+
"Checksum failed, restarting download. checksum: {} expected: {}",
239203
found_checksum.to_hex(),
240204
expected_checksum.to_hex(),
241-
filepath.display(),
242205
);
243206
} else {
244-
info!("Downloaded file. filepath: {}", filepath.display());
245207
Ok(())
246208
}
247209
}

lgn-provers/src/provers/v1/groth16/euclid_prover.rs

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
use std::collections::HashMap;
2+
use std::fs::read;
23

4+
use anyhow::Context;
35
use anyhow::Result;
46
use groth16_framework::Groth16Prover;
57
use tracing::debug;
@@ -21,17 +23,20 @@ impl Groth16EuclidProver {
2123
pk_file: &str,
2224
checksums: &HashMap<String, blake3::Hash>,
2325
) -> Result<Self> {
24-
let circuit_bytes =
26+
let circuit_bytes_path =
2527
params::download_and_checksum(url, dir, circuit_file, checksums).await?;
26-
let r1cs_bytes = params::download_and_checksum(url, dir, r1cs_file, checksums).await?;
27-
let pk_bytes = params::download_and_checksum(url, dir, pk_file, checksums).await?;
28+
let r1cs_bytes_path = params::download_and_checksum(url, dir, r1cs_file, checksums).await?;
29+
let pk_bytes_path = params::download_and_checksum(url, dir, pk_file, checksums).await?;
30+
31+
let r1cs = read(&r1cs_bytes_path)
32+
.with_context(|| format!("while reading {}", r1cs_bytes_path.display()))?;
33+
let pk = read(&pk_bytes_path)
34+
.with_context(|| format!("while reading {}", pk_bytes_path.display()))?;
35+
let circuit = read(&circuit_bytes_path)
36+
.with_context(|| format!("while reading {}", circuit_bytes_path.display()))?;
2837

2938
debug!("Creating Groth16 prover");
30-
let inner = Groth16Prover::from_bytes(
31-
r1cs_bytes.to_vec(),
32-
pk_bytes.to_vec(),
33-
circuit_bytes.to_vec(),
34-
)?;
39+
let inner = Groth16Prover::from_bytes(r1cs, pk, circuit)?;
3540

3641
debug!("Groth16 prover created");
3742
Ok(Self { inner })

lgn-provers/src/provers/v1/preprocessing/euclid_prover.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use std::collections::HashMap;
2+
use std::fs::File;
23

34
use alloy::primitives::Address;
45
use alloy::primitives::U256;
@@ -46,9 +47,10 @@ impl PreprocessingEuclidProver {
4647
checksums: &HashMap<String, blake3::Hash>,
4748
with_tracing: bool,
4849
) -> anyhow::Result<Self> {
49-
let params = params::download_and_checksum(url, dir, file, checksums).await?;
50-
let reader = std::io::BufReader::new(params.as_ref());
51-
let params = bincode::deserialize_from(reader)?;
50+
let params_path = params::download_and_checksum(url, dir, file, checksums).await?;
51+
let file = File::open(params_path)?;
52+
let reader = unsafe { memmap2::Mmap::map(&file)? };
53+
let params = bincode::deserialize(&reader)?;
5254
Ok(Self {
5355
params,
5456
with_tracing,

lgn-provers/src/provers/v1/query/euclid_prover.rs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
use std::collections::HashMap;
2+
use std::fs::File;
23

34
use anyhow::Context;
45
use lgn_messages::types::v1::query::tasks::MatchingRowInput;
@@ -51,9 +52,10 @@ impl QueryEuclidProver {
5152
checksums: &HashMap<String, blake3::Hash>,
5253
with_tracing: bool,
5354
) -> anyhow::Result<Self> {
54-
let params = params::download_and_checksum(url, dir, file, checksums).await?;
55-
let reader = std::io::BufReader::new(params.as_ref());
56-
let params = bincode::deserialize_from(reader)?;
55+
let params_path = params::download_and_checksum(url, dir, file, checksums).await?;
56+
let file = File::open(params_path)?;
57+
let reader = unsafe { memmap2::Mmap::map(&file)? };
58+
let params = bincode::deserialize(&reader)?;
5759
Ok(Self {
5860
params,
5961
with_tracing,

lgn-worker/Cargo.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ alloy = { workspace = true, features = ["signers", "signer-local", "signer-keyst
2020
anyhow = { workspace = true }
2121
backtrace = { workspace = true }
2222
bincode.workspace = true
23-
blake3 = { workspace = true }
23+
blake3 = { workspace = true, features = ["mmap", "rayon"] }
2424
clap = { workspace = true, features = ["derive", "env", "help", "std", "suggestions"] }
2525
config = { workspace = true, features = ["toml"] }
2626
elliptic-curve = { workspace = true }

0 commit comments

Comments
 (0)