
Commit a26786b

wz337 authored and meta-codesync[bot] committed
Add a test to validate that _local_masked_blocked_params total elements match the original tensor when param_assignment_strategy=FSDPParamAssignmentStrategy.REPLICATE in FullyShardLosslessDistributor
Summary: Add a test to validate that the total number of elements in `_local_masked_blocked_params` matches the original tensor when `param_assignment_strategy=FSDPParamAssignmentStrategy.REPLICATE` is used in `FullyShardLosslessDistributor`, since the sharded tensor is replicated into a full tensor on each rank.

Reviewed By: hjmshi

Differential Revision: D86479219

fbshipit-source-id: a4530afb6c21c766485768523430a81ab9e9c383
1 parent 36df9ff commit a26786b
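For context on the property the new test relies on: with a Shard(0) placement each rank only holds a slice of the parameter, whereas REPLICATE gives every rank the full tensor, which is why the blocked params on a rank should sum to the original element count. Below is a minimal standalone sketch of that DTensor behavior using only public torch.distributed.tensor APIs (not the distributor itself); it assumes a torchrun launch with at least 2 GPUs, the usual WORLD_SIZE / LOCAL_RANK environment variables, and a dim-0 size that divides evenly across ranks.

import os

import torch
from torch.distributed.tensor import Replicate, Shard, distribute_tensor, init_device_mesh

# Assumes torchrun set WORLD_SIZE / LOCAL_RANK; init_device_mesh creates the
# default process group if one is not already initialized.
torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))
world_size = int(os.environ["WORLD_SIZE"])
mesh = init_device_mesh("cuda", (world_size,))

full = torch.randn(100, 100, device="cuda")

sharded = distribute_tensor(full, mesh, [Shard(0)])     # each rank keeps a slice of dim 0
replicated = sharded.redistribute(mesh, [Replicate()])  # every rank keeps a full copy

# The local shard holds only 1/world_size of the elements (dim 0 must divide evenly),
# while the replicated placement matches the original element count on every rank.
assert sharded.to_local().numel() == full.numel() // world_size
assert replicated.to_local().numel() == full.numel()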

1 file changed: +39 −0 lines changed


distributed_shampoo/distributor/gpu_tests/shampoo_fully_shard_lossless_distributor_test.py

Lines changed: 39 additions & 0 deletions
@@ -38,6 +38,7 @@
 
 from torch import distributed as dist, nn
 from torch.distributed.fsdp import FSDPModule, fully_shard
+from torch.distributed.tensor import distribute_tensor, init_device_mesh, Shard
 from torch.optim.optimizer import ParamsT
 from torch.testing._internal.common_distributed import skip_if_lt_x_gpu
 from torch.testing._internal.common_utils import (
@@ -124,6 +125,44 @@ def _shampoo_optim_factory(
             distributed_config=distributed_config,
         )
 
+    @with_comms
+    @skip_if_lt_x_gpu(2)
+    def test_init_assign_full_parameters(self) -> None:
+        """Test that FullyShardLosslessDistributor initializes correctly and blocked params match original tensor size."""
+        # TODO: Figure out a better way to test this because the access of private field.
+        dummy_param = torch.randn(100, 100, device="cuda")
+        original_numel = dummy_param.numel()
+
+        mesh = init_device_mesh("cuda", (self.world_size,))
+        dummy_dtensor_param = distribute_tensor(dummy_param, mesh, [Shard(0)])
+        dummy_param_group = {
+            "params": [dummy_dtensor_param],
+            "distributed_config": FullyShardDistributedConfig(
+                param_assignment_strategy=FSDPParamAssignmentStrategy.REPLICATE
+            ),
+            "max_preconditioner_dim": PRECONDITIONER_DIM,
+        }
+
+        distributor = FullyShardLosslessDistributor(dummy_param_group)
+
+        # Validate that _assigned_full_params is initialized
+        self.assertIsNotNone(distributor._assigned_full_params)
+        self.assertEqual(len(distributor._assigned_full_params), 1)
+        self.assertEqual(distributor._assigned_full_params[0].shape, (100, 100))
+
+        # Validate that _local_masked_blocked_params total elements match the original tensor
+        # Note: _local_masked_blocked_params will be empty initially since there are no gradients yet
+        # So we check _global_blocked_params instead
+        total_blocked_elements = sum(
+            block.numel() for block in distributor._global_blocked_params
+        )
+        self.assertEqual(
+            total_blocked_elements,
+            original_numel,
+            f"Total elements in blocked params ({total_blocked_elements}) "
+            f"should match original tensor ({original_numel})",
+        )
+
     @with_comms
     @skip_if_lt_x_gpu(2)
     @parametrize("model_linear_layers_dims", TEST_MODEL_LAYER_DIMS)
