Skip to content

Commit f4c8861

Browse files
[SDK] Add env & env_from in client tune (#2235)
* add env & env_from spec * unify env and env_from specs
1 parent fbe7c78 commit f4c8861

File tree

1 file changed

+18
-0
lines changed

1 file changed

+18
-0
lines changed

sdk/python/v1beta1/kubeflow/katib/api/katib_client.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -140,6 +140,7 @@ def tune(
140140
parameters: Dict[str, Any],
141141
base_image: str = constants.BASE_IMAGE_TENSORFLOW,
142142
namespace: Optional[str] = None,
143+
env_per_trial: Optional[Union[Dict[str, str], List[Union[client.V1EnvVar, client.V1EnvFromSource]]]] = None,
143144
algorithm_name: str = "random",
144145
algorithm_settings: Union[dict, List[models.V1beta1AlgorithmSetting], None] = None,
145146
objective_metric_name: str = None,
@@ -172,6 +173,12 @@ def tune(
172173
objective function.
173174
base_image: Image to use when executing the objective function.
174175
namespace: Namespace for the Experiment.
176+
env_per_trial: Environment variable(s) to be attached to each trial container.
177+
You can specify a dictionary as a mapping object representing the environment variables.
178+
Otherwise, you can specify a list, in which the element can either be a kubernetes.client.models.V1EnvVar (documented here:
179+
https://github.com/kubernetes-client/python/blob/master/kubernetes/docs/V1EnvVar.md)
180+
or a kubernetes.client.models.V1EnvFromSource (documented here:
181+
https://github.com/kubernetes-client/python/blob/master/kubernetes/docs/V1EnvFromSource.md)
175182
algorithm_name: Search algorithm for the HyperParameter tuning.
176183
algorithm_settings: Settings for the search algorithm given.
177184
For available fields, check this doc: https://www.kubeflow.org/docs/components/katib/experiment/#search-algorithms-in-detail.
@@ -318,6 +325,15 @@ def tune(
318325
requests=resources_per_trial,
319326
limits=resources_per_trial,
320327
)
328+
329+
if isinstance(env_per_trial, dict):
330+
env, env_from = [client.V1EnvVar(name=str(k), value=str(v)) for k, v in env_per_trial.items()] or None, None
331+
332+
if env_per_trial:
333+
env = [x for x in env_per_trial if isinstance(x, client.V1EnvVar)] or None
334+
env_from = [x for x in env_per_trial if isinstance(x, client.V1EnvFromSource)] or None
335+
else:
336+
env, env_from = None, None
321337

322338
# Create Trial specification.
323339
trial_spec = client.V1Job(
@@ -336,6 +352,8 @@ def tune(
336352
image=base_image,
337353
command=["bash", "-c"],
338354
args=[exec_script],
355+
env=env,
356+
env_from=env_from,
339357
resources=resources_per_trial,
340358
)
341359
],

0 commit comments

Comments
 (0)