1
+ import hashlib
1
2
import json
2
3
import os
3
4
import platform
5
+ import re
4
6
import sys
7
+ from typing import Dict , List , Optional , Union
5
8
6
9
from metaflow .decorators import StepDecorator
7
10
from metaflow .exception import MetaflowException
12
15
KUBERNETES_CONTAINER_IMAGE ,
13
16
KUBERNETES_CONTAINER_REGISTRY ,
14
17
KUBERNETES_GPU_VENDOR ,
18
+ KUBERNETES_LABELS ,
15
19
KUBERNETES_NAMESPACE ,
16
20
KUBERNETES_NODE_SELECTOR ,
17
21
KUBERNETES_TOLERATIONS ,
18
- KUBERNETES_SERVICE_ACCOUNT ,
19
22
KUBERNETES_SECRETS ,
23
+ KUBERNETES_SERVICE_ACCOUNT ,
20
24
KUBERNETES_FETCH_EC2_METADATA ,
21
25
)
22
26
from metaflow .plugins .resources_decorator import ResourcesDecorator
@@ -65,6 +69,8 @@ class KubernetesDecorator(StepDecorator):
65
69
in Metaflow configuration.
66
70
tolerations : List[str], default: METAFLOW_KUBERNETES_TOLERATIONS
67
71
Kubernetes tolerations to use when launching pod in Kubernetes.
72
+ labels : Dict[str, str], default: METAFLOW_KUBERNETES_LABELS
73
+ Kubernetes labels to use when launching pod in Kubernetes.
68
74
"""
69
75
70
76
name = "kubernetes"
@@ -76,6 +82,7 @@ class KubernetesDecorator(StepDecorator):
76
82
"service_account" : None ,
77
83
"secrets" : None , # e.g., mysecret
78
84
"node_selector" : None , # e.g., kubernetes.io/os=linux
85
+ "labels" : None , # e.g., my_label=my_value
79
86
"namespace" : None ,
80
87
"gpu" : None , # value of 0 implies that the scheduled node should not have GPUs
81
88
"gpu_vendor" : None ,
@@ -99,9 +106,17 @@ def __init__(self, attributes=None, statically_defined=False):
99
106
self .attributes ["node_selector" ] = KUBERNETES_NODE_SELECTOR
100
107
if not self .attributes ["tolerations" ] and KUBERNETES_TOLERATIONS :
101
108
self .attributes ["tolerations" ] = json .loads (KUBERNETES_TOLERATIONS )
109
+ if not self .attributes ["labels" ] and KUBERNETES_LABELS :
110
+ self .attributes ["labels" ] = KUBERNETES_LABELS
111
+
112
+ if isinstance (self .attributes ["labels" ], str ):
113
+ self .attributes ["labels" ] = self .parse_kube_keyvalue_list (
114
+ self .attributes ["labels" ].split ("," ), False
115
+ )
116
+ self .validate_kube_labels (self .attributes ["labels" ])
102
117
103
118
if isinstance (self .attributes ["node_selector" ], str ):
104
- self .attributes ["node_selector" ] = self .parse_node_selector (
119
+ self .attributes ["node_selector" ] = self .parse_kube_keyvalue_list (
105
120
self .attributes ["node_selector" ].split ("," )
106
121
)
107
122
@@ -280,10 +295,11 @@ def runtime_step_cli(
280
295
for k , v in self .attributes .items ():
281
296
if k == "namespace" :
282
297
cli_args .command_options ["k8s_namespace" ] = v
283
- elif k == "node_selector" and v :
284
- cli_args .command_options [k ] = "," .join (
285
- ["=" .join ([key , str (val )]) for key , val in v .items ()]
286
- )
298
+ elif k in {"node_selector" , "labels" } and v :
299
+ cli_args .command_options [k ] = [
300
+ "=" .join ([key , str (val )]) if val else key
301
+ for key , val in v .items ()
302
+ ]
287
303
elif k == "tolerations" :
288
304
cli_args .command_options [k ] = json .dumps (v )
289
305
else :
@@ -391,14 +407,80 @@ def _save_package_once(cls, flow_datastore, package):
391
407
[package .blob ], len_hint = 1
392
408
)[0 ]
393
409
410
@classmethod
def _parse_decorator_spec(cls, deco_spec: str):
    """Create a decorator instance from a textual decorator spec.

    The spec is a comma-separated list of ``option=value`` pairs. Values of
    the ``labels`` and ``node_selector`` options may themselves contain
    commas, either as a JSON object or as a ``key=value,key=value`` list.
    Such values are normalised to compact JSON before the rebuilt spec is
    handed to the parent class implementation.

    Parameters
    ----------
    deco_spec : str
        Decorator spec; an empty spec yields ``cls()``.

    Raises
    ------
    KubernetesException
        If a ``labels``/``node_selector`` value starts with ``{`` but is not
        valid JSON, if a JSON object holds a value that is neither a string
        nor null, or if the key/value list cannot be parsed.
    ValueError
        If a part of the spec contains no ``=`` (tuple unpacking fails).
    """
    if not deco_spec:
        return cls()

    # Split only on commas that are immediately followed by a known option
    # name and "=". The option names must be an alternation group: a
    # character class would match any word that merely ends in one of the
    # letters used by the option names (e.g. ",env=") and would wrongly cut
    # a labels/node_selector list in the middle.
    option_names = "|".join(re.escape(option) for option in cls.defaults)
    option_boundary = r",(?=\s*(?:{})=)".format(option_names)

    deco_spec_parts = []
    for part in re.split(option_boundary, deco_spec):
        name, val = part.split("=", 1)
        if name in {"labels", "node_selector"}:
            try:
                # Quotes may arrive backslash-escaped; undo that before
                # attempting to decode the value as JSON.
                parsed = json.loads(val.strip().replace('\\"', '"'))
            except json.JSONDecodeError:
                if val.startswith("{"):
                    raise KubernetesException(
                        "Malform json detected in %s" % str(val)
                    )
                parsed = None

            if isinstance(parsed, dict):
                # A JSON object is forwarded as given once its values are
                # known to be strings or null.
                for entry in parsed.values():
                    if not (entry is None or isinstance(entry, str)):
                        raise KubernetesException(
                            "All values must be string or null."
                        )
            else:
                # Not a JSON object (this includes bare JSON scalars such as
                # "123" or "true"): treat the value as a key=value list.
                # Only node_selector requires every key to carry a value.
                both = name == "node_selector"
                val = json.dumps(
                    cls.parse_kube_keyvalue_list(val.split(","), both),
                    separators=(",", ":"),
                )
        deco_spec_parts.append("=".join([name, val]))

    deco_spec_parsed = ",".join(deco_spec_parts)
    return super()._parse_decorator_spec(deco_spec_parsed)
440
+
394
441
@staticmethod
def parse_kube_keyvalue_list(
    items: List[str], requires_both: bool = True
) -> Dict[str, Optional[str]]:
    """Convert a list of ``key=value`` strings into a dictionary.

    Each item is split on its first ``=``; anything after it (including
    further ``=`` characters) becomes the value. An item without ``=`` maps
    its key to ``None`` when ``requires_both`` is false.

    Parameters
    ----------
    items : List[str]
        Items to parse. ``None`` or an empty list yields an empty dict.
    requires_both : bool, default: True
        When true, every item must contain ``=``.

    Returns
    -------
    Dict[str, Optional[str]]
        Mapping of keys to values, in input order.

    Raises
    ------
    KubernetesException
        If a key occurs more than once, if an item is not a string, or if
        ``requires_both`` is true and an item has no ``=``.
    """
    parsed = {}
    # Accept None like an empty list so callers need not special-case it.
    for item_str in items or []:
        try:
            key, separator, value = item_str.partition("=")
        except (AttributeError, TypeError):
            # The item is not a string (or not splittable on a str separator).
            raise KubernetesException(
                "Unable to parse kubernetes list: %s" % items
            )
        if requires_both and not separator:
            raise KubernetesException(
                "Unable to parse kubernetes list: %s" % items
            )
        if key in parsed:
            raise KubernetesException("Duplicate key found: %s" % key)
        # "key=" keeps an empty-string value; a bare "key" has no value.
        parsed[key] = value if separator else None
    return parsed
457
+
458
@staticmethod
def validate_kube_labels(
    labels: Optional[Dict[str, Optional[str]]],
) -> bool:
    """Validate label values.

    This validates the kubernetes label values. It does not validate the keys.
    Ideally, keys should be static and also the validation rules for keys are
    more complex than those for values. For full validation rules, see:

    https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/#syntax-and-character-set

    Parameters
    ----------
    labels : Optional[Dict[str, Optional[str]]]
        Mapping of label keys to values. ``None`` or an empty mapping is
        accepted, as are ``None`` or empty-string values.

    Returns
    -------
    bool
        Always ``True``; invalid input is reported by raising.

    Raises
    ------
    KubernetesException
        If any value is longer than 63 characters, contains characters other
        than alphanumerics, '-', '_' or '.', or does not begin and end with
        an alphanumeric character.
    """
    # The pattern is applied with fullmatch rather than "^...$" with search:
    # "$" also matches just before a trailing newline, which would let a
    # value such as "abc\n" slip through.
    value_pattern = re.compile(
        r"(([A-Za-z0-9][-A-Za-z0-9_.]{0,61})?[A-Za-z0-9])?"
    )

    for value in (labels or {}).values():
        if not value:
            # allow empty label
            continue
        if value_pattern.fullmatch(value) is None:
            raise KubernetesException(
                'Invalid value: "%s"\n'
                "A valid label must be an empty string or one that\n"
                " - Consist of alphanumeric, '-', '_' or '.' characters\n"
                " - Begins and ends with an alphanumeric character\n"
                " - Is at most 63 characters" % value
            )
    return True
0 commit comments