Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
53 commits
Select commit Hold shift + click to select a range
b8ee4d6
.
May 3, 2018
5417950
.
May 8, 2018
d11cdfe
AutoSklearnClassifier/Regressor's fit, refit, fit_ensemble now return…
May 9, 2018
66ca590
Initial commit. Work in Progress.
May 14, 2018
43f6bb4
Fix minor printing error in sprint_statistics.
May 14, 2018
f9634a3
Merge pull request #1 from ahn1340/Fix#420
ahn1340 May 16, 2018
3f7cd1a
Merge pull request #2 from ahn1340/Fix#460
ahn1340 May 16, 2018
28f6805
Revert "Fix#460"
ahn1340 May 16, 2018
4f33872
Merge pull request #3 from ahn1340/revert-2-Fix#460
ahn1340 May 16, 2018
6614c23
Merge branch 'development' of https://github.com/ahn1340/auto-sklearn…
May 17, 2018
21e9eec
Raise error if ensemble is not built (#480)
ahn1340 May 18, 2018
2f5b318
Merge remote-tracking branch 'upstream/development' into development
ahn1340 May 29, 2018
5519d5c
Check target type at the beginning of the fitting process.
ahn1340 Jul 3, 2018
eec11b3
.
ahn1340 Jul 4, 2018
6632a53
Fixed minor error in uniitest
ahn1340 Jul 5, 2018
5eb8a14
.
ahn1340 Jul 5, 2018
9c67c1d
Add unittest for target type checking.
ahn1340 Sep 23, 2018
616a35a
.
ahn1340 Sep 23, 2018
7d9a315
.
ahn1340 Oct 1, 2018
c75d2b2
[Debug] try with numpy version 1.14.5
ahn1340 Oct 17, 2018
34d0a7c
[Debug] Check if numpy version 1.14.6 raises error.
ahn1340 Oct 18, 2018
e207cec
Check target type at the beginning of the fitting process.
ahn1340 Jul 3, 2018
ea11474
.
ahn1340 Jul 4, 2018
71c0abb
Fixed minor error in uniitest
ahn1340 Jul 5, 2018
82ad644
.
ahn1340 Jul 5, 2018
bec184b
Add unittest for target type checking.
ahn1340 Sep 23, 2018
45f0757
.
ahn1340 Sep 23, 2018
0b8758d
.
ahn1340 Oct 1, 2018
080961a
[Debug] Check if numpy version 1.14.6 raises error.
ahn1340 Oct 18, 2018
612e132
Merge branch 'target_type' of https://github.com/ahn1340/auto-sklearn…
ahn1340 Oct 18, 2018
f3384a9
Fix numpy version to 1.14.5
ahn1340 Oct 18, 2018
f4cf7c7
Add comment to Mock in test_type_of_target
ahn1340 Oct 18, 2018
00e7133
Fix line length in example_parallel.py
ahn1340 Oct 19, 2018
ddd4997
Fix minor error
ahn1340 Oct 19, 2018
ca1dc63
Merge branch 'development' of https://github.com/ahn1340/auto-sklearn…
Oct 29, 2018
4c8d853
FIX classifier returning prediction larger than 1
Nov 21, 2018
21493cb
Remove comments
Nov 21, 2018
24da442
ADD unittest for ensemble_selection.predict()
Nov 27, 2018
597cc8b
minor FIX
Nov 27, 2018
a2df7ca
Merge branch 'development' of https://github.com/automl/auto-sklearn …
ahn1340 Nov 29, 2018
096d207
ADD assertion in predict_proba to check probabilities sum up to 1.
Nov 30, 2018
0ce0c16
REVERT changes in autosklearn/ensemble_builder.py
Nov 30, 2018
a789c41
simplify ensemble prediction method
Nov 30, 2018
0066005
Merge branch 'target_type' into classifier_bug
Nov 30, 2018
cc31c99
Merge branch 'development' into classifier_bug
Nov 30, 2018
6a0f737
Modify assertion statement
Dec 3, 2018
9eeb0ad
ADD case check in ensemble_selection.predict()
ahn1340 Dec 5, 2018
5f1970b
Fix minor error in pred_probs verficiation.
ahn1340 Dec 5, 2018
02fbaeb
Modify unittest for ensemble_selection.predict()
ahn1340 Dec 5, 2018
c3a2aa6
FIX flake8 errors
ahn1340 Dec 5, 2018
0837c67
FIX flake8 error
ahn1340 Dec 5, 2018
4cce583
ADD Ignore assertion for multilabel, check probabilities lie between …
ahn1340 Dec 6, 2018
b6336e3
Debug flake8 error
ahn1340 Dec 6, 2018
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 17 additions & 4 deletions autosklearn/ensembles/ensemble_selection.py
Original file line number Diff line number Diff line change
Expand Up @@ -209,10 +209,23 @@ def _bagging(self, predictions, labels, fraction=0.5, n_bags=20):
return np.array(order_of_each_bag)

def predict(self, predictions):
non_null_weights = (weight for weight in self.weights_ if weight > 0)
for i, weight in enumerate(non_null_weights):
predictions[i] *= weight
return np.sum(predictions, axis=0)
predictions = np.asarray(predictions)

# if predictions.shape[0] == len(self.weights_),
# predictions include those of zero-weight models.
if predictions.shape[0] == len(self.weights_):
return np.average(predictions, axis=0, weights=self.weights_)

# if prediction model.shape[0] == len(non_null_weights),
# predictions do not include those of zero-weight models.
elif predictions.shape[0] == np.count_nonzero(self.weights_):
non_null_weights = [w for w in self.weights_ if w > 0]
return np.average(predictions, axis=0, weights=non_null_weights)

# If none of the above applies, then something must have gone wrong.
else:
raise ValueError("The dimensions of ensemble predictions"
" and ensemble weights do not match!")

def __str__(self):
return 'Ensemble Selection:\n\tTrajectory: %s\n\tMembers: %s' \
Expand Down
22 changes: 21 additions & 1 deletion autosklearn/estimators.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# -*- encoding: utf-8 -*-
from sklearn.base import BaseEstimator
import numpy as np

from autosklearn.automl import AutoMLClassifier, AutoMLRegressor
from autosklearn.util.backend import create
Expand Down Expand Up @@ -486,6 +487,9 @@ def fit(self, X, y,
raise ValueError("classification with data of type %s is"
" not supported" % target_type)

# remember target type for using in predict_proba later.
self.target_type = target_type

super().fit(
X=X,
y=y,
Expand Down Expand Up @@ -527,9 +531,25 @@ def predict_proba(self, X, batch_size=None, n_jobs=1):
The predicted class probabilities.

"""
return super().predict_proba(
pred_proba = super().predict_proba(
X, batch_size=batch_size, n_jobs=n_jobs)

# Check if all probabilities sum up to 1.
# Assert only if target type is not multilabel-indicator.
if self.target_type not in ['multilabel-indicator']:
assert(
np.allclose(
np.sum(pred_proba, axis=1),
np.ones_like(pred_proba[:, 0]))
), "prediction probability does not sum up to 1!"

# Check that all probability values lie between 0 and 1.
assert(
(pred_proba >= 0).all() and (pred_proba <= 1).all()
), "found prediction probability value outside of [0, 1]!"

return pred_proba

def _get_automl_class(self):
return AutoMLClassifier

Expand Down
71 changes: 68 additions & 3 deletions test/test_ensemble_builder/test_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,14 @@
import unittest
import unittest.mock

from autosklearn.metrics import roc_auc, accuracy
from autosklearn.ensembles.ensemble_selection import EnsembleSelection
from autosklearn.ensemble_builder import EnsembleBuilder, Y_VALID, Y_TEST
import numpy as np

this_directory = os.path.dirname(__file__)
sys.path.append(this_directory)

from autosklearn.ensemble_builder import EnsembleBuilder, Y_ENSEMBLE, Y_VALID, Y_TEST
from autosklearn.metrics import roc_auc


class BackendMock(object):

Expand Down Expand Up @@ -260,3 +260,68 @@ def testLimit(self):

# it should try to reduce ensemble_nbest until it also failed at 2
self.assertEqual(ensbuilder.ensemble_nbest,1)


class EnsembleSelectionTest(unittest.TestCase):
def testPredict(self):
# Test that ensemble prediction applies weights correctly to given
# predictions. There are two possible cases:
# 1) predictions.shape[0] == len(self.weights_). In this case,
# predictions include those made by zero-weighted models. Therefore,
# we simply apply each weights to the corresponding model preds.
# 2) predictions.shape[0] < len(self.weights_). In this case,
# predictions exclude those made by zero-weighted models. Therefore,
# we first exclude all occurrences of zero in self.weights_, and then
# apply the weights.
# If none of the above is the case, predict() raises Error.
ensemble = EnsembleSelection(ensemble_size=3,
task_type=1,
metric=accuracy,
)
# Test for case 1. Create (3, 2, 2) predictions.
per_model_pred = np.array([
[[0.9, 0.1],
[0.4, 0.6]],
[[0.8, 0.2],
[0.3, 0.7]],
[[1.0, 0.0],
[0.1, 0.9]]
])
# Weights of 3 hypothetical models
ensemble.weights_ = [0.7, 0.2, 0.1]
pred = ensemble.predict(per_model_pred)
truth = np.array([[0.89, 0.11], # This should be the true prediction.
[0.35, 0.65]])
self.assertTrue(np.allclose(pred, truth))

# Test for case 2.
per_model_pred = np.array([
[[0.9, 0.1],
[0.4, 0.6]],
[[0.8, 0.2],
[0.3, 0.7]],
[[1.0, 0.0],
[0.1, 0.9]]
])
# The third model now has weight of zero.
ensemble.weights_ = [0.7, 0.2, 0.0, 0.1]
pred = ensemble.predict(per_model_pred)
truth = np.array([[0.89, 0.11],
[0.35, 0.65]])
self.assertTrue(np.allclose(pred, truth))

# Test for error case.
per_model_pred = np.array([
[[0.9, 0.1],
[0.4, 0.6]],
[[0.8, 0.2],
[0.3, 0.7]],
[[1.0, 0.0],
[0.1, 0.9]]
])
# Now the weights have 2 zero weights and 2 non-zero weights,
# which is incompatible.
ensemble.weights_ = [0.6, 0.0, 0.0, 0.4]

with self.assertRaises(ValueError):
ensemble.predict(per_model_pred)