Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions CHANGES.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,11 @@
- Introduced `SliceSource.close()` so
[contextlib.closing()](https://docs.python.org/3/library/contextlib.html#contextlib.closing)
is applicable. Deprecated `SliceSource.dispose()`.

- Introduced new optional configuration setting `slice_source_kwargs` that
contains keyword-arguments, which are passed to a configured `slice_source` together with
each slice item.


## Version 0.6.0 (from 2024-03-12)

Expand Down
4 changes: 4 additions & 0 deletions docs/guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -760,6 +760,10 @@ argument to your slice source.
- `dict`: keyword arguments only;
- Any other type is interpreted as single positional argument.

You can also pass extra keyword arguments to your slice source using the
`slice_source_kwargs` setting. Keyword arguments passed as slice items take
precedence, that is, they overwrite arguments passed by `slice_source_kawrgs`.

In addition, your slice source function or class constructor specified by
`slice_source` may define a 1st positional argument or keyword argument
named `ctx`, which will receive the current processing context of type
Expand Down
18 changes: 18 additions & 0 deletions tests/config/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,24 @@ def test_slice_source_as_type(self):
}
)

def test_slice_source_kwargs(self):
config = Config(
{
"target_dir": "memory://target.zarr",
}
)
self.assertEqual(None, config.slice_source_kwargs)

config = Config(
{
"target_dir": "memory://target.zarr",
"slice_source_kwargs": {"a": 1, "b": True, "c": "nearest"},
}
)
self.assertEqual(
{"a": 1, "b": True, "c": "nearest"}, config.slice_source_kwargs
)


def new_custom_slice_source(ctx: Context, index: int):
return CustomSliceSource(ctx, index)
Expand Down
1 change: 1 addition & 0 deletions tests/config/test_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ def test_get_config_schema(self):
"slice_engine",
"slice_polling",
"slice_source",
"slice_source_kwargs",
"slice_storage_options",
"target_storage_options",
"target_dir",
Expand Down
33 changes: 30 additions & 3 deletions tests/slice/test_cm.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
# noinspection PyUnusedLocal


# noinspection PyShadowingBuiltins,PyRedeclaration
# noinspection PyShadowingBuiltins,PyRedeclaration,PyMethodMayBeStatic
class OpenSliceDatasetTest(unittest.TestCase):
def setUp(self):
clear_memory_fs()
Expand Down Expand Up @@ -164,7 +164,7 @@ def get_dataset(name):
with slice_cm as slice_ds:
self.assertIsInstance(slice_ds, xr.Dataset)

def test_slice_item_is_slice_source(self):
def test_slice_item_is_slice_source_arg(self):
class MySliceSource(SliceSource):
def __init__(self, name):
self.uri = f"memory://{name}.zarr"
Expand All @@ -191,7 +191,7 @@ def close(self):
with slice_cm as slice_ds:
self.assertIsInstance(slice_ds, xr.Dataset)

def test_slice_item_is_deprecated_slice_source(self):
def test_slice_item_is_deprecated_slice_source_arg(self):
class MySliceSource(SliceSource):
def __init__(self, name):
self.uri = f"memory://{name}.zarr"
Expand Down Expand Up @@ -219,6 +219,33 @@ def dispose(self):
with slice_cm as slice_ds:
self.assertIsInstance(slice_ds, xr.Dataset)

def test_slice_item_is_slice_source_arg_with_extra_kwargs(self):
class MySliceSource(SliceSource):
def __init__(self, *args, **kwargs):
self.args = args
self.kwargs = kwargs

def get_dataset(self):
return xr.Dataset()

ctx = Context(
dict(
target_dir="memory://target.zarr",
slice_source=MySliceSource,
slice_source_kwargs={"a": 1, "b": True, "c": "nearest"},
)
)
slice_cm = open_slice_dataset(ctx, (["bibo"], {"a": 2, "d": 3.14}))
self.assertIsInstance(slice_cm, SliceSourceContextManager)
slice_source = slice_cm.slice_source
self.assertIsInstance(slice_source, MySliceSource)
with slice_cm as slice_ds:
self.assertIsInstance(slice_ds, xr.Dataset)
self.assertEqual(slice_source.args, ("bibo",))
self.assertEqual(
slice_source.kwargs, {"a": 2, "b": True, "c": "nearest", "d": 3.14}
)


class IsContextManagerTest(unittest.TestCase):
"""Assert that context managers are identified by isinstance()"""
Expand Down
7 changes: 7 additions & 0 deletions zappend/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,13 @@ def slice_source(self) -> Callable[[...], Any] | None:
"""
return self._slice_source

@property
def slice_source_kwargs(self) -> dict[str, Any] | None:
"""Extra keyword-arguments passed to a specified `slice_source`
together with each slice item.
"""
return self._config.get("slice_source_kwargs")

@property
def slice_storage_options(self) -> dict[str, Any] | None:
"""The configured slice storage options to be used
Expand Down
8 changes: 8 additions & 0 deletions zappend/config/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -643,6 +643,14 @@
"type": "string",
"minLength": 1,
},
slice_source_kwargs={
"description": (
"Extra keyword-arguments passed to a configured `slice_source`"
" together with each slice item."
),
"type": "object",
"additionalProperties": True,
},
slice_engine={
"description": (
"The name of the engine to be used for opening"
Expand Down
4 changes: 4 additions & 0 deletions zappend/slice/callable.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ def invoke_slice_callable(
A slice item of type `SliceItem`.
"""
slice_args, slice_kwargs = to_slice_args(slice_item)
if ctx.config.slice_source_kwargs:
extra_kwargs = dict(ctx.config.slice_source_kwargs)
extra_kwargs.update(slice_kwargs)
slice_kwargs = extra_kwargs

signature = inspect.signature(slice_callable)
ctx_parameter = signature.parameters.get("ctx")
Expand Down