Skip to content

Commit 923d0fc

Browse files
droctothorpeshipengcheng1230andreyvelich
authored
[SDK] Enable resource specification for trial containers (#2192)
Co-authored-by: shipengcheng1230 <[email protected]> Co-authored-by: Andrey Velichkevich <[email protected]>
1 parent 114485d commit 923d0fc

File tree

1 file changed

+30
-3
lines changed

1 file changed

+30
-3
lines changed

sdk/python/v1beta1/kubeflow/katib/api/katib_client.py

Lines changed: 30 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import multiprocessing
1717
import textwrap
1818
import time
19-
from typing import Any, Callable, Dict, List, Optional
19+
from typing import Any, Callable, Dict, List, Optional, Union
2020

2121
import grpc
2222
import kubeflow.katib.katib_api_pb2 as katib_api_pb2
@@ -144,6 +144,7 @@ def tune(
144144
max_trial_count: int = None,
145145
parallel_trial_count: int = None,
146146
max_failed_trial_count: int = None,
147+
resources_per_trial: Union[dict, client.V1ResourceRequirements, None] = None,
147148
retain_trials: bool = False,
148149
packages_to_install: List[str] = None,
149150
pip_index_url: str = "https://pypi.org/simple",
@@ -177,6 +178,24 @@ def tune(
177178
values check this doc: https://www.kubeflow.org/docs/components/katib/experiment/#configuration-spec.
178179
parallel_trial_count: Number of Trials that Experiment runs in parallel.
179180
max_failed_trial_count: Maximum number of Trials allowed to fail.
181+
resources_per_trial: A parameter that lets you specify the amount of
182+
resources each trial container should have. You can either specify a
183+
kubernetes.client.V1ResourceRequirements object (documented here:
184+
https://github.com/kubernetes-client/python/blob/master/kubernetes/docs/V1ResourceRequirements.md)
185+
or a dictionary that includes one or more of the following keys:
186+
`cpu`, `memory`, or `gpu` (other keys will be ignored). Appropriate
187+
values for these keys are documented here:
188+
https://kubernetes.io/docs/concepts/configuration/manage-resources-containers/.
189+
For example:
190+
{
191+
"cpu": "1",
192+
"gpu": "1",
193+
"memory": "2Gi",
194+
}
195+
Please note, `gpu` specifies a resource request and limit with a key of
196+
`nvidia.com/gpu`, i.e. an NVIDIA GPU. If you need a different type
197+
of GPU, pass in a V1ResourceRequirements instance instead, since it's
198+
more flexible. This parameter is optional and defaults to None.
180199
retain_trials: Whether Trials' resources (e.g. pods) are deleted after Succeeded state.
181200
packages_to_install: List of Python packages to install in addition
182201
to the base image packages. These packages are installed before
@@ -280,6 +299,15 @@ def tune(
280299
+ exec_script
281300
)
282301

302+
if isinstance(resources_per_trial, dict):
303+
if "gpu" in resources_per_trial:
304+
resources_per_trial["nvidia.com/gpu"] = resources_per_trial.pop("gpu")
305+
306+
resources_per_trial = client.V1ResourceRequirements(
307+
requests=resources_per_trial,
308+
limits=resources_per_trial,
309+
)
310+
283311
# Create Trial specification.
284312
trial_spec = client.V1Job(
285313
api_version="batch/v1",
@@ -297,6 +325,7 @@ def tune(
297325
image=base_image,
298326
command=["bash", "-c"],
299327
args=[exec_script],
328+
resources=resources_per_trial,
300329
)
301330
],
302331
),
@@ -640,7 +669,6 @@ def wait_for_experiment_condition(
640669
namespace = namespace or self.namespace
641670

642671
for _ in range(round(timeout / polling_interval)):
643-
644672
# We should get Experiment only once per cycle and check the statuses.
645673
experiment = self.get_experiment(name, namespace, apiserver_timeout)
646674

@@ -1175,7 +1203,6 @@ def get_trial_metrics(
11751203
)
11761204

11771205
with katib_api_pb2.beta_create_DBManager_stub(channel) as client:
1178-
11791206
try:
11801207
# When metric name is empty, we select all logs from the Katib DB.
11811208
observation_logs = client.GetObservationLog(

0 commit comments

Comments
 (0)