@@ -53,10 +53,10 @@ class KatibClient(object):
53
53
max_trial_count : int = None ,
54
54
parallel_trial_count : int = None ,
55
55
max_failed_trial_count : int = None ,
56
- resources_per_trial : Union[dict , client.V1ResourceRequirements, None ] = None ,
57
- pytorch_config = katib.PyTorchConfig(
56
+ pytorch_config = katib.PyTorchConfig(
58
57
num_workers: int = 1 ,
59
58
num_procs_per_worker: int = 1 ,
59
+ resources_per_worker: Union[dict , client.V1ResourceRequirements, None ] = None ,
60
60
),
61
61
retain_trials : bool = False ,
62
62
env_per_trial : Optional[Union[Dict[str , str ], List[Union[client.V1EnvVar, client.V1EnvFromSource]]]] = None ,
@@ -81,8 +81,7 @@ class KatibClient(object):
81
81
- max_trial_count: Maximum number of trials to run.
82
82
- parallel_trial_count: Number of trials to run in parallel.
83
83
- max_failed_trial_count: Maximum number of allowed failed trials.
84
- - resources_per_trial: Resources required per trial.
85
- - pytorch_config: Configuration for PyTorch jobs, including number of workers and processes per worker.
84
+ - pytorch_config: Configuration for PyTorch jobs, including number of workers, processes per worker and resources per worker.
86
85
- retain_trials: Whether to retain trial resources after completion.
87
86
- env_per_trial: Environment variables for worker containers.
88
87
- packages_to_install: Additional Python packages to install.
@@ -149,13 +148,15 @@ katib_client.tune(
149
148
algorithm_name = " random" ,
150
149
max_trial_count = 50 ,
151
150
parallel_trial_count = 2 ,
152
- resources_per_trial = {
153
- " gpu" : 8 ,
154
- " cpu" : 20 ,
155
- " memory" : " 40G" ,
156
- },
157
- num_workers = 4 ,
158
- num_procs_per_worker = 2 ,
151
+ pytorch_config = katib.PyTorchConfig(
152
+ num_workers = 4 ,
153
+ num_procs_per_worker = 2 ,
154
+ resources_per_worker = {
155
+ " gpu" : 2 ,
156
+ " cpu" : 5 ,
157
+ " memory" : " 10G" ,
158
+ },
159
+ ),
159
160
)
160
161
161
162
# Get the best hyperparameters
0 commit comments