Description
The epsilon value of 1e-12 used in the following line of the first_step and sam_train_step functions is too low and can cause NaN errors when training with mixed precision:
e_w = gradients[i] * self.rho / (grad_norm + 1e-12)
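For reference, 1e-12 is below the smallest positive float16 value (roughly 6e-8), so whenever the epsilon or the norm ends up in half precision it is flushed to zero and stops guarding the division. The exact path to the NaNs in the repo may differ, but the following snippet is a minimal illustration of why 1e-12 offers no protection while 1e-7 still does:

import tensorflow as tf

# 1e-12 cannot be represented in float16 and underflows to exactly 0.0,
# whereas 1e-7 survives as a subnormal value.
print(tf.constant(1e-12, dtype=tf.float16).numpy())  # 0.0
print(tf.constant(1e-7, dtype=tf.float16).numpy())   # ~1.19e-07, non-zero

# With the epsilon gone, an underflowed (zero) gradient norm makes the scale
# factor rho / (grad_norm + eps) overflow to inf, and any zero gradient entry
# multiplied by inf then becomes NaN.
grad_norm = tf.constant(0.0, dtype=tf.float16)
rho = 0.05
scale = rho / (grad_norm + tf.constant(1e-12, dtype=tf.float16))
print(scale.numpy())                                          # inf
print((tf.constant(0.0, dtype=tf.float16) * scale).numpy())   # nan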
I recommend raising the value to at least 1e-7 and also adding loss scaling to sam_train_step so that it supports loss scale optimizers. An example implementation is:
def sam_train_step(self, data, rho=0.05, epsilon=1e-7):
    # Unpack the data. Its structure depends on your model and
    # on what you pass to `fit()`.
    x, y = data

    with tf.GradientTape() as tape:
        y_pred = self(x, training=True)  # Forward pass
        # Compute the loss value
        # (the loss function is configured in `compile()`)
        loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
        scaled_loss = self.optimizer.get_scaled_loss(loss)

    # Compute gradients (unscale them before using them for the SAM perturbation)
    trainable_vars = self.trainable_variables
    scaled_gradients = tape.gradient(scaled_loss, trainable_vars)
    gradients = self.optimizer.get_unscaled_gradients(scaled_gradients)

    # First step: perturb the weights by e_w = rho * g / (||g|| + epsilon)
    e_ws = []
    grad_norm = tf.linalg.global_norm(gradients)
    for i in range(len(trainable_vars)):
        e_w = gradients[i] * rho / (grad_norm + epsilon)
        trainable_vars[i].assign_add(e_w)
        e_ws.append(e_w)

    # Second forward/backward pass at the perturbed weights
    with tf.GradientTape() as tape:
        y_pred = self(x, training=True)  # Forward pass
        # Compute the loss value
        # (the loss function is configured in `compile()`)
        loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
        scaled_loss = self.optimizer.get_scaled_loss(loss)

    trainable_vars = self.trainable_variables
    scaled_gradients = tape.gradient(scaled_loss, trainable_vars)
    gradients = self.optimizer.get_unscaled_gradients(scaled_gradients)

    # Restore the original weights before applying the SAM update
    for i in range(len(trainable_vars)):
        trainable_vars[i].assign_add(-e_ws[i])
    self.optimizer.apply_gradients(zip(gradients, trainable_vars))

    # Update metrics (includes the metric that tracks the loss)
    self.compiled_metrics.update_state(y, y_pred)
    # Return a dict mapping metric names to current value
    return {m.name: m.result() for m in self.metrics}
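For context, this is how I have been attaching the modified step to a model; the SAMModel name is only for illustration and is not part of the package:

import tensorflow as tf
from tensorflow import keras

class SAMModel(keras.Model):
    """Illustrative wrapper: route Keras' train_step through sam_train_step above."""

    def train_step(self, data):
        # sam_train_step is the module-level function defined above; passing the
        # model instance as `self` keeps its signature unchanged.
        return sam_train_step(self, data)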
I also recommend suggesting that users use a loss scale optimizer with a low initial scale, e.g. keras.mixed_precision.LossScaleOptimizer(optimizer, initial_scale=2 ** 2); a sketch of the setup I have been testing is shown below. I have not submitted this as a pull request because I have yet to fully experiment with various architectures and hyperparameters, but in my limited experimentation this has proven effective at preventing NaN errors for ResNet- and DenseNet-style architectures when using mixed precision.
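For reference, that setup looks roughly like this (the architecture is a placeholder; the relevant parts are the mixed precision policy, the LossScaleOptimizer wrapping with a low initial_scale, and the float32 output layer):

import tensorflow as tf
from tensorflow import keras

# Run the model in mixed precision.
keras.mixed_precision.set_global_policy("mixed_float16")

# Wrap the base optimizer so that get_scaled_loss / get_unscaled_gradients are
# available inside sam_train_step; start the dynamic loss scale low to avoid
# early overflows.
optimizer = keras.mixed_precision.LossScaleOptimizer(
    keras.optimizers.SGD(learning_rate=0.1, momentum=0.9),
    initial_scale=2 ** 2,
)

# Placeholder architecture, built with the SAMModel sketch above.
inputs = keras.Input(shape=(32, 32, 3))
x = keras.layers.Conv2D(16, 3, activation="relu")(inputs)
x = keras.layers.GlobalAveragePooling2D()(x)
outputs = keras.layers.Dense(10, activation="softmax", dtype="float32")(x)  # keep outputs in float32
model = SAMModel(inputs, outputs)

model.compile(
    optimizer=optimizer,
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)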