
Epsilon value of 1e-12 too small for mixed precision #5

@Avelina9X

Description


The epsilon value of 1e-12 used in the following line in both the first_step and sam_train_step functions is too small and can cause NaN errors when training with mixed precision:
e_w = gradients[i] * self.rho / (grad_norm + 1e-12)
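
For reference, 1e-12 is below the smallest positive float16 value (about 6e-8), so if any part of this expression is evaluated in half precision the epsilon underflows to exactly zero and no longer guards against a zero gradient norm; a quick check:

import tensorflow as tf

# 1e-12 is not representable in float16 and underflows to 0.0,
# while 1e-7 survives as a small subnormal value.
print(tf.cast(1e-12, tf.float16))  # tf.Tensor(0.0, shape=(), dtype=float16)
print(tf.cast(1e-7, tf.float16))   # ~1.2e-07, small but non-zero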

I recommend increasing the value to at least 1e-7 and also adding loss scaling to sam_train_step so that it supports loss scale optimizers. An example implementation is:

def sam_train_step(self, data, rho=0.05, epsilon=1e-7):
    # Unpack the data. Its structure depends on your model and
    # on what you pass to `fit()`.
    x, y = data

    with tf.GradientTape() as tape:
        y_pred = self(x, training=True)  # Forward pass
        # Compute the loss value
        # (the loss function is configured in `compile()`)
        loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
        scaled_loss = self.optimizer.get_scaled_loss(loss)

    # Compute gradients
    trainable_vars = self.trainable_variables
    scaled_gradients = tape.gradient(scaled_loss, trainable_vars)
    gradients = self.optimizer.get_unscaled_gradients(scaled_gradients)

    # first step
    e_ws = []
    grad_norm = tf.linalg.global_norm(gradients)
    for i in range(len(trainable_vars)):
        e_w = gradients[i] * rho / (grad_norm + epsilon)
        trainable_vars[i].assign_add(e_w)
        e_ws.append(e_w)

    # second step: recompute gradients at the perturbed weights
    with tf.GradientTape() as tape:
        y_pred = self(x, training=True)  # Forward pass
        # Compute the loss value
        # (the loss function is configured in `compile()`)
        loss = self.compiled_loss(y, y_pred, regularization_losses=self.losses)
        scaled_loss = self.optimizer.get_scaled_loss(loss)
    trainable_vars = self.trainable_variables
    scaled_gradients = tape.gradient(scaled_loss, trainable_vars)
    gradients = self.optimizer.get_unscaled_gradients(scaled_gradients)

    # restore the original weights before applying the SAM update
    for i in range(len(trainable_vars)):
        trainable_vars[i].assign_add(-e_ws[i])
    self.optimizer.apply_gradients(zip(gradients, trainable_vars))
    # Update metrics (includes the metric that tracks the loss)
    self.compiled_metrics.update_state(y, y_pred)
    # Return a dict mapping metric names to current value
    return {m.name: m.result() for m in self.metrics}
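
One way this could be wired up (the SAMModel name below is just illustrative, not something from this repo) is to bind the function above as the train_step of a subclassed model:

from tensorflow import keras

class SAMModel(keras.Model):
    # fit() calls self.train_step(data); rho and epsilon fall back to their defaults
    train_step = sam_train_step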

I also recommend suggesting that users wrap their optimizer in a loss scale optimizer with a low initial scale, e.g. keras.mixed_precision.LossScaleOptimizer(optimizer, initial_scale=2 ** 2). I have not submitted this as a pull request because I have yet to fully experiment with various architectures and hyperparameters, but in my limited experimentation it has proven effective at preventing NaN errors for ResNet- and DenseNet-style architectures when using mixed precision.
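
A compile-time setup along those lines might look like the sketch below, assuming model is an already-built Keras model whose train_step is sam_train_step (for example an instance of the SAMModel sketch above); the SGD optimizer, loss, and metrics are placeholders:

from tensorflow import keras

# Run the model in mixed precision and wrap the optimizer so that
# get_scaled_loss / get_unscaled_gradients are available to sam_train_step.
keras.mixed_precision.set_global_policy("mixed_float16")
optimizer = keras.mixed_precision.LossScaleOptimizer(
    keras.optimizers.SGD(learning_rate=0.1, momentum=0.9),
    initial_scale=2 ** 2,  # low initial scale, as recommended above
)
model.compile(optimizer=optimizer,
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])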
