16
16
import multiprocessing
17
17
import textwrap
18
18
import time
19
- from typing import Any , Callable , Dict , List , Optional
19
+ from typing import Any , Callable , Dict , List , Optional , Union
20
20
21
21
import grpc
22
22
import kubeflow .katib .katib_api_pb2 as katib_api_pb2
@@ -144,6 +144,7 @@ def tune(
144
144
max_trial_count : int = None ,
145
145
parallel_trial_count : int = None ,
146
146
max_failed_trial_count : int = None ,
147
+ resources_per_trial : Union [dict , client .V1ResourceRequirements , None ] = None ,
147
148
retain_trials : bool = False ,
148
149
packages_to_install : List [str ] = None ,
149
150
pip_index_url : str = "https://pypi.org/simple" ,
@@ -177,6 +178,24 @@ def tune(
177
178
values check this doc: https://www.kubeflow.org/docs/components/katib/experiment/#configuration-spec.
178
179
parallel_trial_count: Number of Trials that Experiment runs in parallel.
179
180
max_failed_trial_count: Maximum number of Trials allowed to fail.
181
+ resources_per_trial: A parameter that lets you specify the amount of
182
+ resources each trial container should have. You can either specify a
183
+ kubernetes.client.V1ResourceRequirements object (documented here:
184
+ https://github.com/kubernetes-client/python/blob/master/kubernetes/docs/V1ResourceRequirements.md)
185
+ or a dictionary that includes one or more of the following keys:
186
+ `cpu`, `memory`, or `gpu` (other keys are passed through unchanged). Appropriate
187
+ values for these keys are documented here:
188
+ https://kubernetes.io/docs/concepts/configuration/manage-resources-containers/.
189
+ For example:
190
+ {
191
+ "cpu": "1",
192
+ "gpu": "1",
193
+ "memory": "2Gi",
194
+ }
195
+ Please note, `gpu` specifies a resource request with a key of
196
+ `nvidia.com/gpu`, i.e. an NVIDIA GPU. If you need a different type
197
+ of GPU, pass in a V1ResourceRequirements instance instead, since it's
198
+ more flexible. This parameter is optional and defaults to None.
180
199
retain_trials: Whether Trials' resources (e.g. pods) are deleted after Succeeded state.
181
200
packages_to_install: List of Python packages to install in addition
182
201
to the base image packages. These packages are installed before
@@ -280,6 +299,15 @@ def tune(
280
299
+ exec_script
281
300
)
282
301
302
+ if isinstance (resources_per_trial , dict ):
303
+ if "gpu" in resources_per_trial :
304
+ resources_per_trial ["nvidia.com/gpu" ] = resources_per_trial .pop ("gpu" )
305
+
306
+ resources_per_trial = client .V1ResourceRequirements (
307
+ requests = resources_per_trial ,
308
+ limits = resources_per_trial ,
309
+ )
310
+
283
311
# Create Trial specification.
284
312
trial_spec = client .V1Job (
285
313
api_version = "batch/v1" ,
@@ -297,6 +325,7 @@ def tune(
297
325
image = base_image ,
298
326
command = ["bash" , "-c" ],
299
327
args = [exec_script ],
328
+ resources = resources_per_trial ,
300
329
)
301
330
],
302
331
),
@@ -640,7 +669,6 @@ def wait_for_experiment_condition(
640
669
namespace = namespace or self .namespace
641
670
642
671
for _ in range (round (timeout / polling_interval )):
643
-
644
672
# We should get Experiment only once per cycle and check the statuses.
645
673
experiment = self .get_experiment (name , namespace , apiserver_timeout )
646
674
@@ -1175,7 +1203,6 @@ def get_trial_metrics(
1175
1203
)
1176
1204
1177
1205
with katib_api_pb2 .beta_create_DBManager_stub (channel ) as client :
1178
-
1179
1206
try :
1180
1207
# When metric name is empty, we select all logs from the Katib DB.
1181
1208
observation_logs = client .GetObservationLog (
0 commit comments