Skip to content

Commit 430897c

Browse files
committed
Fix Plot legend with > 2 layers
Fixes #3023 This ended up being a simple typo :(
1 parent 63186a4 commit 430897c

File tree

2 files changed

+18
-4
lines changed

2 files changed

+18
-4
lines changed

seaborn/_core/plot.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1582,7 +1582,7 @@ def _make_legend(self, p: Plot) -> None:
15821582
merged_contents: dict[
15831583
tuple[str, str | int], tuple[list[Artist], list[str]],
15841584
] = {}
1585-
for key, artists, labels in self._legend_contents:
1585+
for key, new_artists, labels in self._legend_contents:
15861586
# Key is (name, id); we need the id to resolve variable uniqueness,
15871587
# but will need the name in the next step to title the legend
15881588
if key in merged_contents:
@@ -1591,11 +1591,11 @@ def _make_legend(self, p: Plot) -> None:
15911591
for i, artist in enumerate(existing_artists):
15921592
# Matplotlib accepts a tuple of artists and will overlay them
15931593
if isinstance(artist, tuple):
1594-
artist += artist[i],
1594+
artist += new_artists[i],
15951595
else:
1596-
existing_artists[i] = artist, artists[i]
1596+
existing_artists[i] = artist, new_artists[i]
15971597
else:
1598-
merged_contents[key] = artists.copy(), labels
1598+
merged_contents[key] = new_artists.copy(), labels
15991599

16001600
# TODO explain
16011601
loc = "center right" if self._pyplot else "center left"

tests/_core/test_plot.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1963,6 +1963,20 @@ def _legend_artist(self, variables, value, scales):
19631963
assert len(contents.findobj(mpl.lines.Line2D)) == len(names)
19641964
assert len(contents.findobj(mpl.patches.Patch)) == len(names)
19651965

1966+
def test_three_layers(self, xy):
1967+
1968+
class MockMarkLine(MockMark):
1969+
def _legend_artist(self, variables, value, scales):
1970+
return mpl.lines.Line2D([], [])
1971+
1972+
s = pd.Series(["a", "b", "a", "c"], name="s")
1973+
p = Plot(**xy, color=s)
1974+
for _ in range(3):
1975+
p = p.add(MockMarkLine())
1976+
p = p.plot()
1977+
texts = p._figure.legends[0].get_texts()
1978+
assert len(texts) == len(s.unique())
1979+
19661980
def test_identity_scale_ignored(self, xy):
19671981

19681982
s = pd.Series(["r", "g", "b", "g"])

0 commit comments

Comments
 (0)