ID Not Writing To Label File #2067
Replies: 4 comments 7 replies
-
Which file? Could you give more context on what you are running?
Beta Was this translation helpful? Give feedback.
-
The benchmark pipeline, where MOT results are generated and fed to TrackEval for evaluation, passes without issues. I have never encountered this issue. Please provide more context on what you are running.
Beta Was this translation helpful? Give feedback.
-
Aha, I see the issue now. `track` does not support generating these files at the moment. Try the suggested command instead (the snippet was lost in this copy of the thread),
but this will only work if you have a MOT17/MOT20/DanceTrack-like folder structure.
Beta Was this translation helpful? Give feedback.
-
An easy way of enabling this is by monkey-patching ultralytics' `Results.save_txt` method, as in the script below:
# Mikel Broström 🔥 Yolo Tracking 🧾 AGPL-3.0 license
import argparse
from functools import partial
from pathlib import Path
import cv2
import torch
from boxmot import TRACKERS
from boxmot.tracker_zoo import create_tracker
from boxmot.utils import ROOT, TRACKER_CONFIGS, WEIGHTS
from boxmot.utils.checks import RequirementsChecker
from boxmot.engine.detectors import default_imgsz, get_yolo_inferer, is_ultralytics_model
# Check for (and, per the "# install" note, install if missing) the ultralytics
# package before any ultralytics import below; those imports must stay after this call.
checker = RequirementsChecker()
checker.check_packages(("ultralytics", )) # install
from ultralytics import YOLO
from ultralytics.utils.plotting import Annotator # ultralytics.yolo.utils.plotting is deprecated
from ultralytics.utils.plotting import colors
from ultralytics.utils import plotting
# Make every drawing call a no-op
# These assignments replace the methods on the Annotator class itself, so calls to
# box / box_label / line on any Annotator instance draw nothing. Presumably this
# leaves only what plot_trajectories draws via tracker.plot_results — confirm.
plotting.Annotator.box = lambda *args, **kwargs: None
plotting.Annotator.box_label = lambda *args, **kwargs: None
plotting.Annotator.line = lambda *args, **kwargs: None
# NOTE(review): duplicate of the pathlib import at the top of the script; harmless.
from pathlib import Path
# Results is the class whose save_txt method this script overwrites further down.
from ultralytics.engine.results import Results
def save_txt(self, txt_file: "str | Path", save_conf: bool = False) -> str:
    """
    Save tracking results to a text file, one line per *tracked* object.

    Meant to replace ``ultralytics.engine.results.Results.save_txt`` via
    monkey-patching. Boxes without a track ID (``d.is_track`` is False) are
    skipped, so every detection line written ends with its track ID instead of
    the ID sometimes being missing.

    Args:
        self: The ``Results`` object being saved.
        txt_file (str | Path): Path to the output text file.
        save_conf (bool): Whether to include confidence scores in the output.

    Returns:
        (str): ``txt_file`` converted to ``str``.

    Examples:
        >>> from ultralytics import YOLO
        >>> model = YOLO("yolo11n.pt")
        >>> results = model.track("path/to/video.mp4")
        >>> for result in results:
        ...     result.save_txt("output.txt")

    Notes:
        - Classification results: one ``confidence class_name`` line per top-5 class.
        - Detection results: ``class x_center y_center width height [confidence] track_id``
          (normalized). OBB results write the flattened ``xyxyxyxyn`` corners instead
          of ``xywhn``; segmentation results write the flattened normalized polygon
          instead of the box; pose results append the flattened keypoints (with
          per-keypoint confidence when ``has_visible`` is set).
        - The parent directory is created only when there is something to write.
        - Lines are appended; existing file contents are never overwritten.
    """
    is_obb = self.obb is not None
    boxes = self.obb if is_obb else self.boxes
    masks = self.masks
    probs = self.probs
    kpts = self.keypoints
    texts = []
    if probs is not None:
        # Classify: one line per top-5 class.
        texts.extend(f"{probs.data[j]:.2f} {self.names[j]}" for j in probs.top5)
    elif boxes:
        # Detect / segment / pose / OBB.
        for j, d in enumerate(boxes):
            # Untracked boxes have no ID to write; skip them so every line in the
            # label file carries a track ID.
            if not d.is_track:
                continue
            cls_id, conf, track_id = int(d.cls), float(d.conf), int(d.id.item())
            coords = d.xyxyxyxyn if is_obb else d.xywhn
            line = (cls_id, *coords.view(-1))
            if masks:
                # Segmentation: replace the box with the polygon, (n, 2) -> (n*2).
                seg = masks[j].xyn[0].copy().reshape(-1)
                line = (cls_id, *seg)
            if kpts is not None:
                kpt = torch.cat((kpts[j].xyn, kpts[j].conf[..., None]), 2) if kpts[j].has_visible else kpts[j].xyn
                line += (*kpt.reshape(-1).tolist(),)
            if save_conf:
                line += (conf,)
            line += (track_id,)
            texts.append(("%g " * len(line)).rstrip() % line)
    if texts:
        path = Path(txt_file)
        path.parent.mkdir(parents=True, exist_ok=True)
        with path.open("a", encoding="utf-8") as f:
            f.writelines(text + "\n" for text in texts)
    return str(txt_file)
# Overwrite the method on the class so every Results object (including those
# produced during tracking) writes labels with track IDs via save_txt above.
# The previous right-hand side, `my_save_txt`, was never defined (NameError).
Results.save_txt = save_txt
def on_predict_start(predictor, persist=False):
    """
    Create one BoxMOT tracker per batch slot when prediction starts.

    Registered by ``main`` as an ultralytics ``on_predict_start`` callback.
    Replaces ``predictor.trackers`` with a fresh list of trackers built from the
    ``<tracking_method>.yaml`` file in ``TRACKER_CONFIGS``.

    Args:
        predictor (object): The ultralytics predictor. It must carry ``custom_args``
            (the parsed CLI namespace, which ``main`` stores on it), ``dataset.bs``
            and ``device``.
        persist (bool, optional): Accepted but currently unused; trackers are
            always (re)created. NOTE(review): honouring it would skip creation
            whenever ``predictor.trackers`` already exists, so verify whether
            another callback sets that attribute first before enabling it.

    Raises:
        ValueError: If ``custom_args.tracking_method`` is not in ``TRACKERS``.
    """
    opts = predictor.custom_args
    method = opts.tracking_method
    # Explicit raise instead of `assert`: asserts are stripped under `python -O`.
    if method not in TRACKERS:
        raise ValueError(f"'{method}' is not supported. Supported ones are {TRACKERS}")
    tracking_config = TRACKER_CONFIGS / (method + '.yaml')
    trackers = []
    for _ in range(predictor.dataset.bs):
        tracker = create_tracker(
            method,
            tracking_config,
            opts.reid_model,
            predictor.device,
            opts.half,
            opts.per_class,
        )
        # Motion-only trackers have no `model` attribute to warm up.
        if hasattr(tracker, "model"):
            tracker.model.warmup()
        trackers.append(tracker)
    predictor.trackers = trackers
def plot_trajectories(predictor):
    """
    Draw each tracker's results onto its frame after post-processing.

    Registered in ``main`` as the ``on_predict_postprocess_end`` callback.
    ``predictor.results`` holds one ``Results`` per frame in the batch, and
    ``predictor.trackers[k]`` is the tracker for batch slot ``k``. Each frame's
    ``orig_img`` is replaced by whatever ``tracker.plot_results`` returns.

    Args:
        predictor: The ultralytics predictor; ``custom_args.show_trajectories``
            is forwarded to ``plot_results``.
    """
    for slot, frame_result in enumerate(predictor.results):
        slot_tracker = predictor.trackers[slot]
        frame_result.orig_img = slot_tracker.plot_results(
            frame_result.orig_img,
            predictor.custom_args.show_trajectories,
        )
        # 1 ms wait, presumably so OpenCV's GUI can process window events.
        cv2.waitKey(1)
@torch.no_grad()
def main(args):
    """
    Run YOLO detection with BoxMOT tracking over ``args.source``.

    Builds an ultralytics ``YOLO`` wrapper and starts ``yolo.track(..., stream=True)``.
    It then registers the BoxMOT tracker-creation and trajectory-plotting
    callbacks. For non-ultralytics weights, it swaps in the BoxMOT inferer.
    Finally it drains the result stream.

    Args:
        args: Parsed CLI namespace. Its options are forwarded to ``yolo.track``,
            and the namespace is stored on ``yolo.predictor.custom_args`` for the
            callbacks. ``args.imgsz`` is filled in via ``default_imgsz`` when None.

    Notes:
        ``show=True`` is hard-coded rather than taken from ``args``.
    """
    if args.imgsz is None:
        args.imgsz = default_imgsz(args.yolo_model)

    native_weights = is_ultralytics_model(args.yolo_model)
    # Non-ultralytics weights still need a YOLO wrapper to drive the predictor;
    # a yolov8n placeholder is loaded and its model replaced further down.
    yolo = YOLO(args.yolo_model if native_weights else "yolov8n.pt")

    track_options = dict(
        source=args.source,
        imgsz=args.imgsz,
        conf=args.conf,
        iou=args.iou,
        agnostic_nms=args.agnostic_nms,
        classes=args.classes,
        device=args.device,
        vid_stride=args.vid_stride,
        show=True,
        show_conf=args.show_conf,
        show_labels=args.show_labels,
        line_width=args.line_width,
        save=args.save,
        save_txt=args.save_txt,
        save_crop=args.save_crop,
        project=args.project,
        name=args.name,
        exist_ok=args.exist_ok,
        verbose=args.verbose,
        stream=True,
    )
    results = yolo.track(**track_options)

    yolo.add_callback("on_predict_start", partial(on_predict_start, persist=True))
    yolo.add_callback("on_predict_postprocess_end", plot_trajectories)

    if not native_weights:
        # Replace the placeholder model with the BoxMOT inferer (e.g. YOLOX) and
        # route the predictor's pre/post-processing through it.
        inferer_cls = get_yolo_inferer(args.yolo_model)
        yolo_model = inferer_cls(
            model=args.yolo_model,
            device=yolo.predictor.device,
            args=yolo.predictor.args,
        )
        yolo.predictor.model = yolo_model
        # Hand each batch's image paths to the inferer for further processing.
        yolo.add_callback(
            "on_predict_batch_start", lambda p: yolo_model.update_im_paths(p)
        )
        yolo.predictor.preprocess = lambda imgs: yolo_model.preprocess(im=imgs)
        yolo.predictor.postprocess = lambda preds, im, im0s: yolo_model.postprocess(
            preds=preds, im=im, im0s=im0s
        )

    # The callbacks read their options from here.
    yolo.predictor.custom_args = args

    # Consume the streamed results; the per-frame work happens in the callbacks.
    for _frame in results:
        pass
def parse_opt(argv=None):
    """
    Build the CLI namespace consumed by ``main`` and the tracking callbacks.

    Defines every attribute that ``main``, ``on_predict_start`` and
    ``plot_trajectories`` read from ``args``.

    NOTE(review): defaults are chosen to resemble BoxMOT's own tracking script;
    confirm them against the boxmot version in use.

    Args:
        argv (list[str] | None): Arguments to parse; ``None`` means ``sys.argv[1:]``.

    Returns:
        argparse.Namespace: The parsed options.
    """
    parser = argparse.ArgumentParser(description="YOLO + BoxMOT tracking")
    parser.add_argument("--yolo-model", type=Path, default=WEIGHTS / "yolov8n.pt",
                        help="detector weights")
    parser.add_argument("--reid-model", type=Path, default=WEIGHTS / "osnet_x0_25_msmt17.pt",
                        help="ReID weights (ignored by motion-only trackers)")
    parser.add_argument("--tracking-method", type=str, default="bytetrack",
                        help="tracker name; must be one of boxmot.TRACKERS")
    parser.add_argument("--source", type=str, default="0",
                        help="input source passed to yolo.track")
    parser.add_argument("--imgsz", nargs="+", type=int, default=None,
                        help="inference size; model default when omitted")
    parser.add_argument("--conf", type=float, default=0.5, help="confidence threshold")
    parser.add_argument("--iou", type=float, default=0.7, help="NMS IoU threshold")
    parser.add_argument("--device", default="", help="e.g. 0, 0,1 or cpu")
    parser.add_argument("--classes", nargs="+", type=int,
                        help="filter by class: --classes 0, or --classes 0 2 3")
    parser.add_argument("--project", default=ROOT / "runs" / "track", help="save results to project/name")
    parser.add_argument("--name", default="exp", help="save results to project/name")
    parser.add_argument("--exist-ok", action="store_true", help="reuse an existing project/name")
    parser.add_argument("--half", action="store_true", help="use FP16 half-precision inference")
    parser.add_argument("--vid-stride", type=int, default=1, help="video frame-rate stride")
    parser.add_argument("--line-width", type=int, default=None, help="bounding box line width")
    parser.add_argument("--save", action="store_true", help="save annotated video/images")
    parser.add_argument("--save-txt", action="store_true", help="save tracking results to *.txt")
    parser.add_argument("--save-crop", action="store_true", help="save cropped detections")
    parser.add_argument("--show-labels", action="store_false", help="pass to hide labels")
    parser.add_argument("--show-conf", action="store_false", help="pass to hide confidences")
    parser.add_argument("--show-trajectories", action="store_true", help="draw track trajectories")
    parser.add_argument("--per-class", action="store_true", help="track each class separately")
    parser.add_argument("--verbose", action="store_true", help="print per-frame results")
    parser.add_argument("--agnostic-nms", action="store_true", help="class-agnostic NMS")
    return parser.parse_args(argv)


if __name__ == "__main__":
    # main() requires the parsed options; calling it with no arguments raised TypeError.
    main(parse_opt())
Beta Was this translation helpful? Give feedback.
Uh oh!
There was an error while loading. Please reload this page.
-
Is there a reason why the ID is sometimes not written to the .txt file? For example, frame n might have `(0, 0.561567, 0.247865, 0.396077, 0.261982, 3)`, but frame n + 1 will have `(0, 0.943255, 0.225126, 0.110261, 0.0804192)`, without the trailing track ID.
Beta Was this translation helpful? Give feedback.
All reactions