Skip to content

Commit 54cb021

Browse files
committed
update
1 parent 7507d31 commit 54cb021

File tree

2 files changed

+39
-13
lines changed

2 files changed

+39
-13
lines changed

ns_vfs/api/run_with_nsvqa.py

Lines changed: 32 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,17 @@
11
from __future__ import annotations
22

3-
import numpy as np
43
import json
4+
import time
5+
6+
import numpy as np
57
from cvias.image.detection.vllm_detection import VLLMDetection
68

79
from ns_vfs.automaton.video_automaton import VideoAutomaton
810
from ns_vfs.data.frame import FramesofInterest, VideoFrame
11+
from ns_vfs.dataloader.longvideobench import LongVideoBench
912
from ns_vfs.model_checking.stormpy import StormModelChecker
1013
from ns_vfs.percepter.single_vision_percepter import SingleVisionPercepter
1114
from ns_vfs.validator import FrameValidator
12-
from ns_vfs.dataloader.longvideobench import LongVideoBench
13-
from ns_vfs.dataloader.nextqa import NextQA
1415

1516

1617
def run_nsvs_nsvqa(
@@ -28,6 +29,7 @@ def run_nsvs_nsvqa(
2829
desired_fps: int | None = None,
2930
custom_prompt: str | None = None,
3031
) -> None:
32+
print(f"THE LTL FORMULA IS: {ltl_formula}")
3133
# Yolo model initialization
3234
vllm_model = VLLMDetection(
3335
api_key=api_key,
@@ -52,11 +54,15 @@ def run_nsvs_nsvqa(
5254
cv_models=vllm_model,
5355
)
5456

55-
frame_validator = FrameValidator(ltl_formula=ltl_formula)
57+
frame_validator = FrameValidator(
58+
ltl_formula=ltl_formula,
59+
threshold_of_probability=0.3,
60+
)
5661
frame_idx = 0
57-
model_checker_is_filter: bool = (False,)
58-
model_checker_type: str = ("sparse_ma",)
62+
model_checker_is_filter: bool = False
63+
model_checker_type: str = "dtmc" # "sparse_ma"
5964
for nsvqa_input in nsvqa_input_data:
65+
start_time = time.time()
6066
sequence_of_frames = nsvqa_input["frames"]
6167
detected_objects: dict = vision_percepter.perceive(
6268
image=sequence_of_frames,
@@ -73,6 +79,8 @@ def run_nsvs_nsvqa(
7379
activity_of_interest=activity_of_interest,
7480
)
7581
frame_idx += 1
82+
end_time = time.time()
83+
print(f"Time taken to detect objects: {end_time - start_time} seconds")
7684

7785
# 1. frame validation
7886
if frame_validator.validate_frame(frame=frame):
@@ -96,6 +104,15 @@ def run_nsvs_nsvqa(
96104
print(frame_of_interest.foi_list)
97105
# save result
98106
if output_path:
107+
if output_path.endswith("/"):
108+
foi_output_path = f"{output_path}frames_of_interest.txt"
109+
else:
110+
foi_output_path = f"{output_path}/frames_of_interest.txt"
111+
with open(foi_output_path, "w") as f:
112+
f.write("Detected frames of interest:\n")
113+
f.write("--------------------------------\n")
114+
for frame in frame_of_interest.foi_list:
115+
f.write(f"{frame}\n")
99116
frame_of_interest.save(path=output_path)
100117
print(f"\nResults saved in {output_path}")
101118

@@ -104,24 +121,28 @@ def run_nsvs_nsvqa(
104121

105122
if __name__ == "__main__":
106123
# input_data_path = "/nas/mars/experiment_result/nsvqa/1_puls/longvideobench/longvideobench-outputs-fixed-specs-v2.json"
107-
input_data_path = "/nas/mars/experiment_result/nsvqa/1_puls/next-dataset/nextqa-outputs.json"
108-
with open(input_data_path, 'r', encoding='utf-8') as f:
124+
# input_data_path = "/nas/mars/experiment_result/nsvqa/1_puls/next-dataset/nextqa-outputs.json"
125+
# input_data_path = "/nas/mars/experiment_result/nsvqa/1_puls/longvideobench/longvideobench-run1.json"
126+
input_data_path = "/nas/mars/experiment_result/nsvqa/1_puls/longvideobench/longvideobench-run2.json"
127+
128+
with open(input_data_path, "r", encoding="utf-8") as f:
109129
data = json.load(f)
110130

111131
for sample in data:
112132
# loader = LongVideoBench(sample["video_path"], sample["subtitle_path"])
113-
loader = NextQA(sample["video_path"], sample["subtitle_path"])
133+
loader = LongVideoBench(sample["video_path"], sample["subtitle_path"])
114134
nsvqa_input = loader.load_all()
115-
extracted = sample["video_path"].split('/')[-1].split('.')[0]
135+
extracted = sample["video_path"].split("/")[-1].split(".")[0]
116136

117137
run_nsvs_nsvqa(
118138
nsvqa_input_data=nsvqa_input,
119139
desired_interval_in_sec=None,
140+
api_base="http://localhost:8002/v1",
120141
desired_fps=30,
121142
proposition_set=sample["proposition"],
122143
ltl_formula=sample["specification"],
123144
output_path=f"/nas/mars/experiment_result/nsvqa/2_nsvs/longvideobench/{extracted}/",
124145
threshold_satisfaction_probability=0.80,
125146
frame_scale=None,
126-
calibration_method="temperature_scaling",
147+
calibration_method=None, # "temperature_scaling",
127148
)

ns_vfs/data/frame.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -165,9 +165,14 @@ def save(self, path: str | Path):
165165

166166
frame_path.mkdir(parents=True, exist_ok=True)
167167
annotation_path.mkdir(parents=True, exist_ok=True)
168-
169168
for idx, img in enumerate(self.frame_images):
170-
Image.fromarray(img).save(f"{frame_path}/{idx}.png")
169+
if isinstance(img, np.ndarray):
170+
Image.fromarray(img).save(f"{frame_path}/{idx}.png")
171+
else:
172+
for i, img_i in enumerate(img):
173+
Image.fromarray(img_i).save(
174+
f"{frame_path}/frame_chunk_{idx}_{i}.png"
175+
)
171176
try:
172177
if (
173178
len(self.annotated_images) > 0

0 commit comments

Comments
 (0)