Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 11 additions & 14 deletions doc/nextgen/demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@
"id": "ae0e288e-74cf-461c-8e68-786e364032a1",
"metadata": {},
"source": [
"### Data transformation: the Stat\n",
"### Data transformations: the Stat\n",
"\n",
"\n",
"Built-in statistical transformations are one of seaborn's key features. But currently, they are tied up with the different visual representations. E.g., you can aggregate data in `lineplot`, but not in `scatterplot`.\n",
Expand All @@ -273,7 +273,7 @@
"id": "1788d935-5ad5-4262-993f-8d48c66631b9",
"metadata": {},
"source": [
"The `Stat` is computed on subsets of data defined by the semantic mappings:"
"A `Stat` is computed on subsets of data defined by the semantic mappings:"
]
},
{
Expand Down Expand Up @@ -323,7 +323,7 @@
"outputs": [],
"source": [
"class PeakAnnotation(so.Mark):\n",
" def plot(self, split_generator, scales, orient):\n",
" def _plot(self, split_generator, scales, orient):\n",
" for keys, data, ax in split_generator():\n",
" ix = data[\"y\"].idxmax()\n",
" ax.annotate(\n",
Expand Down Expand Up @@ -388,7 +388,7 @@
"source": [
"(\n",
" so.Plot(tips, \"day\", \"total_bill\", color=\"time\")\n",
" .add(so.Bar(), so.Agg(), move=so.Dodge())\n",
" .add(so.Dot(), so.Dodge())\n",
")"
]
},
Expand All @@ -409,7 +409,7 @@
"source": [
"(\n",
" so.Plot(tips, \"day\", \"total_bill\", color=\"time\")\n",
" .add(so.Bar(), so.Agg(), move=so.Dodge(empty=\"fill\", gap=.1))\n",
" .add(so.Bar(), so.Agg(), so.Dodge(empty=\"fill\", gap=.1))\n",
")"
]
},
Expand All @@ -430,7 +430,7 @@
"source": [
"(\n",
" so.Plot(tips, \"day\", \"total_bill\", color=\"time\", alpha=\"sex\")\n",
" .add(so.Bar(), so.Agg(), move=so.Dodge())\n",
" .add(so.Bar(), so.Agg(), so.Dodge())\n",
")"
]
},
Expand All @@ -451,7 +451,7 @@
"source": [
"(\n",
" so.Plot(tips, \"day\", \"total_bill\", color=\"time\", alpha=\"smoker\")\n",
" .add(so.Dot(), move=so.Dodge(by=[\"color\"]))\n",
" .add(so.Dot(), so.Dodge(by=[\"color\"]))\n",
")"
]
},
Expand All @@ -460,7 +460,7 @@
"id": "c001004a-6771-46eb-b231-6accf88fe330",
"metadata": {},
"source": [
"It's also possible to stack multiple moves or kinds of moves by passing a list:"
"It's also possible to stack multiple moves or kinds of moves:"
]
},
{
Expand All @@ -472,10 +472,7 @@
"source": [
"(\n",
" so.Plot(tips, \"day\", \"total_bill\", color=\"time\", alpha=\"smoker\")\n",
" .add(\n",
" so.Dot(),\n",
" move=[so.Dodge(by=[\"color\"]), so.Jitter(.5)]\n",
" )\n",
" .add(so.Dot(), so.Dodge(by=[\"color\"]), so.Jitter(.5))\n",
")"
]
},
Expand Down Expand Up @@ -568,8 +565,8 @@
" so.Plot(planets, x=\"mass\", y=\"distance\", color=\"orbital_period\")\n",
" .scale(\n",
" x=\"log\",\n",
" y=so.Continuous(transform=\"log\").tick(at=[3, 10, 30, 100, 300]),\n",
" color=so.Continuous(\"rocket\", transform=\"log\"),\n",
" y=so.Continuous(trans=\"log\").tick(at=[3, 10, 30, 100, 300]),\n",
" color=so.Continuous(\"rocket\", trans=\"log\"),\n",
" )\n",
" .add(so.Dots())\n",
")"
Expand Down
25 changes: 19 additions & 6 deletions seaborn/_core/moves.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,17 @@
from pandas import DataFrame

from seaborn._core.groupby import GroupBy
from seaborn._core.scales import Scale


@dataclass
class Move:

group_by_orient: ClassVar[bool] = True

def __call__(self, data: DataFrame, groupby: GroupBy, orient: str) -> DataFrame:
def __call__(
self, data: DataFrame, groupby: GroupBy, orient: str, scales: dict[str, Scale],
) -> DataFrame:
raise NotImplementedError


Expand All @@ -31,7 +34,9 @@ class Jitter(Move):
# TODO what is the best way to have a reasonable default?
# The problem is that "reasonable" seems dependent on the mark

def __call__(self, data: DataFrame, groupby: GroupBy, orient: str) -> DataFrame:
def __call__(
self, data: DataFrame, groupby: GroupBy, orient: str, scales: dict[str, Scale],
) -> DataFrame:

# TODO is it a problem that GroupBy is not used for anything here?
# Should we type it as optional?
Expand Down Expand Up @@ -68,7 +73,9 @@ class Dodge(Move):
# TODO should the default be an "all" singleton?
by: Optional[list[str]] = None

def __call__(self, data: DataFrame, groupby: GroupBy, orient: str) -> DataFrame:
def __call__(
self, data: DataFrame, groupby: GroupBy, orient: str, scales: dict[str, Scale],
) -> DataFrame:

grouping_vars = [v for v in groupby.order if v in data]
groups = groupby.agg(data, {"width": "max"})
Expand Down Expand Up @@ -138,7 +145,9 @@ def _stack(self, df, orient):

return df

def __call__(self, data: DataFrame, groupby: GroupBy, orient: str) -> DataFrame:
def __call__(
self, data: DataFrame, groupby: GroupBy, orient: str, scales: dict[str, Scale],
) -> DataFrame:

# TODO where to ensure that other semantic variables are sorted properly?
# TODO why are we not using the passed in groupby here?
Expand All @@ -154,7 +163,9 @@ class Shift(Move):
x: float = 0
y: float = 0

def __call__(self, data: DataFrame, groupby: GroupBy, orient: str) -> DataFrame:
def __call__(
self, data: DataFrame, groupby: GroupBy, orient: str, scales: dict[str, Scale],
) -> DataFrame:

data = data.copy(deep=False)
data["x"] = data["x"] + self.x
Expand Down Expand Up @@ -188,7 +199,9 @@ def _norm(self, df, var):

return df

def __call__(self, data: DataFrame, groupby: GroupBy, orient: str) -> DataFrame:
def __call__(
self, data: DataFrame, groupby: GroupBy, orient: str, scales: dict[str, Scale],
) -> DataFrame:

other = {"x": "y", "y": "x"}[orient]
return groupby.apply(data, self._norm, other)
Expand Down
63 changes: 38 additions & 25 deletions seaborn/_core/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from contextlib import contextmanager
from collections import abc
from collections.abc import Callable, Generator, Hashable
from typing import Any, cast
from typing import Any, List, Optional, cast

from cycler import cycler
import pandas as pd
Expand Down Expand Up @@ -338,16 +338,14 @@ def on(self, target: Axes | SubFigure | Figure) -> Plot:
def add(
self,
mark: Mark,
stat: Stat | None = None,
move: Move | list[Move] | None = None,
*,
*transforms: Stat | Mark,
orient: str | None = None,
legend: bool = True,
data: DataSource = None,
**variables: VariableSpec,
) -> Plot:
"""
Define a layer of the visualization.
Define a layer of the visualization in terms of mark and data transform(s).
This is the main method for specifying how the data should be visualized.
It can be called multiple times with different arguments to define
Expand All @@ -357,48 +355,63 @@ def add(
----------
mark : :class:`seaborn.objects.Mark`
The visual representation of the data to use in this layer.
stat : :class:`seaborn.objects.Stat`
A transformation applied to the data before plotting.
move : :class:`seaborn.objects.Move`
Additional transformation(s) to handle over-plotting.
legend : bool
Option to suppress the mark/mappings for this layer from the legend.
transforms : :class:`seaborn.objects.Stat` or :class:`seaborn.objects.Move`
Objects representing transforms to be applied before plotting the data.
Current, at most one :class:`seaborn.objects.Stat` can be used, and it
must be passed first. This constraint will be relaxed in the future.
orient : "x", "y", "v", or "h"
The orientation of the mark, which affects how the stat is computed.
Typically corresponds to the axis that defines groups for aggregation.
The "v" (vertical) and "h" (horizontal) options are synonyms for "x" / "y",
but may be more intuitive with some marks. When not provided, an
orientation will be inferred from characteristics of the data and scales.
legend : bool
Option to suppress the mark/mappings for this layer from the legend.
data : DataFrame or dict
Data source to override the global source provided in the constructor.
variables : data vectors or identifiers
Additional layer-specific variables, including variables that will be
passed directly to the stat without scaling.
passed directly to the transforms without scaling.
"""
if not isinstance(mark, Mark):
msg = f"mark must be a Mark instance, not {type(mark)!r}."
raise TypeError(msg)

if stat is not None and not isinstance(stat, Stat):
msg = f"stat must be a Stat instance, not {type(stat)!r}."
# TODO This API for transforms was a late decision, and previously Plot.add
# accepted 0 or 1 Stat instances and 0, 1, or a list of Move instances.
# It will take some work to refactor the internals so that Stat and Move are
# treated identically, and until then well need to "unpack" the transforms
# here and enforce limitations on the order / types.

stat: Optional[Stat]
move: Optional[List[Move]]
error = False
if not transforms:
stat, move = None, None
elif isinstance(transforms[0], Stat):
stat = transforms[0]
move = [m for m in transforms[1:] if isinstance(m, Move)]
error = len(move) != len(transforms) - 1
else:
stat = None
move = [m for m in transforms if isinstance(m, Move)]
error = len(move) != len(transforms)

if error:
msg = " ".join([
"Transforms must have at most one Stat type (in the first position),",
"and all others must be a Move type. Given transform type(s):",
", ".join(str(type(t).__name__) for t in transforms) + "."
])
raise TypeError(msg)

# TODO decide how to allow Mark to have default Stat/Move
# if stat is None and hasattr(mark, "default_stat"):
# stat = mark.default_stat()

# TODO it doesn't work to supply scalars to variables, but that would be nice

# TODO accept arbitrary variables defined by the stat (/move?) here
# (but not in the Plot constructor)
# Should stat variables ever go in the constructor, or just in the add call?

new = self._clone()
new._layers.append({
"mark": mark,
"stat": stat,
"move": move,
# TODO it doesn't work to supply scalars to variables, but it should
"vars": variables,
"source": data,
"legend": legend,
Expand Down Expand Up @@ -1232,7 +1245,7 @@ def get_order(var):
move_groupers.insert(0, orient)
order = {var: get_order(var) for var in move_groupers}
groupby = GroupBy(order)
df = move_step(df, groupby, orient)
df = move_step(df, groupby, orient, scales)

df = self._unscale_coords(subplots, df, orient)

Expand Down
Loading