
Commit 81a2cf0

andersonic and Jun Ru Anderson authored
[feat] remove support for non-multitensor Adam
Co-authored-by: Jun Ru Anderson <[email protected]>
1 parent 57079b0 commit 81a2cf0

3 files changed: +27 -192 lines changed
Lines changed: 2 additions & 26 deletions
@@ -1,32 +1,8 @@
 #include <torch/extension.h>
 
 // CUDA forward declaration
-void fused_adam_cuda(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
-
-void fused_adam_cuda_mt(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
-
-
-#define CHECK_CUDA(x) AT_ASSERTM(x.is_cuda(), #x " must be a CUDA tensor")
-#define CHECK_CONTIGUOUS(x) AT_ASSERTM(x.is_contiguous(), #x " must be contiguous")
-#define CHECK_INPUT(x) CHECK_CUDA(x); CHECK_CONTIGUOUS(x)
-
-// C++ interface
-void adam(at::Tensor & p, at::Tensor & p_copy, at::Tensor & m, at::Tensor & v, at::Tensor & g, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay) {
-        CHECK_INPUT(p);
-        if (p_copy.numel() > 0) CHECK_INPUT(p_copy);
-        CHECK_INPUT(m);
-        CHECK_INPUT(v);
-        CHECK_INPUT(g);
-        int64_t num_elem = p.numel();
-        AT_ASSERTM(m.numel() == num_elem, "number of elements in m and p tensors should be equal");
-        AT_ASSERTM(v.numel() == num_elem, "number of elements in v and p tensors should be equal");
-        AT_ASSERTM(g.numel() == num_elem, "number of elements in g and p tensors should be equal");
-        AT_ASSERTM(p_copy.numel() == num_elem || p_copy.numel() == 0, "number of elements in p_copy and p tensors should be equal, or p_copy should be empty");
-
-        fused_adam_cuda(p, p_copy, m, v, g, lr, beta1, beta2, eps, grad_scale, step, mode, bias_correction, decay);
-}
+void fused_adam_cuda(int chunk_size, at::Tensor noop_flag, std::vector<std::vector<at::Tensor>> tensor_lists, float lr, float beta1, float beta2, float eps, float grad_scale, int step, int mode, int bias_correction, float decay);
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
-        m.def("adam", &adam, "Adam optimized CUDA implementation.");
-        m.def("adam_mt", &fused_adam_cuda_mt, "Multi tensor Adam optimized CUDA implementation.");
+        m.def("adam", &fused_adam_cuda, "Multi tensor Adam optimized CUDA implementation.");
 }
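
For orientation, here is a minimal Python sketch (not part of this commit) of how the single remaining adam binding is driven; it mirrors the call made in fairscale/optim/adam.py further down. It assumes the fused_adam_cuda extension has been built and a CUDA device is available, and the hyperparameter values are placeholders.

import torch
import fused_adam_cuda  # extension built from the sources in this commit

# One flat list per state, in the kernel's expected order: p, m, v, g.
params = [torch.randn(1024, device="cuda") for _ in range(3)]
exp_avgs = [torch.zeros_like(p) for p in params]
exp_avg_sqs = [torch.zeros_like(p) for p in params]
grads = [torch.randn_like(p) for p in params]

overflow_buf = torch.cuda.IntTensor([0])  # no-op/overflow flag, as in adam.py

fused_adam_cuda.adam(
    2048 * 32,          # chunk_size used by the multi-tensor apply
    overflow_buf,       # noop_flag
    [params, exp_avgs, exp_avg_sqs, grads],
    1e-3,               # lr
    0.9,                # beta1
    0.999,              # beta2
    1e-8,               # eps
    1.0,                # grad_scale
    1,                  # step
    0,                  # mode (0: eps inside the square root, per adamMode_t)
    1,                  # bias_correction
    0.0,                # weight decay
)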

fairscale/clib/fused_adam_cuda/fused_adam_cuda_kernel.cu

Lines changed: 0 additions & 117 deletions
@@ -21,43 +21,7 @@ typedef enum{
     ADAM_MODE_1 =1 // eps outside square root
 } adamMode_t;
 
-template <typename T, typename GRAD_T>
-__global__ void adam_cuda_kernel(
-        GRAD_T* __restrict__ p,
-        GRAD_T* __restrict__ p_copy, // For mixed precision training, pass NULL if not needed
-        T* __restrict__ m,
-        T* __restrict__ v,
-        const GRAD_T * __restrict__ g,
-        const float b1,
-        const float b2,
-        const float eps,
-        const float grad_scale,
-        const float step_size,
-        const size_t tsize,
-        adamMode_t mode,
-        const float decay)
-{
-        //Assuming 2D grids and 2D blocks
-        const int blockId = gridDim.x * blockIdx.y + blockIdx.x;
-        const int threadsPerBlock = blockDim.x * blockDim.y;
-        const int threadIdInBlock = threadIdx.y * blockDim.x + threadIdx.x;
-        const int i = (blockId * threadsPerBlock + threadIdInBlock);
-        const int totThreads = gridDim.x*gridDim.y*threadsPerBlock;
 
-        for (int j = i; j < tsize; j+=totThreads) {
-                T scaled_grad = g[j]/grad_scale;
-                m[j] = b1*m[j] + (1-b1)*scaled_grad;
-                v[j] = b2*v[j] + (1-b2)*scaled_grad*scaled_grad;
-                float denom;
-                if (mode == ADAM_MODE_0)
-                        denom = sqrtf(v[j] + eps);
-                else // Mode 1
-                        denom = sqrtf(v[j]) + eps;
-                float update = (m[j]/denom) + (decay*p[j]);
-                p[j] = (GRAD_T)((float)p[j] - (step_size*update));
-                if (p_copy != NULL) p_copy[j] = (GRAD_T) p[j];
-        }
-}
 
 template <int DEPTH, typename T, typename GRAD_T>
 struct AdamFunctor
@@ -147,87 +111,6 @@ struct AdamFunctor
 };
 
 void fused_adam_cuda(
-        at::Tensor & p,
-        at::Tensor & p_copy,
-        at::Tensor & m,
-        at::Tensor & v,
-        at::Tensor & g,
-        float lr,
-        float beta1,
-        float beta2,
-        float eps,
-        float grad_scale,
-        int step,
-        int mode,
-        int bias_correction,
-        float decay)
-{
-        // using namespace at;
-
-        //Get tensor size
-        int tsize = p.numel();
-        //Determine #threads and #blocks
-        const int threadsPerBlock = 512;
-        const dim3 blocks((tsize+threadsPerBlock-1)/threadsPerBlock);
-        AT_ASSERTM(at::cuda::detail::canUse32BitIndexMath(p), "parameter tensor is too large to be indexed with int32");
-        //Constants
-        float step_size = 0;
-        if (bias_correction == 1) {
-                const float bias_correction1 = 1 - std::pow(beta1, step);
-                const float bias_correction2 = 1 - std::pow(beta2, step);
-                step_size = lr * std::sqrt(bias_correction2)/bias_correction1;
-        }
-        else {
-                step_size = lr;
-        }
-        cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-
-        if (g.scalar_type() == at::ScalarType::Half) {
-                //all other values should be fp32 for half gradients
-                AT_ASSERTM(p.scalar_type() == at::ScalarType::Half, "expected parameter to be of half type");
-                //dispatch is done on the gradient type
-                using namespace at; // prevents "toString is undefined" errors
-                DISPATCH_FLOAT_AND_HALF(g.scalar_type(), 0, "adam_cuda_kernel",
-                        using accscalar_t = at::acc_type<scalar_t_0, true>;
-                        adam_cuda_kernel<accscalar_t, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
-                                p.DATA_PTR<scalar_t_0>(),
-                                p_copy.numel() ? p_copy.DATA_PTR<scalar_t_0>() : NULL,
-                                m.DATA_PTR<accscalar_t>(),
-                                v.DATA_PTR<accscalar_t>(),
-                                g.DATA_PTR<scalar_t_0>(),
-                                beta1,
-                                beta2,
-                                eps,
-                                grad_scale,
-                                step_size,
-                                tsize,
-                                (adamMode_t) mode,
-                                decay);
-                );
-        } else {
-                using namespace at;
-                DISPATCH_DOUBLE_AND_FLOAT(g.scalar_type(), 0, "adam_cuda_kernel",
-                        adam_cuda_kernel<scalar_t_0, scalar_t_0><<<blocks,threadsPerBlock, 0, stream>>>(
-                                p.DATA_PTR<scalar_t_0>(),
-                                NULL, //don't output p_copy for fp32, it's wasted write
-                                m.DATA_PTR<scalar_t_0>(),
-                                v.DATA_PTR<scalar_t_0>(),
-                                g.DATA_PTR<scalar_t_0>(),
-                                beta1,
-                                beta2,
-                                eps,
-                                grad_scale,
-                                step_size,
-                                tsize,
-                                (adamMode_t) mode,
-                                decay);
-                );
-        }
-        THCudaCheck(cudaGetLastError());
-
-}
-
-void fused_adam_cuda_mt(
         int chunk_size,
         at::Tensor noop_flag,
         std::vector<std::vector<at::Tensor>> tensor_lists, // p, m, v, g, p_copy
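
For reference, this is the per-element update that the deleted adam_cuda_kernel (and its deleted host wrapper) computed, written as a plain PyTorch sketch; the function name and signature below are illustrative only, and the retained multi-tensor path performs the same math.

import torch

def adam_update(p, m, v, g, lr, beta1, beta2, eps, grad_scale, step,
                mode=0, bias_correction=1, decay=0.0):
    # Mirrors the deleted kernel: gradients are descaled, moments updated in place.
    scaled_grad = g / grad_scale
    m.mul_(beta1).add_(scaled_grad, alpha=1 - beta1)
    v.mul_(beta2).addcmul_(scaled_grad, scaled_grad, value=1 - beta2)
    # step_size as computed in the deleted fused_adam_cuda host wrapper.
    if bias_correction:
        step_size = lr * (1 - beta2 ** step) ** 0.5 / (1 - beta1 ** step)
    else:
        step_size = lr
    # mode 0: eps inside the square root; mode 1: eps outside (adamMode_t).
    denom = (v + eps).sqrt() if mode == 0 else v.sqrt() + eps
    p.sub_(step_size * (m / denom + decay * p))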

fairscale/optim/adam.py

Lines changed: 25 additions & 49 deletions
@@ -55,13 +55,10 @@ def __init__(
         weight_decay: Optional[float] = 0.0,
         max_grad_norm: Optional[float] = 0.0,
         amsgrad: Optional[bool] = False,
-        use_mt: Optional[bool] = True,
     ):
 
         self._use_multi_tensor = False
-        if use_mt:
-            self._use_multi_tensor = True
-            self._overflow_buf = torch.cuda.IntTensor([0])  # type: ignore
+        self._overflow_buf = torch.cuda.IntTensor([0])  # type: ignore
 
         if amsgrad:
             raise RuntimeError("FusedAdam does not support the AMSGrad variant.")
@@ -131,51 +128,30 @@ def step(self, closure: Optional[Callable[[], float]] = None, scale: Optional[fl
                 state["step"] += 1
                 out_p = torch.tensor([])
 
-                if self._use_multi_tensor:
-                    pl = [p.data, exp_avg, exp_avg_sq, grad]
-
-                    if p.device not in tensorlists:
-                        tensorlists[p.device] = [[], [], [], []]
-
-                    for tl, t in zip(tensorlists[p.device], pl):
-                        tl.append(t)
-
-                else:
-                    with torch.cuda.device(p.device):
-                        fused_adam_cuda.adam(
-                            p.data,
-                            out_p,
-                            exp_avg,
-                            exp_avg_sq,
-                            grad,
-                            group["lr"],
-                            beta1,
-                            beta2,
-                            group["eps"],
-                            scale,
-                            state["step"],
-                            self.eps_mode,
-                            bias_correction,
-                            group["weight_decay"],
-                        )
-
-            if self._use_multi_tensor:
-                for tensordevice, tensorlist in tensorlists.items():
-                    with torch.cuda.device(tensordevice):
-                        fused_adam_cuda.adam_mt(
-                            2048 * 32,
-                            self._overflow_buf,
-                            tensorlist,
-                            group["lr"],
-                            beta1,
-                            beta2,
-                            group["eps"],
-                            scale,
-                            state["step"],
-                            self.eps_mode,
-                            bias_correction,
-                            group["weight_decay"],
-                        )
+                pl = [p.data, exp_avg, exp_avg_sq, grad]
+
+                if p.device not in tensorlists:
+                    tensorlists[p.device] = [[], [], [], []]
+
+                for tl, t in zip(tensorlists[p.device], pl):
+                    tl.append(t)
+
+            for tensordevice, tensorlist in tensorlists.items():
+                with torch.cuda.device(tensordevice):
+                    fused_adam_cuda.adam(
+                        2048 * 32,
+                        self._overflow_buf,
+                        tensorlist,
+                        group["lr"],
+                        beta1,
+                        beta2,
+                        group["eps"],
+                        scale,
+                        state["step"],
+                        self.eps_mode,
+                        bias_correction,
+                        group["weight_decay"],
+                    )
 
         return loss
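
After this change the optimizer always takes the multi-tensor path, so the old use_mt flag is no longer accepted. A minimal usage sketch, assuming the class defined in fairscale/optim/adam.py is exported as fairscale.optim.Adam and the fused_adam_cuda extension is compiled:

import torch
from fairscale.optim import Adam  # assumed export of fairscale/optim/adam.py

model = torch.nn.Linear(1024, 1024).cuda()
# use_mt is gone from __init__; passing it now raises a TypeError.
opt = Adam(model.parameters(), lr=1e-3)

loss = model(torch.randn(8, 1024, device="cuda")).sum()
loss.backward()
opt.step()
opt.zero_grad()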
