|
5 | 5 | from unittest.mock import patch |
6 | 6 |
|
7 | 7 | import torch |
| 8 | +import torch.nn as nn |
8 | 9 | import torch.distributed.checkpoint as dcp |
9 | | -import torchvision.models as models |
10 | 10 |
|
11 | 11 | from s3torchconnector import S3ReaderConstructor |
12 | 12 | from s3torchconnector.dcp import S3StorageWriter, S3StorageReader |
13 | 13 | from s3torchconnector.s3reader.sequential import SequentialS3Reader |
14 | 14 |
|
15 | 15 |
|
16 | | -@pytest.mark.parametrize( |
17 | | - "model", |
18 | | - [ |
19 | | - torch.nn.Sequential( |
20 | | - torch.nn.Linear(5, 5), |
21 | | - torch.nn.Linear(20, 20), |
22 | | - torch.nn.Linear(10, 10), |
23 | | - ), |
24 | | - models.resnet18(pretrained=False), |
25 | | - ], |
| 16 | +SIMPLE_MODEL = torch.nn.Sequential( |
| 17 | + nn.Linear(5, 5), |
| 18 | + nn.Linear(20, 20), |
| 19 | + nn.Linear(10, 10), |
26 | 20 | ) |
27 | | -def test_prepare_local_plan_sorts_by_storage_offset(checkpoint_directory, model): |
| 21 | + |
| 22 | + |
| 23 | +class NeuralNetwork(nn.Module): |
| 24 | + """NeuralNetwork from PyTorch quickstart tutorial.""" |
| 25 | + |
| 26 | + def __init__(self): |
| 27 | + super().__init__() |
| 28 | + self.flatten = nn.Flatten() |
| 29 | + self.linear_relu_stack = nn.Sequential( |
| 30 | + nn.Linear(28 * 28, 512), |
| 31 | + nn.ReLU(), |
| 32 | + nn.Linear(512, 512), |
| 33 | + nn.ReLU(), |
| 34 | + nn.Linear(512, 10), |
| 35 | + ) |
| 36 | + |
| 37 | + |
| 38 | +LARGER_MODEL = NeuralNetwork() |
| 39 | + |
| 40 | + |
| 41 | +@pytest.mark.parametrize("model", [SIMPLE_MODEL, LARGER_MODEL]) |
| 42 | +def test_dcp_load_reads_tensors_in_sequential_order(checkpoint_directory, model): |
28 | 43 | """ |
29 | 44 | Test that prepare_local_plan allows dcp.load() to read items in offset order. |
30 | 45 |
|
|
0 commit comments