Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
86 changes: 79 additions & 7 deletions scanpy/plotting/_tools/scatterplots.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import collections.abc as cabc
import inspect
import sys
import warnings
from collections.abc import Mapping, Sequence # noqa: TCH003
from copy import copy
from functools import partial
Expand Down Expand Up @@ -63,7 +64,12 @@ def embedding(
mask_obs: NDArray[np.bool_] | str | None = None,
gene_symbols: str | None = None,
use_raw: bool | None = None,
sort_order: bool = True,
sort_order: bool | Empty = Empty,
order_continuous: Literal["ascending", "descending"]
| None
| np.ndarray
| Empty = Empty,
order_categorical: None | np.ndarray = None,
edges: bool = False,
edges_width: float = 0.1,
edges_color: str | Sequence[float] | Sequence[str] = "grey",
Expand Down Expand Up @@ -140,6 +146,9 @@ def embedding(
raise ValueError("Groups and mask arguments are incompatible.")
if mask_obs is not None:
mask_obs = _check_mask(adata, mask_obs, "obs")
order_continuous = check_continuous_order(
order_continuous, sort_order, adata.shape[0]
)

# Figure out if we're using raw
if use_raw is None:
Expand Down Expand Up @@ -282,12 +291,23 @@ def embedding(

# Order points
order = slice(None)
if sort_order and value_to_plot is not None and color_type == "cont":
# Higher values plotted on top, null values on bottom
order = np.argsort(-color_vector, kind="stable")[::-1]
elif sort_order and color_type == "cat":
# Null points go on bottom
order = np.argsort(~pd.isnull(color_source_vector), kind="stable")
if value_to_plot is None:
pass
elif color_type == "cont" and order_continuous is not None:
if isinstance(order_continuous, np.ndarray):
order = order_continuous
elif order_continuous == "ascending":
order = np.argsort(color_source_vector, kind="stable")
elif order_continuous == "descending":
order = np.argsort(-color_source_vector, kind="stable")

elif color_type == "cat" and order_categorical is not None:
order = order_categorical
if (masked_entries := pd.isnull(color_source_vector)).any():
if isinstance(order, slice):
order = np.arange(adata.n_obs)
order = order[np.argsort(~masked_entries[order], kind="stable")]

# Set orders
if isinstance(size, np.ndarray):
size = np.array(size)[order]
Expand Down Expand Up @@ -466,6 +486,58 @@ def embedding(
return axs


def check_continuous_order(
order_continuous: Literal["ascending", "descending"] | None | np.ndarray | Empty,
sort_order: bool | Empty,
N: int,
) -> Literal["ascending", "descending"] | None | np.ndarray:
# Backwards compat
if sort_order is not Empty:
warnings.warn(
"The `sort_order` parameter is deprecated and will be removed in the future. "
"Please use `order_continuous` and `order_categorical` instead.",
FutureWarning,
stacklevel=2,
)
if order_continuous is not Empty:
raise ValueError(
"Cannot specify both `sort_order` and `order_continuous`. "
"Please use only `order_continuous`."
)
elif sort_order:
order_continuous = "ascending"
else:
order_continuous = None
elif order_continuous is Empty:
# Default path
order_continuous = "ascending"
elif isinstance(order_continuous, np.ndarray) and order_continuous.shape != (N,):
raise ValueError(
f"order_continuous array must have shape ({N},). Got shape {order_continuous.shape}."
)
elif order_continuous not in ["ascending", "descending", None]:
raise ValueError(
f"order_continuous must be 'ascending', 'descending', None, or an array of values. Got {order_continuous}."
)
return order_continuous


def check_categorical_order(
order_categorical: None | np.ndarray, N: int
) -> None | np.ndarray:
if order_categorical is None:
pass
elif isinstance(order_categorical, np.ndarray) and order_categorical.shape != (N,):
raise ValueError(
f"order_categorical array must have shape ({N},). Got shape {order_categorical.shape}."
)
else:
raise ValueError(
"order_categorical must be None or an array of values. Got {order_categorical}."
)
return order_categorical


def _panel_grid(hspace, wspace, ncols, num_panels):
from matplotlib import gridspec

Expand Down