You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
block_info (BlockInfo): Information about the block, including methods to allocate tensors.
323
-
factor_matrix_dtype (torch.dtype): Data type for the factor matrices.
326
+
preconditioner_config (RootInvShampooPreconditionerConfig): Configuration for the preconditioner.
324
327
preconditioned_dims (tuple[int, ...]): Dimensions for which the factor matrices are preconditioned.
325
-
block_dtype (torch.dtype): Data type for the block.
326
328
327
329
Returns:
328
330
kronecker_factors_state (RootInvShampooKroneckerFactorsState): An instance of RootInvShampooKroneckerFactorsState with initialized inverse factor matrices.
), f"The number of dimensions to precondition ({sum(preconditioned_dims_selector)}) must match the number of preconditioners ({len(preconditioner_list)})."
@@ -410,6 +416,8 @@ class EigenvalueCorrectedShampooPreconditionerConfig(AmortizedPreconditionerConf
410
416
inverse_exponent_override (dict[int, float]): The inverse_exponent_override attribute is a dictionary that allows for customizing the inverse exponent used in eigenvalue correction.
411
417
The keys of the dictionary represent the order of the tensor, and the values are the exponent override values. For example, if we want to use a custom inverse exponent for 3-D tensors, we can set inverse_exponent_override as inverse_exponent_override={3: 0.25}.
412
418
Note that the inverse_exponent_override dictionary can contain multiple entries for different tensor orders. If the order of the tensor is not specified in the dictionary, the default exponent, 1/2, will be used. (Default: {})
419
+
factor_matrix_eigenvectors_dtype (torch.dtype): Data type for factor matrix eigenvectors. (Default: torch.float32)
420
+
corrected_eigenvalues_dtype (torch.dtype): Data type for corrected eigenvalues. (Default: torch.float32)
413
421
414
422
"""
415
423
@@ -418,6 +426,8 @@ class EigenvalueCorrectedShampooPreconditionerConfig(AmortizedPreconditionerConf
0 commit comments