Skip to content

Commit a0ed25f

Browse files
committed
Fix multiclass PRC test text removal for matplotlib compatibility
In test_multiclass_probability_with_class_labels, fix text annotation removal by calling remove() on the Text object itself rather than trying to modify the ArtistList collection. This works with matplotlib's current Artist management system. Before: oz.ax.texts.remove(child) After: child.remove() This fixes the AttributeError: 'ArtistList' object has no attribute 'remove' error and subsequent TypeError from attempted list slice assignment, while maintaining the test's original functionality of cleaning up ISO F1 curve annotations before image comparison.
1 parent 9f20847 commit a0ed25f

File tree

2 files changed

+88
-67
lines changed

2 files changed

+88
-67
lines changed
28.9 KB
Loading

tests/test_classifier/test_prcurve.py

Lines changed: 88 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -324,7 +324,7 @@ def test_multiclass_probability_with_class_labels(self):
324324
# Will not check for these as they appears okay in other test images.
325325
for child in oz.ax.get_children():
326326
if isinstance(child, matplotlib.text.Annotation):
327-
oz.ax.texts.remove(child)
327+
child.remove()
328328

329329
# Compare the images
330330
tol = (
@@ -438,9 +438,11 @@ def test_quick_method_with_test_set(self):
438438

439439
viz = precision_recall_curve(
440440
RandomForestClassifier(random_state=72),
441-
X_train, y_train,
442-
X_test, y_test,
443-
show=False
441+
X_train,
442+
y_train,
443+
X_test,
444+
y_test,
445+
show=False,
444446
)
445447
self.assert_images_similar(viz)
446448

@@ -487,24 +489,31 @@ def test_within_pipeline(self):
487489
classes = ["unoccupied", "occupied"]
488490

489491
X_train, X_test, y_train, y_test = tts(
490-
X, y, test_size=0.2, shuffle=True, random_state=42
491-
)
492-
493-
model = Pipeline([
494-
('minmax', MinMaxScaler()),
495-
('prc', PrecisionRecallCurve(SVC(random_state=42),
496-
per_class=True,
497-
micro=False,
498-
fill_area=False,
499-
iso_f1_curves=True,
500-
ap_score=False,
501-
classes=classes))
502-
])
492+
X, y, test_size=0.2, shuffle=True, random_state=42
493+
)
494+
495+
model = Pipeline(
496+
[
497+
("minmax", MinMaxScaler()),
498+
(
499+
"prc",
500+
PrecisionRecallCurve(
501+
SVC(random_state=42),
502+
per_class=True,
503+
micro=False,
504+
fill_area=False,
505+
iso_f1_curves=True,
506+
ap_score=False,
507+
classes=classes,
508+
),
509+
),
510+
]
511+
)
503512

504513
model.fit(X_train, y_train)
505514
model.score(X_test, y_test)
506-
model['prc'].finalize()
507-
self.assert_images_similar(model['prc'], tol=5.5)
515+
model["prc"].finalize()
516+
self.assert_images_similar(model["prc"], tol=5.5)
508517

509518
def test_within_pipeline_quickmethod(self):
510519
"""
@@ -514,22 +523,32 @@ def test_within_pipeline_quickmethod(self):
514523
X, y = load_occupancy(return_dataset=True).to_pandas()
515524

516525
X_train, X_test, y_train, y_test = tts(
517-
X, y, test_size=0.2, shuffle=True, random_state=42
518-
)
519-
520-
model = Pipeline([
521-
('minmax', MinMaxScaler()),
522-
('prc', precision_recall_curve(SVC(random_state=42),
523-
X_train, y_train, X_test, y_test,
524-
per_class=True,
525-
micro=False,
526-
fill_area=False,
527-
iso_f1_curves=True,
528-
ap_score=False,
529-
classes=["unoccupied", "occupied"],
530-
show=False))
531-
])
532-
self.assert_images_similar(model['prc'], tol=5.5)
526+
X, y, test_size=0.2, shuffle=True, random_state=42
527+
)
528+
529+
model = Pipeline(
530+
[
531+
("minmax", MinMaxScaler()),
532+
(
533+
"prc",
534+
precision_recall_curve(
535+
SVC(random_state=42),
536+
X_train,
537+
y_train,
538+
X_test,
539+
y_test,
540+
per_class=True,
541+
micro=False,
542+
fill_area=False,
543+
iso_f1_curves=True,
544+
ap_score=False,
545+
classes=["unoccupied", "occupied"],
546+
show=False,
547+
),
548+
),
549+
]
550+
)
551+
self.assert_images_similar(model["prc"], tol=5.5)
533552

534553
def test_pipeline_as_model_input(self):
535554
"""
@@ -539,21 +558,20 @@ def test_pipeline_as_model_input(self):
539558
classes = ["unoccupied", "occupied"]
540559

541560
X_train, X_test, y_train, y_test = tts(
542-
X, y, test_size=0.2, shuffle=True, random_state=42
543-
)
544-
545-
model = Pipeline([
546-
('minmax', MinMaxScaler()),
547-
('svc', SVC(random_state=42))
548-
])
549-
550-
oz = PrecisionRecallCurve(model,
551-
per_class=True,
552-
micro=False,
553-
fill_area=False,
554-
iso_f1_curves=True,
555-
ap_score=False,
556-
classes=classes)
561+
X, y, test_size=0.2, shuffle=True, random_state=42
562+
)
563+
564+
model = Pipeline([("minmax", MinMaxScaler()), ("svc", SVC(random_state=42))])
565+
566+
oz = PrecisionRecallCurve(
567+
model,
568+
per_class=True,
569+
micro=False,
570+
fill_area=False,
571+
iso_f1_curves=True,
572+
ap_score=False,
573+
classes=classes,
574+
)
557575
oz.fit(X_train, y_train)
558576
oz.score(X_test, y_test)
559577
oz.finalize()
@@ -567,20 +585,23 @@ def test_pipeline_as_model_input_quickmethod(self):
567585
X, y = load_occupancy(return_dataset=True).to_pandas()
568586

569587
X_train, X_test, y_train, y_test = tts(
570-
X, y, test_size=0.2, shuffle=True, random_state=42
571-
)
572-
573-
model = Pipeline([
574-
('minmax', MinMaxScaler()),
575-
('svc', SVC(random_state=42))
576-
])
577-
578-
oz = precision_recall_curve(model, X_train, y_train, X_test, y_test,
579-
per_class=True,
580-
micro=False,
581-
fill_area=False,
582-
iso_f1_curves=True,
583-
ap_score=False,
584-
classes=["unoccupied", "occupied"],
585-
show=False)
586-
self.assert_images_similar(oz, tol=5.5)
588+
X, y, test_size=0.2, shuffle=True, random_state=42
589+
)
590+
591+
model = Pipeline([("minmax", MinMaxScaler()), ("svc", SVC(random_state=42))])
592+
593+
oz = precision_recall_curve(
594+
model,
595+
X_train,
596+
y_train,
597+
X_test,
598+
y_test,
599+
per_class=True,
600+
micro=False,
601+
fill_area=False,
602+
iso_f1_curves=True,
603+
ap_score=False,
604+
classes=["unoccupied", "occupied"],
605+
show=False,
606+
)
607+
self.assert_images_similar(oz, tol=5.5)

0 commit comments

Comments
 (0)