8
8
import json
9
9
import os
10
10
11
+ import regex as re
11
12
from hostlist import expand_hostlist
12
13
13
14
14
15
class NodeToGPUMapping :
15
16
"""Helper class to generate JSON file, load it in memory, and query GPU type for a nodename."""
16
17
17
- def __init__ (self , cluster_name , nodes_info_file ):
18
+ def __init__ (self , cluster_name , nodes_info_file , harmonize_gpu_map , gpus ):
18
19
"""Initialize with cluster name and TXT file path to parse."""
19
20
20
21
# Mapping is empty by default.
21
22
self .mapping = {}
22
23
self .json_path = None
24
+ self .harmonize_gpu_map = {
25
+ ** {
26
+ re .compile (regex ): gpu_type
27
+ for regex , gpu_type in harmonize_gpu_map .items ()
28
+ },
29
+ ** {re .compile (f".*{ gpu } .*" ): gpu for gpu in gpus },
30
+ }
23
31
24
32
# Mapping is filled only if TXT file is available.
25
33
if nodes_info_file and os .path .exists (nodes_info_file ):
@@ -36,7 +44,7 @@ def __init__(self, cluster_name, nodes_info_file):
36
44
not os .path .exists (self .json_path )
37
45
or os .stat (self .json_path ).st_mtime < info_file_stat .st_mtime
38
46
):
39
- # Pase TXT file into self.mapping.
47
+ # Parse TXT file into self.mapping.
40
48
self ._parse_nodenames (nodes_info_file , self .mapping )
41
49
# Save self.mapping into JSON file.
42
50
with open (self .json_path , "w" , encoding = "utf-8" ) as file :
@@ -46,9 +54,22 @@ def __init__(self, cluster_name, nodes_info_file):
46
54
with open (self .json_path , encoding = "utf-8" ) as file :
47
55
self .mapping = json .load (file )
48
56
57
+ def _harmonize_gpu (self , gpu_type : str ):
58
+ gpu_type = gpu_type .lower ().replace (" " , "-" ).split (":" )
59
+ if gpu_type [0 ] == "gpu" :
60
+ gpu_type .pop (0 )
61
+ gpu_type = gpu_type [0 ]
62
+ for regex , harmonized_gpu in self .harmonize_gpu_map .items ():
63
+ if regex .match (gpu_type ):
64
+ break
65
+ else :
66
+ harmonized_gpu = None
67
+ return harmonized_gpu
68
+
49
69
def __getitem__ (self , nodename ):
50
70
"""Return GPU type for nodename, or None if not found."""
51
- return self .mapping .get (nodename , None )
71
+ gpu_type = self .mapping .get (nodename , None )
72
+ return self ._harmonize_gpu (gpu_type )
52
73
53
74
@staticmethod
54
75
def _parse_nodenames (path : str , output : dict ):
0 commit comments