Skip to content
This repository was archived by the owner on Feb 26, 2025. It is now read-only.

Commit 2b0b888

Browse files
Restore analysis_config.output, raise if SHMDIR isn't set (#42)
- Allow to access the cache config and the output path of individual analyses with ``analysis_config.cache`` and ``analysis_config.output``, as a shortcut to ``analysis_config.cache.path``. - Raise an error if the env variable ``SHMDIR`` isn't set, instead of logging a warning.
1 parent 0c9106b commit 2b0b888

File tree

10 files changed

+61
-48
lines changed

10 files changed

+61
-48
lines changed

CHANGELOG.rst

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,16 @@
11
Changelog
22
=========
33

4+
Version 0.11.0
5+
--------------
6+
7+
Improvements
8+
~~~~~~~~~~~~
9+
10+
- Allow to access the cache config and the output path of individual analyses with ``analysis_config.cache`` and ``analysis_config.output``, as a shortcut to ``analysis_config.cache.path``.
11+
- Raise an error if the env variable ``SHMDIR`` isn't set, instead of logging a warning.
12+
13+
414
Version 0.10.1
515
--------------
616

src/blueetl/analysis.py

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from blueetl.cache import CacheManager
1212
from blueetl.campaign.config import SimulationCampaign
1313
from blueetl.config.analysis import init_multi_analysis_configuration
14-
from blueetl.config.analysis_model import CacheConfig, MultiAnalysisConfig, SingleAnalysisConfig
14+
from blueetl.config.analysis_model import MultiAnalysisConfig, SingleAnalysisConfig
1515
from blueetl.features import FeaturesCollection
1616
from blueetl.repository import Repository
1717
from blueetl.resolver import AttrResolver, Resolver
@@ -46,19 +46,16 @@ def from_config(
4646
cls,
4747
analysis_config: SingleAnalysisConfig,
4848
simulations_config: SimulationCampaign,
49-
cache_config: CacheConfig,
5049
resolver: Resolver,
5150
) -> "Analyzer":
5251
"""Initialize the Analyzer from the given configuration.
5352
5453
Args:
5554
analysis_config: analysis configuration.
5655
simulations_config: simulation campaign configuration.
57-
cache_config: cache configuration.
5856
resolver: resolver instance.
5957
"""
6058
cache_manager = CacheManager(
61-
cache_config=cache_config,
6259
analysis_config=analysis_config,
6360
simulations_config=simulations_config,
6461
)
@@ -214,9 +211,6 @@ def _init_analyzers(self) -> dict[str, Analyzer]:
214211
name: Analyzer.from_config(
215212
analysis_config=analysis_config,
216213
simulations_config=simulations_config,
217-
cache_config=self.global_config.cache.model_copy(
218-
update={"path": self.global_config.cache.path / name}
219-
),
220214
resolver=resolver,
221215
)
222216
for name, analysis_config in self.global_config.analysis.items()

src/blueetl/cache.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from blueetl_core.utils import is_subfilter
1616

1717
from blueetl.campaign.config import SimulationCampaign
18-
from blueetl.config.analysis_model import CacheConfig, FeaturesConfig, SingleAnalysisConfig
18+
from blueetl.config.analysis_model import FeaturesConfig, SingleAnalysisConfig
1919
from blueetl.store.base import BaseStore
2020
from blueetl.store.feather import FeatherStore
2121
from blueetl.store.parquet import ParquetStore
@@ -143,17 +143,17 @@ class CacheManager:
143143

144144
def __init__(
145145
self,
146-
cache_config: CacheConfig,
147146
analysis_config: SingleAnalysisConfig,
148147
simulations_config: SimulationCampaign,
149148
) -> None:
150149
"""Initialize the object.
151150
152151
Args:
153-
cache_config: cache configuration dict.
154-
analysis_config: analysis configuration dict.
152+
analysis_config: analysis configuration.
155153
simulations_config: simulations campaign configuration.
156154
"""
155+
cache_config = analysis_config.cache
156+
assert cache_config is not None
157157
self._output_dir = cache_config.path
158158
if cache_config.clear:
159159
self._clear_cache()

src/blueetl/config/analysis.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,9 @@ def expand_zip(params: dict, params_zip: dict) -> Iterator[dict]:
171171

172172

173173
def _resolve_analysis_configs(global_config: MultiAnalysisConfig) -> None:
174-
for config in global_config.analysis.values():
174+
global_cache_path = global_config.cache.path
175+
for name, config in global_config.analysis.items():
176+
config.cache = global_config.cache.model_copy(update={"path": global_cache_path / name})
175177
config.simulations_filter = global_config.simulations_filter
176178
config.simulations_filter_in_memory = global_config.simulations_filter_in_memory
177179
config.features = _resolve_features(config.features)

src/blueetl/config/analysis_model.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -182,6 +182,7 @@ class FeaturesConfig(BaseModel):
182182
class SingleAnalysisConfig(BaseModel):
183183
"""SingleAnalysisConfig Model."""
184184

185+
cache: Optional[CacheConfig] = None
185186
simulations_filter: dict[str, Any] = {}
186187
simulations_filter_in_memory: dict[str, Any] = {}
187188
extraction: ExtractionConfig
@@ -204,6 +205,11 @@ def handle_deprecated_fields(cls, data: Any) -> Any:
204205
data.pop("output", None)
205206
return data
206207

208+
@property
209+
def output(self) -> Optional[Path]:
210+
"""Shortcut to the base output path of the analysis."""
211+
return self.cache.path if self.cache else None
212+
207213

208214
class MultiAnalysisConfig(BaseModel):
209215
"""MultiAnalysisConfig Model."""

src/blueetl/utils.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,14 @@
99
import time
1010
from collections.abc import Iterable, Iterator
1111
from contextlib import contextmanager
12+
from enum import Enum
1213
from functools import cache, cached_property
1314
from pathlib import Path
1415
from typing import Any, Callable, Optional, Union
1516

1617
import pandas as pd
1718
import yaml
19+
from pydantic import BaseModel
1820

1921
from blueetl.constants import DTYPES
2022
from blueetl.types import StrOrPath
@@ -190,12 +192,9 @@ def checksum_json(obj: Any) -> str:
190192
@cache
191193
def _get_internal_yaml_dumper() -> type[yaml.SafeDumper]:
192194
"""Return the custom internal yaml dumper class."""
193-
# pylint: disable=import-outside-toplevel
194-
# imported here because optional
195-
from pydantic import BaseModel
196-
197195
_representers = {
198196
Path: str,
197+
Enum: lambda data: data.value,
199198
BaseModel: lambda data: data.dict(),
200199
}
201200

@@ -336,12 +335,15 @@ def copy_config(src: StrOrPath, dst: StrOrPath) -> None:
336335
dump_yaml(dst, config, default_flow_style=None)
337336

338337

339-
def get_shmdir() -> Optional[Path]:
340-
"""Return the shared memory directory, or None if not set."""
338+
def get_shmdir() -> Path:
339+
"""Return the shared memory directory, or raise an error if not set."""
341340
shmdir = os.getenv("SHMDIR")
342341
if not shmdir:
343-
L.warning("SHMDIR should be set to the shared memory directory")
344-
return None
342+
raise RuntimeError(
343+
"SHMDIR must be set to the shared memory directory. "
344+
"The variable should be automatically set when running on an allocated node, "
345+
"but it's not set when connecting via SSH to a pre-allocated node."
346+
)
345347
shmdir = Path(shmdir)
346348
if not shmdir.is_dir():
347349
raise RuntimeError("SHMDIR must be set to an existing directory")

tests/unit/config/test_analysis.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
import pytest
22

33
from blueetl.config import analysis as test_module
4-
from blueetl.config.analysis_model import FeaturesConfig, MultiAnalysisConfig
4+
from blueetl.config.analysis_model import FeaturesConfig, MultiAnalysisConfig, SingleAnalysisConfig
55
from blueetl.utils import load_yaml
66
from tests.functional.utils import TEST_DATA_PATH as TEST_DATA_PATH_FUNCTIONAL
77
from tests.unit.utils import TEST_DATA_PATH as TEST_DATA_PATH_UNIT
@@ -191,3 +191,9 @@ def test_init_multi_analysis_configuration(config_file):
191191
config_dict, base_path=base_path, extra_params={}
192192
)
193193
assert isinstance(result, MultiAnalysisConfig)
194+
assert result.cache.path == base_path / config_dict["cache"]["path"]
195+
assert len(result.analysis) > 0
196+
for name, analysis_config in result.analysis.items():
197+
assert isinstance(analysis_config, SingleAnalysisConfig)
198+
assert analysis_config.cache is not None
199+
assert analysis_config.output == result.cache.path / name

tests/unit/test_cache.py

Lines changed: 8 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -10,9 +10,7 @@
1010
def cache_manager(global_config):
1111
simulations_config = SimulationCampaign.load(global_config.simulation_campaign)
1212
analysis_config = global_config.analysis["spikes"]
13-
cache_config = global_config.cache
1413
instance = test_module.CacheManager(
15-
cache_config=cache_config,
1614
analysis_config=analysis_config,
1715
simulations_config=simulations_config,
1816
)
@@ -62,10 +60,8 @@ def test_lock_manager_shared(tmp_path):
6260
def test_cache_manager_init_and_close(global_config):
6361
simulations_config = SimulationCampaign.load(global_config.simulation_campaign)
6462
analysis_config = global_config.analysis["spikes"]
65-
cache_config = global_config.cache
6663

6764
instance = test_module.CacheManager(
68-
cache_config=cache_config,
6965
analysis_config=analysis_config,
7066
simulations_config=simulations_config,
7167
)
@@ -79,10 +75,8 @@ def test_cache_manager_init_and_close(global_config):
7975
def test_cache_manager_to_readonly(global_config):
8076
simulations_config = SimulationCampaign.load(global_config.simulation_campaign)
8177
analysis_config = global_config.analysis["spikes"]
82-
cache_config = global_config.cache
8378

8479
instance = test_module.CacheManager(
85-
cache_config=cache_config,
8680
analysis_config=analysis_config,
8781
simulations_config=simulations_config,
8882
)
@@ -108,24 +102,20 @@ def test_cache_manager_to_readonly(global_config):
108102
def test_cache_manager_concurrency_is_not_allowed_when_locked(global_config):
109103
simulations_config = SimulationCampaign.load(global_config.simulation_campaign)
110104
analysis_config = global_config.analysis["spikes"]
111-
cache_config = global_config.cache
112105

113106
instance = test_module.CacheManager(
114-
cache_config=cache_config,
115107
analysis_config=analysis_config,
116108
simulations_config=simulations_config,
117109
)
118110
# verify that a new instance cannot be created when the old instance is keeping the lock
119111
with pytest.raises(test_module.CacheError, match="Another process is locking"):
120112
test_module.CacheManager(
121-
cache_config=cache_config,
122113
analysis_config=analysis_config,
123114
simulations_config=simulations_config,
124115
)
125116
# verify that a new instance can be created after closing the old instance
126117
instance.close()
127118
instance = test_module.CacheManager(
128-
cache_config=cache_config,
129119
analysis_config=analysis_config,
130120
simulations_config=simulations_config,
131121
)
@@ -135,12 +125,10 @@ def test_cache_manager_concurrency_is_not_allowed_when_locked(global_config):
135125
def test_cache_manager_concurrency_is_allowed_when_readonly(global_config):
136126
simulations_config = SimulationCampaign.load(global_config.simulation_campaign)
137127
analysis_config = global_config.analysis["spikes"]
138-
cache_config = global_config.cache.model_copy(update={"readonly": False})
139-
cache_config_readonly = global_config.cache.model_copy(update={"readonly": True})
128+
cache_config = analysis_config.cache
140129

141130
# init the cache that will be used later
142131
instance = test_module.CacheManager(
143-
cache_config=cache_config,
144132
analysis_config=analysis_config,
145133
simulations_config=simulations_config,
146134
)
@@ -149,8 +137,9 @@ def test_cache_manager_concurrency_is_allowed_when_readonly(global_config):
149137
# use the same cache in multiple cache managers
150138
instances = [
151139
test_module.CacheManager(
152-
cache_config=cache_config_readonly,
153-
analysis_config=analysis_config,
140+
analysis_config=analysis_config.model_copy(
141+
update={"cache": cache_config.model_copy(update={"readonly": True})}
142+
),
154143
simulations_config=simulations_config,
155144
)
156145
for _ in range(3)
@@ -162,15 +151,13 @@ def test_cache_manager_concurrency_is_allowed_when_readonly(global_config):
162151
def test_cache_manager_clear_cache(global_config, tmp_path):
163152
simulations_config = SimulationCampaign.load(global_config.simulation_campaign)
164153
analysis_config = global_config.analysis["spikes"]
165-
cache_config = global_config.cache.model_copy(update={"clear": False})
166-
cache_config_clear = global_config.cache.model_copy(update={"clear": True})
154+
cache_config = analysis_config.cache
167155

168156
output = cache_config.path
169157
sentinel = output / "sentinel"
170158

171159
assert output.exists() is False
172160
instance = test_module.CacheManager(
173-
cache_config=cache_config_clear,
174161
analysis_config=analysis_config,
175162
simulations_config=simulations_config,
176163
)
@@ -181,7 +168,6 @@ def test_cache_manager_clear_cache(global_config, tmp_path):
181168

182169
# reuse the cache
183170
instance = test_module.CacheManager(
184-
cache_config=cache_config,
185171
analysis_config=analysis_config,
186172
simulations_config=simulations_config,
187173
)
@@ -191,8 +177,9 @@ def test_cache_manager_clear_cache(global_config, tmp_path):
191177

192178
# delete the cache
193179
instance = test_module.CacheManager(
194-
cache_config=cache_config_clear,
195-
analysis_config=analysis_config,
180+
analysis_config=analysis_config.model_copy(
181+
update={"cache": cache_config.model_copy(update={"clear": True})}
182+
),
196183
simulations_config=simulations_config,
197184
)
198185
instance.close()

tests/unit/test_utils.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import json
2+
from enum import Enum
23
from pathlib import Path
34

45
import numpy as np
@@ -95,11 +96,16 @@ def test_resolve_path(tmp_path):
9596

9697

9798
def test_dump_yaml(tmp_path):
99+
class TestEnum(str, Enum):
100+
v0 = "v0"
101+
v1 = "v1"
102+
98103
data = {
99104
"dict": {"str": "mystr", "int": 123},
100105
"list_of_int": [1, 2, 3],
101106
"list_of_str": ["1", "2", "3"],
102107
"path": Path("/custom/path"),
108+
"enum": TestEnum.v0,
103109
}
104110
expected = """
105111
dict:
@@ -114,6 +120,7 @@ def test_dump_yaml(tmp_path):
114120
- '2'
115121
- '3'
116122
path: /custom/path
123+
enum: v0
117124
"""
118125
filepath = tmp_path / "test.yaml"
119126

@@ -150,7 +157,7 @@ def test_load_yaml(tmp_path):
150157
assert loaded_data == expected
151158

152159

153-
def test_dump_jaon_load_json_roundtrip(tmp_path):
160+
def test_dump_json_load_json_roundtrip(tmp_path):
154161
data = {
155162
"dict": {"str": "mystr", "int": 123},
156163
"list_of_int": [1, 2, 3],
@@ -286,8 +293,8 @@ def test_get_shmdir(monkeypatch, tmp_path):
286293
assert shmdir == tmp_path
287294

288295
monkeypatch.delenv("SHMDIR")
289-
shmdir = test_module.get_shmdir()
290-
assert shmdir is None
296+
with pytest.raises(RuntimeError, match="SHMDIR must be set to the shared memory directory"):
297+
test_module.get_shmdir()
291298

292299
monkeypatch.setenv("SHMDIR", str(tmp_path / "non-existent"))
293300
with pytest.raises(RuntimeError, match="SHMDIR must be set to an existing directory"):

tox.ini

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,11 +17,10 @@ minversion = 4
1717

1818
[testenv]
1919
setenv =
20+
TMPDIR={env:TMPDIR:/tmp}
21+
SHMDIR={env:SHMDIR:{env:TMPDIR}}
2022
# Run serially
2123
BLUEETL_JOBLIB_JOBS=1
22-
passenv =
23-
SHMDIR
24-
TMPDIR
2524
extras =
2625
all
2726
deps =

0 commit comments

Comments
 (0)