Skip to content

Commit 70aca0c

Browse files
authored
Support multi node for lmi-dist (#2125)
1 parent d62f747 commit 70aca0c

File tree

8 files changed

+175
-25
lines changed

8 files changed

+175
-25
lines changed

engines/python/setup/djl_python/arg_parser.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,11 @@ def python_engine_args():
7171
dest="tensor_parallel_degree",
7272
type=int,
7373
help='The tensor parallel degree')
74+
parser.add_argument('--cluster-size',
75+
required=False,
76+
dest="cluster_size",
77+
type=int,
78+
help='The cluster size')
7479
parser.add_argument('--recommended-entry-point',
7580
required=False,
7681
type=str,

engines/python/setup/djl_python/properties_manager/hf_properties.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,7 @@ class HuggingFaceProperties(Properties):
4545
device_id: int = -1
4646
task: str = None
4747
tensor_parallel_degree: int = -1
48+
cluster_size: int = 1
4849
device_map: str = None
4950
load_in_4bit: Optional[bool] = None
5051
load_in_8bit: Optional[bool] = None
@@ -112,10 +113,12 @@ def construct_kwargs_device_map(self):
112113
self.kwargs["device_map"] = self.device_map
113114
self.device = None
114115
logging.info(f"Using device map {self.device_map}")
115-
elif self.tensor_parallel_degree > 0 and torch.cuda.device_count() > 0:
116+
elif self.tensor_parallel_degree > 0 \
117+
and self.cluster_size > 0 \
118+
and torch.cuda.device_count() > 0:
116119
self.kwargs["device_map"] = "auto"
117120
self.device = None
118-
world_size = torch.cuda.device_count()
121+
world_size = torch.cuda.device_count() * self.cluster_size
119122
assert world_size == self.tensor_parallel_degree, \
120123
f"TP degree ({self.tensor_parallel_degree}) doesn't match available GPUs ({world_size})"
121124
logging.info(f"Using {world_size} gpus")

engines/python/setup/djl_python/properties_manager/properties.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ class Properties(BaseModel):
4949
# Change the default to auto once the Java front-end changes have landed and the test cases are updated.
5050
rolling_batch: RollingBatchEnum = RollingBatchEnum.disable
5151
tensor_parallel_degree: int = 1
52+
cluster_size: int = 1
5253
trust_remote_code: bool = False
5354
enable_streaming: StreamingEnum = StreamingEnum.false
5455
batch_size: int = 1

engines/python/setup/djl_python/tests/test_properties_manager.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,7 @@ def test_hf_all_configs(self):
290290
"model_id": "model_id",
291291
"model_dir": "model_dir",
292292
"tensor_parallel_degree": "4",
293+
"cluster_size": "2",
293294
"load_in_4bit": "false",
294295
"load_in_8bit": "true",
295296
"low_cpu_mem_usage": "true",
@@ -305,6 +306,10 @@ def test_hf_all_configs(self):
305306
}
306307

307308
hf_configs = HuggingFaceProperties(**properties)
309+
self.assertEqual(hf_configs.tensor_parallel_degree,
310+
int(properties['tensor_parallel_degree']))
311+
self.assertEqual(hf_configs.cluster_size,
312+
int(properties['cluster_size']))
308313
self.assertTrue(hf_configs.load_in_8bit)
309314
self.assertTrue(hf_configs.low_cpu_mem_usage)
310315
self.assertFalse(hf_configs.disable_flash_attn)

engines/python/setup/djl_python_engine.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,23 +47,27 @@ def __init__(self, args, service):
4747

4848
self.model_dir = args.model_dir
4949
self.sock_type = args.sock_type
50-
self.sock_name = f"{args.sock_name}.{rank}" if rank else args.sock_name
50+
self.sock_name = args.sock_name
5151
self.port = args.port
5252
self.service = service
5353
self.device_id = args.device_id
5454
self.tensor_parallel_degree = args.tensor_parallel_degree
55+
self.cluster_size = args.cluster_size
5556
self.entry_point = args.entry_point
5657
self.recommended_entry_point = args.recommended_entry_point
5758

5859
if self.sock_type == "unix":
5960
if self.sock_name is None:
6061
raise ValueError("Missing sock-name argument.")
62+
self.sock_name = f"{args.sock_name}.{rank}" if rank else args.sock_name
6163

6264
self.clean_up()
6365
elif self.sock_type == "tcp":
64-
self.sock_name = "127.0.0.1"
66+
if self.sock_name is None:
67+
self.sock_name = "0.0.0.0"
6568
if self.port is None:
6669
raise ValueError("Missing port argument.")
70+
self.port = int(self.port) + int(rank) if rank else self.port
6771
else:
6872
raise ValueError(f"Invalid socket-type: {self.sock_type}.")
6973

@@ -99,6 +103,8 @@ def run_server(self):
99103
if self.sock_type == "unix":
100104
self.sock.bind(self.sock_name)
101105
else:
106+
logging.info(
107+
f"Socket bind on address: {self.sock_name}:{self.port}")
102108
self.sock.bind((self.sock_name, int(self.port)))
103109

104110
self.sock.listen(128)
@@ -115,6 +121,8 @@ def run_server(self):
115121
prop = inputs.get_properties()
116122
if self.tensor_parallel_degree:
117123
prop["tensor_parallel_degree"] = self.tensor_parallel_degree
124+
if self.cluster_size:
125+
prop["cluster_size"] = self.cluster_size
118126
prop["device_id"] = self.device_id
119127
if "output_formatter" in prop and hasattr(
120128
self.service, prop["output_formatter"]):

engines/python/src/main/java/ai/djl/python/engine/Connection.java

Lines changed: 90 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@
5050
import java.io.IOException;
5151
import java.net.InetSocketAddress;
5252
import java.net.SocketAddress;
53+
import java.net.UnknownHostException;
5354
import java.nio.ByteBuffer;
5455
import java.nio.file.Files;
5556
import java.nio.file.Path;
@@ -64,20 +65,19 @@
6465
class Connection {
6566

6667
private static final Logger logger = LoggerFactory.getLogger(Connection.class);
67-
private static final String MASTER_ADDR = "127.0.0.1";
6868

6969
private int port;
7070
private SocketAddress socketAddress;
7171
private Channel channel;
7272
private RequestHandler requestHandler;
7373

74-
Connection(PyEnv pyEnv, int basePort, int rank) {
74+
Connection(PyEnv pyEnv, int basePort, int rank, String hostname) {
7575
requestHandler = new RequestHandler();
7676
port = 19000 + basePort;
77-
socketAddress = getSocketAddress(pyEnv.isMpiMode(), rank);
77+
socketAddress = getSocketAddress(pyEnv.isMpiMode(), rank, hostname);
7878
}
7979

80-
static Process startPython(PyEnv pyEnv, Model model, int workerId, int port)
80+
static Process startPython(PyEnv pyEnv, Model model, int workerId, int port, String[] hosts)
8181
throws IOException {
8282
Path tmp = Paths.get(System.getProperty("java.io.tmpdir"));
8383
try (Stream<Path> stream = Files.list(tmp)) {
@@ -100,7 +100,7 @@ static Process startPython(PyEnv pyEnv, Model model, int workerId, int port)
100100
});
101101
}
102102
File modelPath = model.getModelPath().toFile();
103-
String[] args = getPythonStartCmd(pyEnv, model, workerId, port);
103+
String[] args = getPythonStartCmd(pyEnv, model, workerId, port, hosts);
104104
String[] envp = pyEnv.getEnvironmentVars(model);
105105
logger.debug("cmd: {}", (Object) args);
106106

@@ -120,21 +120,84 @@ CompletableFuture<Output> send(Input input) throws InterruptedException {
120120
return f;
121121
}
122122

123-
static String[] getPythonStartCmd(PyEnv pyEnv, Model model, int workerId, int port) {
123+
static String[] getPythonStartCmd(
124+
PyEnv pyEnv, Model model, int workerId, int port, String[] hosts) {
124125
Device device = model.getNDManager().getDevice();
125126
int deviceId = device.getDeviceId();
127+
int clusterSize = PyEnv.getClusterSize();
126128
int tensorParallelDegree = pyEnv.getTensorParallelDegree();
127129
String entryPoint = pyEnv.getEntryPoint();
128130
String recommendedEntryPoint = pyEnv.getRecommendedEntryPoint();
131+
132+
if (PyEnv.isMultiNode()) {
133+
String cudaDevices = getVisibleDevices(workerId, tensorParallelDegree / clusterSize);
134+
logger.info("Set before mpirun CUDA_VISIBLE_DEVICES={}", cudaDevices);
135+
StringBuilder sb = new StringBuilder();
136+
boolean first = true;
137+
for (String host : hosts) {
138+
if (first) {
139+
first = false;
140+
} else {
141+
sb.append(',');
142+
}
143+
sb.append(host).append(':').append(tensorParallelDegree / clusterSize);
144+
}
145+
String[] args = new String[46];
146+
args[0] = "mpirun";
147+
args[1] = "-np";
148+
args[2] = String.valueOf(tensorParallelDegree);
149+
args[3] = "--host";
150+
args[4] = sb.toString();
151+
args[5] = "--allow-run-as-root";
152+
args[6] = "--bind-to";
153+
args[7] = "none";
154+
args[8] = "--mca";
155+
args[9] = "orte_keep_fqdn_hostnames";
156+
args[10] = "t";
157+
args[11] = "--tag-output";
158+
args[12] = "-x";
159+
args[13] = "FI_PROVIDER=efa";
160+
args[14] = "-x";
161+
args[15] = "RDMAV_FORK_SAFE=1";
162+
args[16] = "-x";
163+
args[17] = "FI_EFA_USE_DEVICE_RDMA=1";
164+
args[18] = "-x";
165+
args[19] = "LD_LIBRARY_PATH";
166+
args[20] = "-x";
167+
args[21] = "PYTHONPATH";
168+
args[22] = "-x";
169+
args[23] = "CUDA_VISIBLE_DEVICES=" + cudaDevices;
170+
args[24] = "-x";
171+
args[25] = "MASTER_ADDR=" + pyEnv.getMasterAddr();
172+
args[26] = "-x";
173+
args[27] = "MKL_DYNAMIC=FALSE";
174+
args[28] = pyEnv.getPythonExecutable();
175+
args[29] = PyEnv.getEngineCacheDir() + "/djl_python_engine.py";
176+
args[30] = "--model-dir";
177+
args[31] = model.getModelPath().toAbsolutePath().toString();
178+
args[32] = "--entry-point";
179+
args[33] = entryPoint == null ? "" : entryPoint;
180+
args[34] = "--sock-type";
181+
args[35] = "tcp";
182+
args[36] = "--sock-name";
183+
args[37] = "0.0.0.0";
184+
args[38] = "--port";
185+
args[39] = String.valueOf(port);
186+
args[40] = "--tensor-parallel-degree";
187+
args[41] = String.valueOf(tensorParallelDegree);
188+
args[42] = "--cluster-size";
189+
args[43] = String.valueOf(clusterSize);
190+
args[44] = "--recommended-entry-point";
191+
args[45] = recommendedEntryPoint == null ? "" : recommendedEntryPoint;
192+
return args;
193+
}
194+
129195
if (pyEnv.isMpiMode()) {
130196
String cudaDevices = getVisibleDevices(workerId, tensorParallelDegree);
131197
logger.info("Set CUDA_VISIBLE_DEVICES={}", cudaDevices);
132198
String[] args = new String[42];
133199
args[0] = "mpirun";
134200
args[1] = "-np";
135-
// TODO: When we support multi nodes, change it to the product of tensor parallel value
136-
// and
137-
// pipeline parallel value.
138201
args[2] = String.valueOf(tensorParallelDegree);
139202
args[3] = "--allow-run-as-root";
140203
args[4] = "--bind-to";
@@ -156,7 +219,7 @@ static String[] getPythonStartCmd(PyEnv pyEnv, Model model, int workerId, int po
156219
args[20] = "-x";
157220
args[21] = "CUDA_VISIBLE_DEVICES=" + cudaDevices;
158221
args[22] = "-x";
159-
args[23] = "MASTER_ADDR=" + MASTER_ADDR;
222+
args[23] = "MASTER_ADDR=" + pyEnv.getMasterAddr();
160223
args[24] = "-x";
161224
args[25] = "MASTER_PORT=" + port;
162225
args[26] = "-x";
@@ -196,7 +259,7 @@ static String[] getPythonStartCmd(PyEnv pyEnv, Model model, int workerId, int po
196259
logger.info("Set OMP_NUM_THREADS={}", neuronThreads);
197260
}
198261
boolean uds = Epoll.isAvailable() || KQueue.isAvailable();
199-
String[] args = new String[14];
262+
String[] args = new String[16];
200263
args[0] = pyEnv.getPythonExecutable();
201264
args[1] = PyEnv.getEngineCacheDir() + "/djl_python_engine.py";
202265
args[2] = "--sock-type";
@@ -209,8 +272,10 @@ static String[] getPythonStartCmd(PyEnv pyEnv, Model model, int workerId, int po
209272
args[9] = entryPoint == null ? "" : entryPoint;
210273
args[10] = "--device-id";
211274
args[11] = String.valueOf(deviceId);
212-
args[12] = "--recommended-entry-point";
213-
args[13] = recommendedEntryPoint == null ? "" : recommendedEntryPoint;
275+
args[12] = "--cluster-size";
276+
args[13] = String.valueOf(clusterSize);
277+
args[14] = "--recommended-entry-point";
278+
args[15] = recommendedEntryPoint == null ? "" : recommendedEntryPoint;
214279
return args;
215280
}
216281

@@ -248,7 +313,8 @@ private static String getNeuronThreads(int tensorParallelDegree) {
248313
return String.valueOf(1);
249314
}
250315

251-
void connect() throws InterruptedException {
316+
void connect() throws InterruptedException, UnknownHostException {
317+
logger.debug("Connecting to socket: {}", socketAddress);
252318
EventLoopGroup group = PyEnv.getEventLoopGroup();
253319

254320
Bootstrap clientBootstrap = new Bootstrap();
@@ -295,7 +361,10 @@ private static String getSocketPath(int port) {
295361
return System.getProperty("java.io.tmpdir") + "/djl_sock." + port;
296362
}
297363

298-
private SocketAddress getSocketAddress(boolean mpiMode, int rank) {
364+
private SocketAddress getSocketAddress(boolean mpiMode, int rank, String hostname) {
365+
if (PyEnv.isMultiNode()) {
366+
return new InetSocketAddress(hostname, port + rank);
367+
}
299368
if (mpiMode) {
300369
return new DomainSocketAddress(getSocketPath(port) + '.' + rank);
301370
}
@@ -307,16 +376,21 @@ private SocketAddress getSocketAddress(boolean mpiMode, int rank) {
307376
}
308377

309378
/**
 * Creates the Netty event loop group used for connections to the Python workers.
 *
 * <p>Selection order: NIO when {@code PyEnv.isMultiNode()} is true; otherwise Epoll if
 * available, then KQueue if available, and NIO as the final fallback. Every group is created
 * with a {@code DaemonThreadFactory}.
 *
 * @return a new event loop group
 */
static EventLoopGroup newEventLoopGroup() {
// Multi-node mode is checked first, so NIO is used even when a native transport is available.
// NOTE(review): presumably because multi-node workers are addressed with an InetSocketAddress
// and NioSocketChannel elsewhere in this class -- confirm.
379+
if (PyEnv.isMultiNode()) {
380+
return new NioEventLoopGroup(new DaemonThreadFactory());
381+
}
310382
if (Epoll.isAvailable()) {
311383
return new EpollEventLoopGroup(new DaemonThreadFactory());
312384
} else if (KQueue.isAvailable()) {
313385
return new KQueueEventLoopGroup(new DaemonThreadFactory());
314386
}
315-
// Neither native transport is available: fall back to NIO.
316387
return new NioEventLoopGroup(new DaemonThreadFactory());
317388
}
318389

319390
private static Class<? extends Channel> getClientChannel() {
391+
if (PyEnv.isMultiNode()) {
392+
return NioSocketChannel.class;
393+
}
320394
if (Epoll.isAvailable()) {
321395
return EpollDomainSocketChannel.class;
322396
} else if (KQueue.isAvailable()) {

engines/python/src/main/java/ai/djl/python/engine/PyEnv.java

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ public class PyEnv {
4242

4343
static final Logger logger = LoggerFactory.getLogger(PyEnv.class);
4444

45+
private static int clusterSize;
4546
private static String engineCacheDir;
4647
private static String version;
4748
private static EventLoopGroup eventLoopGroup;
@@ -84,6 +85,7 @@ static synchronized void init() {
8485
return;
8586
}
8687

88+
setClusterSize();
8789
eventLoopGroup = Connection.newEventLoopGroup();
8890

8991
Path tmp = null;
@@ -128,6 +130,20 @@ static synchronized void init() {
128130
}
129131
}
130132

133+
/**
 * Initializes the static {@code clusterSize} from the {@code DJL_CLUSTER_SIZE} environment
 * variable.
 *
 * <p>The assignment only happens while {@code clusterSize} is still 0. {@code "1"} is passed to
 * {@code Utils.getenv} as its second argument (presumably the value used when the variable is
 * unset -- confirm against {@code Utils}).
 *
 * @throws NumberFormatException if the looked-up value is not a parsable integer
 */
static void setClusterSize() {
// The field is declared without an initializer, so 0 means "not set yet".
// NOTE(review): there is no range validation here; a value of "0" leaves the field at 0 (so the
// next call parses again) and a negative value is stored as-is.
134+
if (clusterSize == 0) {
135+
clusterSize = Integer.parseInt(Utils.getenv("DJL_CLUSTER_SIZE", "1"));
136+
}
137+
}
138+
139+
/**
 * Returns the current value of the static {@code clusterSize} field.
 *
 * <p>The field is only assigned by {@code setClusterSize()}, so this returns 0 if that method
 * has not run yet.
 *
 * @return the cluster size
 */
static int getClusterSize() {
140+
return clusterSize;
141+
}
142+
143+
/**
 * Returns whether multi-node mode is active, defined as a cluster size greater than 1.
 *
 * @return {@code true} if {@code clusterSize > 1}
 */
static boolean isMultiNode() {
144+
return clusterSize > 1;
145+
}
146+
131147
static String getVersion() {
132148
return version;
133149
}
@@ -304,6 +320,15 @@ public void setPythonExecutable(String pythonExecutable) {
304320
this.pythonExecutable = pythonExecutable;
305321
}
306322

323+
/**
324+
* Returns the master address taken from the {@code MASTER_ADDR} environment variable.
325+
*
* <p>The lookup is performed on every call through {@code Utils.getenv}, with
* {@code "127.0.0.1"} passed as the second argument (presumably the value returned when the
* variable is not set -- confirm against {@code Utils}).
*
326+
* @return the master address
327+
*/
328+
public String getMasterAddr() {
329+
return Utils.getenv("MASTER_ADDR", "127.0.0.1");
330+
}
331+
307332
/**
308333
* Returns the tensor parallel degree.
309334
*
@@ -339,7 +364,7 @@ public void setTensorParallelDegree(int tensorParallelDegree) {
339364
}
340365

341366
int getMpiWorkers() {
342-
int gpuCount = CudaUtils.getGpuCount();
367+
int gpuCount = CudaUtils.getGpuCount() * clusterSize;
343368
String visibleDevices = Utils.getenv("CUDA_VISIBLE_DEVICES");
344369
if (gpuCount > 0 && visibleDevices != null) {
345370
int visibleCount = visibleDevices.split(",").length;

0 commit comments

Comments
 (0)