Skip to content

Commit 2460746

Browse files
GdoongMathew, Borda, and SkafteNicki
authored
Fix misaligned columns when using the rich model summary with DeepSpeedStrategy (#21100)
* Fix misaligned columns when using the rich model summary with the DeepSpeed strategy. * test: add a minimum GPU requirement to `test_deepspeed_summary_with_rich_model_summary` * Update the changelog --------- Co-authored-by: Jirka B <[email protected]> Co-authored-by: Nicki Skafte Detlefsen <[email protected]>
1 parent da88f5a commit 2460746

File tree

4 files changed

+47
-2
lines changed

4 files changed

+47
-2
lines changed

src/lightning/pytorch/CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
3737
- Fixed learning rate not being correctly set after using `LearningRateFinder` callback ([#21068](https://github.com/Lightning-AI/pytorch-lightning/pull/21068))
3838

3939

40+
- Fixed misaligned columns when using the rich model summary with `DeepSpeedStrategy` ([#21100](https://github.com/Lightning-AI/pytorch-lightning/pull/21100))
41+
4042
---
4143

4244
## [2.5.3] - 2025-08-13

src/lightning/pytorch/callbacks/rich_model_summary.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -79,18 +79,21 @@ def summarize(
7979
from rich.table import Table
8080

8181
console = get_console()
82+
column_names = list(zip(*summary_data))[0]
8283

8384
header_style: str = summarize_kwargs.get("header_style", "bold magenta")
8485
table = Table(header_style=header_style)
8586
table.add_column(" ", style="dim")
8687
table.add_column("Name", justify="left", no_wrap=True)
8788
table.add_column("Type")
8889
table.add_column("Params", justify="right")
90+
91+
if "Params per Device" in column_names:
92+
table.add_column("Params per Device", justify="right")
93+
8994
table.add_column("Mode")
9095
table.add_column("FLOPs", justify="right")
9196

92-
column_names = list(zip(*summary_data))[0]
93-
9497
for column_name in ["In sizes", "Out sizes"]:
9598
if column_name in column_names:
9699
table.add_column(column_name, justify="right", style="white")

src/lightning/pytorch/utilities/model_summary/model_summary_deepspeed.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,7 @@ def _get_summary_data(self) -> list[tuple[str, list[str]]]:
9999
("Params", list(map(get_human_readable_count, self.param_nums))),
100100
("Params per Device", list(map(get_human_readable_count, self.parameters_per_layer))),
101101
("Mode", ["train" if mode else "eval" for mode in self.training_modes]),
102+
("FLOPs", list(map(get_human_readable_count, (sum(x.values()) for x in self.flop_counts.values())))),
102103
]
103104
if self._model.example_input_array is not None:
104105
arrays.append(("In sizes", [str(x) for x in self.in_sizes]))

tests/tests_pytorch/utilities/test_deepspeed_model_summary.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from unittest import mock
16+
17+
import torch
18+
1519
import lightning.pytorch as pl
1620
from lightning.pytorch import Callback, Trainer
1721
from lightning.pytorch.demos.boring_classes import BoringModel
@@ -51,3 +55,38 @@ def on_fit_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
5155
)
5256

5357
trainer.fit(model)
58+
59+
60+
@RunIf(min_cuda_gpus=1, deepspeed=True, rich=True)
61+
@mock.patch("rich.table.Table.add_row", autospec=True)
62+
def test_deepspeed_summary_with_rich_model_summary(mock_table_add_row, tmp_path):
63+
from lightning.pytorch.callbacks import RichModelSummary
64+
65+
model = BoringModel()
66+
model.example_input_array = torch.randn(4, 32)
67+
68+
trainer = Trainer(
69+
strategy=DeepSpeedStrategy(stage=3),
70+
default_root_dir=tmp_path,
71+
accelerator="gpu",
72+
fast_dev_run=True,
73+
devices=1,
74+
enable_model_summary=True,
75+
callbacks=[RichModelSummary()],
76+
)
77+
78+
trainer.fit(model)
79+
80+
# assert that the input summary data was converted correctly
81+
args, _ = mock_table_add_row.call_args_list[0]
82+
assert args[1:] == (
83+
"0",
84+
"layer",
85+
"Linear",
86+
"66 ",
87+
"66 ",
88+
"train",
89+
"512 ",
90+
"[4, 32]",
91+
"[4, 2]",
92+
)

0 commit comments

Comments
 (0)