Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions seaborn/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -286,7 +286,7 @@ def __init__(

if map_type == "numeric":

levels, lookup_table, norm = self.numeric_mapping(
levels, lookup_table, norm, size_range = self.numeric_mapping(
data, sizes, norm,
)

Expand All @@ -297,6 +297,7 @@ def __init__(
levels, lookup_table = self.categorical_mapping(
data, sizes, order,
)
size_range = None

# --- Option 3: datetime mapping

Expand All @@ -308,11 +309,13 @@ def __init__(
# pandas and numpy represent datetime64 data
list(data), sizes, order,
)
size_range = None

self.map_type = map_type
self.levels = levels
self.norm = norm
self.sizes = sizes
self.size_range = size_range
self.lookup_table = lookup_table

def infer_map_type(self, norm, sizes, var_type):
Expand All @@ -334,9 +337,7 @@ def _lookup_single(self, key):
normed = self.norm(key)
if np.ma.is_masked(normed):
normed = np.nan
size_values = self.lookup_table.values()
size_range = min(size_values), max(size_values)
value = size_range[0] + normed * np.ptp(size_range)
value = self.size_range[0] + normed * np.ptp(self.size_range)
return value

def categorical_mapping(self, data, sizes, order):
Expand Down Expand Up @@ -385,15 +386,15 @@ def categorical_mapping(self, data, sizes, order):
# across the visual representation of the data. But at this
# point, we don't know the visual representation. Likely we
# want to change the logic of this Mapping so that it gives
# points on a nornalized range that then gets unnormalized
# points on a normalized range that then gets un-normalized
# when we know what we're drawing. But given the way the
# package works now, this way is cleanest.
sizes = self.plotter._default_size_range

# For categorical sizes, use regularly-spaced linear steps
# between the minimum and maximum sizes. Then reverse the
# ramp so that the largest value is used for the first entry
# in size_order, etc. This is because "ordered" categoricals
# in size_order, etc. This is because "ordered" categories
# are often thought to go in decreasing priority.
sizes = np.linspace(*sizes, len(levels))[::-1]
lookup_table = dict(zip(levels, sizes))
Expand Down Expand Up @@ -437,7 +438,7 @@ def numeric_mapping(self, data, sizes, norm):

# When not provided, we get the size range from the plotter
# object we are attached to. See the note in the categorical
# method about how this is suboptimal for future development.:
# method about how this is suboptimal for future development.
size_range = self.plotter._default_size_range

# Now that we know the minimum and maximum sizes that will get drawn,
Expand Down Expand Up @@ -477,7 +478,7 @@ def numeric_mapping(self, data, sizes, norm):
sizes = lo + sizes_scaled * (hi - lo)
lookup_table = dict(zip(levels, sizes))

return levels, lookup_table, norm
return levels, lookup_table, norm, size_range


@share_init_params_with_map
Expand Down
34 changes: 34 additions & 0 deletions seaborn/tests/test_relational.py
Original file line number Diff line number Diff line change
Expand Up @@ -1626,6 +1626,40 @@ def test_linewidths(self, long_df):
scatterplot(data=long_df, x="x", y="y", linewidth=lw)
assert ax.collections[0].get_linewidths().item() == lw

def test_size_norm_extrapolation(self):
    """Sizes must follow the explicit ``size_norm``, not the data range.

    Regression test for https://github.com/mwaskom/seaborn/issues/2539:
    with a fixed ``size_norm``, equal data values should be drawn at
    equal point sizes even when one plot's data covers only part of
    the normed range.
    """
    # https://github.com/mwaskom/seaborn/issues/2539
    x = np.arange(0, 20, 2)
    # Two axes: one gets the full data range, the other only a prefix of it.
    f, axs = plt.subplots(1, 2, sharex=True, sharey=True)

    slc = 5
    # Same explicit size range and size_norm for both plots.
    kws = dict(sizes=(50, 200), size_norm=(0, x.max()), legend="brief")

    scatterplot(x=x, y=x, size=x, ax=axs[0], **kws)
    scatterplot(x=x[:slc], y=x[:slc], size=x[:slc], ax=axs[1], **kws)

    # The shared data values must be drawn at identical sizes on both axes.
    assert np.allclose(
        axs[0].collections[0].get_sizes()[:slc],
        axs[1].collections[0].get_sizes()
    )

    # Legend entries with the same label must also show the same marker size.
    legends = [ax.legend_ for ax in axs]
    legend_data = [
        {
            label.get_text(): handle.get_sizes().item()
            for label, handle in zip(legend.get_texts(), legend.legendHandles)
        } for legend in legends
    ]

    for key in set(legend_data[0]) & set(legend_data[1]):
        if key == "y":
            # At some point (circa 3.0) matplotlib auto-added pandas series
            # with a valid name into the legend, which messes up this test.
            # I can't track down when that was added (or removed), so let's
            # just anticipate and ignore it here.
            continue
        assert legend_data[0][key] == legend_data[1][key]

def test_datetime_scale(self, long_df):

ax = scatterplot(data=long_df, x="t", y="y")
Expand Down