Skip to content
Merged
3 changes: 3 additions & 0 deletions docs/plotting.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,3 +7,6 @@ Plotting (`.pl`)

.. automodule:: spatialdata_plot.pl.basic
:members:

.. autoclass:: spatialdata_plot.PercentileNormalize
:members:
3 changes: 2 additions & 1 deletion src/spatialdata_plot/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,8 @@
from . import pl
from ._logging import set_verbosity
from ._settings import Verbosity
from .pl._color import PercentileNormalize

__all__ = ["pl", "set_verbosity", "Verbosity"]
__all__ = ["PercentileNormalize", "Verbosity", "pl", "set_verbosity"]

__version__ = version("spatialdata-plot")
86 changes: 68 additions & 18 deletions src/spatialdata_plot/pl/_color.py
Original file line number Diff line number Diff line change
Expand Up @@ -110,31 +110,81 @@ def _make_continuous_mappable(vmin: float, vmax: float, cmap: Any) -> ScalarMapp
return ScalarMappable(norm=Normalize(vmin=vmin, vmax=vmax), cmap=cmap)


class PercentileNormalize(Normalize):
    """:class:`~matplotlib.colors.Normalize` that autoscales to data percentiles instead of min/max.

    Heavy-tailed images (fluorescence, Xenium morphology) have a few very bright pixels, so the
    default min/max mapping crushes the bulk of the signal into near-black. ``PercentileNormalize``
    derives ``vmin``/``vmax`` from the ``pmin``/``pmax`` percentiles of the data instead, which
    matches the per-channel contrast limits used by viewers like Xenium Explorer.

    It plugs into the existing ``norm`` argument like any other ``Normalize``: a single instance is
    autoscaled independently per channel, and a list applies channelwise limits.

    Parameters
    ----------
    pmin
        Lower percentile in ``[0, 100]`` (``pmin < pmax``) used to derive ``vmin``.
    pmax
        Upper percentile in ``[0, 100]`` used to derive ``vmax``.
    clip
        Forwarded to :class:`~matplotlib.colors.Normalize`.

    Raises
    ------
    ValueError
        If ``0 <= pmin < pmax <= 100`` does not hold.

    Notes
    -----
    Explicitly setting ``vmin``/``vmax`` overrides the corresponding percentile. On the datashader
    backend contrast autoscales to the aggregate range rather than to these percentiles.
    """

    def __init__(self, pmin: float = 0.0, pmax: float = 100.0, clip: bool = False) -> None:
        # The chained comparison is False for NaN as well, so NaN percentiles are rejected too.
        if not 0.0 <= pmin < pmax <= 100.0:
            raise ValueError(f"Require 0 <= pmin < pmax <= 100, got pmin={pmin}, pmax={pmax}.")
        # Limits start unset so that they are filled from the data by ``autoscale_None``.
        super().__init__(vmin=None, vmax=None, clip=clip)
        self.pmin = pmin
        self.pmax = pmax

    def __repr__(self) -> str:
        """Return a constructor-style representation for debugging."""
        return f"{type(self).__name__}(pmin={self.pmin!r}, pmax={self.pmax!r}, clip={self.clip!r})"

    def autoscale_None(self, A: Any) -> None:
        """Fill unset ``vmin``/``vmax`` from the ``pmin``/``pmax`` percentiles of finite values.

        Limits that are already set are left untouched. If ``A`` holds no finite, unmasked value
        the unset limits stay ``None``.
        """
        if self.vmin is not None and self.vmax is not None:
            # Nothing to fill: skip the full-array mask/copy below.
            return
        finite = np.ma.masked_invalid(np.ma.asarray(A)).compressed()  # drops mask + NaN/inf
        if not finite.size:
            return
        # One call computes both percentiles in a single pass over the data.
        low, high = np.percentile(finite, [self.pmin, self.pmax])
        if self.vmin is None:
            self.vmin = float(low)
        if self.vmax is None:
            self.vmax = float(high)


def _resolve_continuous_norm(values: Any, cmap_params: CmapParams) -> Normalize:
    """Resolve ``cmap_params.norm`` with concrete vmin/vmax for continuous coloring.

    Honor explicit ``norm`` vmin/vmax, else delegate to the norm's own ``autoscale_None`` over the
    finite values of ``values`` (so plain ``Normalize`` uses min/max, ``LogNorm`` uses its
    positive-only range, and ``PercentileNormalize`` uses percentiles), else fall back to ``[0, 1]``.
    Shared by the pixel and colorbar sites so both derive the same range; preserves the norm
    subclass so non-linear scaling is not silently linearized.

    Parameters
    ----------
    values
        Data the colors are derived from; non-numeric input is coerced with
        ``pd.to_numeric(..., errors="coerce")``.
    cmap_params
        Provides the ``norm`` to resolve. The norm is copied; the original is not modified.

    Returns
    -------
    A copy of ``cmap_params.norm`` with both ``vmin`` and ``vmax`` set.
    """
    resolved = copy(cmap_params.norm)
    if resolved.vmin is None or resolved.vmax is None:
        arr = np.asarray(values)
        if not np.issubdtype(arr.dtype, np.number):
            arr = pd.to_numeric(arr.ravel(), errors="coerce")
        finite = arr[np.isfinite(arr)]
        if finite.size:
            resolved.autoscale_None(finite)
        if isinstance(resolved, LogNorm):
            # LogNorm needs strictly-positive bounds; all-nonpositive/empty data can't provide them
            # (matplotlib leaves them at 0), so fall back to a valid domain instead of raising later.
            vmin_ok = resolved.vmin is not None and resolved.vmin > 0
            vmax_ok = resolved.vmax is not None and resolved.vmax > 0
            if not vmin_ok:
                # Stay at or below a valid vmax: a blind 1.0 would give vmin > vmax when vmax < 1.
                resolved.vmin = min(1.0, resolved.vmax) if vmax_ok else 1.0
            if not vmax_ok:
                # vmin is valid at this point; never place vmax below it.
                resolved.vmax = max(1.0, resolved.vmin)
        else:
            if resolved.vmin is None:
                resolved.vmin = 0.0
            if resolved.vmax is None:
                resolved.vmax = 1.0
    if resolved.vmin == resolved.vmax and not isinstance(resolved, LogNorm):
        # a single distinct value would collapse the cmap onto its floor
        resolved.vmin, resolved.vmax = 0.0, 1.0
    return resolved


Expand Down
4 changes: 4 additions & 0 deletions src/spatialdata_plot/pl/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -768,6 +768,10 @@ def render_images(
A single :class:`~matplotlib.colors.Normalize` applies to all channels.
A list of :class:`~matplotlib.colors.Normalize` objects applies per-channel
(length must match the number of channels).
For heavy-tailed images (e.g. fluorescence/Xenium morphology) where min/max
scaling looks dim, pass :class:`~spatialdata_plot.PercentileNormalize` to clip each
channel to a percentile range (single instance for all channels, or a list for
channelwise limits).
palette : list[str] | str | None
Palette to color images. Can be a single palette name (broadcast to all channels) or a list
matching the number of channels.
Expand Down
57 changes: 34 additions & 23 deletions src/spatialdata_plot/pl/render.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from spatialdata_plot._logging import _log_context, logger
from spatialdata_plot.pl._color import (
ColorSpec,
ColorType,
_get_colors_for_categorical_obs,
_get_linear_colormap,
_map_color_seg,
Expand Down Expand Up @@ -697,8 +698,7 @@ def _render_shapes(
nan_count = int(pd.isna(cv).sum())
if nan_count:
logger.warning(
f"Found {nan_count} NaN values in color data. "
"These observations will be colored with the 'na_color'."
f"Found {nan_count} NaN values in color data. These observations will be colored with the 'na_color'."
)
color_spec = color_spec.evolve(color_vector=cv)

Expand Down Expand Up @@ -968,9 +968,11 @@ def _render_shapes(
path.vertices = trans.transform(path.vertices)

if color_spec.is_continuous:
# Colorbar range from the same resolved norm the fill pixels use.
# Colorbar uses the same resolved norm the fill pixels use, including its subclass
# (LogNorm/PowerNorm) — set_norm, not set_clim, which would leave the collection's
# default linear Normalize in place and mis-scale the bar for non-linear norms.
used_norm = _resolve_continuous_norm(color_spec.color_vector, render_params.cmap_params)
_cax.set_clim(vmin=used_norm.vmin, vmax=used_norm.vmax)
_cax.set_norm(used_norm)

_add_legend_and_colorbar(
ax=ax,
Expand Down Expand Up @@ -1413,8 +1415,6 @@ def _render_points(

trans, trans_data = _prepare_transformation(sdata.points[element], coordinate_system, ax)

norm = render_params.cmap_params.fresh_norm()

method = render_params.method

if render_params.density:
Expand All @@ -1426,6 +1426,9 @@ def _render_points(
_default_reduction: _DsReduction = "sum"

if method == "datashader":
# datashader colors the per-pixel aggregate (count/sum/reduction), not the per-point vector,
# so pass an un-resolved norm and let _apply_ds_norm autoscale to the aggregate.
norm = render_params.cmap_params.fresh_norm()
_log_datashader_method(method, render_params.ds_reduction, _default_reduction)

# Apply transformations and materialize to pandas immediately so
Expand Down Expand Up @@ -1462,6 +1465,13 @@ def _render_points(
color_spec = color_spec.evolve(source_vector=csv, color_vector=cv)

elif method == "matplotlib":
# matplotlib colors each point by its own value, so resolve the norm to match shapes/labels
# instead of letting ax.scatter autoscale a fresh one. Non-continuous keeps the fresh norm.
norm = (
_resolve_continuous_norm(color_spec.color_vector, render_params.cmap_params)
if color_spec.is_continuous
else render_params.cmap_params.fresh_norm()
)
# update axis limits if plot was empty before (necessary if datashader comes after)
update_parameters = not _mpl_ax_contains_elements(ax)
cax = _scatter_points(
Expand Down Expand Up @@ -2297,25 +2307,29 @@ def _render_labels(
# (`_map_color_seg` Case C) instead of collapsing every dot to a single na_color.
point_color_vector = np.random.default_rng(42).random((len(point_ids), 3))
point_color_source_vector = None
point_colortype: ColorType = "none" # colour is not data-driven
allow_datashader = False
elif len(color_spec.color_vector) == len(instance_id):
# data-driven colour is per-instance
# data-driven colour is per-instance; carry the upstream classification (invariant under mask)
point_color_vector = np.asarray(color_spec.color_vector)[keep]
point_color_source_vector = None if color_spec.source_vector is None else color_spec.source_vector[keep]
point_colortype = color_spec.colortype
else:
# literal colour / user-set na_color -> one colour per centroid
point_color_vector = np.full(len(point_ids), na_color.get_hex_with_alpha())
point_color_source_vector = None
point_colortype = "none" # colour is not data-driven
# transform rendered-raster intrinsic centroids to coordinate-system coords
xy = trans.transform(np.column_stack([centroids["x"].to_numpy(), centroids["y"].to_numpy()]))
_render_centroids_as_points(
ax,
render_params,
x=xy[:, 0],
y=xy[:, 1],
# point colours are derived fresh; classify by source so the spec stays self-consistent
# point colours are derived fresh; carry the resolved colortype so the spec invariant
# (categorical => pd.Categorical source) holds and `none` is not mislabelled categorical
color_spec=ColorSpec(
"categorical" if point_color_source_vector is not None else "continuous",
point_colortype,
point_color_source_vector,
point_color_vector,
),
Expand Down Expand Up @@ -2352,20 +2366,17 @@ def _draw_labels(
outline_color_source_vector=outline_color_source_vector if seg_boundaries else None,
)

# labels is pre-baked RGB; cmap/norm only drive the colorbar, so feed the same resolved norm.
cax = ax.imshow(
labels,
rasterized=True,
cmap=None if color_spec.is_categorical else render_params.cmap_params.cmap,
norm=None
if color_spec.is_categorical
else _resolve_continuous_norm(color_spec.color_vector, render_params.cmap_params),
alpha=alpha,
origin="lower",
zorder=render_params.zorder,
)
cax.set_transform(trans_data)
return cax
# labels is pre-baked RGB, so imshow ignores cmap/norm for display. Passing the resolved
# norm to imshow would make it try to normalize the RGBA array — which raises for a
# non-linear norm (LogNorm/PowerNorm). Display the RGB without a norm and build the
# continuous colorbar mappable separately from the resolved norm (mirrors the outline path),
# so the colorbar reflects the real norm subclass.
img = ax.imshow(labels, rasterized=True, alpha=alpha, origin="lower", zorder=render_params.zorder)
img.set_transform(trans_data)
if color_spec.is_categorical:
return img
used_norm = _resolve_continuous_norm(color_spec.color_vector, render_params.cmap_params)
return ScalarMappable(norm=used_norm, cmap=render_params.cmap_params.cmap)

# When color is a literal (col_for_color is None) and no explicit outline_color,
# use the literal color for outlines so they are visible (e.g., color='white' on
Expand Down
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
10 changes: 10 additions & 0 deletions tests/pl/test_render_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from spatialdata.models import Image2DModel, Image3DModel

import spatialdata_plot # noqa: F401
from spatialdata_plot import PercentileNormalize
from spatialdata_plot._logging import logger, logger_no_warns, logger_warns
from spatialdata_plot.pl.render import _is_rgb_image
from tests.conftest import DPI, PlotTester, PlotTesterMeta, _viridis_with_under_over
Expand Down Expand Up @@ -86,6 +87,15 @@ def test_plot_can_pass_normalize_clip_False(self, sdata_blobs: SpatialData):
element="blobs_image", channel=0, norm=norm, cmap=_viridis_with_under_over()
).pl.show()

def test_plot_percentile_normalize_broadcast(self, sdata_blobs: SpatialData):
    # one PercentileNormalize instance is broadcast to every channel and autoscaled
    # per channel to its percentile range
    shared_norm = PercentileNormalize(0, 90)
    rendered = sdata_blobs.pl.render_images(element="blobs_image", norm=shared_norm)
    rendered.pl.show()

def test_plot_percentile_normalize_channelwise(self, sdata_blobs: SpatialData):
    # one norm per channel: the list applies channelwise percentile limits
    upper_percentiles = (99, 90, 80)
    channel_norms = [PercentileNormalize(0, upper) for upper in upper_percentiles]
    rendered = sdata_blobs.pl.render_images(element="blobs_image", channel=[0, 1, 2], norm=channel_norms)
    rendered.pl.show()

def test_plot_can_pass_color_to_single_channel(self, sdata_blobs: SpatialData):
sdata_blobs.pl.render_images(element="blobs_image", channel=1, palette="red").pl.show()

Expand Down
39 changes: 39 additions & 0 deletions tests/pl/test_render_labels.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,45 @@ def test_render_labels_all_nan_color_renders_under_rasterize(sdata_blobs: Spatia
plt.close(fig)


def test_render_labels_as_points_all_nan_color_does_not_crash(sdata_blobs: SpatialData):
    # Regression: as_points rebuilds the centroid ColorSpec; an all-NaN column is the "none" colortype
    # and must not be re-classified "categorical" (which would later call Categorical-only methods on
    # the na-array source). It must just render the centroids in na_color.
    labels_name = "blobs_labels"
    instances = get_element_instances(sdata_blobs[labels_name])
    n_obs = len(instances)
    # one table row per label instance, with an all-NaN column to color by
    adata = AnnData(np.zeros((n_obs, 1)))
    adata.obs["instance_id"] = instances.values
    adata.obs["nanvals"] = np.full(n_obs, np.nan)
    adata.obs["region"] = labels_name
    sdata_blobs["label_table"] = TableModel.parse(
        adata=adata, region_key="region", instance_key="instance_id", region=labels_name
    )
    fig, ax = plt.subplots()
    try:
        sdata_blobs.pl.render_labels(
            labels_name, color="nanvals", table_name="label_table", as_points=True
        ).pl.show(ax=ax)
    finally:
        # close even when rendering raises, so a regression does not leak the figure into later tests
        plt.close(fig)


def test_render_labels_lognorm_with_zeros_does_not_crash(sdata_blobs: SpatialData):
    # Regression: a continuous LogNorm column containing 0 must derive a positive vmin instead of a
    # LogNorm(vmin=0) that raises "Invalid vmin or vmax" when the segmentation colors are mapped.
    from matplotlib.colors import LogNorm

    labels_name = "blobs_labels"
    instances = get_element_instances(sdata_blobs[labels_name])
    n_obs = len(instances)
    # one table row per label instance, with a continuous column to color by
    adata = AnnData(np.zeros((n_obs, 1)))
    adata.obs["instance_id"] = instances.values
    adata.obs["counts"] = np.linspace(0.0, 10.0, n_obs)  # includes 0
    adata.obs["region"] = labels_name
    sdata_blobs["label_table"] = TableModel.parse(
        adata=adata, region_key="region", instance_key="instance_id", region=labels_name
    )
    fig, ax = plt.subplots()
    try:
        sdata_blobs.pl.render_labels(
            labels_name, color="counts", table_name="label_table", norm=LogNorm()
        ).pl.show(ax=ax)
    finally:
        # close even when rendering raises, so a regression does not leak the figure into later tests
        plt.close(fig)


@pytest.mark.parametrize("dtype", [np.float16, np.float32, np.float64])
def test_render_labels_rejects_float_dtype(dtype):
# Regression test for #606: float-dtype labels must raise a clear
Expand Down
12 changes: 12 additions & 0 deletions tests/pl/test_render_points.py
Original file line number Diff line number Diff line change
Expand Up @@ -666,6 +666,18 @@ def test_groups_na_color_none_no_match_points(sdata_blobs: SpatialData):
).pl.show()


def test_render_points_lognorm_with_zeros_does_not_crash(sdata_blobs: SpatialData):
    # Regression: matplotlib points resolve their continuous norm through the shared resolver, so a
    # LogNorm column containing 0 must derive a positive vmin instead of a LogNorm(vmin=0) that raises.
    from matplotlib.colors import LogNorm

    n = len(sdata_blobs["blobs_points"])
    sdata_blobs["blobs_points"]["counts"] = pd.Series(np.linspace(0.0, 10.0, n))  # includes 0
    fig, ax = plt.subplots()
    try:
        sdata_blobs.pl.render_points(
            "blobs_points", color="counts", norm=LogNorm(), method="matplotlib"
        ).pl.show(ax=ax)
    finally:
        # close even when rendering raises, so a regression does not leak the figure into later tests
        plt.close(fig)


@pytest.mark.parametrize("na_color", [None, "red"])
def test_groups_warns_when_no_groups_match_points(sdata_blobs: SpatialData, caplog, na_color):
"""Warning fires regardless of na_color when no groups match."""
Expand Down
23 changes: 23 additions & 0 deletions tests/pl/test_render_shapes.py
Original file line number Diff line number Diff line change
Expand Up @@ -1150,6 +1150,29 @@ def test_render_shapes_all_nan_color_with_groups_does_not_crash(sdata_blobs_shap
plt.close(fig)


def test_render_shapes_lognorm_with_zeros_does_not_crash(sdata_blobs_shapes_annotated: SpatialData):
    # Regression: a continuous LogNorm column containing 0 must derive a positive vmin instead of
    # producing a LogNorm(vmin=0) that raises "Invalid vmin or vmax" when the fill is mapped.
    from matplotlib.colors import LogNorm

    sdata_blobs_shapes_annotated["blobs_polygons"]["counts"] = [0.0, 2.5, 5.0, 7.5, 10.0]
    fig, ax = plt.subplots()
    try:
        sdata_blobs_shapes_annotated.pl.render_shapes(
            "blobs_polygons", color="counts", norm=LogNorm()
        ).pl.show(ax=ax)
    finally:
        # close even when rendering raises, so a regression does not leak the figure into later tests
        plt.close(fig)


def test_render_shapes_continuous_colorbar_reflects_norm_subclass(sdata_blobs_shapes_annotated: SpatialData):
    # Regression: the fill colorbar must use the resolved norm subclass (LogNorm), not the
    # collection's default linear Normalize — i.e. set_norm, not set_clim.
    from matplotlib.colors import LogNorm

    sdata_blobs_shapes_annotated["blobs_polygons"]["counts"] = [1.0, 2.5, 5.0, 7.5, 10.0]
    fig, ax = plt.subplots()
    try:
        sdata_blobs_shapes_annotated.pl.render_shapes(
            "blobs_polygons", color="counts", norm=LogNorm()
        ).pl.show(ax=ax)
        assert any(isinstance(c.norm, LogNorm) for c in ax.collections), "fill colorbar norm was linearized"
    finally:
        # the assertion runs before cleanup; close here so a failing check does not leak the figure
        plt.close(fig)


def test_gene_symbols_auto_detect_table(sdata_blobs: SpatialData):
"""gene_symbols resolves correctly without explicit table_name (#247)."""
sdata_blobs["table"].obs["region"] = pd.Categorical(["blobs_circles"] * sdata_blobs["table"].n_obs)
Expand Down
Loading
Loading