
Commit 33a7782

[eval] update eval scripts with webvid for pab (#225)
1 parent 49f77c8 commit 33a7782

8 files changed: +439 -4 lines changed


.gitignore

Lines changed: 1 addition & 0 deletions
@@ -165,3 +165,4 @@ pretrained
 samples
 cache_dir
 test_outputs
+datasets
Lines changed: 205 additions & 0 deletions
@@ -0,0 +1,205 @@
import argparse
import os

import imageio
import torch
import torchvision.transforms.functional as F
import tqdm
from calculate_lpips import calculate_lpips
from calculate_psnr import calculate_psnr
from calculate_ssim import calculate_ssim


def load_video(video_path):
    """
    Load a video from the given path and convert it to a PyTorch tensor.
    """
    # Read the video using imageio
    reader = imageio.get_reader(video_path, "ffmpeg")

    # Extract frames and convert to a list of tensors
    frames = []
    for frame in reader:
        # Convert the frame to a tensor and permute the dimensions to match (C, H, W)
        frame_tensor = torch.tensor(frame).cuda().permute(2, 0, 1)
        frames.append(frame_tensor)

    # Stack the list of tensors into a single tensor with shape (T, C, H, W)
    video_tensor = torch.stack(frames)

    return video_tensor


def resize_video(video, target_height, target_width):
    resized_frames = []
    for frame in video:
        resized_frame = F.resize(frame, [target_height, target_width])
        resized_frames.append(resized_frame)
    return torch.stack(resized_frames)


def resize_gt_video(gt_video, gen_video):
    gen_video_shape = gen_video.shape
    T_gen, _, H_gen, W_gen = gen_video_shape
    T_eval, _, H_eval, W_eval = gt_video.shape

    if T_eval < T_gen:
        raise ValueError(f"Eval video time steps ({T_eval}) are less than generated video time steps ({T_gen}).")

    if H_eval < H_gen or W_eval < W_gen:
        # Resize the video maintaining the aspect ratio
        resize_height = max(H_gen, int(H_gen * (H_eval / W_eval)))
        resize_width = max(W_gen, int(W_gen * (W_eval / H_eval)))
        gt_video = resize_video(gt_video, resize_height, resize_width)
        # Recalculate the dimensions
        T_eval, _, H_eval, W_eval = gt_video.shape

    # Center crop
    start_h = (H_eval - H_gen) // 2
    start_w = (W_eval - W_gen) // 2
    cropped_video = gt_video[:T_gen, :, start_h : start_h + H_gen, start_w : start_w + W_gen]

    return cropped_video


def get_video_ids(gt_video_dirs, gen_video_dirs):
    video_ids = []
    for f in os.listdir(gt_video_dirs[0]):
        if f.endswith(f".mp4"):
            video_ids.append(f.replace(f".mp4", ""))
    video_ids.sort()

    for video_dir in gt_video_dirs + gen_video_dirs:
        tmp_video_ids = []
        for f in os.listdir(video_dir):
            if f.endswith(f".mp4"):
                tmp_video_ids.append(f.replace(f".mp4", ""))
        tmp_video_ids.sort()
        if tmp_video_ids != video_ids:
            raise ValueError(f"Video IDs in {video_dir} are different.")
    return video_ids


def get_videos(video_ids, gt_video_dirs, gen_video_dirs):
    gt_videos = {}
    generated_videos = {}

    for gt_video_dir in gt_video_dirs:
        tmp_gt_videos_tensor = []
        for video_id in video_ids:
            gt_video = load_video(os.path.join(gt_video_dir, f"{video_id}.mp4"))
            tmp_gt_videos_tensor.append(gt_video)
        gt_videos[gt_video_dir] = tmp_gt_videos_tensor

    for generated_video_dir in gen_video_dirs:
        tmp_generated_videos_tensor = []
        for video_id in video_ids:
            generated_video = load_video(os.path.join(generated_video_dir, f"{video_id}.mp4"))
            tmp_generated_videos_tensor.append(generated_video)
        generated_videos[generated_video_dir] = tmp_generated_videos_tensor

    return gt_videos, generated_videos


def print_results(lpips_results, psnr_results, ssim_results, gt_video_dirs, gen_video_dirs):
    out_str = ""

    for gt_video_dir in gt_video_dirs:
        for generated_video_dir in gen_video_dirs:
            if gt_video_dir == generated_video_dir:
                continue
            lpips = sum(lpips_results[gt_video_dir][generated_video_dir]) / len(
                lpips_results[gt_video_dir][generated_video_dir]
            )
            psnr = sum(psnr_results[gt_video_dir][generated_video_dir]) / len(
                psnr_results[gt_video_dir][generated_video_dir]
            )
            ssim = sum(ssim_results[gt_video_dir][generated_video_dir]) / len(
                ssim_results[gt_video_dir][generated_video_dir]
            )
            out_str += f"\ngt: {gt_video_dir} -> gen: {generated_video_dir}, lpips: {lpips:.4f}, psnr: {psnr:.4f}, ssim: {ssim:.4f}"

    return out_str


def main(args):
    device = "cuda"
    gt_video_dirs = args.gt_video_dirs
    gen_video_dirs = args.gen_video_dirs

    video_ids = get_video_ids(gt_video_dirs, gen_video_dirs)
    print(f"Find {len(video_ids)} videos")

    prompt_interval = 1
    batch_size = 8
    calculate_lpips_flag, calculate_psnr_flag, calculate_ssim_flag = True, True, True

    lpips_results = {}
    psnr_results = {}
    ssim_results = {}
    for gt_video_dir in gt_video_dirs:
        lpips_results[gt_video_dir] = {}
        psnr_results[gt_video_dir] = {}
        ssim_results[gt_video_dir] = {}
        for generated_video_dir in gen_video_dirs:
            lpips_results[gt_video_dir][generated_video_dir] = []
            psnr_results[gt_video_dir][generated_video_dir] = []
            ssim_results[gt_video_dir][generated_video_dir] = []

    total_len = len(video_ids) // batch_size + (1 if len(video_ids) % batch_size != 0 else 0)

    for idx in tqdm.tqdm(range(total_len)):
        video_ids_batch = video_ids[idx * batch_size : (idx + 1) * batch_size]
        gt_videos, generated_videos = get_videos(video_ids_batch, gt_video_dirs, gen_video_dirs)

        for gt_video_dir, gt_videos_tensor in gt_videos.items():
            for generated_video_dir, generated_videos_tensor in generated_videos.items():
                if gt_video_dir == generated_video_dir:
                    continue

                if not isinstance(gt_videos_tensor, torch.Tensor):
                    for i in range(len(gt_videos_tensor)):
                        gt_videos_tensor[i] = resize_gt_video(gt_videos_tensor[i], generated_videos_tensor[0])
                    gt_videos_tensor = (torch.stack(gt_videos_tensor) / 255.0).cpu()

                generated_videos_tensor = (torch.stack(generated_videos_tensor) / 255.0).cpu()

                if calculate_lpips_flag:
                    result = calculate_lpips(gt_videos_tensor, generated_videos_tensor, device=device)
                    result = result["value"].values()
                    result = float(sum(result) / len(result))
                    lpips_results[gt_video_dir][generated_video_dir].append(result)

                if calculate_psnr_flag:
                    result = calculate_psnr(gt_videos_tensor, generated_videos_tensor)
                    result = result["value"].values()
                    result = float(sum(result) / len(result))
                    psnr_results[gt_video_dir][generated_video_dir].append(result)

                if calculate_ssim_flag:
                    result = calculate_ssim(gt_videos_tensor, generated_videos_tensor)
                    result = result["value"].values()
                    result = float(sum(result) / len(result))
                    ssim_results[gt_video_dir][generated_video_dir].append(result)

        if (idx + 1) % prompt_interval == 0:
            out_str = print_results(lpips_results, psnr_results, ssim_results, gt_video_dirs, gen_video_dirs)
            print(f"Processed {idx + 1} / {total_len} videos. {out_str}")

    out_str = print_results(lpips_results, psnr_results, ssim_results, gt_video_dirs, gen_video_dirs)

    # save
    with open(f"./batch_eval.txt", "w+") as f:
        f.write(out_str)

    print(f"Processed all videos. {out_str}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--gt_video_dirs", type=str, nargs="+")
    parser.add_argument("--gen_video_dirs", type=str, nargs="+")

    args = parser.parse_args()

    main(args)
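
For reference, a minimal sketch of how the metric helpers above are consumed; the random stand-in tensors mirror the stacked (batch, frames, channels, height, width) batches built in main(), and the {"value": ...} return shape is an assumption read off that loop rather than taken from the helpers' own code:

import torch

from calculate_psnr import calculate_psnr

# Random stand-ins shaped like the normalized batches built in main():
# (batch, frames, channels, height, width), values in [0, 1], on CPU.
gt = torch.rand(2, 16, 3, 256, 256)
gen = torch.rand(2, 16, 3, 256, 256)

result = calculate_psnr(gt, gen)           # assumed to return {"value": {index: score, ...}}
scores = result["value"].values()
print(float(sum(scores) / len(scores)))    # same averaging as in main()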

eval/pab/experiments/opensora_plan.py

Lines changed: 4 additions & 4 deletions
@@ -4,7 +4,7 @@
 
 
 def eval_base(prompt_list):
-    config = OpenSoraPlanConfig()
+    config = OpenSoraPlanConfig(version="v110", transformer_type="65x512x512")
     engine = VideoSysEngine(config)
     generate_func(engine, prompt_list, "./samples/opensoraplan_base", loop=5)
 
@@ -15,7 +15,7 @@ def eval_pab1(prompt_list):
         temporal_gap=4,
         cross_gap=6,
     )
-    config = OpenSoraPlanConfig(enable_pab=True, pab_config=pab_config)
+    config = OpenSoraPlanConfig(version="v110", transformer_type="65x512x512", enable_pab=True, pab_config=pab_config)
     engine = VideoSysEngine(config)
     generate_func(engine, prompt_list, "./samples/opensoraplan_pab1", loop=5)
 
@@ -26,7 +26,7 @@ def eval_pab2(prompt_list):
         temporal_gap=5,
         cross_gap=7,
     )
-    config = OpenSoraPlanConfig(enable_pab=True, pab_config=pab_config)
+    config = OpenSoraPlanConfig(version="v110", transformer_type="65x512x512", enable_pab=True, pab_config=pab_config)
     engine = VideoSysEngine(config)
     generate_func(engine, prompt_list, "./samples/opensoraplan_pab2", loop=5)
 
@@ -37,7 +37,7 @@ def eval_pab3(prompt_list):
         temporal_gap=7,
         cross_gap=9,
     )
-    config = OpenSoraPlanConfig(enable_pab=True, pab_config=pab_config)
+    config = OpenSoraPlanConfig(version="v110", transformer_type="65x512x512", enable_pab=True, pab_config=pab_config)
     engine = VideoSysEngine(config)
     generate_func(engine, prompt_list, "./samples/opensoraplan_pab3", loop=5)
 

eval/pab/webvid/download.py

Lines changed: 67 additions & 0 deletions
@@ -0,0 +1,67 @@
import csv
import os
import time

import requests
import tqdm


def read_csv(csv_file):
    with open(csv_file, "r") as f:
        reader = csv.reader(f)
        data = list(reader)
    data = data[1:]
    print(f"Read {len(data)} rows from {csv_file}")
    return data


def select_csv(data, min_text_len, min_vid_len, select_num):
    results = []
    assert 0 <= min_vid_len <= 60
    min_vid_len_str = f"PT00H00M{min_vid_len:02d}S"
    for d in data:
        # [id, link, duration, page, text]
        if d[2] < min_vid_len_str:
            continue
        token_num = len(d[4].split(" "))
        if token_num < min_text_len:
            continue
        results.append(d)
        if len(results) == select_num:
            break
    return results


def save_data_list(data, save_path):
    with open(save_path, "w") as f:
        writer = csv.writer(f)
        writer.writerow(["id", "link", "duration", "page", "text"])
        for d in data:
            writer.writerow(d)


def download_video(data, save_path):
    os.makedirs(save_path, exist_ok=True)
    for d in tqdm.tqdm(data):
        url = d[1]
        video_path = os.path.join(save_path, f"{d[0]}.mp4")
        while True:
            try:
                r = requests.get(url, stream=True)
                with open(video_path, "wb") as f:
                    for chunk in r.iter_content(chunk_size=1024):
                        if chunk:
                            f.write(chunk)
                break
            except ConnectionError:
                time.sleep(1)
                print(f"Failed to download {url}, retrying...")
                continue
        time.sleep(0.1)


if __name__ == "__main__":
    data = read_csv("./datasets/webvid.csv")
    selected_data = select_csv(data, 20, 5, 500)
    save_data_list(selected_data, "./datasets/webvid_selected.csv")
    download_video(selected_data, "./datasets/webvid")
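
A small aside on the duration filter in select_csv above: the duration column is compared as a zero-padded "PT00HxxMyyS" string, and such strings sort lexicographically in duration order, so a plain string comparison against the threshold is enough. A self-contained sketch with made-up durations:

# Zero-padded "PT00HxxMyyS" strings sort lexicographically in duration order,
# so select_csv's string comparison implements the minimum-length filter.
min_vid_len_str = "PT00H00M05S"            # min_vid_len=5, as in select_csv(data, 20, 5, 500)
print("PT00H00M03S" < min_vid_len_str)     # True  -> 3 s clip is skipped
print("PT00H00M30S" < min_vid_len_str)     # False -> 30 s clip is kept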

eval/pab/webvid/latte.py

Lines changed: 50 additions & 0 deletions
@@ -0,0 +1,50 @@
from utils import generate_func, load_eval_prompts

from videosys import LatteConfig, LattePABConfig, VideoSysEngine


def eval_base(prompt_list):
    config = LatteConfig()
    engine = VideoSysEngine(config)
    generate_func(engine, prompt_list, "./samples/latte_base")


def eval_pab1(prompt_list):
    pab_config = LattePABConfig(
        spatial_range=2,
        temporal_range=3,
        cross_range=6,
    )
    config = LatteConfig(enable_pab=True, pab_config=pab_config)
    engine = VideoSysEngine(config)
    generate_func(engine, prompt_list, "./samples/latte_pab1")


def eval_pab2(prompt_list):
    pab_config = LattePABConfig(
        spatial_range=3,
        temporal_range=4,
        cross_range=7,
    )
    config = LatteConfig(enable_pab=True, pab_config=pab_config)
    engine = VideoSysEngine(config)
    generate_func(engine, prompt_list, "./samples/latte_pab2")


def eval_pab3(prompt_list):
    pab_config = LattePABConfig(
        spatial_range=4,
        temporal_range=6,
        cross_range=9,
    )
    config = LatteConfig(enable_pab=True, pab_config=pab_config)
    engine = VideoSysEngine(config)
    generate_func(engine, prompt_list, "./samples/latte_pab3")


if __name__ == "__main__":
    prompt_list = load_eval_prompts("./datasets/webvid_selected.csv")
    eval_base(prompt_list)
    eval_pab1(prompt_list)
    eval_pab2(prompt_list)
    eval_pab3(prompt_list)
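
A sketch of how the pieces added in this commit could be wired together once download.py and latte.py have run; the module name batch_eval for the 205-line metric script above and the choice of directories are assumptions, since this diff does not show that file's path:

# Hypothetical driver: score each PAB variant's samples against the base run
# using the batch metric script shown earlier in this commit.
from types import SimpleNamespace

import batch_eval  # assumed module name for the metric script above

args = SimpleNamespace(
    gt_video_dirs=["./samples/latte_base"],
    gen_video_dirs=["./samples/latte_pab1", "./samples/latte_pab2", "./samples/latte_pab3"],
)
batch_eval.main(args)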
