Skip to content

Commit e7abeb2

Browse files
authored
enable data parallelism (#61)
1 parent a4d3407 commit e7abeb2

File tree

2 files changed

+140
-74
lines changed

2 files changed

+140
-74
lines changed

Cargo.toml

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,6 @@
11
[package]
22
name = "labelme2yolo"
3-
version = "0.2.0"
3+
version = "0.2.1"
44
edition = "2021"
55

66
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
@@ -12,3 +12,5 @@ serde = { version = "1.0", features = ["derive"] }
1212
serde_json = "1.0"
1313
rand = "0.8"
1414
sanitize-filename = "0.5"
15+
indicatif = "0.16"
16+
rayon = "1.5"

src/main.rs

Lines changed: 137 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,41 @@
11
use clap::{Parser, ValueEnum};
22
use glob::glob;
3-
use std::path::{Path, PathBuf};
4-
use std::str::FromStr;
5-
use serde::{Serialize, Deserialize};
3+
use indicatif::{ProgressBar, ProgressStyle};
4+
use rand::rngs::StdRng;
5+
use rand::seq::SliceRandom;
6+
use rand::SeedableRng;
7+
use rayon::prelude::*;
8+
use serde::{Deserialize, Serialize};
69
use serde_json;
710
use std::collections::HashMap;
811
use std::fs;
12+
use std::fs::copy;
913
use std::fs::File;
1014
use std::io::Write;
11-
use std::fs::copy;
12-
use rand::seq::SliceRandom;
13-
use rand::SeedableRng;
14-
use rand::rngs::StdRng;
15+
use std::path::{Path, PathBuf};
16+
use std::str::FromStr;
17+
use std::sync::{Arc, Mutex};
1518

16-
#[derive(Debug, Serialize, Deserialize)]
17-
struct Shape {
18-
label: String,
19-
points: Vec<(f64, f64)>,
19+
#[derive(Debug, Serialize, Deserialize)]
20+
struct Shape {
21+
label: String,
22+
points: Vec<(f64, f64)>,
2023
group_id: Option<String>,
21-
shape_type: String,
24+
shape_type: String,
2225
description: Option<String>,
2326
mask: Option<String>,
24-
}
25-
26-
#[derive(Debug, Serialize, Deserialize)]
27+
}
28+
29+
#[derive(Debug, Serialize, Deserialize)]
2730
#[serde(rename_all = "camelCase")]
28-
struct ImageAnnotation {
29-
version: String,
31+
struct ImageAnnotation {
32+
version: String,
3033
flags: Option<HashMap<String, bool>>,
31-
shapes: Vec<Shape>,
32-
image_path: String,
33-
image_data: String,
34-
image_height: u32,
35-
image_width: u32,
34+
shapes: Vec<Shape>,
35+
image_path: String,
36+
image_data: String,
37+
image_height: u32,
38+
image_width: u32,
3639
}
3740

3841
/// A powerful tool for converting LabelMe's JSON format to YOLO dataset format.
@@ -105,26 +108,20 @@ fn read_and_parse_json(path: &Path) -> Option<ImageAnnotation> {
105108

106109
fn main() {
107110
let args = Args::parse();
108-
109111
let dirname = PathBuf::from(&args.json_dir);
110112
let pattern = dirname.join("**/*.json");
111-
112113
let labels_dir = dirname.join("YOLODataset/labels");
113114
let images_dir = dirname.join("YOLODataset/images");
114-
115115
create_dir(&labels_dir);
116116
create_dir(&images_dir);
117-
118117
let train_labels_dir = labels_dir.join("train");
119118
let val_labels_dir = labels_dir.join("val");
120119
let train_images_dir = images_dir.join("train");
121120
let val_images_dir = images_dir.join("val");
122-
123121
create_dir(&train_labels_dir);
124122
create_dir(&val_labels_dir);
125123
create_dir(&train_images_dir);
126124
create_dir(&val_images_dir);
127-
128125
let (test_labels_dir, test_images_dir) = if args.test_size > 0.0 {
129126
let test_labels_dir = labels_dir.join("test");
130127
let test_images_dir = images_dir.join("test");
@@ -134,12 +131,9 @@ fn main() {
134131
} else {
135132
(None, None)
136133
};
137-
138-
let mut label_map = HashMap::new();
139-
let mut next_class_id = 0;
140-
134+
let label_map = Arc::new(Mutex::new(HashMap::new()));
135+
let next_class_id = Arc::new(Mutex::new(0));
141136
let mut annotations = Vec::new();
142-
143137
for entry in glob(pattern.to_str().expect("Failed to convert path to string"))
144138
.expect("Failed to read glob pattern")
145139
{
@@ -149,7 +143,6 @@ fn main() {
149143
}
150144
}
151145
}
152-
153146
// Shuffle and split the annotations into train, val, and test sets
154147
let seed: u64 = 42; // Fixed random seed
155148
let mut rng = StdRng::seed_from_u64(seed);
@@ -158,51 +151,122 @@ fn main() {
158151
let val_size = (annotations.len() as f32 * args.val_size).ceil() as usize;
159152
let (test_annotations, rest_annotations) = annotations.split_at(test_size);
160153
let (val_annotations, train_annotations) = rest_annotations.split_at(val_size);
161-
162154
// Update label_map from label_list if not empty
163155
if !args.label_list.is_empty() {
156+
let mut label_map_guard = label_map.lock().unwrap();
164157
for (id, label) in args.label_list.iter().enumerate() {
165-
label_map.insert(label.clone(), id);
158+
label_map_guard.insert(label.clone(), id);
166159
}
167-
next_class_id = args.label_list.len();
168-
}
169-
170-
for (path, annotation) in train_annotations {
171-
process_annotation(path, annotation, &train_labels_dir, &train_images_dir, &mut label_map, &mut next_class_id, &args, &dirname);
172-
}
173-
174-
for (path, annotation) in val_annotations {
175-
process_annotation(path, annotation, &val_labels_dir, &val_images_dir, &mut label_map, &mut next_class_id, &args, &dirname);
160+
*next_class_id.lock().unwrap() = args.label_list.len();
176161
}
177162

163+
// Create progress bars
164+
let train_pb = Arc::new(Mutex::new(ProgressBar::new(train_annotations.len() as u64)));
165+
train_pb.lock().unwrap().set_style(ProgressStyle::default_bar()
166+
.template("{spinner:.green} [Train] [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({eta})")
167+
.progress_chars("#>-"));
168+
169+
let val_pb = Arc::new(Mutex::new(ProgressBar::new(val_annotations.len() as u64)));
170+
val_pb.lock().unwrap().set_style(ProgressStyle::default_bar()
171+
.template("{spinner:.green} [Val] [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({eta})")
172+
.progress_chars("#>-"));
173+
174+
let test_pb = Arc::new(Mutex::new(ProgressBar::new(test_annotations.len() as u64)));
175+
test_pb.lock().unwrap().set_style(ProgressStyle::default_bar()
176+
.template("{spinner:.green} [Test] [{elapsed_precise}] [{bar:40.cyan/blue}] {pos}/{len} ({eta})")
177+
.progress_chars("#>-"));
178+
179+
// Process train_annotations in parallel
180+
train_annotations.par_iter().for_each(|(path, annotation)| {
181+
let mut label_map_guard = label_map.lock().unwrap();
182+
let mut next_class_id_guard = next_class_id.lock().unwrap();
183+
process_annotation(
184+
path,
185+
annotation,
186+
&train_labels_dir,
187+
&train_images_dir,
188+
&mut label_map_guard,
189+
&mut next_class_id_guard,
190+
&args,
191+
&dirname,
192+
);
193+
train_pb.lock().unwrap().inc(1);
194+
});
195+
train_pb
196+
.lock()
197+
.unwrap()
198+
.finish_with_message("Train processing complete");
199+
200+
// Process val_annotations in parallel
201+
val_annotations.par_iter().for_each(|(path, annotation)| {
202+
let mut label_map_guard = label_map.lock().unwrap();
203+
let mut next_class_id_guard = next_class_id.lock().unwrap();
204+
process_annotation(
205+
path,
206+
annotation,
207+
&val_labels_dir,
208+
&val_images_dir,
209+
&mut label_map_guard,
210+
&mut next_class_id_guard,
211+
&args,
212+
&dirname,
213+
);
214+
val_pb.lock().unwrap().inc(1);
215+
});
216+
val_pb
217+
.lock()
218+
.unwrap()
219+
.finish_with_message("Val processing complete");
220+
221+
// Process test_annotations in parallel
178222
if let (Some(test_labels_dir), Some(test_images_dir)) = (test_labels_dir, test_images_dir) {
179-
for (path, annotation) in test_annotations {
180-
process_annotation(path, annotation, &test_labels_dir, &test_images_dir, &mut label_map, &mut next_class_id, &args, &dirname);
181-
}
223+
test_annotations.par_iter().for_each(|(path, annotation)| {
224+
let mut label_map_guard = label_map.lock().unwrap();
225+
let mut next_class_id_guard = next_class_id.lock().unwrap();
226+
process_annotation(
227+
path,
228+
annotation,
229+
&test_labels_dir,
230+
&test_images_dir,
231+
&mut label_map_guard,
232+
&mut next_class_id_guard,
233+
&args,
234+
&dirname,
235+
);
236+
test_pb.lock().unwrap().inc(1);
237+
});
238+
test_pb
239+
.lock()
240+
.unwrap()
241+
.finish_with_message("Test processing complete");
182242
}
183243

184244
// Create dataset.yaml file after processing annotations
185245
let dataset_yaml_path = dirname.join("YOLODataset/dataset.yaml");
186-
let mut dataset_yaml = File::create(dataset_yaml_path).expect("Failed to create dataset.yaml file");
187-
188-
let absolute_path = fs::canonicalize(&dirname.join("YOLODataset")).expect("Failed to get absolute path");
189-
190-
let mut yaml_content = format!("path: {}\ntrain: images/train\nval: images/val\n", absolute_path.to_str().unwrap());
246+
let mut dataset_yaml =
247+
File::create(dataset_yaml_path).expect("Failed to create dataset.yaml file");
248+
let absolute_path =
249+
fs::canonicalize(&dirname.join("YOLODataset")).expect("Failed to get absolute path");
250+
let mut yaml_content = format!(
251+
"path: {}\ntrain: images/train\nval: images/val\n",
252+
absolute_path.to_str().unwrap()
253+
);
191254
if args.test_size > 0.0 {
192255
yaml_content.push_str("test: images/test\n");
193256
} else {
194257
yaml_content.push_str("test:\n");
195258
}
196259
yaml_content.push_str("\nnames:\n");
197-
198260
// Read names from label_map
199-
let mut sorted_labels: Vec<_> = label_map.iter().collect();
261+
let label_map_guard = label_map.lock().unwrap();
262+
let mut sorted_labels: Vec<_> = label_map_guard.iter().collect();
200263
sorted_labels.sort_by_key(|&(_, id)| id);
201264
for (label, id) in sorted_labels {
202265
yaml_content.push_str(&format!(" {}: {}\n", id, label));
203266
}
204-
205-
dataset_yaml.write_all(yaml_content.as_bytes()).expect("Failed to write to dataset.yaml file");
267+
dataset_yaml
268+
.write_all(yaml_content.as_bytes())
269+
.expect("Failed to write to dataset.yaml file");
206270
}
207271

208272
fn process_annotation(
@@ -237,12 +301,7 @@ fn process_annotation(
237301
if shape.shape_type == "rectangle" {
238302
let (x1, y1) = shape.points[0];
239303
let (x2, y2) = shape.points[1];
240-
let rect_points = vec![
241-
(x1, y1),
242-
(x2, y1),
243-
(x2, y2),
244-
(x1, y2),
245-
];
304+
let rect_points = vec![(x1, y1), (x2, y1), (x2, y2), (x1, y2)];
246305
for &(x, y) in &rect_points {
247306
let x_norm = x / annotation.image_width as f64;
248307
let y_norm = y / annotation.image_height as f64;
@@ -261,12 +320,7 @@ fn process_annotation(
261320
let (x_min, y_min, x_max, y_max) = shape.points.iter().fold(
262321
(f64::MAX, f64::MAX, f64::MIN, f64::MIN),
263322
|(x_min, y_min, x_max, y_max), &(x, y)| {
264-
(
265-
x_min.min(x),
266-
y_min.min(y),
267-
x_max.max(x),
268-
y_max.max(y),
269-
)
323+
(x_min.min(x), y_min.min(y), x_max.max(x), y_max.max(y))
270324
},
271325
);
272326

@@ -275,18 +329,28 @@ fn process_annotation(
275329
let width = (x_max - x_min) / annotation.image_width as f64;
276330
let height = (y_max - y_min) / annotation.image_height as f64;
277331

278-
yolo_data.push_str(&format!("{} {:.6} {:.6} {:.6} {:.6}\n", class_id, x_center, y_center, width, height));
332+
yolo_data.push_str(&format!(
333+
"{} {:.6} {:.6} {:.6} {:.6}\n",
334+
class_id, x_center, y_center, width, height
335+
));
279336
}
280337
}
281338

282-
let output_path = labels_dir.join(sanitize_filename::sanitize(path.file_stem().unwrap().to_str().unwrap())).with_extension("txt");
339+
let output_path = labels_dir
340+
.join(sanitize_filename::sanitize(
341+
path.file_stem().unwrap().to_str().unwrap(),
342+
))
343+
.with_extension("txt");
283344
let mut file = File::create(output_path).expect("Failed to create YOLO data file");
284-
file.write_all(yolo_data.as_bytes()).expect("Failed to write YOLO data");
345+
file.write_all(yolo_data.as_bytes())
346+
.expect("Failed to write YOLO data");
285347

286348
// Copy the image to the images directory
287349
let image_path = base_dir.join(&annotation.image_path);
288350
if image_path.exists() {
289-
let image_output_path = images_dir.join(sanitize_filename::sanitize(image_path.file_name().unwrap().to_str().unwrap()));
351+
let image_output_path = images_dir.join(sanitize_filename::sanitize(
352+
image_path.file_name().unwrap().to_str().unwrap(),
353+
));
290354
copy(&image_path, &image_output_path).expect("Failed to copy image");
291355
} else {
292356
eprintln!("Image file not found: {:?}", image_path);

0 commit comments

Comments
 (0)