Skip to content

Commit 0f9fca0

Browse files
tsunghsienlee authored and facebook-github-bot committed
Add custom dtype supports for subclasses of ShampooPreconditionerConfig (#249)
Summary: Pull Request resolved: #249 1. Add `inv_factor_matrix_dtype` option in `RootInvShampooPreconditionerConfig`. 2. Add `factor_matrix_eigenvectors_dtype` and `factor_matrix_eigenvalues_dtype` options in `EigendecomposedShampooPreconditionerConfig`. 3. Add `factor_matrix_eigenvectors_dtype` and `corrected_eigenvalues_dtype` options in `EigenvalueCorrectedShampooPreconditionerConfig`. Reviewed By: runame Differential Revision: D82239916 fbshipit-source-id: 6848168df817eff044b6ca4042477a5392e16e63
1 parent c135ab9 commit 0f9fca0

File tree

3 files changed

+62
-26
lines changed

3 files changed

+62
-26
lines changed

distributed_shampoo/preconditioner/shampoo_preconditioner_list.py

Lines changed: 39 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,10 @@
3636
)
3737
from distributed_shampoo.shampoo_types import (
3838
AmortizedPreconditionerConfig,
39+
EigendecomposedShampooPreconditionerConfig,
40+
EigenvalueCorrectedShampooPreconditionerConfig,
3941
PreconditionerValueError,
42+
RootInvShampooPreconditionerConfig,
4043
)
4144
from distributed_shampoo.utils.dict_zip_iterator import DictZipIterator
4245
from distributed_shampoo.utils.optimizer_modules import OptimizerModule
@@ -320,31 +323,31 @@ def from_block(cls, **kwargs: Any) -> "RootInvShampooKroneckerFactorsState":
320323
321324
Args:
322325
block_info (BlockInfo): Information about the block, including methods to allocate tensors.
323-
factor_matrix_dtype (torch.dtype): Data type for the factor matrices.
326+
preconditioner_config (RootInvShampooPreconditionerConfig): Configuration for the preconditioner.
324327
preconditioned_dims (tuple[int, ...]): Dimensions for which the factor matrices are preconditioned.
325-
block_dtype (torch.dtype): Data type for the block.
326328
327329
Returns:
328330
kronecker_factors_state (RootInvShampooKroneckerFactorsState): An instance of RootInvShampooKroneckerFactorsState with initialized inverse factor matrices.
329331
"""
330332
block_info: BlockInfo = kwargs["block_info"]
331-
factor_matrix_dtype: torch.dtype = kwargs["factor_matrix_dtype"]
333+
preconditioner_config: RootInvShampooPreconditionerConfig = kwargs[
334+
"preconditioner_config"
335+
]
332336
preconditioned_dims: tuple[int, ...] = kwargs["preconditioned_dims"]
333-
block_dtype: torch.dtype = kwargs["block_dtype"]
334337

335338
return cls(
336339
**asdict(
337340
BaseShampooKroneckerFactorsState.from_block(
338341
block_info=block_info,
339-
factor_matrix_dtype=factor_matrix_dtype,
342+
factor_matrix_dtype=preconditioner_config.factor_matrix_dtype,
340343
preconditioned_dims=preconditioned_dims,
341344
)
342345
),
343346
# Initialize inv_factor_matrices as identity matrices.
344347
inv_factor_matrices=tuple(
345348
block_info.allocate_eye_tensor(
346349
n=dim,
347-
dtype=block_dtype,
350+
dtype=preconditioner_config.inv_factor_matrix_dtype,
348351
device=block_info.param.device,
349352
)
350353
for dim in preconditioned_dims
@@ -517,31 +520,31 @@ def from_block(cls, **kwargs: Any) -> "EigendecomposedShampooKroneckerFactorsSta
517520
518521
Args:
519522
block_info (BlockInfo): Information about the block, including methods to allocate tensors.
520-
factor_matrix_dtype (torch.dtype): Data type for the factor matrices.
523+
preconditioner_config (EigendecomposedShampooPreconditionerConfig): Configuration for the preconditioner.
521524
preconditioned_dims (tuple[int, ...]): Dimensions for which the factor matrices are preconditioned.
522-
block_dtype (torch.dtype): Data type for the block.
523525
524526
Returns:
525527
kronecker_factors_state (EigendecomposedShampooKroneckerFactorsState): An instance of EigendecomposedShampooKroneckerFactorsState.
526528
"""
527529
block_info: BlockInfo = kwargs["block_info"]
528-
factor_matrix_dtype: torch.dtype = kwargs["factor_matrix_dtype"]
530+
preconditioner_config: EigendecomposedShampooPreconditionerConfig = kwargs[
531+
"preconditioner_config"
532+
]
529533
preconditioned_dims: tuple[int, ...] = kwargs["preconditioned_dims"]
530-
block_dtype: torch.dtype = kwargs["block_dtype"]
531534

532535
return cls(
533536
**asdict(
534537
BaseShampooKroneckerFactorsState.from_block(
535538
block_info=block_info,
536-
factor_matrix_dtype=factor_matrix_dtype,
539+
factor_matrix_dtype=preconditioner_config.factor_matrix_dtype,
537540
preconditioned_dims=preconditioned_dims,
538541
)
539542
),
540543
# Initialize factor_matrices_eigenvectors as identity matrices.
541544
factor_matrices_eigenvectors=tuple(
542545
block_info.allocate_eye_tensor(
543546
n=dim,
544-
dtype=block_dtype,
547+
dtype=preconditioner_config.factor_matrix_eigenvectors_dtype,
545548
device=block_info.param.device,
546549
)
547550
for dim in preconditioned_dims
@@ -550,7 +553,7 @@ def from_block(cls, **kwargs: Any) -> "EigendecomposedShampooKroneckerFactorsSta
550553
factor_matrices_eigenvalues=tuple(
551554
block_info.allocate_ones_tensor(
552555
size=(dim,),
553-
dtype=block_dtype,
556+
dtype=preconditioner_config.factor_matrix_eigenvalues_dtype,
554557
device=block_info.param.device,
555558
)
556559
for dim in preconditioned_dims
@@ -760,41 +763,41 @@ def from_block(
760763
761764
Args:
762765
block_info (BlockInfo): Information about the block, including methods to allocate tensors.
763-
factor_matrix_dtype (torch.dtype): Data type for the factor matrices.
766+
preconditioner_config (EigenvalueCorrectedShampooPreconditionerConfig): Configuration for the preconditioner.
764767
preconditioned_dims (tuple[int, ...]): Dimensions for which the factor matrices are preconditioned.
765-
block_dtype (torch.dtype): Data type for the block.
766768
dims (tuple[int, ...]): Dimensions of the block.
767769
768770
Returns:
769771
kronecker_factors_state (EigenvalueCorrectedShampooKroneckerFactorsState): An instance of EigenvalueCorrectedShampooKroneckerFactorsState.
770772
"""
771773
block_info: BlockInfo = kwargs["block_info"]
772-
factor_matrix_dtype: torch.dtype = kwargs["factor_matrix_dtype"]
774+
preconditioner_config: EigenvalueCorrectedShampooPreconditionerConfig = kwargs[
775+
"preconditioner_config"
776+
]
773777
preconditioned_dims: tuple[int, ...] = kwargs["preconditioned_dims"]
774-
block_dtype: torch.dtype = kwargs["block_dtype"]
775778
dims: tuple[int, ...] = kwargs["dims"]
776779

777780
return EigenvalueCorrectedShampooKroneckerFactorsState(
778781
**asdict(
779782
BaseShampooKroneckerFactorsState.from_block(
780783
block_info=block_info,
781-
factor_matrix_dtype=factor_matrix_dtype,
784+
factor_matrix_dtype=preconditioner_config.factor_matrix_dtype,
782785
preconditioned_dims=preconditioned_dims,
783786
)
784787
),
785788
# Initialize factor_matrices_eigenvectors as identity matrices.
786789
factor_matrices_eigenvectors=tuple(
787790
block_info.allocate_eye_tensor(
788791
n=dim,
789-
dtype=block_dtype,
792+
dtype=preconditioner_config.factor_matrix_eigenvectors_dtype,
790793
device=block_info.param.device,
791794
)
792795
for dim in preconditioned_dims
793796
),
794797
corrected_eigenvalues=block_info.allocate_zeros_tensor(
795798
# Note that the corrected eigenvalues are not affected by the preconditioned_dims.
796799
size=tuple(dims),
797-
dtype=block_dtype,
800+
dtype=preconditioner_config.corrected_eigenvalues_dtype,
798801
device=block_info.param.device,
799802
),
800803
)
@@ -1133,9 +1136,8 @@ def _create_kronecker_factors_state(
11331136
)
11341137
block_state[SHAMPOO] = kronecker_factors_state_type.from_block(
11351138
block_info=block_info,
1136-
factor_matrix_dtype=self._preconditioner_config.factor_matrix_dtype,
1139+
preconditioner_config=self._preconditioner_config,
11371140
preconditioned_dims=preconditioned_dims,
1138-
block_dtype=block.dtype,
11391141
dims=dims,
11401142
)
11411143
kronecker_factors_unwrapped.append(
@@ -1316,17 +1318,30 @@ def _precondition_grad(
13161318
assert (
13171319
sum(preconditioned_dims_selector) == len(preconditioner_list)
13181320
), f"The number of dimensions to precondition ({sum(preconditioned_dims_selector)}) must match the number of preconditioners ({len(preconditioner_list)})."
1321+
1322+
# Extract all dtypes and assert they are unique
1323+
assert (
1324+
len(unique_dtypes := {p.dtype for p in preconditioner_list}) <= 1
1325+
), f"All preconditioners must have the same dtype, but found: {unique_dtypes}"
1326+
1327+
# Use the single dtype if preconditioners exist, otherwise use grad dtype
1328+
target_dtype = next(iter(unique_dtypes), grad.dtype)
13191329
preconditioner_list_iter = iter(preconditioner_list)
1330+
13201331
return reduce(
13211332
lambda grad, should_precondition: torch.tensordot(
1322-
grad, next(preconditioner_list_iter), dims=dims
1333+
# Use the single target dtype for all operations
1334+
grad.to(dtype=target_dtype),
1335+
# Use the actual iterator for the operation
1336+
next(preconditioner_list_iter),
1337+
dims=dims,
13231338
)
13241339
if should_precondition
13251340
# Perform a left rotation on grad if not preconditioned.
13261341
else grad.permute(*range(1, grad.ndim), 0),
13271342
preconditioned_dims_selector,
13281343
grad,
1329-
)
1344+
).to(dtype=grad.dtype)
13301345

13311346
@overload
13321347
@staticmethod

distributed_shampoo/preconditioner/tests/shampoo_preconditioner_list_test.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1033,7 +1033,11 @@ def _amortized_computation_properties(self) -> AmortizedComputationProperties:
10331033

10341034
@property
10351035
def _default_preconditioner_config(self) -> RootInvShampooPreconditionerConfig:
1036-
return replace(DefaultShampooConfig, factor_matrix_dtype=torch.float64)
1036+
return replace(
1037+
DefaultShampooConfig,
1038+
factor_matrix_dtype=torch.float64,
1039+
inv_factor_matrix_dtype=torch.float64,
1040+
)
10371041

10381042
@property
10391043
def _preconditioner_list_factory(self) -> Callable[..., PreconditionerList]:
@@ -1059,6 +1063,8 @@ def _default_preconditioner_config( # type: ignore[override]
10591063
return EigendecomposedShampooPreconditionerConfig(
10601064
amortized_computation_config=QREigendecompositionConfig(),
10611065
factor_matrix_dtype=torch.float64,
1066+
factor_matrix_eigenvectors_dtype=torch.float64,
1067+
factor_matrix_eigenvalues_dtype=torch.float64,
10621068
)
10631069

10641070
@property
@@ -1077,7 +1083,12 @@ def _amortized_computation_properties(self) -> AmortizedComputationProperties:
10771083
def _default_preconditioner_config(
10781084
self,
10791085
) -> EigenvalueCorrectedShampooPreconditionerConfig:
1080-
return replace(DefaultSOAPConfig, factor_matrix_dtype=torch.float64)
1086+
return replace(
1087+
DefaultSOAPConfig,
1088+
factor_matrix_dtype=torch.float64,
1089+
factor_matrix_eigenvectors_dtype=torch.float64,
1090+
corrected_eigenvalues_dtype=torch.float64,
1091+
)
10811092

10821093
@property
10831094
def _preconditioner_list_factory(self) -> Callable[..., PreconditionerList]:

distributed_shampoo/shampoo_types.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -298,13 +298,15 @@ class RootInvShampooPreconditionerConfig(ShampooPreconditionerConfig):
298298
| (^0.1667), the default inverse exponent 1/(2*3) since inverse_exponent_override[3][2] is not specified
299299
|
300300
no preconditioning since inverse_exponent_override[3][0]=0.0
301+
inv_factor_matrix_dtype (torch.dtype): Data type for inverse factor matrix. (Default: torch.float32)
301302
302303
303304
"""
304305

305306
amortized_computation_config: RootInvConfig = field(
306307
default_factory=lambda: DefaultEigenConfig
307308
)
309+
inv_factor_matrix_dtype: torch.dtype = torch.float32
308310

309311

310312
DefaultShampooConfig = RootInvShampooPreconditionerConfig()
@@ -353,13 +355,17 @@ class EigendecomposedShampooPreconditionerConfig(ShampooPreconditionerConfig):
353355
| (^0.1667), the default inverse exponent 1/(2*3) since inverse_exponent_override[3][2] is not specified
354356
|
355357
no preconditioning since inverse_exponent_override[3][0]=0.0
358+
factor_matrix_eigenvectors_dtype (torch.dtype): Data type for factor matrix eigenvectors. (Default: torch.float32)
359+
factor_matrix_eigenvalues_dtype (torch.dtype): Data type for factor matrix eigenvalues. (Default: torch.float32)
356360
357361
358362
"""
359363

360364
amortized_computation_config: EigendecompositionConfig = field(
361365
default_factory=lambda: DefaultEigendecompositionConfig
362366
)
367+
factor_matrix_eigenvectors_dtype: torch.dtype = torch.float32
368+
factor_matrix_eigenvalues_dtype: torch.dtype = torch.float32
363369

364370

365371
@dataclass(kw_only=True)
@@ -410,6 +416,8 @@ class EigenvalueCorrectedShampooPreconditionerConfig(AmortizedPreconditionerConf
410416
inverse_exponent_override (dict[int, float]): The inverse_exponent_override attribute is a dictionary that allows for customizing the inverse exponent used in eigenvalue correction.
411417
The keys of the dictionary represent the order of the tensor, and the values are the exponent override values. For example, if we want to use a custom inverse exponent for 3-D tensors, we can set inverse_exponent_override as inverse_exponent_override={3: 0.25}.
412418
Note that the inverse_exponent_override dictionary can contain multiple entries for different tensor orders. If the order of the tensor is not specified in the dictionary, the default exponent, 1/2, will be used. (Default: {})
419+
factor_matrix_eigenvectors_dtype (torch.dtype): Data type for factor matrix eigenvectors. (Default: torch.float32)
420+
corrected_eigenvalues_dtype (torch.dtype): Data type for corrected eigenvalues. (Default: torch.float32)
413421
414422
"""
415423

@@ -418,6 +426,8 @@ class EigenvalueCorrectedShampooPreconditionerConfig(AmortizedPreconditionerConf
418426
)
419427
ignored_basis_change_dims: dict[int, list[int]] = field(default_factory=dict)
420428
inverse_exponent_override: dict[int, float] = field(default_factory=dict)
429+
factor_matrix_eigenvectors_dtype: torch.dtype = torch.float32
430+
corrected_eigenvalues_dtype: torch.dtype = torch.float32
421431

422432
def __post_init__(self) -> None:
423433
super().__post_init__()

0 commit comments

Comments (0)