Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 14 additions & 4 deletions paddlenlp/experimental/autonlp/auto_trainer_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ def _preprocess_fn(
preprocess an example from raw features to input features that Transformers models expect (e.g. input_ids, attention_mask, labels, etc)
"""

@abstractmethod
def export(self, export_path, trial_id=None):
"""
Export the model from a certain `trial_id` to the given file path.
Expand Down Expand Up @@ -171,15 +172,24 @@ def evaluate(self, trial_id=None, eval_dataset=None) -> Dict[str, float]:
"""
raise NotImplementedError

@abstractmethod
def predict(self, test_dataset: Dataset, trial_id: Optional[str] = None):
"""
Run prediction and returns predictions and potential metrics from a certain `trial_id` on the given dataset
Args:
test_dataset (Dataset, required): Custom test dataset and must contains the 'text_column' and 'label_column' fields.
trial_id (str, optional): Specify the model to be evaluated through the `trial_id`. Defaults to the best model selected by `metric_for_best_model`.
"""
raise NotImplementedError

def _override_hp(self, config: Dict[str, Any], default_hp: Any) -> Any:
"""
Overrides the arguments with the provided hyperparameter config
"""
new_hp = copy.deepcopy(default_hp)
for key, value in config.items():
if key.startswith(default_hp.__class__.__name__):
_, hp_key = key.split(".")
setattr(new_hp, hp_key, value)
if key in new_hp.to_dict():
setattr(new_hp, key, value)
return new_hp

def _filter_model_candidates(
Expand Down Expand Up @@ -264,7 +274,7 @@ def train(
experiment_name: (str, optional): name of the experiment. Experiment log will be stored under <output_dir>/<experiment_name>.
Defaults to UNIX timestamp.
hp_overrides: (dict[str, Any], optional): Advanced users only.
override the hyperparameters of every model candidate. For example, {"TrainingArguments.max_steps": 5}.
override the hyperparameters of every model candidate. For example, {"max_steps": 5}.
custom_model_candiates: (dict[str, Any], optional): Advanced users only.
Run the user-provided model candidates instead of the default model candidated from PaddleNLP. See `._model_candidates` property as an example

Expand Down
Loading