Skip to content

Commit 538a401

Browse files
CopilotFlova
andcommitted
Pass yoeo substruct directly to YOEO components instead of full config
Co-authored-by: Flova <[email protected]>
1 parent 0a56fc6 commit 538a401

File tree

4 files changed

+48
-49
lines changed

4 files changed

+48
-49
lines changed

bitbots_vision/bitbots_vision/vision.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,7 @@ def _dynamic_reconfigure_callback(self, params) -> SetParametersResult:
8080
return SetParametersResult(successful=True)
8181

8282
def _configure_vision(self, config) -> None:
83-
yoeo.YOEOObjectManager.configure(config)
83+
yoeo.YOEOObjectManager.configure(config.yoeo)
8484

8585
debug_image = debug.DebugImage(config.component_debug_image_active)
8686
self._debug_image = debug_image

bitbots_vision/bitbots_vision/vision_modules/yoeo/object_manager.py

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -73,19 +73,19 @@ def is_team_color_detection_supported(cls) -> bool:
7373
return cls._model_config.team_colors_are_provided()
7474

7575
@classmethod
76-
def configure(cls, config: parameters.Params) -> None:
76+
def configure(cls, yoeo_config) -> None:
7777
if not cls._package_directory_set:
7878
logger.error("Package directory not set!")
7979

80-
framework = config.yoeo.framework
80+
framework = yoeo_config.framework
8181
cls._verify_framework_parameter(framework)
8282

83-
model_path = cls._get_full_model_path(config.yoeo.model_path)
83+
model_path = cls._get_full_model_path(yoeo_config.model_path)
8484
cls._verify_required_neural_network_files_exist(framework, model_path)
8585

86-
cls._configure_yoeo_instance(config, framework, model_path)
86+
cls._configure_yoeo_instance(yoeo_config, framework, model_path)
8787

88-
cls._config = config
88+
cls._config = yoeo_config
8989
cls._framework = framework
9090
cls._model_path = model_path
9191

@@ -108,13 +108,13 @@ def _model_files_exist(cls, framework: str, model_path: str) -> bool:
108108
return cls._HANDLERS_BY_NAME[framework].model_files_exist(model_path)
109109

110110
@classmethod
111-
def _configure_yoeo_instance(cls, config: parameters.Params, framework: str, model_path: str) -> None:
111+
def _configure_yoeo_instance(cls, yoeo_config, framework: str, model_path: str) -> None:
112112
if cls._new_yoeo_handler_is_needed(framework, model_path):
113113
cls._load_model_config(model_path)
114114
cls._instantiate_new_yoeo_handler(config, framework, model_path)
115-
elif cls._yoeo_parameters_have_changed(config):
115+
elif cls._yoeo_parameters_have_changed(yoeo_config):
116116
assert cls._yoeo_instance is not None, "YOEO handler instance not set!"
117-
cls._yoeo_instance.configure(config)
117+
cls._yoeo_instance.configure(yoeo_config)
118118

119119
@classmethod
120120
def _new_yoeo_handler_is_needed(cls, framework: str, model_path: str) -> bool:
@@ -125,9 +125,9 @@ def _load_model_config(cls, model_path: str) -> None:
125125
cls._model_config = ModelConfigLoader.load_from(model_path)
126126

127127
@classmethod
128-
def _instantiate_new_yoeo_handler(cls, config: parameters.Params, framework: str, model_path: str) -> None:
128+
def _instantiate_new_yoeo_handler(cls, yoeo_config, framework: str, model_path: str) -> None:
129129
cls._yoeo_instance = cls._HANDLERS_BY_NAME[framework](
130-
config,
130+
yoeo_config,
131131
model_path,
132132
cls._model_config.get_detection_classes(),
133133
cls._model_config.get_robot_class_ids(),
@@ -136,9 +136,9 @@ def _instantiate_new_yoeo_handler(cls, config: parameters.Params, framework: str
136136
logger.info(f"Using {cls._yoeo_instance.__class__.__name__}")
137137

138138
@classmethod
139-
def _yoeo_parameters_have_changed(cls, new_config: parameters.Params) -> bool:
139+
def _yoeo_parameters_have_changed(cls, new_yoeo_config) -> bool:
140140
if cls._config is None:
141141
return True
142142

143143
# Compare YOEO parameters using the hierarchical structure
144-
return cls._config.yoeo != new_config.yoeo
144+
return cls._config != new_yoeo_config

bitbots_vision/bitbots_vision/vision_modules/yoeo/yoeo_handlers.py

Lines changed: 29 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@ class IYOEOHandler(ABC):
2424
"""
2525

2626
@abstractmethod
27-
def configure(self, config: parameters.Params) -> None:
27+
def configure(self, yoeo_config) -> None:
2828
"""
2929
Allows to (re-) configure the YOEO handler.
3030
"""
@@ -100,7 +100,7 @@ class YOEOHandlerTemplate(IYOEOHandler):
100100

101101
def __init__(
102102
self,
103-
config: parameters.Params,
103+
yoeo_config,
104104
model_directory: str,
105105
det_class_names: list[str],
106106
det_robot_class_ids: list[int],
@@ -119,12 +119,12 @@ def __init__(
119119
self._seg_class_names: list[str] = seg_class_names
120120
self._seg_masks: dict = dict()
121121

122-
self._use_caching: bool = config.caching
122+
self._use_caching: bool = yoeo_config.caching
123123

124124
logger.debug("Leaving YOEOHandlerTemplate constructor")
125125

126-
def configure(self, config: parameters.Params) -> None:
127-
self._use_caching = config.caching
126+
def configure(self, yoeo_config) -> None:
127+
self._use_caching = yoeo_config.caching
128128

129129
def get_available_detection_class_names(self) -> list[str]:
130130
return self._det_class_names
@@ -213,7 +213,7 @@ class YOEOHandlerONNX(YOEOHandlerTemplate):
213213

214214
def __init__(
215215
self,
216-
config: parameters.Params,
216+
yoeo_config,
217217
model_directory: str,
218218
det_class_names: list[str],
219219
det_robot_class_ids: list[int],
@@ -240,8 +240,8 @@ def __init__(
240240
self._det_postprocessor: utils.IDetectionPostProcessor = utils.DefaultDetectionPostProcessor(
241241
image_preprocessor=self._img_preprocessor,
242242
output_img_size=self._input_layer.shape[2],
243-
conf_thresh=config.yoeo.conf_threshold,
244-
nms_thresh=config.yoeo.nms_threshold,
243+
conf_thresh=yoeo_config.conf_threshold,
244+
nms_thresh=yoeo_config.nms_threshold,
245245
robot_class_ids=self.get_robot_class_ids(),
246246
)
247247
self._seg_postprocessor: utils.ISegmentationPostProcessor = utils.DefaultSegmentationPostProcessor(
@@ -250,13 +250,13 @@ def __init__(
250250

251251
logger.debug(f"Leaving {self.__class__.__name__} constructor")
252252

253-
def configure(self, config: parameters.Params) -> None:
253+
def configure(self, yoeo_config) -> None:
254254
super().configure(config)
255255
self._det_postprocessor.configure(
256256
image_preprocessor=self._img_preprocessor,
257257
output_img_size=self._input_layer.shape[2],
258-
conf_thresh=config.yoeo.conf_threshold,
259-
nms_thresh=config.yoeo.nms_threshold,
258+
conf_thresh=yoeo_config.conf_threshold,
259+
nms_thresh=yoeo_config.nms_threshold,
260260
robot_class_ids=self.get_robot_class_ids(),
261261
)
262262

@@ -286,7 +286,7 @@ class YOEOHandlerOpenVino(YOEOHandlerTemplate):
286286

287287
def __init__(
288288
self,
289-
config: parameters.Params,
289+
yoeo_config,
290290
model_directory: str,
291291
det_class_names: list[str],
292292
det_robot_class_ids: list[int],
@@ -322,8 +322,8 @@ def __init__(
322322
self._det_postprocessor: utils.IDetectionPostProcessor = utils.DefaultDetectionPostProcessor(
323323
image_preprocessor=self._img_preprocessor,
324324
output_img_size=self._input_layer.shape[2],
325-
conf_thresh=config.yoeo.conf_threshold,
326-
nms_thresh=config.yoeo.nms_threshold,
325+
conf_thresh=yoeo_config.conf_threshold,
326+
nms_thresh=yoeo_config.nms_threshold,
327327
robot_class_ids=self.get_robot_class_ids(),
328328
)
329329
self._seg_postprocessor: utils.ISegmentationPostProcessor = utils.DefaultSegmentationPostProcessor(
@@ -339,13 +339,13 @@ def _select_device(self) -> str:
339339
device = "CPU"
340340
return device
341341

342-
def configure(self, config: parameters.Params) -> None:
342+
def configure(self, yoeo_config) -> None:
343343
super().configure(config)
344344
self._det_postprocessor.configure(
345345
image_preprocessor=self._img_preprocessor,
346346
output_img_size=self._input_layer.shape[2],
347-
conf_thresh=config.yoeo.conf_threshold,
348-
nms_thresh=config.yoeo.nms_threshold,
347+
conf_thresh=yoeo_config.conf_threshold,
348+
nms_thresh=yoeo_config.nms_threshold,
349349
robot_class_ids=self.get_robot_class_ids(),
350350
)
351351

@@ -374,7 +374,7 @@ class YOEOHandlerPytorch(YOEOHandlerTemplate):
374374

375375
def __init__(
376376
self,
377-
config: parameters.Params,
377+
yoeo_config,
378378
model_directory: str,
379379
det_class_names: list[str],
380380
det_robot_class_ids: list[int],
@@ -400,8 +400,8 @@ def __init__(
400400
logger.debug(f"Loading files...\n\t{config_path}\n\t{weights_path}")
401401
self._model = torch_models.load_model(config_path, weights_path)
402402

403-
self._conf_thresh: float = config.yoeo.conf_threshold
404-
self._nms_thresh: float = config.yoeo.nms_threshold
403+
self._conf_thresh: float = yoeo_config.conf_threshold
404+
self._nms_thresh: float = yoeo_config.nms_threshold
405405
self._group_config: torch_GroupConfig = self._update_group_config()
406406

407407
logger.debug(f"Leaving {self.__class__.__name__} constructor")
@@ -411,10 +411,10 @@ def _update_group_config(self):
411411

412412
return self.torch_group_config(group_ids=robot_class_ids, surrogate_id=robot_class_ids[0])
413413

414-
def configure(self, config: parameters.Params) -> None:
414+
def configure(self, yoeo_config) -> None:
415415
super().configure(config)
416-
self._conf_thresh = config.yoeo.conf_threshold
417-
self._nms_thresh = config.yoeo.nms_threshold
416+
self._conf_thresh = yoeo_config.conf_threshold
417+
self._nms_thresh = yoeo_config.nms_threshold
418418
self._group_config = self._update_group_config()
419419

420420
@staticmethod
@@ -448,7 +448,7 @@ class YOEOHandlerTVM(YOEOHandlerTemplate):
448448

449449
def __init__(
450450
self,
451-
config: parameters.Params,
451+
yoeo_config,
452452
model_directory: str,
453453
det_class_names: list[str],
454454
det_robot_class_ids: list[int],
@@ -487,8 +487,8 @@ def __init__(
487487
self._det_postprocessor: utils.IDetectionPostProcessor = utils.DefaultDetectionPostProcessor(
488488
image_preprocessor=self._img_preprocessor,
489489
output_img_size=self._input_layer_shape[2],
490-
conf_thresh=config.yoeo.conf_threshold,
491-
nms_thresh=config.yoeo.nms_threshold,
490+
conf_thresh=yoeo_config.conf_threshold,
491+
nms_thresh=yoeo_config.nms_threshold,
492492
robot_class_ids=self.get_robot_class_ids(),
493493
)
494494
self._seg_postprocessor: utils.ISegmentationPostProcessor = utils.DefaultSegmentationPostProcessor(
@@ -497,13 +497,13 @@ def __init__(
497497

498498
logger.debug(f"Leaving {self.__class__.__name__} constructor")
499499

500-
def configure(self, config: parameters.Params) -> None:
500+
def configure(self, yoeo_config) -> None:
501501
super().configure(config)
502502
self._det_postprocessor.configure(
503503
image_preprocessor=self._img_preprocessor,
504504
output_img_size=self._input_layer_shape[2],
505-
conf_thresh=config.yoeo.conf_threshold,
506-
nms_thresh=config.yoeo.nms_threshold,
505+
conf_thresh=yoeo_config.conf_threshold,
506+
nms_thresh=yoeo_config.nms_threshold,
507507
robot_class_ids=self.get_robot_class_ids(),
508508
)
509509

bitbots_vision/config/vision_parameters.yaml

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,11 @@ bitbots_vision:
106106
description: "The neural network framework that should be used"
107107
validation:
108108
one_of<>: [["pytorch", "openvino", "onnx", "tvm"]]
109+
110+
caching:
111+
type: bool
112+
default_value: true
113+
description: "Used to deactivate caching for profiling reasons"
109114

110115
# Ball detection parameters
111116
ball_candidate_rating_threshold:
@@ -120,10 +125,4 @@ bitbots_vision:
120125
default_value: 1
121126
description: "The maximum number of balls that should be published"
122127
validation:
123-
bounds<>: [0, 50]
124-
125-
# Caching parameter
126-
caching:
127-
type: bool
128-
default_value: true
129-
description: "Used to deactivate caching for profiling reasons"
128+
bounds<>: [0, 50]

0 commit comments

Comments
 (0)