6 changes: 5 additions & 1 deletion opacus/optimizers/__init__.py
@@ -24,6 +24,7 @@
from .optimizer import DPOptimizer
from .optimizer_fast_gradient_clipping import DPOptimizerFastGradientClipping
from .perlayeroptimizer import DPPerLayerOptimizer
from .autoclipoptimizer import AutoSFixedDPOptimizer


__all__ = [
@@ -32,9 +33,10 @@
"DistributedDPOptimizer",
"DPOptimizer",
"DPOptimizerFastGradientClipping",
"DistributedDPOptimizerFastGradientlipping",
"DistributedDPOptimizerFastGradientClipping",
"DPPerLayerOptimizer",
"SimpleDistributedPerLayerOptimizer",
"AutoSFixedDPOptimizer",
]


@@ -58,6 +60,8 @@ def get_optimizer_class(clipping: str, distributed: bool, grad_sample_mode: str
raise ValueError(f"Unexpected grad_sample_mode: {grad_sample_mode}")
elif clipping == "adaptive" and distributed is False:
return AdaClipDPOptimizer
elif clipping == "Auto":
return AutoSFixedDPOptimizer
raise ValueError(
f"Unexpected optimizer parameters. Clipping: {clipping}, distributed: {distributed}"
)
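A minimal usage sketch (not part of the diff) of the new "Auto" branch; it relies only on the names imported and exported above.

# Select the new optimizer class via the "Auto" clipping flag.
from opacus.optimizers import AutoSFixedDPOptimizer, get_optimizer_class

optim_cls = get_optimizer_class(clipping="Auto", distributed=False, grad_sample_mode="hooks")
assert optim_cls is AutoSFixedDPOptimizer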
56 changes: 36 additions & 20 deletions opacus/optimizers/adaclipoptimizer.py
@@ -43,17 +43,18 @@ def __init__(
optimizer: Optimizer,
*,
noise_multiplier: float,
target_unclipped_quantile: float,
clipbound_learning_rate: float,
max_clipbound: float,
min_clipbound: float,
unclipped_num_std: float,
max_grad_norm: float,
expected_batch_size: Optional[int],
loss_reduction: str = "mean",
generator=None,
secure_mode: bool = False,
normalize_clipping: bool = False,
optim_args: dict = None,
):

assert normalize_clipping, "AdaClipDPOptimizer currently supports only the normalized clipping variant (normalize_clipping=True)."
max_grad_norm = 1.0

super().__init__(
optimizer,
noise_multiplier=noise_multiplier,
@@ -62,12 +63,22 @@
loss_reduction=loss_reduction,
generator=generator,
secure_mode=secure_mode,
normalize_clipping=normalize_clipping,
optim_args=optim_args,
)
assert (
max_clipbound > min_clipbound
), "max_clipbound must be larger than min_clipbound."

target_unclipped_quantile = optim_args.get('target_unclipped_quantile', 0.0)
clipbound_learning_rate = optim_args.get('clipbound_learning_rate', 1.0)
count_threshold = optim_args.get('count_threshold', 1.0)
max_clipbound = optim_args.get('max_clipbound', torch.inf)
min_clipbound = optim_args.get('min_clipbound', -torch.inf)
unclipped_num_std = optim_args.get('unclipped_num_std')
clip_bound_init = optim_args.get('clip_bound_init', 1.0)
assert (max_clipbound > min_clipbound), "max_clipbound must be larger than min_clipbound."
self.clipbound = clip_bound_init
self.target_unclipped_quantile = target_unclipped_quantile
self.clipbound_learning_rate = clipbound_learning_rate
self.count_threshold = count_threshold
self.max_clipbound = max_clipbound
self.min_clipbound = min_clipbound
self.unclipped_num_std = unclipped_num_std
@@ -92,16 +103,19 @@ def clip_and_accumulate(self):
g.view(len(g), -1).norm(2, dim=-1) for g in self.grad_samples
]
per_sample_norms = torch.stack(per_param_norms, dim=1).norm(2, dim=1)
per_sample_clip_factor = (self.max_grad_norm / (per_sample_norms + 1e-6)).clamp(
max=1.0

#print(f"max per_param_norms before clipping: {per_sample_norms.max().item()}")

# Create a mask to determine which gradients need to be clipped based on the clipbound
per_sample_clip_factor = torch.minimum(
self.max_grad_norm / (per_sample_norms + 1e-6),
torch.full_like(per_sample_norms, self.max_grad_norm / self.clipbound),
)

# tracking sample_size and unclipped_num below, together with the
# clipbound-based clip factor above, is what differs from the parent DPOptimizer class.
self.sample_size += len(per_sample_clip_factor)
self.unclipped_num += (
len(per_sample_clip_factor) - (per_sample_clip_factor < 1).sum()
)
self.unclipped_num += (per_sample_norms < self.clipbound * self.count_threshold).sum()

for p in self.params:
_check_processed_flag(p.grad_sample)
@@ -127,24 +141,26 @@ def add_noise(self):
self.unclipped_num = float(self.unclipped_num)
self.unclipped_num += unclipped_num_noise

def update_max_grad_norm(self):
def update_clipbound(self):
"""
Update clipping bound based on unclipped fraction
"""
unclipped_frac = self.unclipped_num / self.sample_size
self.max_grad_norm *= torch.exp(
self.clipbound *= torch.exp(
-self.clipbound_learning_rate
* (unclipped_frac - self.target_unclipped_quantile)
)
if self.max_grad_norm > self.max_clipbound:
self.max_grad_norm = self.max_clipbound
elif self.max_grad_norm < self.min_clipbound:
self.max_grad_norm = self.min_clipbound
if self.clipbound > self.max_clipbound:
self.clipbound = self.max_clipbound
elif self.clipbound < self.min_clipbound:
self.clipbound = self.min_clipbound

#print(f"!!! self.clipbound: {self.clipbound}")

def pre_step(
self, closure: Optional[Callable[[], float]] = None
) -> Optional[float]:
pre_step_full = super().pre_step()
if pre_step_full:
self.update_max_grad_norm()
self.update_clipbound()
return pre_step_full
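For reference, the geometric update performed by update_clipbound() above, restated as a standalone sketch (not part of the diff); the numbers in the usage example are hypothetical.

import math

def updated_clipbound(clipbound, unclipped_num, sample_size,
                      target_unclipped_quantile, clipbound_learning_rate,
                      min_clipbound, max_clipbound):
    # fraction of samples counted as unclipped this step
    unclipped_frac = unclipped_num / sample_size
    # multiplicative update: grow the bound when too few samples are unclipped,
    # shrink it when too many are
    clipbound *= math.exp(-clipbound_learning_rate
                          * (unclipped_frac - target_unclipped_quantile))
    # keep the bound inside [min_clipbound, max_clipbound]
    return min(max(clipbound, min_clipbound), max_clipbound)

# 30 of 100 samples unclipped against a target quantile of 0.5 grows the bound
# by exp(0.2) ≈ 1.22 at clipbound_learning_rate = 1.0.
print(updated_clipbound(1.0, 30, 100, 0.5, 1.0, 0.0, 10.0))  # ≈ 1.2214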
102 changes: 102 additions & 0 deletions opacus/optimizers/autoclipoptimizer.py
@@ -0,0 +1,102 @@
from __future__ import annotations
import logging
from typing import Callable, Optional

import torch
from torch.optim import Optimizer

from .optimizer import (
DPOptimizer,
_check_processed_flag,
_mark_as_processed,
)

logger = logging.getLogger(__name__)


class AutoSFixedDPOptimizer(DPOptimizer):
"""
R-independent AUTO-S clipping (arXiv:2206.07136 §4):
g_i -> g_i / (||g_i||_2 + γ), with γ > 0.
Noise std equals `noise_multiplier` (σ from accountant). We force max_grad_norm=1.
"""

def __init__(
self,
optimizer: Optimizer,
*,
noise_multiplier: float, # set σ from accountant directly
max_grad_norm: float, # ignored; forced to 1.0 to keep R-independent
expected_batch_size: Optional[int],
loss_reduction: str = "mean",
generator=None,
secure_mode: bool = False,
normalize_clipping: bool = False, # must be False
optim_args: dict | None = None,
):
if normalize_clipping:
raise AssertionError(
"AUTO-S uses unnormalized clipping (normalize_clipping=False)."
)

# Force R = 1 to make noise std = σ and remove any dependence on R.
if max_grad_norm != 1.0:
logger.warning(
"AutoSFixedDPOptimizer: overriding max_grad_norm=%s to 1.0 for R-independence.",
max_grad_norm,
)
max_grad_norm = 1.0

super().__init__(
optimizer,
noise_multiplier=noise_multiplier, # this is σ
max_grad_norm=max_grad_norm, # fixed to 1.0
expected_batch_size=expected_batch_size,
loss_reduction=loss_reduction,
generator=generator,
secure_mode=secure_mode,
normalize_clipping=normalize_clipping,
optim_args=optim_args,
)

gamma_default = 1e-2 # §4: γ=0.01 as default
self.gamma = float((optim_args or {}).get("stability_const", gamma_default))

if self.gamma <= 0.0:
raise ValueError("stability_const γ must be > 0.")

def clip_and_accumulate(self):
"""
Apply R-independent AUTO-S: clip factor = 1 / (||g_i||_2 + γ).
"""
# per-sample L2 norms over all params
per_param_norms = [
g.view(len(g), -1).norm(2, dim=-1) for g in self.grad_samples
]
per_sample_norms = torch.stack(per_param_norms, dim=1).norm(2, dim=1)

# No min(1, ·), no R; pure 1 / (||g_i|| + γ)
per_sample_clip_factor = 1.0 / (per_sample_norms + self.gamma)

for p in self.params:
_check_processed_flag(p.grad_sample)
grad_sample = self._get_flat_grad_sample(p) # shape: [N, ...]
grad = torch.einsum("i,i...", per_sample_clip_factor, grad_sample)

if p.summed_grad is not None:
p.summed_grad += grad
else:
p.summed_grad = grad

_mark_as_processed(p.grad_sample)

def add_noise(self):
"""
Keep DPOptimizer's noise addition. With max_grad_norm=1.0, the noise std is exactly σ.
"""
super().add_noise()

def pre_step(
self, closure: Optional[Callable[[], float]] = None
) -> Optional[float]:
return super().pre_step(closure)
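A toy illustration (not part of the diff) of the per-sample scaling applied in clip_and_accumulate() above; the norms are made up and γ is the 0.01 default set in the constructor.

import torch

gamma = 1e-2                                       # default stability_const
per_sample_norms = torch.tensor([0.5, 2.0, 10.0])  # hypothetical ||g_i||_2 values
per_sample_clip_factor = 1.0 / (per_sample_norms + gamma)  # no min(1, ·), no R
scaled_norms = per_sample_clip_factor * per_sample_norms   # each scaled gradient has norm just below 1
print(scaled_norms)  # tensor([0.9804, 0.9950, 0.9990])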
6 changes: 6 additions & 0 deletions opacus/optimizers/optimizer.py
@@ -206,6 +206,7 @@ def __init__(
generator=None,
secure_mode: bool = False,
normalize_clipping: bool = False,
optim_args: dict = None,
):
"""

@@ -448,6 +449,9 @@ def clip_and_accumulate(self):
g.reshape(len(g), -1).norm(2, dim=-1) for g in self.grad_samples
]
per_sample_norms = torch.stack(per_param_norms, dim=1).norm(2, dim=1)

#print(f"{per_sample_norms.mean()}")

per_sample_clip_factor = (
self.max_grad_norm / (per_sample_norms + 1e-6)
).clamp(max=1.0)
@@ -487,6 +491,8 @@ def add_noise(self):

_mark_as_processed(p.summed_grad)

#print(f"last noise add: {noise[:10]}")

def scale_grad(self):
"""
Applies given ``loss_reduction`` to ``p.grad``.
9 changes: 8 additions & 1 deletion opacus/privacy_engine.py
@@ -18,6 +18,7 @@
from typing import IO, Any, BinaryIO, Dict, List, Optional, Tuple, Union

import torch
from torch import distributed as dist
from opacus.accountants import create_accountant
from opacus.accountants.utils import get_noise_multiplier
from opacus.data_loader import DPDataLoader, switch_generator
@@ -111,6 +112,7 @@ def _prepare_optimizer(
noise_generator=None,
grad_sample_mode="hooks",
normalize_clipping: bool = False,
optim_args: dict = None,
**kwargs,
) -> DPOptimizer:
if isinstance(optimizer, DPOptimizer):
@@ -137,6 +139,7 @@
generator=generator,
secure_mode=self.secure_mode,
normalize_clipping=normalize_clipping,
optim_args=optim_args,
**kwargs,
)

@@ -294,6 +297,7 @@ def make_private(
grad_sample_mode: str = "hooks",
normalize_clipping: bool = False,
total_steps: int = None,
optim_args: dict = None,
**kwargs,
) -> Tuple[GradSampleModule, DPOptimizer, DataLoader]:
"""
@@ -375,7 +379,7 @@
"Module parameters are different than optimizer Parameters"
)

distributed = isinstance(module, (DPDDP, DDP))
distributed = dist.get_world_size() > 1

module = self._prepare_model(
module,
@@ -427,6 +431,7 @@
clipping=clipping,
grad_sample_mode=grad_sample_mode,
normalize_clipping=normalize_clipping,
optim_args=optim_args,
**kwargs,
)

@@ -454,6 +459,7 @@ def make_private_with_epsilon(
grad_sample_mode: str = "hooks",
normalize_clipping: bool = False,
total_steps: int = None,
optim_args: dict = None,
**kwargs,
):
"""
@@ -569,6 +575,7 @@
clipping=clipping,
normalize_clipping=normalize_clipping,
total_steps=total_steps,
optim_args=optim_args,
)

def get_epsilon(self, delta):
Expand Down
Loading