1
- #include <cudnn_v7 .h>
1
+ #include <cudnn .h>
2
2
3
3
cudnnStatus_t gocudnnNewConvolution (cudnnConvolutionDescriptor_t * retVal ,
4
- cudnnMathType_t mathType , const int groupCount ,
4
+ cudnnMathType_t mathType , const int groupCount ,
5
5
const int size , const int * padding ,
6
- const int * filterStrides ,
6
+ const int * filterStrides ,
7
7
const int * dilation ,
8
8
cudnnConvolutionMode_t convolutionMode , cudnnDataType_t dataType ) {
9
9
10
10
cudnnStatus_t status ;
11
11
status = cudnnCreateConvolutionDescriptor (retVal );
12
12
if (status != CUDNN_STATUS_SUCCESS ) {
13
- return status ;
13
+ return status ;
14
14
}
15
15
16
16
status = cudnnSetConvolutionMathType (* retVal , mathType );
17
17
if (status != CUDNN_STATUS_SUCCESS ) {
18
18
return status ;
19
19
}
20
20
21
- status = cudnnSetConvolutionGroupCount (* retVal , groupCount );
21
+ status = cudnnSetConvolutionGroupCount (* retVal , groupCount );
22
22
if (status != CUDNN_STATUS_SUCCESS ) {
23
23
return status ;
24
24
}
25
25
26
26
int padH ;
27
27
int padW ;
28
28
int u ;
29
- int v ;
29
+ int v ;
30
30
int dilationH ;
31
31
int dilationW ;
32
32
switch (size ) {
@@ -39,17 +39,17 @@ cudnnStatus_t gocudnnNewConvolution(cudnnConvolutionDescriptor_t *retVal,
39
39
u = filterStrides [0 ];
40
40
v = filterStrides [1 ];
41
41
dilationH = dilation [0 ];
42
- dilationW = dilation [1 ];
42
+ dilationW = dilation [1 ];
43
43
44
- status = cudnnSetConvolution2dDescriptor (* retVal ,
45
- padH , padW ,
46
- u , v ,
47
- dilationH , dilationW ,
44
+ status = cudnnSetConvolution2dDescriptor (* retVal ,
45
+ padH , padW ,
46
+ u , v ,
47
+ dilationH , dilationW ,
48
48
convolutionMode , dataType );
49
49
break ;
50
50
default :
51
51
status = cudnnSetConvolutionNdDescriptor (* retVal , size , padding , filterStrides , dilation , convolutionMode , dataType );
52
52
break ;
53
53
}
54
54
return status ;
55
- }
55
+ }
0 commit comments