Skip to content

Commit 932b8c4

Browse files
committed
Refactor label resolution
1 parent 9493b0c commit 932b8c4

File tree

1 file changed

+18
-23
lines changed

1 file changed

+18
-23
lines changed

seaborn/_core/plot.py

Lines changed: 18 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -789,6 +789,20 @@ def _extract_data(self, p: Plot) -> tuple[PlotData, list[Layer]]:
789789

790790
return common_data, layers
791791

792+
def _resolve_label(self, p: Plot, var: str, auto_label: str | None) -> str | None:
793+
794+
label: str | None
795+
if var in p._labels:
796+
manual_label = p._labels[var]
797+
if callable(manual_label) and auto_label is not None:
798+
label = manual_label(auto_label)
799+
else:
800+
# mypy needs a lot of help here, I'm not sure why
801+
label = cast(Optional[str], manual_label)
802+
else:
803+
label = auto_label
804+
return label
805+
792806
def _setup_figure(self, p: Plot, common: PlotData, layers: list[Layer]) -> None:
793807

794808
# --- Parsing the faceting/pairing parameterization to specify figure grid
@@ -830,16 +844,7 @@ def _setup_figure(self, p: Plot, common: PlotData, layers: list[Layer]) -> None:
830844
*(layer["data"].names.get(axis_key) for layer in layers)
831845
]
832846
auto_label = next((name for name in names if name is not None), None)
833-
if axis_key in p._labels:
834-
manual_label = p._labels[axis_key]
835-
label: str | None
836-
if callable(manual_label) and auto_label is not None:
837-
label = manual_label(auto_label)
838-
else:
839-
# mypy needs a lot of help here, I'm not sure why
840-
label = cast(Optional[str], manual_label)
841-
else:
842-
label = auto_label
847+
label = self._resolve_label(p, axis_key, auto_label)
843848
ax.set(**{f"{axis}label": label})
844849

845850
# ~~ Decoration visibility
@@ -1196,7 +1201,7 @@ def get_order(var):
11961201
view["ax"].autoscale_view()
11971202

11981203
if layer["legend"]:
1199-
self._update_legend_contents(mark, data, scales, p._labels)
1204+
self._update_legend_contents(p, mark, data, scales)
12001205

12011206
def _scale_coords(self, subplots: list[dict], df: DataFrame) -> DataFrame:
12021207
# TODO stricter type on subplots
@@ -1392,10 +1397,10 @@ def split_generator(keep_na=False) -> Generator:
13921397

13931398
def _update_legend_contents(
13941399
self,
1400+
p: Plot,
13951401
mark: Mark,
13961402
data: PlotData,
13971403
scales: dict[str, Scale],
1398-
titles: dict[str, str | Callable[[str], str] | None],
13991404
) -> None:
14001405
"""Add legend artists / labels for one layer in the plot."""
14011406
if data.frame.empty and data.frames:
@@ -1420,18 +1425,8 @@ def _update_legend_contents(
14201425
part_vars.append(var)
14211426
break
14221427
else:
1423-
# TODO copied from _setup_figure
14241428
auto_title = data.names[var]
1425-
if var in titles:
1426-
manual_title = titles[var]
1427-
title: str | None
1428-
if callable(manual_title) and auto_title is not None:
1429-
title = manual_title(auto_title)
1430-
else:
1431-
# mypy needs a lot of help here, I'm not sure why
1432-
title = cast(Optional[str], manual_title)
1433-
else:
1434-
title = auto_title
1429+
title = self._resolve_label(p, var, auto_title)
14351430
entry = (title, data.ids[var]), [var], (values, labels)
14361431
schema.append(entry)
14371432

0 commit comments

Comments
 (0)