Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/packages/slicing.rst
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ Programmatic data set slicing: SF creation, monitoring utilities, and representa
apply.dask.PandasParallelSFApplier
PandasSFApplier
SFApplier
SliceAwareClassifier
SliceCombinerModule
SlicingClassifier
SlicingFunction
apply.spark.SparkSFApplier
add_slice_labels
Expand Down
2 changes: 1 addition & 1 deletion snorkel/slicing/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,5 +4,5 @@
from .modules.slice_combiner import SliceCombinerModule # noqa: F401
from .monitor import slice_dataframe # noqa: F401
from .sf.core import SlicingFunction, slicing_function # noqa: F401
from .slicing_classifier import SlicingClassifier # noqa: F401
from .sliceaware_classifier import SliceAwareClassifier # noqa: F401
from .utils import add_slice_labels, convert_to_slice_tasks # noqa: F401
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
from .utils import add_slice_labels, convert_to_slice_tasks


class SlicingClassifier(MultitaskClassifier):
class SliceAwareClassifier(MultitaskClassifier):
"""A slice-aware classifier that supports training + scoring on slice labels.

NOTE: This model currently only supports binary classification.
Expand Down Expand Up @@ -87,7 +87,7 @@ def __init__(
slice_tasks = convert_to_slice_tasks(self.base_task, slice_names)

# Initialize a MultitaskClassifier with all slice_tasks
model_name = f"{task_name}_slicing_classifier"
model_name = f"{task_name}_sliceaware_classifier"
super().__init__(tasks=slice_tasks, name=model_name, **multitask_kwargs)
self.slice_names = slice_names

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from snorkel.analysis import Scorer
from snorkel.classification import DictDataset
from snorkel.slicing import SFApplier, SlicingClassifier, slicing_function
from snorkel.slicing import SFApplier, SliceAwareClassifier, slicing_function


@slicing_function()
Expand Down Expand Up @@ -63,7 +63,7 @@ def setUp(self):
for split in splits
]

self.slice_model = SlicingClassifier(
self.slice_model = SliceAwareClassifier(
base_architecture=self.mlp,
head_dim=self.hidden_dim,
slice_names=[sf.name for sf in sfs],
Expand Down