44 * SPDX-License-Identifier: BSD-3-Clause
55 */
66
7+ // Sphinx: #1
78#include < stdlib.h>
89#include < stdio.h>
910
@@ -29,11 +30,10 @@ if( err != CUTENSORNET_STATUS_SUCCESS ) \
2930
3031struct GPUTimer
3132{
32- GPUTimer ()
33+ GPUTimer (cudaStream_t stream): stream_(stream )
3334 {
3435 cudaEventCreate (&start_);
3536 cudaEventCreate (&stop_);
36- cudaEventRecord (start_, 0 );
3737 }
3838
3939 ~GPUTimer ()
@@ -44,19 +44,21 @@ struct GPUTimer
4444
4545 void start ()
4646 {
47- cudaEventRecord (start_, 0 );
47+ cudaEventRecord (start_, stream_ );
4848 }
4949
5050 float seconds ()
5151 {
52- cudaEventRecord (stop_, 0 );
52+ cudaEventRecord (stop_, stream_ );
5353 cudaEventSynchronize (stop_);
5454 float time;
5555 cudaEventElapsedTime (&time, start_, stop_);
5656 return time * 1e-3 ;
5757 }
58+
5859 private:
5960 cudaEvent_t start_, stop_;
61+ cudaStream_t stream_;
6062};
6163
6264
@@ -80,12 +82,12 @@ int main()
8082 printf (" ========================\n " );
8183
8284 typedef float floatType;
83-
8485 cudaDataType_t typeData = CUDA_R_32F;
8586 cutensornetComputeType_t typeCompute = CUTENSORNET_COMPUTE_32F;
8687
8788 printf (" Include headers and define data types\n " );
8889
90+ // Sphinx: #2
8991 /* *********************
9092 * Computing: D_{m,x,n,y} = A_{m,h,k,n} B_{u,k,h} C_{x,u,y}
9193 **********************/
@@ -124,6 +126,7 @@ int main()
124126
125127 printf (" Define network, modes, and extents\n " );
126128
129+ // Sphinx: #3
127130 /* *********************
128131 * Allocating data
129132 **********************/
@@ -182,18 +185,20 @@ int main()
182185 *******************/
183186
184187 for (uint64_t i = 0 ; i < elementsA; i++)
185- A[i] = ((( float ) rand ())/RAND_MAX - 0.5 )* 100 ;
188+ A[i] = ((float ) rand ())/RAND_MAX;
186189 for (uint64_t i = 0 ; i < elementsB; i++)
187- B[i] = ((( float ) rand ())/RAND_MAX - 0.5 )* 100 ;
190+ B[i] = ((float ) rand ())/RAND_MAX;
188191 for (uint64_t i = 0 ; i < elementsC; i++)
189- C[i] = (((float ) rand ())/RAND_MAX - 0.5 )*100 ;
192+ C[i] = ((float ) rand ())/RAND_MAX;
193+ memset (D, 0 , sizeof (floatType) * elementsD);
190194
191195 HANDLE_CUDA_ERROR (cudaMemcpy (rawDataIn_d[0 ], A, sizeA, cudaMemcpyHostToDevice));
192196 HANDLE_CUDA_ERROR (cudaMemcpy (rawDataIn_d[1 ], B, sizeB, cudaMemcpyHostToDevice));
193197 HANDLE_CUDA_ERROR (cudaMemcpy (rawDataIn_d[2 ], C, sizeC, cudaMemcpyHostToDevice));
194198
195199 printf (" Allocate memory for data and workspace, and initialize data.\n " );
196200
201+ // Sphinx: #4
197202 /* ************************
198203 * cuTensorNet
199204 *************************/
@@ -247,6 +252,7 @@ int main()
247252
248253 printf (" Initialize the cuTensorNet library and create a network descriptor.\n " );
249254
255+ // Sphinx: #5
250256 /* ******************************
251257 * Find "optimal" contraction order and slicing
252258 *******************************/
@@ -284,6 +290,7 @@ int main()
284290
285291 printf (" Find an optimized contraction path with cuTensorNet optimizer.\n " );
286292
293+ // Sphinx: #6
287294 /* ******************************
288295 * Initialize all pair-wise contraction plans (for cuTENSOR)
289296 *******************************/
@@ -349,10 +356,11 @@ int main()
349356
350357 printf (" Create a contraction plan for cuTENSOR and optionally auto-tune it.\n " );
351358
359+ // Sphinx: #7
352360 /* *********************
353361 * Run
354362 **********************/
355- GPUTimer timer;
363+ GPUTimer timer{stream} ;
356364 double minTimeCUTENSOR = 1e100 ;
357365 const int numRuns = 3 ; // to get stable perf results
358366 for (int i=0 ; i < numRuns; ++i)
@@ -364,6 +372,9 @@ int main()
364372 * Contract over all slices.
365373 *
366374 * A user may choose to parallelize this loop across multiple devices.
375+ * (Note, however, that as of cuTensorNet v1.0.0 the contraction must
376+ * start from slice 0, see the cutensornetContraction documentation at
377+ * https://docs.nvidia.com/cuda/cuquantum/cutensornet/api/functions.html#cutensornetcontraction )
367378 */
368379 for (int64_t sliceId=0 ; sliceId < numSlices; ++sliceId)
369380 {
0 commit comments