 import logging
 from typing import Callable, Optional

+import numpy as np
 import torch
 from torch.optim import Optimizer

@@ -103,6 +104,11 @@ def clip_and_accumulate(self):
         ]
         per_sample_norms = torch.stack(per_param_norms, dim=1).norm(2, dim=1)

+        quantiles = [0.01, 0.05, 0.1, 0.25, 0.5, 0.75, 0.9, 0.95, 0.99]
+        quantile_values = torch.quantile(per_sample_norms, torch.tensor(quantiles, device=per_sample_norms.device))
+        for q, value in zip(quantiles, quantile_values):
+            print(f"!!! quantile {int(q * 100)}%: {value.item()}")
+
         #print(f"max per_param_norms before clipping: {per_sample_norms.max().item()}")

         # Create a mask to determine which gradients need to be clipped based on the clipbound
@@ -154,7 +160,9 @@ def update_clipbound(self):
         elif self.clipbound < self.min_clipbound:
             self.clipbound = self.min_clipbound

-        #print(f"!!! self.clipbound: {self.clipbound}")
+        print(f"!!! unclipped_frac: {unclipped_frac}, self.target_unclipped_quantile: {self.target_unclipped_quantile}")
+        print(f"!!! self.clipbound: {self.clipbound}")
+        print("============================================")

     def pre_step(
         self, closure: Optional[Callable[[], float]] = None
@@ -163,3 +171,195 @@ def pre_step(
         if pre_step_full:
             self.update_clipbound()
         return pre_step_full
+
+class LayerwiseAdaClipDPOptimizer(DPOptimizer):
+
+    def __init__(
+        self,
+        optimizer: Optimizer,
+        *,
+        noise_multiplier: float,
+        max_grad_norm: float,
+        expected_batch_size: Optional[int],
+        loss_reduction: str = "mean",
+        generator=None,
+        secure_mode: bool = False,
+        normalize_clipping: bool = False,
+        optim_args: dict = None,
+    ):
+
+        assert normalize_clipping, "Let us focus on the normalized version first"
+        max_grad_norm = 1.0
+
+        super().__init__(
+            optimizer,
+            noise_multiplier=noise_multiplier,
+            max_grad_norm=max_grad_norm,
+            expected_batch_size=expected_batch_size,
+            loss_reduction=loss_reduction,
+            generator=generator,
+            secure_mode=secure_mode,
+            normalize_clipping=normalize_clipping,
+            optim_args=optim_args,
+        )
+
+        target_unclipped_quantile = optim_args.get('target_unclipped_quantile', 0.0)
+        clipbound_learning_rate = optim_args.get('clipbound_learning_rate', 1.0)
+        count_threshold = optim_args.get('count_threshold', 1.0)
+        max_clipbound = optim_args.get('max_clipbound', torch.inf)
+        min_clipbound = optim_args.get('min_clipbound', -torch.inf)
+        unclipped_num_std = optim_args.get('unclipped_num_std')
+        assert max_clipbound > min_clipbound, "max_clipbound must be larger than min_clipbound."
+        self.backbone_clipbound = max_grad_norm  # Initial clip bound for backbone
+        self.head_clipbound = max_grad_norm  # Initial clip bound for head
+        self.target_unclipped_quantile = target_unclipped_quantile
+        self.clipbound_learning_rate = clipbound_learning_rate
+        self.count_threshold = count_threshold
+        self.max_clipbound = max_clipbound
+        self.min_clipbound = min_clipbound
+        self.unclipped_num_std = unclipped_num_std
+        # Theorem 1 in https://arxiv.org/pdf/1905.03871.pdf
+        self.noise_multiplier = (
+            self.noise_multiplier ** (-2) - (2 * unclipped_num_std) ** (-2)
+        ) ** (-1 / 2)
+        self.sample_size = 0
+        self.unclipped_num_backbone = 0
+        self.unclipped_num_head = 0
+
+    def zero_grad(self, set_to_none: bool = False):
+        """
+        Clear gradients, self.sample_size and the per-group unclipped counters.
+        """
+        super().zero_grad(set_to_none)
+
+        self.sample_size = 0
+        self.unclipped_num_backbone = 0
+        self.unclipped_num_head = 0
+
+    def ensure_base_bound(self, mean_backbone_norm, mean_head_norm):
+        """
+        Rescale the mean backbone and head norms so that their combined L2 norm
+        equals max_grad_norm.
+        """
+        factor = self.max_grad_norm / np.sqrt(mean_backbone_norm ** 2 + mean_head_norm ** 2)
+        backbone_max_grad_norm = mean_backbone_norm * factor
+        head_max_grad_norm = mean_head_norm * factor
+        return backbone_max_grad_norm, head_max_grad_norm
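+        # Example: with max_grad_norm = 1.0, mean_backbone_norm = 3.0 and
+        # mean_head_norm = 4.0, factor = 0.2, giving per-group bounds of 0.6 and 0.8
+        # whose combined L2 norm is again 1.0.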
247+
248+ def clip_and_accumulate (self ):
249+ per_param_norms = [
250+ g .view (len (g ), - 1 ).norm (2 , dim = - 1 ) for g in self .grad_samples
251+ ]
252+ # per_sample_norms = torch.stack(per_param_norms, dim=1).norm(2, dim=1)
253+
254+ # Separate backbone and head gradients
255+ backbone_norms = per_param_norms [:- 2 ]
256+ head_norms = per_param_norms [- 2 :]
257+
258+ per_sample_norms_backbone = torch .stack (backbone_norms , dim = 1 ).norm (2 , dim = 1 )
259+ per_sample_norms_head = torch .stack (head_norms , dim = 1 ).norm (2 , dim = 1 )
260+
261+ mean_backbone_norm = per_sample_norms_backbone .mean ().item ()
262+ mean_head_norm = per_sample_norms_head .mean ().item ()
263+
264+ # NOTE: now it is not private, fix it later
265+ print (f" - mean_backbone_norm: { mean_backbone_norm } " )
266+ print (f" - mean_head_norm: { mean_head_norm } " )
267+
268+ backbone_max_grad_norm , head_max_grad_norm = self .ensure_base_bound (mean_backbone_norm , mean_head_norm )
269+
270+ # Calculate separate clip factors based on adjusted max_grad_norms
271+ backbone_clip_factor = torch .minimum (
272+ backbone_max_grad_norm / (per_sample_norms_backbone + 1e-6 ),
273+ torch .full_like (per_sample_norms_backbone , backbone_max_grad_norm / self .backbone_clipbound ),
274+ )
275+
276+ head_clip_factor = torch .minimum (
277+ head_max_grad_norm / (per_sample_norms_head + 1e-6 ),
278+ torch .full_like (per_sample_norms_head , head_max_grad_norm / self .head_clipbound ),
279+ )
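+        # Within each group, samples whose norm exceeds the adaptive clipbound are
+        # clipped to exactly the per-group bound, while samples below it are uniformly
+        # rescaled by (per-group bound / clipbound), i.e. normalized clipping.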
+
+        # Clip and scale gradients
+        for p in self.params[:-2]:
+            _check_processed_flag(p.grad_sample)
+            grad_sample = self._get_flat_grad_sample(p)
+            grad = torch.einsum("i,i...->...", backbone_clip_factor, grad_sample)
+
+            if p.summed_grad is not None:
+                p.summed_grad += grad
+            else:
+                p.summed_grad = grad
+
+            _mark_as_processed(p.grad_sample)
+
+        for p in self.params[-2:]:
+            _check_processed_flag(p.grad_sample)
+            grad_sample = self._get_flat_grad_sample(p)
+            grad = torch.einsum("i,i...->...", head_clip_factor, grad_sample)
+
+            if p.summed_grad is not None:
+                p.summed_grad += grad
+            else:
+                p.summed_grad = grad
+
+            _mark_as_processed(p.grad_sample)
+
+        # Track the per-group unclipped counts used by update_clipbound()
+        self.sample_size += len(per_sample_norms_head)
+        self.unclipped_num_backbone += (
+            per_sample_norms_backbone < self.backbone_clipbound * self.count_threshold
+        ).sum()
+        self.unclipped_num_head += (
+            per_sample_norms_head < self.head_clipbound * self.count_threshold
+        ).sum()
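+        # A sample counts as "unclipped" when its group norm is below count_threshold
+        # times the current clipbound; these counts are privatized in add_noise()
+        # before they drive the clipbound update.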
+
+    def add_noise(self):
+        super().add_noise()
+
+        unclipped_num_noise_backbone = _generate_noise(
+            std=self.unclipped_num_std,
+            reference=self.unclipped_num_backbone,
+            generator=self.generator,
+        )
+
+        unclipped_num_noise_head = _generate_noise(
+            std=self.unclipped_num_std,
+            reference=self.unclipped_num_head,
+            generator=self.generator,
+        )
+
+        self.unclipped_num_backbone = float(self.unclipped_num_backbone)
+        self.unclipped_num_head = float(self.unclipped_num_head)
+        self.unclipped_num_backbone += unclipped_num_noise_backbone
+        self.unclipped_num_head += unclipped_num_noise_head
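+        # Releasing the counts with Gaussian noise of std unclipped_num_std is what
+        # keeps the data-dependent clipbound update differentially private
+        # (cf. Theorem 1 of arXiv:1905.03871 referenced in __init__).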
+
+    def update_clipbound(self):
+        """
+        Update the per-group clipping bounds based on the noisy unclipped fractions.
+        """
+        unclipped_frac_backbone = self.unclipped_num_backbone / self.sample_size
+        unclipped_frac_head = self.unclipped_num_head / self.sample_size
+
+        self.backbone_clipbound *= torch.exp(
+            -self.clipbound_learning_rate
+            * (unclipped_frac_backbone - self.target_unclipped_quantile)
+        )
+        self.head_clipbound *= torch.exp(
+            -self.clipbound_learning_rate
+            * (unclipped_frac_head - self.target_unclipped_quantile)
+        )
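+        # Geometric update: if more than the target fraction of samples is unclipped,
+        # the exponent is negative and the bound shrinks; otherwise it grows.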
+
+        # Ensure bounds are within min and max limits
+        self.backbone_clipbound = torch.clamp(self.backbone_clipbound, self.min_clipbound, self.max_clipbound)
+        self.head_clipbound = torch.clamp(self.head_clipbound, self.min_clipbound, self.max_clipbound)
+
+        print(f"!!! - self.backbone_clipbound: {self.backbone_clipbound}")
+        print(f"!!! - self.head_clipbound: {self.head_clipbound}")
+
+    def pre_step(
+        self, closure: Optional[Callable[[], float]] = None
+    ) -> Optional[float]:
+        pre_step_full = super().pre_step()
+        if pre_step_full:
+            self.update_clipbound()
+        return pre_step_full
+
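+
+# Hypothetical usage sketch (placeholder values, not prescribed by this class):
+#
+#   dp_optimizer = LayerwiseAdaClipDPOptimizer(
+#       optimizer=torch.optim.SGD(model.parameters(), lr=0.1),
+#       noise_multiplier=1.0,
+#       max_grad_norm=1.0,
+#       expected_batch_size=64,
+#       normalize_clipping=True,   # required by the assert in __init__
+#       optim_args={
+#           "target_unclipped_quantile": 0.5,
+#           "clipbound_learning_rate": 0.2,
+#           "unclipped_num_std": 3.0,  # hypothetical choice
+#       },
+#   )
+#
+# Since the head group is taken as self.params[-2:], this assumes the classifier
+# weight and bias are the last two trainable parameters of the wrapped model.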