Skip to content

Commit a992dde

Browse files
authored
Add kubernetes labels (Netflix#1236)
* Add kubernetes labels * Add labels to cli * Empty labels fix * Fix labels in decorator * Allow dictionaries in decorator * Convert strings to dictionaries * Make parse node selector function more generic * Add argo labels * Clean kubernetes labels * Hash original string and use shorter hash * Fix command join * Add rename parse list and add tests for value cleaning * override spec parser * Don't reencode json objects * Strip quotes * Fix rebase error * Throw exception for invalid labels * Changes based on PR comments * Fix invalid label error message * Add code comment to kube validation function * Add tests and fix bad json * Fix pre-commit error * Remove f-strings :( * More python 3.5 fixes
1 parent e148cd8 commit a992dde

File tree

6 files changed

+216
-19
lines changed

6 files changed

+216
-19
lines changed

metaflow/metaflow_config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -269,6 +269,8 @@
269269
KUBERNETES_NODE_SELECTOR = from_conf("KUBERNETES_NODE_SELECTOR", "")
270270
KUBERNETES_TOLERATIONS = from_conf("KUBERNETES_TOLERATIONS", "")
271271
KUBERNETES_SECRETS = from_conf("KUBERNETES_SECRETS", "")
272+
# Default labels for kubernetes pods
273+
KUBERNETES_LABELS = from_conf("KUBERNETES_LABELS", "")
272274
# Default GPU vendor to use by K8S jobs created by Metaflow (supports nvidia, amd)
273275
KUBERNETES_GPU_VENDOR = from_conf("KUBERNETES_GPU_VENDOR", "nvidia")
274276
# Default container image for K8S

metaflow/plugins/argo/argo_workflows.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -887,20 +887,22 @@ def _container_templates(self):
887887
.retry_strategy(
888888
times=total_retries,
889889
minutes_between_retries=minutes_between_retries,
890-
)
891-
.metadata(
890+
).metadata(
892891
ObjectMeta().annotation("metaflow/step_name", node.name)
893892
# Unfortunately, we can't set the task_id since it is generated
894893
# inside the pod. However, it can be inferred from the annotation
895894
# set by argo-workflows - `workflows.argoproj.io/outputs` - refer
896895
# the field 'task-id' in 'parameters'
897896
# .annotation("metaflow/task_id", ...)
898897
.annotation("metaflow/attempt", retry_count)
898+
# Set labels
899+
.labels(resources.get("labels"))
899900
)
900901
# Set emptyDir volume for state management
901902
.empty_dir_volume("out")
902903
# Set node selectors
903904
.node_selectors(resources.get("node_selector"))
905+
# Set tolerations
904906
.tolerations(resources.get("tolerations"))
905907
# Set container
906908
.container(

metaflow/plugins/kubernetes/kubernetes.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
from metaflow.metaflow_config import (
1010
SERVICE_HEADERS,
1111
SERVICE_INTERNAL_URL,
12+
CARD_AZUREROOT,
13+
CARD_GSROOT,
1214
CARD_S3ROOT,
1315
DATASTORE_SYSROOT_S3,
1416
DATATOOLS_S3ROOT,
@@ -29,8 +31,8 @@
2931
BASH_SAVE_LOGS,
3032
bash_capture_logs,
3133
export_mflog_env_vars,
32-
tail_logs,
3334
get_log_tailer,
35+
tail_logs,
3436
)
3537

3638
from .kubernetes_client import KubernetesClient
@@ -152,6 +154,7 @@ def create_job(
152154
run_time_limit=None,
153155
env=None,
154156
tolerations=None,
157+
labels=None,
155158
):
156159

157160
if env is None:
@@ -185,6 +188,7 @@ def create_job(
185188
retries=0,
186189
step_name=step_name,
187190
tolerations=tolerations,
191+
labels=labels,
188192
)
189193
.environment_variable("METAFLOW_CODE_SHA", code_package_sha)
190194
.environment_variable("METAFLOW_CODE_URL", code_package_url)

metaflow/plugins/kubernetes/kubernetes_cli.py

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import time
44
import traceback
55

6-
from metaflow import util, JSONTypeClass
6+
from metaflow import JSONTypeClass, util
77
from metaflow._vendor import click
88
from metaflow.exception import METAFLOW_EXIT_DISALLOW_RETRY, CommandException
99
from metaflow.metadata.util import sync_local_metadata_from_datastore
@@ -91,6 +91,12 @@ def kubernetes():
9191
type=JSONTypeClass(),
9292
multiple=False,
9393
)
94+
@click.option(
95+
"--labels",
96+
multiple=True,
97+
default=None,
98+
help="Labels for Kubernetes pod.",
99+
)
94100
@click.pass_context
95101
def step(
96102
ctx,
@@ -110,6 +116,7 @@ def step(
110116
gpu_vendor=None,
111117
run_time_limit=None,
112118
tolerations=None,
119+
labels=None,
113120
**kwargs
114121
):
115122
def echo(msg, stream="stderr", job_id=None, **kwargs):
@@ -175,7 +182,12 @@ def echo(msg, stream="stderr", job_id=None, **kwargs):
175182
stderr_location = ds.get_log_location(TASK_LOG_SOURCE, "stderr")
176183

177184
# `node_selector` is a tuple of strings, convert it to a dictionary
178-
node_selector = KubernetesDecorator.parse_node_selector(node_selector)
185+
node_selector = KubernetesDecorator.parse_kube_keyvalue_list(node_selector)
186+
187+
# `labels` is a tuple of strings or a tuple with a single comma separated string
188+
# convert it to a dict
189+
labels = KubernetesDecorator.parse_kube_keyvalue_list(labels, False)
190+
KubernetesDecorator.validate_kube_labels(labels)
179191

180192
def _sync_metadata():
181193
if ctx.obj.metadata.TYPE == "local":
@@ -218,6 +230,7 @@ def _sync_metadata():
218230
run_time_limit=run_time_limit,
219231
env=env,
220232
tolerations=tolerations,
233+
labels=labels,
221234
)
222235
except Exception as e:
223236
traceback.print_exc(chain=False)

metaflow/plugins/kubernetes/kubernetes_decorator.py

Lines changed: 96 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
1+
import hashlib
12
import json
23
import os
34
import platform
5+
import re
46
import sys
7+
from typing import Dict, List, Optional, Union
58

69
from metaflow.decorators import StepDecorator
710
from metaflow.exception import MetaflowException
@@ -12,11 +15,12 @@
1215
KUBERNETES_CONTAINER_IMAGE,
1316
KUBERNETES_CONTAINER_REGISTRY,
1417
KUBERNETES_GPU_VENDOR,
18+
KUBERNETES_LABELS,
1519
KUBERNETES_NAMESPACE,
1620
KUBERNETES_NODE_SELECTOR,
1721
KUBERNETES_TOLERATIONS,
18-
KUBERNETES_SERVICE_ACCOUNT,
1922
KUBERNETES_SECRETS,
23+
KUBERNETES_SERVICE_ACCOUNT,
2024
KUBERNETES_FETCH_EC2_METADATA,
2125
)
2226
from metaflow.plugins.resources_decorator import ResourcesDecorator
@@ -65,6 +69,8 @@ class KubernetesDecorator(StepDecorator):
6569
in Metaflow configuration.
6670
tolerations : List[str], default: METAFLOW_KUBERNETES_TOLERATIONS
6771
Kubernetes tolerations to use when launching pod in Kubernetes.
72+
labels : Dict[str, str], default: METAFLOW_KUBERNETES_LABELS
73+
Kubernetes labels to use when launching pod in Kubernetes.
6874
"""
6975

7076
name = "kubernetes"
@@ -76,6 +82,7 @@ class KubernetesDecorator(StepDecorator):
7682
"service_account": None,
7783
"secrets": None, # e.g., mysecret
7884
"node_selector": None, # e.g., kubernetes.io/os=linux
85+
"labels": None, # e.g., my_label=my_value
7986
"namespace": None,
8087
"gpu": None, # value of 0 implies that the scheduled node should not have GPUs
8188
"gpu_vendor": None,
@@ -99,9 +106,17 @@ def __init__(self, attributes=None, statically_defined=False):
99106
self.attributes["node_selector"] = KUBERNETES_NODE_SELECTOR
100107
if not self.attributes["tolerations"] and KUBERNETES_TOLERATIONS:
101108
self.attributes["tolerations"] = json.loads(KUBERNETES_TOLERATIONS)
109+
if not self.attributes["labels"] and KUBERNETES_LABELS:
110+
self.attributes["labels"] = KUBERNETES_LABELS
111+
112+
if isinstance(self.attributes["labels"], str):
113+
self.attributes["labels"] = self.parse_kube_keyvalue_list(
114+
self.attributes["labels"].split(","), False
115+
)
116+
self.validate_kube_labels(self.attributes["labels"])
102117

103118
if isinstance(self.attributes["node_selector"], str):
104-
self.attributes["node_selector"] = self.parse_node_selector(
119+
self.attributes["node_selector"] = self.parse_kube_keyvalue_list(
105120
self.attributes["node_selector"].split(",")
106121
)
107122

@@ -280,10 +295,11 @@ def runtime_step_cli(
280295
for k, v in self.attributes.items():
281296
if k == "namespace":
282297
cli_args.command_options["k8s_namespace"] = v
283-
elif k == "node_selector" and v:
284-
cli_args.command_options[k] = ",".join(
285-
["=".join([key, str(val)]) for key, val in v.items()]
286-
)
298+
elif k in {"node_selector", "labels"} and v:
299+
cli_args.command_options[k] = [
300+
"=".join([key, str(val)]) if val else key
301+
for key, val in v.items()
302+
]
287303
elif k == "tolerations":
288304
cli_args.command_options[k] = json.dumps(v)
289305
else:
@@ -391,14 +407,80 @@ def _save_package_once(cls, flow_datastore, package):
391407
[package.blob], len_hint=1
392408
)[0]
393409

410+
@classmethod
def _parse_decorator_spec(cls, deco_spec: str):
    """Normalise a decorator spec string and delegate to the parent parser.

    The spec is a comma-separated sequence of ``name=value`` pairs. Values of
    ``labels`` and ``node_selector`` may contain commas themselves (a JSON
    object, or a ``key=value,key2=value2`` list), so the spec is split only
    on commas that are directly followed by one of the option names in
    ``cls.defaults`` and an ``=``. Those two options are rewritten as compact
    JSON objects before the re-assembled spec is passed to the parent class
    implementation of ``_parse_decorator_spec``.

    Parameters
    ----------
    deco_spec : str
        Raw spec string; an empty value yields ``cls()``.

    Raises
    ------
    KubernetesException
        If a ``labels``/``node_selector`` value is malformed JSON, is JSON
        but not an object, contains a value that is neither a string nor
        null, or cannot be parsed as a key/value list.
    """
    if not deco_spec:
        return cls()

    # The option names must form a non-capturing alternation. Placing them in
    # a character class would make the split fire on almost any comma that is
    # followed by `word=`, breaking values such as `labels=app=x,env=prod`.
    option_names = "|".join(re.escape(option) for option in cls.defaults)
    splitter = re.compile(r",(?=\s*(?:%s)=)" % option_names)

    deco_spec_parts = []
    for part in splitter.split(deco_spec):
        name, val = part.split("=", 1)
        if name in {"labels", "node_selector"}:
            try:
                # Quotes may arrive backslash-escaped; undo that before decoding.
                decoded = json.loads(val.strip().replace('\\"', '"'))
            except json.JSONDecodeError:
                if val.startswith("{"):
                    # Looked like a JSON object but did not decode.
                    raise KubernetesException(
                        "Malform json detected in %s" % str(val)
                    )
                # Plain `key=value,...` list. Only node selectors require
                # every entry to carry a value.
                both = name == "node_selector"
                val = json.dumps(
                    cls.parse_kube_keyvalue_list(val.split(","), both),
                    separators=(",", ":"),
                )
            else:
                # Valid JSON is passed through untouched, but it has to be an
                # object whose values are strings or null.
                if not isinstance(decoded, dict):
                    raise KubernetesException(
                        "Expected a JSON object for %s but got %s"
                        % (name, str(val))
                    )
                for val_i in decoded.values():
                    if not (val_i is None or isinstance(val_i, str)):
                        raise KubernetesException(
                            "All values must be string or null."
                        )
        deco_spec_parts.append("=".join([name, val]))

    deco_spec_parsed = ",".join(deco_spec_parts)
    return super()._parse_decorator_spec(deco_spec_parsed)
440+
394441
@staticmethod
395-
def parse_node_selector(node_selector: list):
442+
def parse_kube_keyvalue_list(items: List[str], requires_both: bool = True):
396443
try:
397-
return {
398-
str(k.split("=", 1)[0]): str(k.split("=", 1)[1])
399-
for k in node_selector or []
400-
}
444+
ret = {}
445+
for item_str in items:
446+
item = item_str.split("=", 1)
447+
if requires_both:
448+
item[1] # raise IndexError
449+
if str(item[0]) in ret:
450+
raise KubernetesException("Duplicate key found: %s" % str(item[0]))
451+
ret[str(item[0])] = str(item[1]) if len(item) > 1 else None
452+
return ret
453+
except KubernetesException as e:
454+
raise e
401455
except (AttributeError, IndexError):
402-
raise KubernetesException(
403-
"Unable to parse node_selector: %s" % node_selector
404-
)
456+
raise KubernetesException("Unable to parse kubernetes list: %s" % items)
457+
458+
@staticmethod
459+
def validate_kube_labels(
460+
labels: Optional[Dict[str, Optional[str]]],
461+
) -> bool:
462+
"""Validate label values.
463+
464+
This validates the kubernetes label values. It does not validate the keys.
465+
Ideally, keys should be static and also the validation rules for keys are
466+
more complex than those for values. For full validation rules, see:
467+
468+
https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/#syntax-and-character-set
469+
"""
470+
471+
def validate_label(s: Optional[str]):
472+
regex_match = r"^(([A-Za-z0-9][-A-Za-z0-9_.]{0,61})?[A-Za-z0-9])?$"
473+
if not s:
474+
# allow empty label
475+
return True
476+
if not re.search(regex_match, s):
477+
raise KubernetesException(
478+
'Invalid value: "%s"\n'
479+
"A valid label must be an empty string or one that\n"
480+
" - Consist of alphanumeric, '-', '_' or '.' characters\n"
481+
" - Begins and ends with an alphanumeric character\n"
482+
" - Is at most 63 characters" % s
483+
)
484+
return True
485+
486+
return all([validate_label(v) for v in labels.values()]) if labels else True
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
import pytest
2+
3+
from metaflow.plugins.kubernetes.kubernetes import KubernetesException
4+
from metaflow.plugins.kubernetes.kubernetes_decorator import KubernetesDecorator
5+
6+
7+
@pytest.mark.parametrize(
    "labels",
    [
        None,
        {"label": "value"},
        {"label1": "val1", "label2": "val2"},
        {"label1": "val1", "label2": None},
        {"label": "a"},
        {"label": ""},
        # Exactly 63 characters, the longest value that is still accepted.
        {"label": "1234567890" * 6 + "123"},
        # 63 characters with '-', '_' and '.' in the interior.
        {"label": "1234567890" * 2 + "1234-_.890" + "1234567890" * 3 + "123"},
    ],
)
def test_kubernetes_decorator_validate_kube_labels(labels):
    """Accept missing, empty, and well-formed label values."""
    outcome = KubernetesDecorator.validate_kube_labels(labels)
    assert outcome
42+
43+
44+
@pytest.mark.parametrize(
    "labels",
    [
        {"label": "a-"},
        {"label": ".a"},
        {"label": "test()"},
        # 64 characters, one more than the permitted maximum.
        {"label": "1234567890" * 6 + "1234"},
        {"label": "(){}??"},
        {"valid": "test", "invalid": "bißchen"},
    ],
)
def test_kubernetes_decorator_validate_kube_labels_fail(labels):
    """Reject values with invalid characters, bad edges, or excess length."""
    expected_error = pytest.raises(KubernetesException)
    with expected_error:
        KubernetesDecorator.validate_kube_labels(labels)
69+
70+
71+
@pytest.mark.parametrize(
    "items,requires_both,expected",
    [
        # A single pair parses the same whether or not a value is required.
        (["key=value"], True, {"key": "value"}),
        (["key=value"], False, {"key": "value"}),
        # A bare key is allowed when values are optional.
        (["key"], False, {"key": None}),
        (
            ["key=value", "key2=value2"],
            True,
            {"key": "value", "key2": "value2"},
        ),
    ],
)
def test_kubernetes_parse_keyvalue_list(items, requires_both, expected):
    """Parse well-formed key/value lists into the expected dictionary."""
    assert expected == KubernetesDecorator.parse_kube_keyvalue_list(
        items, requires_both
    )
83+
84+
85+
@pytest.mark.parametrize(
    "items,requires_both",
    [
        # Same key given twice.
        (["key=value", "key=value2"], True),
        # Key without a value while a value is required.
        (["key"], True),
    ],
)
def test_kubernetes_parse_keyvalue_list_fail(items, requires_both):
    """Reject duplicate keys and value-less entries when values are required.

    The `_fail` suffix keeps this name distinct from the positive-case test
    in this module; a second definition under the same name would replace
    the first one, and pytest would never collect it.
    """
    with pytest.raises(KubernetesException):
        KubernetesDecorator.parse_kube_keyvalue_list(items, requires_both)

0 commit comments

Comments
 (0)