Merged
Changes from 14 commits
131 changes: 83 additions & 48 deletions autosklearn/automl.py
@@ -1,9 +1,10 @@
from __future__ import annotations

from typing import Any, Callable, Iterable, Mapping, Optional, Tuple
from typing import Any, Callable, Iterable, Mapping, Optional, Sequence, Tuple

import copy
import io
import itertools
import json
import logging.handlers
import multiprocessing
@@ -66,7 +67,7 @@
from autosklearn.evaluation import ExecuteTaFuncWithQueue, get_cost_of_crash
from autosklearn.evaluation.abstract_evaluator import _fit_and_suppress_warnings
from autosklearn.evaluation.train_evaluator import TrainEvaluator, _fit_with_budget
from autosklearn.metrics import Scorer, calculate_metric, default_metric_for_task
from autosklearn.metrics import Scorer, compute_single_metric, default_metric_for_task
from autosklearn.pipeline.base import BasePipeline
from autosklearn.pipeline.components.classification import ClassifierChoice
from autosklearn.pipeline.components.data_preprocessing.categorical_encoding import (
@@ -210,7 +211,7 @@ def __init__(
get_smac_object_callback: Optional[Callable] = None,
smac_scenario_args: Optional[Mapping] = None,
logging_config: Optional[Mapping] = None,
metric: Optional[Scorer] = None,
metrics: Sequence[Scorer] | None = None,
scoring_functions: Optional[list[Scorer]] = None,
get_trials_callback: Optional[IncorporateRunResultCallback] = None,
dataset_compression: bool | Mapping[str, Any] = True,
@@ -244,7 +245,7 @@
self._delete_tmp_folder_after_terminate = delete_tmp_folder_after_terminate
self._time_for_task = time_left_for_this_task
self._per_run_time_limit = per_run_time_limit
self._metric = metric
self._metrics = metrics
self._ensemble_size = ensemble_size
self._ensemble_nbest = ensemble_nbest
self._max_models_on_disc = max_models_on_disc
@@ -265,7 +266,7 @@
initial_configurations_via_metalearning
)

self._scoring_functions = scoring_functions or {}
self._scoring_functions = scoring_functions or []
self._resampling_strategy_arguments = resampling_strategy_arguments or {}

# Single core, local runs should use fork to prevent the __main__ requirements
@@ -422,8 +423,8 @@ def _do_dummy_prediction(self) -> None:
if self._resampling_strategy in ["partial-cv", "partial-cv-iterative-fit"]:
return

if self._metric is None:
raise ValueError("Metric was not set")
if self._metrics is None:
raise ValueError("Metric/Metrics was/were not set")

# Dummy prediction always have num_run set to 1
dummy_run_num = 1
@@ -447,11 +448,11 @@ def _do_dummy_prediction(self) -> None:
resampling_strategy=self._resampling_strategy,
initial_num_run=dummy_run_num,
stats=stats,
metric=self._metric,
metrics=self._metrics,
memory_limit=memory_limit,
disable_file_output=self._disable_evaluator_output,
abort_on_first_run_crash=False,
cost_for_crash=get_cost_of_crash(self._metric),
cost_for_crash=get_cost_of_crash(self._metrics),
port=self._logger_port,
pynisher_context=self._multiprocessing_context,
**self._resampling_strategy_arguments,
@@ -611,8 +612,8 @@ def fit(
self._task = task

# Assign a metric if it doesn't exist
if self._metric is None:
self._metric = default_metric_for_task[self._task]
if self._metrics is None:
self._metrics = [default_metric_for_task[self._task]]

if dataset_name is None:
dataset_name = str(uuid.uuid1(clock_seq=os.getpid()))
@@ -690,11 +691,16 @@ def fit(
# The metric must exist as of this point
# It can be provided in the constructor, or automatically
# defined in the estimator fit call
if self._metric is None:
raise ValueError("No metric given.")
if not isinstance(self._metric, Scorer):
if isinstance(self._metrics, Sequence):
for entry in self._metrics:
if not isinstance(entry, Scorer):
raise ValueError(
"Metric {entry} must be instance of autosklearn.metrics.Scorer."
)
else:
raise ValueError(
"Metric must be instance of " "autosklearn.metrics.Scorer."
"Metric must be a sequence of instances of "
"autosklearn.metrics.Scorer."
)

# If no dask client was provided, we create one, so that we can
@@ -790,7 +796,7 @@ def fit(
backend=copy.deepcopy(self._backend),
dataset_name=dataset_name,
task=self._task,
metric=self._metric,
metric=self._metrics[0],
ensemble_size=self._ensemble_size,
ensemble_nbest=self._ensemble_nbest,
max_models_on_disc=self._max_models_on_disc,
@@ -862,7 +868,7 @@ def fit(
config_file=configspace_path,
seed=self._seed,
metadata_directory=self._metadata_directory,
metric=self._metric,
metrics=self._metrics,
resampling_strategy=self._resampling_strategy,
resampling_strategy_args=self._resampling_strategy_arguments,
include=self._include,
@@ -1001,7 +1007,10 @@ def _log_fit_setup(self) -> None:
)
self._logger.debug(" smac_scenario_args: %s", str(self._smac_scenario_args))
self._logger.debug(" logging_config: %s", str(self.logging_config))
self._logger.debug(" metric: %s", str(self._metric))
if len(self._metrics) == 1:
self._logger.debug(" metric: %s", str(self._metrics[0]))
else:
self._logger.debug(" metrics: %s", str(self._metrics))
self._logger.debug("Done printing arguments to auto-sklearn")
self._logger.debug("Starting to print available components")
for choice in (
@@ -1254,8 +1263,8 @@ def fit_pipeline(
self._task = task

# Assign a metric if it doesn't exist
if self._metric is None:
self._metric = default_metric_for_task[self._task]
if self._metrics is None:
self._metrics = [default_metric_for_task[self._task]]

# Get the configuration space
# This also ensures that the Backend has processed the
@@ -1288,8 +1297,14 @@ def fit_pipeline(
kwargs["memory_limit"] = self._memory_limit
if "resampling_strategy" not in kwargs:
kwargs["resampling_strategy"] = self._resampling_strategy
if "metric" not in kwargs:
kwargs["metric"] = self._metric
if "metrics" not in kwargs:
if "metric" in kwargs:
kwargs["metrics"] = kwargs["metric"]
del kwargs["metric"]
else:
kwargs["metrics"] = self._metrics
if not isinstance(kwargs["metrics"], Sequence):
kwargs["metrics"] = [kwargs["metrics"]]
if "disable_file_output" not in kwargs:
kwargs["disable_file_output"] = self._disable_evaluator_output
if "pynisher_context" not in kwargs:
@@ -1307,7 +1322,7 @@ def fit_pipeline(
autosklearn_seed=self._seed,
abort_on_first_run_crash=False,
multi_objectives=["cost"],
cost_for_crash=get_cost_of_crash(kwargs["metric"]),
cost_for_crash=get_cost_of_crash(kwargs["metrics"]),
port=self._logger_port,
**kwargs,
**self._resampling_strategy_arguments,
@@ -1492,7 +1507,7 @@ def fit_ensemble(
backend=copy.deepcopy(self._backend),
dataset_name=dataset_name if dataset_name else self._dataset_name,
task=task if task else self._task,
metric=self._metric,
metric=self._metrics[0],
ensemble_size=ensemble_size if ensemble_size else self._ensemble_size,
ensemble_nbest=ensemble_nbest if ensemble_nbest else self._ensemble_nbest,
max_models_on_disc=self._max_models_on_disc,
@@ -1590,7 +1605,7 @@ def _load_best_individual_model(self):

# SingleBest contains the best model found by AutoML
ensemble = SingleBest(
metric=self._metric,
metric=self._metrics[0],
seed=self._seed,
run_history=self.runhistory_,
backend=self._backend,
@@ -1624,15 +1639,15 @@ def score(self, X, y):
# same representation domain
prediction = self.InputValidator.target_validator.transform(prediction)

return calculate_metric(
return compute_single_metric(
solution=y,
prediction=prediction,
task_type=self._task,
metric=self._metric,
metric=self._metrics[0],
)

def _get_runhistory_models_performance(self):
metric = self._metric
metric = self._metrics[0]
data = self.runhistory_.data
performance_list = []
for run_key, run_value in data.items():
@@ -1644,7 +1659,10 @@ def _get_runhistory_models_performance(self):
endtime = pd.Timestamp(
time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(run_value.endtime))
)
val_score = metric._optimum - (metric._sign * run_value.cost)
cost = run_value.cost
if len(self._metrics) > 1:
cost = cost[0]
Contributor: I assume this is a point of API conflict? It would be good to know about all the metrics for a model but at the end of the day, we currently only support one and so we choose the first?

Contributor (Author): Yes, it would be good to know about all the metrics. I will look into returning multiple metrics here (should be possible).

Contributor (Author): Please see my comment wrt this in the PR comment at the top.

(A standalone sketch of this first-metric selection follows the end of this hunk.)

val_score = metric._optimum - (metric._sign * cost)
train_score = metric._optimum - (
metric._sign * run_value.additional_info["train_loss"]
)
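A minimal standalone sketch of the point discussed in the review thread above, under assumed values (an accuracy-style scorer with optimum 1.0 and sign +1; this is not auto-sklearn code): when several metrics were optimized, the run's cost is a sequence and only the first entry is converted into the reported score via optimum - sign * cost.

from typing import Sequence, Union

def single_best_score(
    cost: Union[float, Sequence[float]],
    optimum: float = 1.0,
    sign: float = 1.0,
    n_metrics: int = 1,
) -> float:
    # With several metrics the optimizer reports a sequence of costs;
    # only the first (leading) metric feeds the reported validation score.
    if n_metrics > 1:
        cost = cost[0]
    return optimum - sign * cost

print(single_best_score(0.2))                       # one metric -> 0.8
print(single_best_score([0.2, 0.35], n_metrics=2))  # first of two metrics -> 0.8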
@@ -1656,9 +1674,10 @@ def _get_runhistory_models_performance(self):
# Append test-scores, if data for test_loss are available.
# This is the case, if X_test and y_test where provided.
if "test_loss" in run_value.additional_info:
test_score = metric._optimum - (
metric._sign * run_value.additional_info["test_loss"]
)
test_loss = run_value.additional_info["test_loss"]
if len(self._metrics) > 1:
test_loss = test_loss[0]
test_score = metric._optimum - (metric._sign * test_loss)
scores["single_best_test_score"] = test_score

performance_list.append(scores)
@@ -1747,14 +1766,11 @@ def cv_results_(self):

metric_mask = dict()
metric_dict = dict()
metric_name = []

for metric in self._scoring_functions:
metric_name.append(metric.name)
for metric in itertools.chain(self._metrics, self._scoring_functions):
metric_dict[metric.name] = []
metric_mask[metric.name] = []

mean_test_score = []
mean_fit_time = []
params = []
status = []
@@ -1787,9 +1803,7 @@

param_dict = config.get_dictionary()
params.append(param_dict)
mean_test_score.append(
self._metric._optimum - (self._metric._sign * run_value.cost)
)

mean_fit_time.append(run_value.time)
budgets.append(run_key.budget)

@@ -1804,6 +1818,14 @@
parameter_dictionaries[hp_name].append(hp_value)
masks[hp_name].append(mask_value)

cost = [run_value.cost] if len(self._metrics) == 1 else run_value.cost
for metric_idx, metric in enumerate(self._metrics):
metric_cost = cost[metric_idx]
metric_value = metric._optimum - (metric._sign * metric_cost)
mask_value = False
metric_dict[metric.name].append(metric_value)
metric_mask[metric.name].append(mask_value)

for metric in self._scoring_functions:
if metric.name in run_value.additional_info.keys():
metric_cost = run_value.additional_info[metric.name]
@@ -1815,15 +1837,26 @@
metric_dict[metric.name].append(metric_value)
metric_mask[metric.name].append(mask_value)

results["mean_test_score"] = np.array(mean_test_score)
for name in metric_name:
masked_array = ma.MaskedArray(metric_dict[name], metric_mask[name])
results["metric_%s" % name] = masked_array
if len(self._metrics) == 1:
results["mean_test_score"] = np.array(metric_dict[self._metrics[0].name])
rank_order = -1 * self._metrics[0]._sign * results["mean_test_score"]
results["rank_test_scores"] = scipy.stats.rankdata(rank_order, method="min")
else:
for metric in self._metrics:
key = f"mean_test_{metric.name}"
results[key] = np.array(metric_dict[metric.name])
rank_order = -1 * metric._sign * results[key]
results[f"rank_test_{metric.name}"] = scipy.stats.rankdata(
rank_order, method="min"
)
for metric in self._scoring_functions:
masked_array = ma.MaskedArray(
metric_dict[metric.name], metric_mask[metric.name]
)
results[f"metric_{metric.name}"] = masked_array

results["mean_fit_time"] = np.array(mean_fit_time)
results["params"] = params
rank_order = -1 * self._metric._sign * results["mean_test_score"]
results["rank_test_scores"] = scipy.stats.rankdata(rank_order, method="min")
results["status"] = status
results["budgets"] = budgets

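For reference, a minimal standalone sketch of the ranking step used in cv_results_ above, with made-up scores (not auto-sklearn code): scores are negated according to the metric's sign so that the best configuration receives rank 1, and ties share the lower rank via scipy.stats.rankdata with method="min".

import numpy as np
import scipy.stats

# Hypothetical per-configuration test scores; higher is better (sign = +1).
mean_test_score = np.array([0.80, 0.92, 0.92, 0.75])
sign = 1

rank_order = -1 * sign * mean_test_score
rank_test_scores = scipy.stats.rankdata(rank_order, method="min")
print(rank_test_scores)  # [3. 1. 1. 4.]: the two tied best models share rank 1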
@@ -1841,7 +1874,10 @@ def sprint_statistics(self) -> str:
sio = io.StringIO()
sio.write("auto-sklearn results:\n")
sio.write(" Dataset name: %s\n" % self._dataset_name)
sio.write(" Metric: %s\n" % self._metric)
if len(self._metrics) == 1:
sio.write(" Metric: %s\n" % self._metrics[0])
else:
sio.write(" Metrics: %s\n" % self._metrics)
idx_success = np.where(
np.array(
[
@@ -1852,7 +1888,7 @@ def sprint_statistics(self) -> str:
)
)[0]
if len(idx_success) > 0:
if not self._metric._optimum:
if not self._metrics[0]._optimum:
idx_best_run = np.argmin(cv_results["mean_test_score"][idx_success])
else:
idx_best_run = np.argmax(cv_results["mean_test_score"][idx_success])
@@ -1912,7 +1948,6 @@ def show_models(self) -> dict[int, Any]:
.. code-block:: python

import sklearn.datasets
import sklearn.metrics
import autosklearn.regression

X, y = sklearn.datasets.load_diabetes(return_X_y=True)