diff --git a/pyproject.toml b/pyproject.toml index 353811c9..e6a96e96 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -27,7 +27,6 @@ dependencies = [ "click", "dask-image", "dask>=2026.3.0", - "distributed>=2026.3.0", "datashader", "fsspec[s3,http]", "geopandas>=0.14", @@ -37,6 +36,7 @@ dependencies = [ "numpy", "ome_zarr>=0.16.0", "pandas", + "platformdirs", "pooch", "pyarrow", "rich", @@ -60,6 +60,9 @@ extra = [ "spatialdata-plot", "spatialdata-io", ] +zarrs = [ + "zarrs" +] [dependency-groups] dev = [ @@ -71,6 +74,7 @@ test = [ "pytest-mock", "pytest-xdist", "torch", + "zarrs", ] docs = [ "sphinx>=4.5", diff --git a/src/spatialdata/_core/_utils.py b/src/spatialdata/_core/_utils.py index 9dfd613b..b4c90fc5 100644 --- a/src/spatialdata/_core/_utils.py +++ b/src/spatialdata/_core/_utils.py @@ -1,8 +1,10 @@ from __future__ import annotations from collections.abc import Iterable +from typing import Any from anndata import AnnData +from ome_zarr.types import JSONDict from spatialdata._core.spatialdata import SpatialData @@ -164,3 +166,41 @@ def get_unique_name(name: str, attr: str, is_dataframe_column: bool = False) -> setattr(sanitized, attr, new_dict) return None if inplace else sanitized + + +def create_raster_element_kwargs( + raster_write_kwargs: dict[str, JSONDict | list[JSONDict]] | list[JSONDict] | None, + element_name: str, + element_names: set[str], +) -> dict[str, Any] | list[dict[str, Any]] | None: + """Normalize raster keyword arguments to the kwargs required by `zarr.create_array` for a single raster.""" + if raster_write_kwargs is None: + return {} + + kwargs_copy = raster_write_kwargs.copy() + if isinstance(kwargs_copy, dict): + element_write_kwargs: JSONDict | list[JSONDict] | None = kwargs_copy.get(element_name) + if element_write_kwargs: + return element_write_kwargs + + # If we get here it means that we do not have kwargs with the specific element. We need to clear out kwargs + # that could be there of other elements. + for name in element_names: + kwargs_copy.pop(name, None) + + # We return here if there are no kwargs after stripping all kwargs directly corresponding to a given element. + if not kwargs_copy: + return {} + + if isinstance(kwargs_copy, dict) and not all(isinstance(x, (dict, list)) for x in kwargs_copy.values()): + return kwargs_copy + + if isinstance(kwargs_copy, list): + if not all(isinstance(x, dict) for x in kwargs_copy): + raise ValueError( + "If passing raster_write_kwargs as list, it is assumed to be the storage " + "options for each scale of a multiscale raster as a dictionary." + ) + return kwargs_copy + + raise ValueError(f"Type of raster_write_kwargs should be either dict or list, got {type(kwargs_copy)}.") diff --git a/src/spatialdata/_core/spatialdata.py b/src/spatialdata/_core/spatialdata.py index fb55ab08..460c706e 100644 --- a/src/spatialdata/_core/spatialdata.py +++ b/src/spatialdata/_core/spatialdata.py @@ -16,6 +16,7 @@ from dask.dataframe import DataFrame as DaskDataFrame from dask.dataframe import Scalar from geopandas import GeoDataFrame +from ome_zarr.types import JSONDict from shapely import MultiPolygon, Polygon from upath import UPath from xarray import DataArray, DataTree @@ -29,9 +30,10 @@ raise_validation_errors, validate_table_attr_keys, ) +from spatialdata._docs import docstring_parameter from spatialdata._logging import logger from spatialdata._types import ArrayLike, Raster_T -from spatialdata._utils import _deprecation_alias +from spatialdata._utils import _deprecation_alias, zarrs_context from spatialdata.models import ( Image2DModel, Image3DModel, @@ -57,6 +59,27 @@ SpatialDataFormatType, ) +RASTER_WRITE_KWARGS_DOCS = """\ + Storage options for raster elements. These options are passed to the zarr storage backend for writing and + can be provided in several formats: + + 1. Single dictionary + A dictionary containing all storage options applied globally. + 2. Dictionary per raster element + A dictionary where: + - Keys = names of raster elements + - Values = storage options for each element + - For single-scale data: a dictionary + - For multiscale data: a list of dictionaries (one per scale) + 3. List of dictionaries (multiscale only) + A list where each dictionary defines the storage options for one scale of a multiscale raster element. + + Important Notes + - The available key–value pairs in these dictionaries depend on the Zarr format used for writing. + - For a full list of supported storage options, refer to: + https://zarr.readthedocs.io/en/stable/api/zarr/create/#zarr.create_array + """ + class SpatialData: """ @@ -1105,6 +1128,7 @@ def _validate_all_elements(self) -> None: validate_table_attr_keys(element, location=element_path) @_deprecation_alias(format="sdata_formats", version="0.7.0") + @docstring_parameter(raster_write_kwargs=RASTER_WRITE_KWARGS_DOCS) def write( self, file_path: str | Path, @@ -1113,6 +1137,7 @@ def write( update_sdata_path: bool = True, sdata_formats: SpatialDataFormatType | list[SpatialDataFormatType] | None = None, shapes_geometry_encoding: Literal["WKB", "geoarrow"] | None = None, + raster_write_kwargs: dict[str, JSONDict | list[JSONDict]] | list[JSONDict] | None = None, raster_compressor: dict[Literal["lz4", "zstd"], int] | None = None, ) -> None: """ @@ -1161,6 +1186,8 @@ def write( shapes_geometry_encoding Whether to use the WKB or geoarrow encoding for GeoParquet. See :meth:`geopandas.GeoDataFrame.to_parquet` for details. If None, uses the value from :attr:`spatialdata.settings.shapes_geometry_encoding`. + raster_write_kwargs + {RASTER_WRITE_KWARGS_DOCS} raster_compressor A lenght-1 dictionary with as key the type of compression to use for images and labels and as value the compression level which should be inclusive between 0 and 9. For compression, `lz4` and `zstd` are @@ -1193,6 +1220,7 @@ def write( overwrite=False, parsed_formats=parsed, shapes_geometry_encoding=shapes_geometry_encoding, + raster_write_kwargs=raster_write_kwargs, raster_compressor=raster_compressor, ) @@ -1211,6 +1239,7 @@ def _write_element( overwrite: bool, parsed_formats: dict[str, SpatialDataFormatType] | None = None, shapes_geometry_encoding: Literal["WKB", "geoarrow"] | None = None, + raster_write_kwargs: dict[str, JSONDict | list[JSONDict] | Any] | list[JSONDict] | None = None, raster_compressor: dict[Literal["lz4", "zstd"], int] | None = None, ) -> None: from spatialdata._io.io_zarr import _get_groups_for_element @@ -1244,51 +1273,63 @@ def _write_element( validate_element(element) - if element_type == "images": - write_image( - image=element, - group=element_group, - name=element_name, - element_format=parsed_formats["raster"], - raster_compressor=raster_compressor, - ) - elif element_type == "labels": - write_labels( - labels=element, - group=root_group, - name=element_name, - element_format=parsed_formats["raster"], - raster_compressor=raster_compressor, - ) - elif element_type == "points": - write_points( - points=element, - group=element_group, - element_format=parsed_formats["points"], - ) - elif element_type == "shapes": - write_shapes( - shapes=element, - group=element_group, - element_format=parsed_formats["shapes"], - geometry_encoding=shapes_geometry_encoding, - ) - elif element_type == "tables": - write_table( - table=element, - group=element_type_group, - name=element_name, - element_format=parsed_formats["tables"], - ) - else: - raise ValueError(f"Unknown element type: {element_type}") + element_raster_write_kwargs = None + if element_type in ("images", "labels") and raster_write_kwargs: + from spatialdata._core._utils import create_raster_element_kwargs + + element_names = set(self.images.keys()).union(self.labels.keys()) + element_raster_write_kwargs = create_raster_element_kwargs(raster_write_kwargs, element_name, element_names) + + with zarrs_context(): + if element_type == "images": + write_image( + image=element, + group=element_group, + name=element_name, + element_format=parsed_formats["raster"], + storage_options=element_raster_write_kwargs, + raster_compressor=raster_compressor, + ) + elif element_type == "labels": + write_labels( + labels=element, + group=root_group, + name=element_name, + element_format=parsed_formats["raster"], + storage_options=element_raster_write_kwargs, + raster_compressor=raster_compressor, + ) + elif element_type == "points": + write_points( + points=element, + group=element_group, + element_format=parsed_formats["points"], + ) + elif element_type == "shapes": + write_shapes( + shapes=element, + group=element_group, + element_format=parsed_formats["shapes"], + geometry_encoding=shapes_geometry_encoding, + ) + elif element_type == "tables": + write_table( + table=element, + group=element_type_group, + name=element_name, + element_format=parsed_formats["tables"], + ) + else: + raise ValueError(f"Unknown element type: {element_type}") + @docstring_parameter(raster_write_kwargs=RASTER_WRITE_KWARGS_DOCS) def write_element( self, element_name: str | list[str], overwrite: bool = False, sdata_formats: SpatialDataFormatType | list[SpatialDataFormatType] | None = None, shapes_geometry_encoding: Literal["WKB", "geoarrow"] | None = None, + raster_write_kwargs: dict[str, JSONDict | list[JSONDict] | Any] | list[JSONDict] | None = None, raster_compressor: dict[Literal["lz4", "zstd"], int] | None = None, ) -> None: """ @@ -1308,6 +1349,8 @@ def write_element( shapes_geometry_encoding Whether to use the WKB or geoarrow encoding for GeoParquet. See :meth:`geopandas.GeoDataFrame.to_parquet` for details. If None, uses the value from :attr:`spatialdata.settings.shapes_geometry_encoding`. + raster_write_kwargs + {RASTER_WRITE_KWARGS_DOCS} raster_compressor A lenght-1 dictionary with as key the type of compression to use for images and labels and as value the compression level which should be inclusive between 0 and 9. For compression, `lz4` and `zstd` are @@ -1331,6 +1374,7 @@ def write_element( overwrite=overwrite, sdata_formats=sdata_formats, shapes_geometry_encoding=shapes_geometry_encoding, + raster_write_kwargs=raster_write_kwargs, raster_compressor=raster_compressor, ) return @@ -1367,6 +1411,7 @@ def write_element( overwrite=overwrite, parsed_formats=parsed_formats, shapes_geometry_encoding=shapes_geometry_encoding, + raster_write_kwargs=raster_write_kwargs, raster_compressor=raster_compressor, ) # After every write, metadata should be consolidated, otherwise this can lead to IO problems like when deleting. diff --git a/src/spatialdata/_io/io_raster.py b/src/spatialdata/_io/io_raster.py index 2feb7a77..8b24a795 100644 --- a/src/spatialdata/_io/io_raster.py +++ b/src/spatialdata/_io/io_raster.py @@ -148,13 +148,13 @@ def _prepare_storage_options( return None if isinstance(storage_options, dict): prepared = dict(storage_options) - if "chunks" in prepared: + if "chunks" in prepared and prepared["chunks"] is not None: prepared["chunks"] = _normalize_explicit_chunks(prepared["chunks"]) return prepared prepared_options = [dict(options) for options in storage_options] for options in prepared_options: - if "chunks" in options: + if "chunks" in options and options["chunks"] is not None: options["chunks"] = _normalize_explicit_chunks(options["chunks"]) return prepared_options @@ -284,6 +284,19 @@ def _write_raster( raster_format The format used to write the raster data. storage_options + Storage options for raster elements, which have been extracted from potentially mixed kwargs dict by + `create_raster_element_kwargs`. These options are passed to the zarr storage backend for writing and can be + provided in several formats: + + 1. Single dictionary + A dictionary containing all storage options applied to the raster, either single or multiscale. + 2. List of dictionaries (multiscale only) + A list where each dictionary defines the storage options for one scale of the multiscale raster element. + + Important Notes + - The available key–value pairs in these dictionaries depend on the Zarr format used for writing. + - For a full list of supported storage options, refer to: + https://zarr.readthedocs.io/en/stable/api/zarr/create/#zarr.create_array Additional options for writing the raster data, like chunks and compression. raster_compressor Compression settings as a len-1 dictionary with a single key-value {compression: compression level} pair @@ -292,6 +305,10 @@ def _write_raster( metadata Additional metadata for the raster element """ + from dataclasses import asdict + + from spatialdata import settings + if raster_type not in ["image", "labels"]: raise ValueError(f"{raster_type} is not a valid raster type. Must be 'image' or 'labels'.") # "name" and "label_metadata" are only used for labels. "name" is written in write_multiscale_ngff() but ignored in @@ -308,6 +325,14 @@ def _write_raster( for c in channels: metadata["metadata"]["omero"]["channels"].append({"label": c}) # type: ignore[union-attr, index, call-overload] + # Prefixing with raster_ to account for anndata chunks / shards that will be supported in the future. + base_options = {k.split("_")[1]: v for k, v in asdict(settings).items() if k in ("raster_chunks", "raster_shards")} + + if isinstance(storage_options, list): + storage_options = [{**base_options, **x} for x in storage_options] + else: + storage_options = {**base_options, **(storage_options or {})} + if isinstance(raster_data, DataArray): _write_raster_dataarray( raster_type, diff --git a/src/spatialdata/_utils.py b/src/spatialdata/_utils.py index 609cd040..afd1a551 100644 --- a/src/spatialdata/_utils.py +++ b/src/spatialdata/_utils.py @@ -4,12 +4,14 @@ import re import warnings from collections.abc import Callable, Generator -from contextlib import contextmanager +from contextlib import contextmanager, nullcontext +from importlib.util import find_spec from itertools import islice from typing import Any, TypeVar import numpy as np import pandas as pd +import zarr from anndata import AnnData from dask import array as da from dask import config @@ -354,3 +356,20 @@ def _check_match_length_channels_c_dim( f" with length {c_length}." ) return c_coords + + +# TODO: get this in scverse-misc and import from there +@contextmanager +def zarrs_context() -> Generator[None, None, None]: + with ( + zarr.config.set({"codec_pipeline.path": "zarrs.ZarrsCodecPipeline"}) if find_spec("zarrs") else nullcontext(), + warnings.catch_warnings() if find_spec("zarrs") else nullcontext(), + ): + # The warning is there in case zarrs doesn't support the store type you passed in to read_zarr. + if find_spec("zarrs"): + warnings.filterwarnings( + "ignore", + message=r".*unsupported by ZarrsCodecPipeline.*", + category=UserWarning, + ) + yield diff --git a/tests/conftest.py b/tests/conftest.py index 617acb90..0941b9f4 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,6 +12,8 @@ import copy as _copy from collections.abc import Callable, Sequence +from contextlib import contextmanager +from dataclasses import replace from pathlib import Path from typing import Any @@ -30,6 +32,7 @@ from skimage import data from xarray import DataArray, DataTree +from spatialdata import settings from spatialdata._core._deepcopy import deepcopy from spatialdata._core.spatialdata import SpatialData from spatialdata._types import ArrayLike @@ -715,3 +718,27 @@ def complex_sdata() -> SpatialData: sdata.tables["labels_table"].layers["log"] = np.log1p(np.abs(sdata.tables["labels_table"].X)) return sdata + + +@pytest.fixture() +def settings_cls(tmp_path, monkeypatch): + """ + Provide setting class with default path redirected. + """ + from spatialdata.config import Settings + + monkeypatch.setattr("spatialdata.config._config_path", lambda: tmp_path / "default_settings.json") + return Settings + + +@contextmanager +def temporary_settings(**kwargs): + old = replace(settings) + try: + for k, v in kwargs.items(): + setattr(settings, k, v) + settings.save() + yield + finally: + settings.__dict__.update(old.__dict__) + settings.save() diff --git a/tests/io/test_readwrite.py b/tests/io/test_readwrite.py index 034c01d3..7e969c1c 100644 --- a/tests/io/test_readwrite.py +++ b/tests/io/test_readwrite.py @@ -34,7 +34,7 @@ ) from spatialdata._io.io_raster import write_image from spatialdata.datasets import blobs -from spatialdata.models import Image2DModel +from spatialdata.models import Image2DModel, Labels2DModel from spatialdata.models._utils import get_channel_names from spatialdata.testing import assert_spatial_data_objects_are_identical from spatialdata.transformations.operations import ( @@ -53,6 +53,27 @@ RNG = default_rng(0) SDATA_FORMATS = list(SpatialDataContainerFormats.values()) +RASTER_CASES = [ + pytest.param( + {"model": Image2DModel, "dims": ("c", "y", "x"), "data_shape": (3, 800, 1000), "zarr_subpath": "images"}, + id="image", + ), + pytest.param( + {"model": Labels2DModel, "dims": ("y", "x"), "data_shape": (800, 1000), "zarr_subpath": "labels"}, + id="label", + ), +] + +RASTER_CASES_MULTISCALE = [ + pytest.param( + {"model": Image2DModel, "dims": ("c", "y", "x"), "data_shape": (3, 1600, 2000), "zarr_subpath": "images"}, + id="image", + ), + pytest.param( + {"model": Labels2DModel, "dims": ("y", "x"), "data_shape": (1600, 2000), "zarr_subpath": "labels"}, + id="label", + ), +] @pytest.mark.filterwarnings("ignore:SpatialData is not stored in the most current format:UserWarning") @@ -820,6 +841,168 @@ def test_single_scale_image_roundtrip_stays_dataarray(tmp_path: Path) -> None: assert list(image_group.keys()) == ["s0"] +@pytest.mark.parametrize("raster_case", RASTER_CASES) +@pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS) +def test_write_raster_sharding( + tmp_path: Path, + raster_case: dict, + sdata_container_format: SpatialDataContainerFormatType, +) -> None: + model, dims, data_shape, zarr_subpath = ( + raster_case["model"], + raster_case["dims"], + raster_case["data_shape"], + raster_case["zarr_subpath"], + ) + chunks = (1, 100, 200) if len(dims) == 3 else (100, 200) + write_chunks = (1, 50, 100) if len(dims) == 3 else (50, 100) + write_shards = (1, 100, 200) if len(dims) == 3 else (100, 200) + + data = da.from_array(RNG.random(data_shape), chunks=chunks) + element = model.parse(data, dims=dims) + name = "element" + sdata = SpatialData(**{zarr_subpath: {name: element}}) + path = tmp_path / "data.zarr" + + if sdata_container_format.zarr_format == 2: + with pytest.raises(ValueError, match="Zarr format 2 arrays can only"): + sdata.write( + path, + sdata_formats=sdata_container_format, + raster_write_kwargs={"chunks": write_chunks, "shards": write_shards}, + ) + else: + sdata.write( + path, + sdata_formats=sdata_container_format, + raster_write_kwargs={"chunks": write_chunks, "shards": write_shards}, + ) + arr = zarr.open_group(path / zarr_subpath / name, mode="r")["s0"] + assert arr.chunks == write_chunks + assert arr.shards == write_shards + + +@pytest.mark.parametrize("raster_case", RASTER_CASES_MULTISCALE) +def test_write_multiscale_raster_sharding(tmp_path: Path, raster_case: dict) -> None: + model, dims, data_shape, zarr_subpath = ( + raster_case["model"], + raster_case["dims"], + raster_case["data_shape"], + raster_case["zarr_subpath"], + ) + chunks = (1, 100, 200) if len(dims) == 3 else (100, 200) + write_chunks = (1, 50, 100) if len(dims) == 3 else (50, 100) + write_shards = (1, 100, 200) if len(dims) == 3 else (100, 200) + + data = da.from_array(RNG.random(data_shape), chunks=chunks) + element = model.parse(data, dims=dims, scale_factors=[2]) + name = "element" + sdata = SpatialData(**{zarr_subpath: {name: element}}) + path = tmp_path / "data.zarr" + + sdata.write(path, raster_write_kwargs={"chunks": write_chunks, "shards": write_shards}) + + group = zarr.open_group(path / zarr_subpath / name, mode="r") + for scale in ("s0", "s1"): + arr = group[scale] + assert arr.chunks == write_chunks + assert arr.shards == write_shards + + +@pytest.mark.parametrize("raster_case", RASTER_CASES_MULTISCALE) +def test_write_multiscale_raster_scale_sharding(tmp_path: Path, raster_case: dict) -> None: + model, dims, data_shape, zarr_subpath = ( + raster_case["model"], + raster_case["dims"], + raster_case["data_shape"], + raster_case["zarr_subpath"], + ) + chunks_s0 = (1, 50, 100) if len(dims) == 3 else (50, 100) + shards_s0 = (1, 100, 200) if len(dims) == 3 else (100, 200) + chunks_s1 = (1, 25, 50) if len(dims) == 3 else (25, 50) + shards_s1 = (1, 50, 100) if len(dims) == 3 else (50, 100) + base_chunks = (1, 100, 200) if len(dims) == 3 else (100, 200) + + data = da.from_array(RNG.random(data_shape), chunks=base_chunks) + element = model.parse(data, dims=dims, scale_factors=[2]) + name = "element" + sdata = SpatialData(**{zarr_subpath: {name: element}}) + path = tmp_path / "data.zarr" + + sdata.write( + path, + raster_write_kwargs=[ + {"chunks": chunks_s0, "shards": shards_s0}, + {"chunks": chunks_s1, "shards": shards_s1}, + ], + ) + + group = zarr.open_group(path / zarr_subpath / name, mode="r") + assert group["s0"].chunks == chunks_s0 + assert group["s0"].shards == shards_s0 + assert group["s1"].chunks == chunks_s1 + assert group["s1"].shards == shards_s1 + + +@pytest.mark.parametrize("raster_case", RASTER_CASES) +def test_write_raster_sharding_keyword(tmp_path: Path, raster_case: dict) -> None: + model, dims, data_shape, zarr_subpath = ( + raster_case["model"], + raster_case["dims"], + raster_case["data_shape"], + raster_case["zarr_subpath"], + ) + base_chunks = (1, 100, 200) if len(dims) == 3 else (100, 200) + write_chunks = (1, 50, 100) if len(dims) == 3 else (50, 100) + write_shards = (1, 100, 200) if len(dims) == 3 else (100, 200) + + data = da.from_array(RNG.random(data_shape), chunks=base_chunks) + element = model.parse(data, dims=dims) + other = model.parse(data.copy(), dims=dims) + name, other_name = "element", "other_element" + sdata = SpatialData(**{zarr_subpath: {name: element, other_name: other}}) + path = tmp_path / "data.zarr" + + sdata.write( + path, + raster_write_kwargs={name: {"chunks": write_chunks, "shards": write_shards}}, + ) + + arr = zarr.open_group(path / zarr_subpath / name, mode="r")["s0"] + assert arr.chunks == write_chunks + assert arr.shards == write_shards + + other_arr = zarr.open_group(path / zarr_subpath / other_name, mode="r")["s0"] + assert other_arr.chunks == base_chunks + assert not other_arr.shards + + +def test_write_raster_elements_sharding_chunking(tmp_path: Path) -> None: + write_chunks = (1, 50, 100) + write_shards = (1, 100, 200) + + data = da.from_array(RNG.random((1, 500, 600))) + element = Image2DModel.parse(data, dims=("c", "y", "x")) + + sdata = SpatialData() + path = tmp_path / "data.zarr" + + sdata.write(path) + sdata["image"] = element + sdata["other_image"] = element + + sdata.write_element( + element_name=["image", "other_image"], raster_write_kwargs={"chunks": write_chunks, "shards": write_shards} + ) + + arr = zarr.open_group(path / "images" / "image", mode="r")["s0"] + assert arr.chunks == write_chunks + assert arr.shards == write_shards + arr = zarr.open_group(path / "images" / "other_image", mode="r")["s0"] + assert arr.chunks == write_chunks + assert arr.shards == write_shards + + @pytest.mark.filterwarnings("ignore:SpatialData is not stored in the most current format:UserWarning") @pytest.mark.parametrize("sdata_container_format", SDATA_FORMATS) def test_self_contained(full_sdata: SpatialData, sdata_container_format: SpatialDataContainerFormatType) -> None: