Skip to content

Commit b5905d6

Browse files
authored
Merge pull request #68 from Toloka/fix_rasa
Fix RASA & HRRASA
2 parents 57288fb + ce81975 commit b5905d6

File tree

2 files changed

+18
-18
lines changed

2 files changed

+18
-18
lines changed

crowdkit/aggregation/texts/text_hrrasa.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -60,11 +60,11 @@ def fit_predict_scores(self, data: pd.DataFrame, true_objects: pd.Series = None)
6060
"""Fit the model and return scores.
6161
6262
Args:
63-
data (DataFrame): Workers' outputs.
64-
A pandas.DataFrame containing `task`, `worker` and `output` columns.
65-
true_objects (Series): Tasks' ground truth labels.
63+
data (DataFrame): Workers' responses.
64+
A pandas.DataFrame containing `task`, `worker` and `text` columns.
65+
true_objects (Series): Tasks' ground truth texts.
6666
A pandas.Series indexed by `task` such that `true_objects.loc[task]`
67-
is the tasks's ground truth label.
67+
is the task's ground truth text.
6868
6969
Returns:
7070
DataFrame: Tasks' label scores.
@@ -78,11 +78,11 @@ def fit_predict(self, data: pd.DataFrame, true_objects: pd.Series = None) -> pd.
7878
"""Fit the model and return aggregated texts.
7979
8080
Args:
81-
data (DataFrame): Workers' outputs.
82-
A pandas.DataFrame containing `task`, `worker` and `output` columns.
83-
true_objects (Series): Tasks' ground truth labels.
81+
data (DataFrame): Workers' responses.
82+
A pandas.DataFrame containing `task`, `worker` and `text` columns.
83+
true_objects (Series): Tasks' ground truth texts.
8484
A pandas.Series indexed by `task` such that `true_objects.loc[task]`
85-
is the tasks's ground truth label.
85+
is the task's ground truth text.
8686
8787
Returns:
8888
Series: Tasks' texts.
@@ -91,11 +91,11 @@ def fit_predict(self, data: pd.DataFrame, true_objects: pd.Series = None) -> pd.
9191
"""
9292

9393
hrrasa_results = self._hrrasa.fit_predict(self._encode_data(data), self._encode_true_objects(true_objects))
94-
self.texts_ = hrrasa_results.reset_index()[['task', 'output']].set_index('task')
94+
self.texts_ = hrrasa_results.reset_index()[['task', 'output']].rename(columns={'output': 'text'}).set_index('task')
9595
return self.texts_
9696

9797
def _encode_data(self, data: pd.DataFrame) -> pd.DataFrame:
98-
data = data[['task', 'worker', 'output']]
98+
data = data[['task', 'worker', 'text']].rename(columns={'text': 'output'})
9999
data['embedding'] = data.output.apply(self.encoder)
100100
return data
101101

crowdkit/aggregation/texts/text_rasa.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -67,11 +67,11 @@ def fit_predict_scores(self, data: pd.DataFrame, true_objects: Optional[pd.Serie
6767
"""Fit the model and return scores.
6868
6969
Args:
70-
data (DataFrame): Workers' outputs.
70+
data (DataFrame): Workers' responses.
7171
A pandas.DataFrame containing `task`, `worker` and `text` columns.
72-
true_objects (Series): Tasks' ground truth labels.
72+
true_objects (Series): Tasks' ground truth texts.
7373
A pandas.Series indexed by `task` such that `true_objects.loc[task]`
74-
is the tasks's ground truth label.
74+
is the task's ground truth text.
7575
7676
Returns:
7777
DataFrame: Tasks' label scores.
@@ -85,11 +85,11 @@ def fit_predict(self, data: pd.DataFrame, true_objects: Optional[pd.Series] = No
8585
"""Fit the model and return aggregated texts.
8686
8787
Args:
88-
data (DataFrame): Workers' outputs.
88+
data (DataFrame): Workers' responses.
8989
A pandas.DataFrame containing `task`, `worker` and `text` columns.
90-
true_objects (Series): Tasks' ground truth labels.
90+
true_objects (Series): Tasks' ground truth texts.
9191
A pandas.Series indexed by `task` such that `true_objects.loc[task]`
92-
is the tasks's ground truth label.
92+
is the task's ground truth text.
9393
9494
Returns:
9595
Series: Tasks' texts.
@@ -98,11 +98,11 @@ def fit_predict(self, data: pd.DataFrame, true_objects: Optional[pd.Series] = No
9898
"""
9999

100100
rasa_results = self._rasa.fit_predict(self._encode_data(data), self._encode_true_objects(true_objects))
101-
self.texts_ = rasa_results.reset_index()[['task', 'output']].set_index('task')
101+
self.texts_ = rasa_results.reset_index()[['task', 'output']].rename(columns={'output': 'text'}).set_index('task')
102102
return self.texts_
103103

104104
def _encode_data(self, data: pd.DataFrame) -> pd.DataFrame:
105-
data = data[['task', 'worker', 'output']]
105+
data = data[['task', 'worker', 'text']].rename(columns={'text': 'output'})
106106
data['embedding'] = data.output.apply(self.encoder)
107107
return data
108108

0 commit comments

Comments
 (0)