1 change: 1 addition & 0 deletions .gitignore
@@ -165,3 +165,4 @@ pretrained
 samples
 cache_dir
 test_outputs
+datasets
205 changes: 205 additions & 0 deletions eval/pab/common_metrics/batch_eval.py
@@ -0,0 +1,205 @@
import argparse
import os

import imageio
import torch
import torchvision.transforms.functional as F
import tqdm
from calculate_lpips import calculate_lpips
from calculate_psnr import calculate_psnr
from calculate_ssim import calculate_ssim


def load_video(video_path):
    """
    Load a video from the given path and convert it to a PyTorch tensor.
    """
    # Read the video using imageio
    reader = imageio.get_reader(video_path, "ffmpeg")

    # Extract frames and convert to a list of tensors
    frames = []
    for frame in reader:
        # Convert the frame to a tensor and permute the dimensions to match (C, H, W)
        frame_tensor = torch.tensor(frame).cuda().permute(2, 0, 1)
        frames.append(frame_tensor)

    # Stack the list of tensors into a single tensor with shape (T, C, H, W)
    video_tensor = torch.stack(frames)

    return video_tensor


def resize_video(video, target_height, target_width):
    resized_frames = []
    for frame in video:
        resized_frame = F.resize(frame, [target_height, target_width])
        resized_frames.append(resized_frame)
    return torch.stack(resized_frames)


def resize_gt_video(gt_video, gen_video):
    gen_video_shape = gen_video.shape
    T_gen, _, H_gen, W_gen = gen_video_shape
    T_eval, _, H_eval, W_eval = gt_video.shape

    if T_eval < T_gen:
        raise ValueError(f"Eval video time steps ({T_eval}) are less than generated video time steps ({T_gen}).")

    if H_eval < H_gen or W_eval < W_gen:
        # Resize the video, maintaining the aspect ratio
        resize_height = max(H_gen, int(H_gen * (H_eval / W_eval)))
        resize_width = max(W_gen, int(W_gen * (W_eval / H_eval)))
        gt_video = resize_video(gt_video, resize_height, resize_width)
        # Recalculate the dimensions
        T_eval, _, H_eval, W_eval = gt_video.shape

    # Center-crop to the generated video's spatial size and truncate to its length
    start_h = (H_eval - H_gen) // 2
    start_w = (W_eval - W_gen) // 2
    cropped_video = gt_video[:T_gen, :, start_h : start_h + H_gen, start_w : start_w + W_gen]

    return cropped_video


def get_video_ids(gt_video_dirs, gen_video_dirs):
    video_ids = []
    for f in os.listdir(gt_video_dirs[0]):
        if f.endswith(".mp4"):
            video_ids.append(f.replace(".mp4", ""))
    video_ids.sort()

    for video_dir in gt_video_dirs + gen_video_dirs:
        tmp_video_ids = []
        for f in os.listdir(video_dir):
            if f.endswith(".mp4"):
                tmp_video_ids.append(f.replace(".mp4", ""))
        tmp_video_ids.sort()
        if tmp_video_ids != video_ids:
            raise ValueError(f"Video IDs in {video_dir} differ from those in {gt_video_dirs[0]}.")
    return video_ids


def get_videos(video_ids, gt_video_dirs, gen_video_dirs):
    gt_videos = {}
    generated_videos = {}

    for gt_video_dir in gt_video_dirs:
        tmp_gt_videos_tensor = []
        for video_id in video_ids:
            gt_video = load_video(os.path.join(gt_video_dir, f"{video_id}.mp4"))
            tmp_gt_videos_tensor.append(gt_video)
        gt_videos[gt_video_dir] = tmp_gt_videos_tensor

    for generated_video_dir in gen_video_dirs:
        tmp_generated_videos_tensor = []
        for video_id in video_ids:
            generated_video = load_video(os.path.join(generated_video_dir, f"{video_id}.mp4"))
            tmp_generated_videos_tensor.append(generated_video)
        generated_videos[generated_video_dir] = tmp_generated_videos_tensor

    return gt_videos, generated_videos


def print_results(lpips_results, psnr_results, ssim_results, gt_video_dirs, gen_video_dirs):
    out_str = ""

    for gt_video_dir in gt_video_dirs:
        for generated_video_dir in gen_video_dirs:
            if gt_video_dir == generated_video_dir:
                continue
            lpips = sum(lpips_results[gt_video_dir][generated_video_dir]) / len(
                lpips_results[gt_video_dir][generated_video_dir]
            )
            psnr = sum(psnr_results[gt_video_dir][generated_video_dir]) / len(
                psnr_results[gt_video_dir][generated_video_dir]
            )
            ssim = sum(ssim_results[gt_video_dir][generated_video_dir]) / len(
                ssim_results[gt_video_dir][generated_video_dir]
            )
            out_str += f"\ngt: {gt_video_dir} -> gen: {generated_video_dir}, lpips: {lpips:.4f}, psnr: {psnr:.4f}, ssim: {ssim:.4f}"

    return out_str


def main(args):
    device = "cuda"
    gt_video_dirs = args.gt_video_dirs
    gen_video_dirs = args.gen_video_dirs

    video_ids = get_video_ids(gt_video_dirs, gen_video_dirs)
    print(f"Found {len(video_ids)} videos")

    prompt_interval = 1
    batch_size = 8
    calculate_lpips_flag, calculate_psnr_flag, calculate_ssim_flag = True, True, True

    lpips_results = {}
    psnr_results = {}
    ssim_results = {}
    for gt_video_dir in gt_video_dirs:
        lpips_results[gt_video_dir] = {}
        psnr_results[gt_video_dir] = {}
        ssim_results[gt_video_dir] = {}
        for generated_video_dir in gen_video_dirs:
            lpips_results[gt_video_dir][generated_video_dir] = []
            psnr_results[gt_video_dir][generated_video_dir] = []
            ssim_results[gt_video_dir][generated_video_dir] = []

    total_len = len(video_ids) // batch_size + (1 if len(video_ids) % batch_size != 0 else 0)

    for idx in tqdm.tqdm(range(total_len)):
        video_ids_batch = video_ids[idx * batch_size : (idx + 1) * batch_size]
        gt_videos, generated_videos = get_videos(video_ids_batch, gt_video_dirs, gen_video_dirs)

        for gt_video_dir, gt_videos_tensor in gt_videos.items():
            for generated_video_dir, generated_videos_tensor in generated_videos.items():
                if gt_video_dir == generated_video_dir:
                    continue

                # On the first pairing, crop/resize the ground-truth clips to match
                # the generated clips, then stack and normalize to [0, 1]
                if not isinstance(gt_videos_tensor, torch.Tensor):
                    for i in range(len(gt_videos_tensor)):
                        gt_videos_tensor[i] = resize_gt_video(gt_videos_tensor[i], generated_videos_tensor[0])
                    gt_videos_tensor = (torch.stack(gt_videos_tensor) / 255.0).cpu()

                generated_videos_tensor = (torch.stack(generated_videos_tensor) / 255.0).cpu()

                if calculate_lpips_flag:
                    result = calculate_lpips(gt_videos_tensor, generated_videos_tensor, device=device)
                    result = result["value"].values()
                    result = float(sum(result) / len(result))
                    lpips_results[gt_video_dir][generated_video_dir].append(result)

                if calculate_psnr_flag:
                    result = calculate_psnr(gt_videos_tensor, generated_videos_tensor)
                    result = result["value"].values()
                    result = float(sum(result) / len(result))
                    psnr_results[gt_video_dir][generated_video_dir].append(result)

                if calculate_ssim_flag:
                    result = calculate_ssim(gt_videos_tensor, generated_videos_tensor)
                    result = result["value"].values()
                    result = float(sum(result) / len(result))
                    ssim_results[gt_video_dir][generated_video_dir].append(result)

        if (idx + 1) % prompt_interval == 0:
            out_str = print_results(lpips_results, psnr_results, ssim_results, gt_video_dirs, gen_video_dirs)
            print(f"Processed {idx + 1} / {total_len} batches. {out_str}")

    out_str = print_results(lpips_results, psnr_results, ssim_results, gt_video_dirs, gen_video_dirs)

    # save
    with open("./batch_eval.txt", "w") as f:
        f.write(out_str)

    print(f"Processed all videos. {out_str}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--gt_video_dirs", type=str, nargs="+")
    parser.add_argument("--gen_video_dirs", type=str, nargs="+")

    args = parser.parse_args()

    main(args)
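
For a quick sanity check, the metric helpers can also be called directly on raw tensors. A minimal sketch, assuming the calculate_* modules above are importable and accept the same (B, T, C, H, W) float layout in [0, 1] that the script builds before calling them:

import torch
from calculate_psnr import calculate_psnr

# Random stand-in clips, shaped like the stacked/normalized tensors in main()
gt = torch.rand(1, 8, 3, 64, 64)
gen = torch.rand(1, 8, 3, 64, 64)
result = calculate_psnr(gt, gen)
# result["value"] is a dict of per-sample scores, averaged the same way main() does
print(float(sum(result["value"].values()) / len(result["value"])))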
8 changes: 4 additions & 4 deletions eval/pab/experiments/opensora_plan.py
@@ -4,7 +4,7 @@


 def eval_base(prompt_list):
-    config = OpenSoraPlanConfig()
+    config = OpenSoraPlanConfig(version="v110", transformer_type="65x512x512")
     engine = VideoSysEngine(config)
     generate_func(engine, prompt_list, "./samples/opensoraplan_base", loop=5)

@@ -15,7 +15,7 @@ def eval_pab1(prompt_list):
         temporal_gap=4,
         cross_gap=6,
     )
-    config = OpenSoraPlanConfig(enable_pab=True, pab_config=pab_config)
+    config = OpenSoraPlanConfig(version="v110", transformer_type="65x512x512", enable_pab=True, pab_config=pab_config)
     engine = VideoSysEngine(config)
     generate_func(engine, prompt_list, "./samples/opensoraplan_pab1", loop=5)

@@ -26,7 +26,7 @@ def eval_pab2(prompt_list):
         temporal_gap=5,
         cross_gap=7,
     )
-    config = OpenSoraPlanConfig(enable_pab=True, pab_config=pab_config)
+    config = OpenSoraPlanConfig(version="v110", transformer_type="65x512x512", enable_pab=True, pab_config=pab_config)
     engine = VideoSysEngine(config)
     generate_func(engine, prompt_list, "./samples/opensoraplan_pab2", loop=5)

@@ -37,7 +37,7 @@ def eval_pab3(prompt_list):
         temporal_gap=7,
         cross_gap=9,
     )
-    config = OpenSoraPlanConfig(enable_pab=True, pab_config=pab_config)
+    config = OpenSoraPlanConfig(version="v110", transformer_type="65x512x512", enable_pab=True, pab_config=pab_config)
     engine = VideoSysEngine(config)
     generate_func(engine, prompt_list, "./samples/opensoraplan_pab3", loop=5)

67 changes: 67 additions & 0 deletions eval/pab/webvid/download.py
@@ -0,0 +1,67 @@
import csv
import os
import time

import requests
import tqdm


def read_csv(csv_file):
    with open(csv_file, "r") as f:
        reader = csv.reader(f)
        data = list(reader)
    data = data[1:]  # drop the header row
    print(f"Read {len(data)} rows from {csv_file}")
    return data


def select_csv(data, min_text_len, min_vid_len, select_num):
    results = []
    assert 0 <= min_vid_len <= 60
    min_vid_len_str = f"PT00H00M{min_vid_len:02d}S"
    for d in data:
        # row layout: [id, link, duration, page, text]
        # durations use the fixed-width "PTxxHxxMxxS" form, so the string
        # comparison below matches numeric ordering
        if d[2] < min_vid_len_str:
            continue
        token_num = len(d[4].split(" "))
        if token_num < min_text_len:
            continue
        results.append(d)
        if len(results) == select_num:
            break
    return results


def save_data_list(data, save_path):
    with open(save_path, "w") as f:
        writer = csv.writer(f)
        writer.writerow(["id", "link", "duration", "page", "text"])
        for d in data:
            writer.writerow(d)


def download_video(data, save_path):
    os.makedirs(save_path, exist_ok=True)
    for d in tqdm.tqdm(data):
        url = d[1]
        video_path = os.path.join(save_path, f"{d[0]}.mp4")
        while True:
            try:
                r = requests.get(url, stream=True)
                with open(video_path, "wb") as f:
                    for chunk in r.iter_content(chunk_size=1024):
                        if chunk:
                            f.write(chunk)
                break
            # catch requests' ConnectionError: the bare builtin ConnectionError
            # would not catch failures raised by requests, so the retry loop
            # would never trigger
            except requests.exceptions.ConnectionError:
                time.sleep(1)
                print(f"Failed to download {url}, retrying...")
                continue
        time.sleep(0.1)


if __name__ == "__main__":
    data = read_csv("./datasets/webvid.csv")
    selected_data = select_csv(data, 20, 5, 500)
    save_data_list(selected_data, "./datasets/webvid_selected.csv")
    download_video(selected_data, "./datasets/webvid")
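
The duration filter in select_csv leans on the fixed-width timestamp format rather than parsing it; a quick illustration of that assumption:

# Lexicographic order matches numeric order only because every duration
# string has the same "PTxxHxxMxxS" width
assert "PT00H00M04S" < "PT00H00M05S" < "PT00H01M00S"
assert not ("PT00H00M30S" < "PT00H00M05S")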
50 changes: 50 additions & 0 deletions eval/pab/webvid/latte.py
@@ -0,0 +1,50 @@
from utils import generate_func, load_eval_prompts

from videosys import LatteConfig, LattePABConfig, VideoSysEngine


def eval_base(prompt_list):
    config = LatteConfig()
    engine = VideoSysEngine(config)
    generate_func(engine, prompt_list, "./samples/latte_base")


def eval_pab1(prompt_list):
    pab_config = LattePABConfig(
        spatial_range=2,
        temporal_range=3,
        cross_range=6,
    )
    config = LatteConfig(enable_pab=True, pab_config=pab_config)
    engine = VideoSysEngine(config)
    generate_func(engine, prompt_list, "./samples/latte_pab1")


def eval_pab2(prompt_list):
    pab_config = LattePABConfig(
        spatial_range=3,
        temporal_range=4,
        cross_range=7,
    )
    config = LatteConfig(enable_pab=True, pab_config=pab_config)
    engine = VideoSysEngine(config)
    generate_func(engine, prompt_list, "./samples/latte_pab2")


def eval_pab3(prompt_list):
    pab_config = LattePABConfig(
        spatial_range=4,
        temporal_range=6,
        cross_range=9,
    )
    config = LatteConfig(enable_pab=True, pab_config=pab_config)
    engine = VideoSysEngine(config)
    generate_func(engine, prompt_list, "./samples/latte_pab3")


if __name__ == "__main__":
    prompt_list = load_eval_prompts("./datasets/webvid_selected.csv")
    eval_base(prompt_list)
    eval_pab1(prompt_list)
    eval_pab2(prompt_list)
    eval_pab3(prompt_list)
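
load_eval_prompts and generate_func come from the shared utils module, which is not part of this diff. A hypothetical sketch of what the prompt loader would need to do, assuming it keys prompts by video id so generated files line up with the {id}.mp4 names that batch_eval.py matches on:

import csv

def load_eval_prompts(csv_path):
    # Hypothetical reconstruction: map video id -> caption using the
    # "id"/"text" columns written by save_data_list in download.py
    prompts = {}
    with open(csv_path, "r") as f:
        for row in csv.DictReader(f):
            prompts[row["id"]] = row["text"]
    return prompts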