
Commit 1f59af5

Add verbose output, add layerwise adapt clip (non-DP yet)
1 parent 1e8a50d commit 1f59af5

File tree: 3 files changed, +214 −2 lines changed

opacus/optimizers/__init__.py

Lines changed: 7 additions & 1 deletion
@@ -12,7 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from .adaclipoptimizer import AdaClipDPOptimizer
+from .adaclipoptimizer import (
+    AdaClipDPOptimizer,
+    LayerwiseAdaClipDPOptimizer,
+)
 from .ddp_perlayeroptimizer import (
     DistributedPerLayerOptimizer,
     SimpleDistributedPerLayerOptimizer,
@@ -56,6 +59,9 @@ def get_optimizer_class(clipping: str, distributed: bool, grad_sample_mode: str
             return SimpleDistributedPerLayerOptimizer
         else:
             raise ValueError(f"Unexpected grad_sample_mode: {grad_sample_mode}")
+    elif clipping == "adapt_layer":
+        print("Using LayerwiseAdaClipDPOptimizer")
+        return LayerwiseAdaClipDPOptimizer
     elif clipping == "adaptive" and distributed is False:
         return AdaClipDPOptimizer
     raise ValueError(
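
As a quick sanity check, the new dispatch branch can be exercised directly (a minimal sketch; it assumes `grad_sample_mode` keeps its upstream default and may be omitted):

```python
from opacus.optimizers import get_optimizer_class

# Resolve the optimizer class added in this commit; the new branch
# also prints "Using LayerwiseAdaClipDPOptimizer" as a side effect.
cls = get_optimizer_class(clipping="adapt_layer", distributed=False)
assert cls.__name__ == "LayerwiseAdaClipDPOptimizer"
```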

opacus/optimizers/adaclipoptimizer.py

Lines changed: 201 additions & 1 deletion
@@ -17,6 +17,7 @@
 import logging
 from typing import Callable, Optional
 
+import numpy as np
 import torch
 from torch.optim import Optimizer
@@ -103,6 +104,11 @@ def clip_and_accumulate(self):
         ]
         per_sample_norms = torch.stack(per_param_norms, dim=1).norm(2, dim=1)
 
+        quantiles = [0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]
+        quantile_values = torch.quantile(per_sample_norms, torch.tensor(quantiles).cuda())
+        for q, value in zip(quantiles, quantile_values):
+            print(f"!!! quantile {int(q * 100)}%: {value.item()}")
+
         #print(f"max per_param_norms before clipping: {per_sample_norms.max().item()}")
 
         # Create a mask to determine which gradients need to be clipped based on the clipbound
@@ -154,7 +160,9 @@ def update_clipbound(self):
         elif self.clipbound < self.min_clipbound:
             self.clipbound = self.min_clipbound
 
-        #print(f"!!! self.clipbound: {self.clipbound}")
+        print(f"!!! unclipped_frac: {unclipped_frac}, self.target_unclipped_quantile: {self.target_unclipped_quantile}")
+        print(f"!!! self.clipbound: {self.clipbound}")
+        print("============================================")
 
     def pre_step(
         self, closure: Optional[Callable[[], float]] = None
@@ -163,3 +171,195 @@ def pre_step(
         if pre_step_full:
             self.update_clipbound()
         return pre_step_full
+
+
+class LayerwiseAdaClipDPOptimizer(DPOptimizer):
+    def __init__(
+        self,
+        optimizer: Optimizer,
+        *,
+        noise_multiplier: float,
+        max_grad_norm: float,
+        expected_batch_size: Optional[int],
+        loss_reduction: str = "mean",
+        generator=None,
+        secure_mode: bool = False,
+        normalize_clipping: bool = False,
+        optim_args: dict = None,
+    ):
+        assert normalize_clipping, "Let us focus on the normalized version first"
+        max_grad_norm = 1.0
+
+        super().__init__(
+            optimizer,
+            noise_multiplier=noise_multiplier,
+            max_grad_norm=max_grad_norm,
+            expected_batch_size=expected_batch_size,
+            loss_reduction=loss_reduction,
+            generator=generator,
+            secure_mode=secure_mode,
+            normalize_clipping=normalize_clipping,
+            optim_args=optim_args,
+        )
+
+        target_unclipped_quantile = optim_args.get('target_unclipped_quantile', 0.0)
+        clipbound_learning_rate = optim_args.get('clipbound_learning_rate', 1.0)
+        count_threshold = optim_args.get('count_threshold', 1.0)
+        max_clipbound = optim_args.get('max_clipbound', torch.inf)
+        min_clipbound = optim_args.get('min_clipbound', -torch.inf)
+        unclipped_num_std = optim_args.get('unclipped_num_std')
+        assert max_clipbound > min_clipbound, "max_clipbound must be larger than min_clipbound."
+        self.backbone_clipbound = max_grad_norm  # initial clip bound for the backbone
+        self.head_clipbound = max_grad_norm  # initial clip bound for the head
+        self.target_unclipped_quantile = target_unclipped_quantile
+        self.clipbound_learning_rate = clipbound_learning_rate
+        self.count_threshold = count_threshold
+        self.max_clipbound = max_clipbound
+        self.min_clipbound = min_clipbound
+        self.unclipped_num_std = unclipped_num_std
+        # Theorem 1 in https://arxiv.org/pdf/1905.03871.pdf
+        self.noise_multiplier = (
+            self.noise_multiplier ** (-2) - (2 * unclipped_num_std) ** (-2)
+        ) ** (-1 / 2)
+        self.sample_size = 0
+        self.unclipped_num_backbone = 0
+        self.unclipped_num_head = 0
+
+    def zero_grad(self, set_to_none: bool = False):
+        """
+        Clear gradients, self.sample_size and the unclipped counts.
+        """
+        super().zero_grad(set_to_none)
+
+        self.sample_size = 0
+        self.unclipped_num_backbone = 0
+        self.unclipped_num_head = 0
+
+    def ensure_base_bound(self, mean_backbone_norm, mean_head_norm):
+        """
+        Rescale the backbone and head norms so that their combined norm
+        equals max_grad_norm.
+        """
+        factor = self.max_grad_norm / np.sqrt(mean_backbone_norm**2 + mean_head_norm**2)
+        backbone_max_grad_norm = mean_backbone_norm * factor
+        head_max_grad_norm = mean_head_norm * factor
+        return backbone_max_grad_norm, head_max_grad_norm
+
+    def clip_and_accumulate(self):
+        per_param_norms = [
+            g.view(len(g), -1).norm(2, dim=-1) for g in self.grad_samples
+        ]
+        # per_sample_norms = torch.stack(per_param_norms, dim=1).norm(2, dim=1)
+
+        # Separate backbone and head gradients (the last two parameters
+        # are treated as the head)
+        backbone_norms = per_param_norms[:-2]
+        head_norms = per_param_norms[-2:]
+
+        per_sample_norms_backbone = torch.stack(backbone_norms, dim=1).norm(2, dim=1)
+        per_sample_norms_head = torch.stack(head_norms, dim=1).norm(2, dim=1)
+
+        mean_backbone_norm = per_sample_norms_backbone.mean().item()
+        mean_head_norm = per_sample_norms_head.mean().item()
+
+        # NOTE: this is not private yet; fix it later
+        print(f" - mean_backbone_norm: {mean_backbone_norm}")
+        print(f" - mean_head_norm: {mean_head_norm}")
+
+        backbone_max_grad_norm, head_max_grad_norm = self.ensure_base_bound(mean_backbone_norm, mean_head_norm)
+
+        # Calculate separate clip factors based on the adjusted max_grad_norms
+        backbone_clip_factor = torch.minimum(
+            backbone_max_grad_norm / (per_sample_norms_backbone + 1e-6),
+            torch.full_like(per_sample_norms_backbone, backbone_max_grad_norm / self.backbone_clipbound),
+        )
+
+        head_clip_factor = torch.minimum(
+            head_max_grad_norm / (per_sample_norms_head + 1e-6),
+            torch.full_like(per_sample_norms_head, head_max_grad_norm / self.head_clipbound),
+        )
+
+        # Clip and scale gradients
+        for p in self.params[:-2]:
+            _check_processed_flag(p.grad_sample)
+            grad_sample = self._get_flat_grad_sample(p)
+            grad = torch.einsum("i,i...", backbone_clip_factor, grad_sample)
+
+            if p.summed_grad is not None:
+                p.summed_grad += grad
+            else:
+                p.summed_grad = grad
+
+            _mark_as_processed(p.grad_sample)
+
+        for p in self.params[-2:]:
+            _check_processed_flag(p.grad_sample)
+            grad_sample = self._get_flat_grad_sample(p)
+            grad = torch.einsum("i,i...", head_clip_factor, grad_sample)
+
+            if p.summed_grad is not None:
+                p.summed_grad += grad
+            else:
+                p.summed_grad = grad
+
+            _mark_as_processed(p.grad_sample)
+
+        # Track unclipped counts for the adaptive updates
+        self.sample_size += len(per_sample_norms_head)
+        self.unclipped_num_backbone += (
+            per_sample_norms_backbone < self.backbone_clipbound * self.count_threshold
+        ).sum()
+        self.unclipped_num_head += (
+            per_sample_norms_head < self.head_clipbound * self.count_threshold
+        ).sum()
+
+    def add_noise(self):
+        super().add_noise()
+
+        unclipped_num_noise_backbone = _generate_noise(
+            std=self.unclipped_num_std,
+            reference=self.unclipped_num_backbone,
+            generator=self.generator,
+        )
+
+        unclipped_num_noise_head = _generate_noise(
+            std=self.unclipped_num_std,
+            reference=self.unclipped_num_head,
+            generator=self.generator,
+        )
+
+        self.unclipped_num_backbone = float(self.unclipped_num_backbone)
+        self.unclipped_num_head = float(self.unclipped_num_head)
+        self.unclipped_num_backbone += unclipped_num_noise_backbone
+        self.unclipped_num_head += unclipped_num_noise_head
+
+    def update_clipbound(self):
+        """
+        Update the clipping bounds based on the unclipped fractions.
+        """
+        unclipped_frac_backbone = self.unclipped_num_backbone / self.sample_size
+        unclipped_frac_head = self.unclipped_num_head / self.sample_size
+
+        self.backbone_clipbound *= torch.exp(
+            -self.clipbound_learning_rate
+            * (unclipped_frac_backbone - self.target_unclipped_quantile)
+        )
+        self.head_clipbound *= torch.exp(
+            -self.clipbound_learning_rate
+            * (unclipped_frac_head - self.target_unclipped_quantile)
+        )
+
+        # Ensure bounds stay within the min and max limits
+        self.backbone_clipbound = torch.clamp(self.backbone_clipbound, self.min_clipbound, self.max_clipbound)
+        self.head_clipbound = torch.clamp(self.head_clipbound, self.min_clipbound, self.max_clipbound)
+
+        print(f"!!! - self.backbone_clipbound: {self.backbone_clipbound}")
+        print(f"!!! - self.head_clipbound: {self.head_clipbound}")
+
+    def pre_step(
+        self, closure: Optional[Callable[[], float]] = None
+    ) -> Optional[float]:
+        pre_step_full = super().pre_step()
+        if pre_step_full:
+            self.update_clipbound()
+        return pre_step_full
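
Two pieces of the new class come straight from Andrew et al., "Differentially Private Learning with Adaptive Clipping" (https://arxiv.org/pdf/1905.03871.pdf), and read more easily as formulas. In my notation (not the code's): $\sigma$ is the input noise multiplier, $\sigma_b$ = `unclipped_num_std`, $\eta_C$ = `clipbound_learning_rate`, and $\gamma$ = `target_unclipped_quantile`:

```latex
% Theorem 1 noise split: releasing the noisy unclipped counts with std
% \sigma_b forces the gradient noise multiplier up from \sigma to
\sigma_\Delta = \left( \sigma^{-2} - (2\sigma_b)^{-2} \right)^{-1/2}

% Geometric clip-bound update, applied per group g \in \{backbone, head\},
% where \tilde{b}_g is the (noisy) unclipped fraction for group g:
C_g \leftarrow C_g \cdot \exp\!\left( -\eta_C \, (\tilde{b}_g - \gamma) \right)
```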

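The geometry of `ensure_base_bound` is easiest to see with numbers (a standalone sketch with made-up inputs, not part of the commit):

```python
import numpy as np

# ensure_base_bound rescales the two mean norms so their combined L2
# norm equals max_grad_norm (fixed to 1.0 in this class): the per-layer
# budgets keep the backbone/head ratio, while their root-sum-of-squares
# equals the global bound.
max_grad_norm = 1.0
mean_backbone_norm, mean_head_norm = 3.0, 4.0  # hypothetical values
factor = max_grad_norm / np.sqrt(mean_backbone_norm**2 + mean_head_norm**2)
backbone_bound = mean_backbone_norm * factor  # 0.6
head_bound = mean_head_norm * factor          # 0.8
assert np.isclose(np.hypot(backbone_bound, head_bound), max_grad_norm)
```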
opacus/optimizers/optimizer.py

Lines changed: 6 additions & 0 deletions
@@ -456,6 +456,12 @@ def clip_and_accumulate(self):
             self.max_grad_norm / (per_sample_norms + 1e-6)
         ).clamp(max=1.0)
 
+        quantiles = [0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]
+        quantile_values = torch.quantile(per_sample_norms, torch.tensor(quantiles).cuda())
+        for q, value in zip(quantiles, quantile_values):
+            print(f"!!! quantile {int(q * 100)}%: {value.item()}")
+        print("---------------------------")
+
         for p in self.params:
             _check_processed_flag(p.grad_sample)
             grad_sample = self._get_flat_grad_sample(p)
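
One caveat on the quantile probes added here and in adaclipoptimizer.py: the quantile tensor is hard-coded to `.cuda()`, so the verbose output will fail on CPU-only runs. A device-agnostic variant would be (a sketch, not part of the commit):

```python
import torch

def print_norm_quantiles(per_sample_norms: torch.Tensor) -> None:
    # Same probe as in the diff, but the quantile tensor follows the
    # device of the input norms instead of hard-coding .cuda().
    quantiles = [0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]
    q = torch.tensor(quantiles, device=per_sample_norms.device)
    for frac, value in zip(quantiles, torch.quantile(per_sample_norms, q)):
        print(f"!!! quantile {int(frac * 100)}%: {value.item()}")
```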
