Skip to content

Commit 366de3f

Browse files
Authored: Improve the warm-up model device logic
Update the `default_devices` logic to load models onto a single device (`cuda` if available, otherwise `cpu`), instead of both devices, to streamline resource usage.
1 parent 650ac4d commit 366de3f

File tree

3 files changed

+32
-3
lines changed

3 files changed

+32
-3
lines changed

src/marqo/inference/native_inference/remote/server/on_start_script.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@ def __init__(self, config: Config):
104104
# TBD to include cross-encoder/ms-marco-TinyBERT-L-2-v2
105105

106106
# TODO [Refactoring device logic] use device info gathered from device manager
107-
self.default_devices = ['cpu'] if not torch.cuda.is_available() else ['cuda', 'cpu']
107+
self.default_devices = ['cpu'] if not torch.cuda.is_available() else ['cuda']
108108

109109
self.logger.info(f"pre-loading {self.models} onto devices={self.default_devices}")
110110

src/marqo/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
__version__ = "2.19.0"
1+
__version__ = "2.19.1"
22

33
def get_version() -> str:
44
return f"{__version__}"

tests/unit_tests/marqo/inference/native_inference/remote/server/test_on_start_script.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -216,4 +216,33 @@ def test_missing_punkt_downloaded(self):
216216
checker = on_start_script.CheckNLTKTokenizers()
217217
with self.assertRaises(StartupSanityCheckError):
218218
checker.run()
219-
mock_nltk_download.assert_any_call("punkt_tab")
219+
mock_nltk_download.assert_any_call("punkt_tab")
220+
221+
def test_models_only_load_to_one_device(self):
222+
"""
223+
Ensure models are only loaded to one device (cuda if available, else cpu) when warming up,
224+
not to all devices.
225+
"""
226+
with mock.patch("marqo.inference.native_inference.remote.server.on_start_script.torch.cuda.is_available") as mock_cuda_available, \
227+
mock.patch("os.environ", {
228+
enums.EnvVars.MARQO_MODELS_TO_PRELOAD: json.dumps(["LanguageBind/Video_V1.5_FT_Audio_FT_Image"])
229+
}):
230+
231+
for cuda_available in [True, False]:
232+
expected_device = "cuda" if cuda_available else "cpu"
233+
mock_cuda_available.return_value = cuda_available
234+
235+
cache_model_module = on_start_script.CacheModels(self.mock_config)
236+
self.assertEqual(cache_model_module.default_devices, [expected_device])
237+
238+
with mock.patch.object(cache_model_module, "_preload_model") as mock_preload_model:
239+
cache_model_module.run()
240+
mock_preload_model.assert_called_with(
241+
model="LanguageBind/Video_V1.5_FT_Audio_FT_Image",
242+
content="this is a test string",
243+
device=expected_device
244+
)
245+
246+
# Ensure the other device is not used
247+
other_device = "cpu" if expected_device == "cuda" else "cuda"
248+
self.assertNotIn(other_device, mock_preload_model.call_args[1]["device"])

0 commit comments

Comments (0)