Skip to content

Commit 1ab0e55

Browse files
committed
Clean up Plot.plot by moving some logic to _setup_scales
1 parent e5dc350 commit 1ab0e55

File tree

1 file changed

+26
-18
lines changed

1 file changed

+26
-18
lines changed

seaborn/_core/plot.py

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -654,40 +654,37 @@ def save(self, fname, **kwargs) -> Plot:
654654

655655
def plot(self, pyplot=False) -> Plotter:
656656
"""
657-
Compile the plot and return the :class:`Plotter` engine.
658-
657+
Compile the plot spec and return a Plotter object.
659658
"""
660659
# TODO if we have _target object, pyplot should be determined by whether it
661660
# is hooked into the pyplot state machine (how do we check?)
662661

663662
plotter = Plotter(pyplot=pyplot)
664663

664+
# Process the variable assignments and initialize the figure
665665
common, layers = plotter._extract_data(self)
666666
plotter._setup_figure(self, common, layers)
667667

668+
# Process the scale spec for coordinate variables and transform their data
668669
coord_vars = [v for v in self._variables if re.match(r"^x|y", v)]
669670
plotter._setup_scales(self, common, layers, coord_vars)
670671

672+
# Apply statistical transform(s)
671673
plotter._compute_stats(self, layers)
672674

673-
other_vars = set() # TODO move this into a method
674-
for layer in layers:
675-
if layer["data"].frame.empty and layer["data"].frames:
676-
for df in layer["data"].frames.values():
677-
other_vars.update(df.columns)
678-
else:
679-
other_vars.update(layer["data"].frame.columns)
680-
other_vars -= set(coord_vars)
681-
plotter._setup_scales(self, common, layers, list(other_vars))
675+
# Process scale spec for semantic variables and coordinates computed by stat
676+
plotter._setup_scales(self, common, layers)
682677

683678
# TODO Remove these after updating other methods
684679
# ---- Maybe have debug= param that attaches these when True?
685680
plotter._data = common
686681
plotter._layers = layers
687682

683+
# Process the data for each layer and add matplotlib artists
688684
for layer in layers:
689685
plotter._plot_layer(self, layer)
690686

687+
# Add various figure decorations
691688
plotter._make_legend(self)
692689
plotter._finalize_figure(self)
693690

@@ -696,7 +693,6 @@ def plot(self, pyplot=False) -> Plotter:
696693
def show(self, **kwargs) -> None:
697694
"""
698695
Render and display the plot.
699-
700696
"""
701697
# TODO make pyplot configurable at the class level, and when not using,
702698
# import IPython.display and call on self to populate cell output?
@@ -1001,9 +997,23 @@ def _get_subplot_data(self, df, var, view, share_state):
1001997
return seed_values
1002998

1003999
def _setup_scales(
1004-
self, p: Plot, common: PlotData, layers: list[Layer], variables: list[str],
1000+
self, p: Plot,
1001+
common: PlotData,
1002+
layers: list[Layer],
1003+
variables: list[str] | None = None,
10051004
) -> None:
10061005

1006+
if variables is None:
1007+
# Add variables that have data but not a scale, which happens
1008+
# because this method can be called multiple time, to handle
1009+
# variables added during the Stat transform.
1010+
variables = []
1011+
for layer in layers:
1012+
variables.extend(layer["data"].frame.columns)
1013+
for df in layer["data"].frames.values():
1014+
variables.extend(v for v in df if v not in variables)
1015+
variables = [v for v in variables if v not in self._scales]
1016+
10071017
for var in variables:
10081018

10091019
# Determine whether this is a coordinate variable
@@ -1028,11 +1038,9 @@ def _setup_scales(
10281038
cols = [var, "col", "row"]
10291039
parts = [common.frame.filter(cols)]
10301040
for layer in layers:
1031-
if layer["data"].frame.empty and layer["data"].frames:
1032-
for df in layer["data"].frames.values():
1033-
parts.append(df.filter(cols))
1034-
else:
1035-
parts.append(layer["data"].frame.filter(cols))
1041+
parts.append(layer["data"].frame.filter(cols))
1042+
for df in layer["data"].frames.values():
1043+
parts.append(df.filter(cols))
10361044
var_df = pd.concat(parts, ignore_index=True)
10371045

10381046
prop = PROPERTIES[prop_key]

0 commit comments

Comments
 (0)