5050import java .io .IOException ;
5151import java .net .InetSocketAddress ;
5252import java .net .SocketAddress ;
53+ import java .net .UnknownHostException ;
5354import java .nio .ByteBuffer ;
5455import java .nio .file .Files ;
5556import java .nio .file .Path ;
6465class Connection {
6566
6667 private static final Logger logger = LoggerFactory .getLogger (Connection .class );
67- private static final String MASTER_ADDR = "127.0.0.1" ;
6868
6969 private int port ;
7070 private SocketAddress socketAddress ;
7171 private Channel channel ;
7272 private RequestHandler requestHandler ;
7373
74- Connection (PyEnv pyEnv , int basePort , int rank ) {
74+ Connection (PyEnv pyEnv , int basePort , int rank , String hostname ) {
7575 requestHandler = new RequestHandler ();
7676 port = 19000 + basePort ;
77- socketAddress = getSocketAddress (pyEnv .isMpiMode (), rank );
77+ socketAddress = getSocketAddress (pyEnv .isMpiMode (), rank , hostname );
7878 }
7979
80- static Process startPython (PyEnv pyEnv , Model model , int workerId , int port )
80+ static Process startPython (PyEnv pyEnv , Model model , int workerId , int port , String [] hosts )
8181 throws IOException {
8282 Path tmp = Paths .get (System .getProperty ("java.io.tmpdir" ));
8383 try (Stream <Path > stream = Files .list (tmp )) {
@@ -100,7 +100,7 @@ static Process startPython(PyEnv pyEnv, Model model, int workerId, int port)
100100 });
101101 }
102102 File modelPath = model .getModelPath ().toFile ();
103- String [] args = getPythonStartCmd (pyEnv , model , workerId , port );
103+ String [] args = getPythonStartCmd (pyEnv , model , workerId , port , hosts );
104104 String [] envp = pyEnv .getEnvironmentVars (model );
105105 logger .debug ("cmd: {}" , (Object ) args );
106106
@@ -120,21 +120,84 @@ CompletableFuture<Output> send(Input input) throws InterruptedException {
120120 return f ;
121121 }
122122
123- static String [] getPythonStartCmd (PyEnv pyEnv , Model model , int workerId , int port ) {
123+ static String [] getPythonStartCmd (
124+ PyEnv pyEnv , Model model , int workerId , int port , String [] hosts ) {
124125 Device device = model .getNDManager ().getDevice ();
125126 int deviceId = device .getDeviceId ();
127+ int clusterSize = PyEnv .getClusterSize ();
126128 int tensorParallelDegree = pyEnv .getTensorParallelDegree ();
127129 String entryPoint = pyEnv .getEntryPoint ();
128130 String recommendedEntryPoint = pyEnv .getRecommendedEntryPoint ();
131+
132+ if (PyEnv .isMultiNode ()) {
133+ String cudaDevices = getVisibleDevices (workerId , tensorParallelDegree / clusterSize );
134+ logger .info ("Set before mpirun CUDA_VISIBLE_DEVICES={}" , cudaDevices );
135+ StringBuilder sb = new StringBuilder ();
136+ boolean first = true ;
137+ for (String host : hosts ) {
138+ if (first ) {
139+ first = false ;
140+ } else {
141+ sb .append (',' );
142+ }
143+ sb .append (host ).append (':' ).append (tensorParallelDegree / clusterSize );
144+ }
145+ String [] args = new String [46 ];
146+ args [0 ] = "mpirun" ;
147+ args [1 ] = "-np" ;
148+ args [2 ] = String .valueOf (tensorParallelDegree );
149+ args [3 ] = "--host" ;
150+ args [4 ] = sb .toString ();
151+ args [5 ] = "--allow-run-as-root" ;
152+ args [6 ] = "--bind-to" ;
153+ args [7 ] = "none" ;
154+ args [8 ] = "--mca" ;
155+ args [9 ] = "orte_keep_fqdn_hostnames" ;
156+ args [10 ] = "t" ;
157+ args [11 ] = "--tag-output" ;
158+ args [12 ] = "-x" ;
159+ args [13 ] = "FI_PROVIDER=efa" ;
160+ args [14 ] = "-x" ;
161+ args [15 ] = "RDMAV_FORK_SAFE=1" ;
162+ args [16 ] = "-x" ;
163+ args [17 ] = "FI_EFA_USE_DEVICE_RDMA=1" ;
164+ args [18 ] = "-x" ;
165+ args [19 ] = "LD_LIBRARY_PATH" ;
166+ args [20 ] = "-x" ;
167+ args [21 ] = "PYTHONPATH" ;
168+ args [22 ] = "-x" ;
169+ args [23 ] = "CUDA_VISIBLE_DEVICES=" + cudaDevices ;
170+ args [24 ] = "-x" ;
171+ args [25 ] = "MASTER_ADDR=" + pyEnv .getMasterAddr ();
172+ args [26 ] = "-x" ;
173+ args [27 ] = "MKL_DYNAMIC=FALSE" ;
174+ args [28 ] = pyEnv .getPythonExecutable ();
175+ args [29 ] = PyEnv .getEngineCacheDir () + "/djl_python_engine.py" ;
176+ args [30 ] = "--model-dir" ;
177+ args [31 ] = model .getModelPath ().toAbsolutePath ().toString ();
178+ args [32 ] = "--entry-point" ;
179+ args [33 ] = entryPoint == null ? "" : entryPoint ;
180+ args [34 ] = "--sock-type" ;
181+ args [35 ] = "tcp" ;
182+ args [36 ] = "--sock-name" ;
183+ args [37 ] = "0.0.0.0" ;
184+ args [38 ] = "--port" ;
185+ args [39 ] = String .valueOf (port );
186+ args [40 ] = "--tensor-parallel-degree" ;
187+ args [41 ] = String .valueOf (tensorParallelDegree );
188+ args [42 ] = "--cluster-size" ;
189+ args [43 ] = String .valueOf (clusterSize );
190+ args [44 ] = "--recommended-entry-point" ;
191+ args [45 ] = recommendedEntryPoint == null ? "" : recommendedEntryPoint ;
192+ return args ;
193+ }
194+
129195 if (pyEnv .isMpiMode ()) {
130196 String cudaDevices = getVisibleDevices (workerId , tensorParallelDegree );
131197 logger .info ("Set CUDA_VISIBLE_DEVICES={}" , cudaDevices );
132198 String [] args = new String [42 ];
133199 args [0 ] = "mpirun" ;
134200 args [1 ] = "-np" ;
135- // TODO: When we support multi nodes, change it to the product of tensor parallel value
136- // and
137- // pipeline parallel value.
138201 args [2 ] = String .valueOf (tensorParallelDegree );
139202 args [3 ] = "--allow-run-as-root" ;
140203 args [4 ] = "--bind-to" ;
@@ -156,7 +219,7 @@ static String[] getPythonStartCmd(PyEnv pyEnv, Model model, int workerId, int po
156219 args [20 ] = "-x" ;
157220 args [21 ] = "CUDA_VISIBLE_DEVICES=" + cudaDevices ;
158221 args [22 ] = "-x" ;
159- args [23 ] = "MASTER_ADDR=" + MASTER_ADDR ;
222+ args [23 ] = "MASTER_ADDR=" + pyEnv . getMasterAddr () ;
160223 args [24 ] = "-x" ;
161224 args [25 ] = "MASTER_PORT=" + port ;
162225 args [26 ] = "-x" ;
@@ -196,7 +259,7 @@ static String[] getPythonStartCmd(PyEnv pyEnv, Model model, int workerId, int po
196259 logger .info ("Set OMP_NUM_THREADS={}" , neuronThreads );
197260 }
198261 boolean uds = Epoll .isAvailable () || KQueue .isAvailable ();
199- String [] args = new String [14 ];
262+ String [] args = new String [16 ];
200263 args [0 ] = pyEnv .getPythonExecutable ();
201264 args [1 ] = PyEnv .getEngineCacheDir () + "/djl_python_engine.py" ;
202265 args [2 ] = "--sock-type" ;
@@ -209,8 +272,10 @@ static String[] getPythonStartCmd(PyEnv pyEnv, Model model, int workerId, int po
209272 args [9 ] = entryPoint == null ? "" : entryPoint ;
210273 args [10 ] = "--device-id" ;
211274 args [11 ] = String .valueOf (deviceId );
212- args [12 ] = "--recommended-entry-point" ;
213- args [13 ] = recommendedEntryPoint == null ? "" : recommendedEntryPoint ;
275+ args [12 ] = "--cluster-size" ;
276+ args [13 ] = String .valueOf (clusterSize );
277+ args [14 ] = "--recommended-entry-point" ;
278+ args [15 ] = recommendedEntryPoint == null ? "" : recommendedEntryPoint ;
214279 return args ;
215280 }
216281
@@ -248,7 +313,8 @@ private static String getNeuronThreads(int tensorParallelDegree) {
248313 return String .valueOf (1 );
249314 }
250315
251- void connect () throws InterruptedException {
316+ void connect () throws InterruptedException , UnknownHostException {
317+ logger .debug ("Connecting to socket: {}" , socketAddress );
252318 EventLoopGroup group = PyEnv .getEventLoopGroup ();
253319
254320 Bootstrap clientBootstrap = new Bootstrap ();
@@ -295,7 +361,10 @@ private static String getSocketPath(int port) {
295361 return System .getProperty ("java.io.tmpdir" ) + "/djl_sock." + port ;
296362 }
297363
298- private SocketAddress getSocketAddress (boolean mpiMode , int rank ) {
364+ private SocketAddress getSocketAddress (boolean mpiMode , int rank , String hostname ) {
365+ if (PyEnv .isMultiNode ()) {
366+ return new InetSocketAddress (hostname , port + rank );
367+ }
299368 if (mpiMode ) {
300369 return new DomainSocketAddress (getSocketPath (port ) + '.' + rank );
301370 }
@@ -307,16 +376,21 @@ private SocketAddress getSocketAddress(boolean mpiMode, int rank) {
307376 }
308377
309378 static EventLoopGroup newEventLoopGroup () {
379+ if (PyEnv .isMultiNode ()) {
380+ return new NioEventLoopGroup (new DaemonThreadFactory ());
381+ }
310382 if (Epoll .isAvailable ()) {
311383 return new EpollEventLoopGroup (new DaemonThreadFactory ());
312384 } else if (KQueue .isAvailable ()) {
313385 return new KQueueEventLoopGroup (new DaemonThreadFactory ());
314386 }
315-
316387 return new NioEventLoopGroup (new DaemonThreadFactory ());
317388 }
318389
319390 private static Class <? extends Channel > getClientChannel () {
391+ if (PyEnv .isMultiNode ()) {
392+ return NioSocketChannel .class ;
393+ }
320394 if (Epoll .isAvailable ()) {
321395 return EpollDomainSocketChannel .class ;
322396 } else if (KQueue .isAvailable ()) {
0 commit comments