
Commit a2dfa3c

Aggressively trim test bloat (#346)
## Summary

1. Disable the tests for experimental kernels.
2. Reduce tensor sizes where the tests take too long.
3. Remove redundant tests that cover the same thing.

The goal is to keep unit test time under 5 minutes. (A minimal illustrative sketch of these trimming patterns follows the checklist below.)

## Testing Done

- Hardware Type: <BLANK>
- [ ] run `make test` to ensure correctness
- [ ] run `make checkstyle` to ensure code style
- [ ] run `make test-convergence` to ensure convergence
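For illustration only, a minimal sketch of the three trimming patterns (this snippet is not part of the diff; `supports_bfloat16()` here is a stand-in for the repo's own helper):

```python
import pytest
import torch


def supports_bfloat16() -> bool:
    # Stand-in for the repo helper: bf16 needs compute capability >= 8.0 (Ampere+).
    return torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8


# 1. Tests for experimental kernels are skipped outright.
@pytest.mark.skip(reason="LigerEmbedding is under experimentation")
def test_experimental_kernel():
    ...


# 2. Shapes are shrunk, and 3. redundant scalar/dtype combinations are dropped,
# keeping one bf16 and one fp32 case per test.
@pytest.mark.parametrize("B, T, V", [(2, 4096, 32000), (3, 423, 32000)])
@pytest.mark.parametrize(
    "scalar, dtype, atol, rtol",
    [
        pytest.param(
            1.0,
            torch.bfloat16,
            1e-8,
            5e-2,
            marks=pytest.mark.skipif(
                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
            ),
        ),
        (1.0, torch.float32, 1e-8, 1e-6),
    ],
)
def test_correctness(B, T, V, scalar, dtype, atol, rtol):
    ...
```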
1 parent e68b291 commit a2dfa3c

13 files changed: +58 / -283 lines

.github/workflows/ci.yml

Lines changed: 25 additions & 25 deletions
@@ -59,28 +59,28 @@ jobs:
       run: |
         modal run dev.modal.unit_tests
 
-  convergence-tests:
-    runs-on: ubuntu-latest
-    needs: [checkstyle]
-
-    env:
-      MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }}
-      MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }}
-
-    steps:
-      - name: Checkout code
-        uses: actions/checkout@v3
-
-      - name: Set up Python
-        uses: actions/setup-python@v3
-        with:
-          python-version: '3.10'
-
-      - name: Install dependencies
-        run: |
-          python -m pip install --upgrade pip
-          pip install modal
-
-      - name: Run convergence tests
-        run: |
-          modal run dev.modal.conv_tests
+  # convergence-tests:
+  #   runs-on: ubuntu-latest
+  #   needs: [checkstyle]
+
+  #   env:
+  #     MODAL_TOKEN_ID: ${{ secrets.MODAL_TOKEN_ID }}
+  #     MODAL_TOKEN_SECRET: ${{ secrets.MODAL_TOKEN_SECRET }}
+
+  #   steps:
+  #     - name: Checkout code
+  #       uses: actions/checkout@v3
+
+  #     - name: Set up Python
+  #       uses: actions/setup-python@v3
+  #       with:
+  #         python-version: '3.10'
+
+  #     - name: Install dependencies
+  #       run: |
+  #         python -m pip install --upgrade pip
+  #         pip install modal
+
+  #     - name: Run convergence tests
+  #       run: |
+  #         modal run dev.modal.conv_tests

dev/modal/unit_tests.py

Lines changed: 1 addition & 1 deletion
@@ -14,7 +14,7 @@
 repo = modal.Mount.from_local_dir(ROOT_PATH, remote_path="/root/liger-kernel")
 
 
-@app.function(gpu="A10G", mounts=[repo], timeout=60 * 20)
+@app.function(gpu="A10G", mounts=[repo], timeout=60 * 5)
 def liger_unit_test():
     import subprocess
 

test/transformers/test_cross_entropy.py

Lines changed: 5 additions & 160 deletions
@@ -170,26 +170,14 @@ def _test_correctness_functional(B, T, V, scalar, dtype, atol, rtol):
 @pytest.mark.parametrize(
     "B, T, V",
     [
-        (2, 4096, 32000),  # llama2, mistral
-        (2, 4096, 32000),  # llama2, mistral
-        (1, 4096, 128256),  # llama3
-        # # weird shapes
-        (3, 423, 32000),
+        (2, 4096, 32000),  # llama
+        (3, 423, 32000),  # weird shapes
     ],
 )
 @pytest.mark.parametrize("reduction", ["sum", "mean"])
 @pytest.mark.parametrize(
     "scalar, dtype, atol, rtol",
     [
-        pytest.param(
-            0.1,
-            torch.bfloat16,
-            1e-8,
-            5e-2,
-            marks=pytest.mark.skipif(
-                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
-            ),
-        ),
         pytest.param(
             1.0,
             torch.bfloat16,
@@ -199,24 +187,9 @@ def _test_correctness_functional(B, T, V, scalar, dtype, atol, rtol):
                 not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
             ),
         ),
-        pytest.param(
-            10.0,
-            torch.bfloat16,
-            1e-7,
-            5e-2,
-            marks=pytest.mark.skipif(
-                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
-            ),
-        ),
-        (0.1, torch.float32, 1e-8, 1e-6),
         (1.0, torch.float32, 1e-8, 1e-6),
-        (10.0, torch.float32, 1e-8, 1e-6),
     ],
 )
-@pytest.mark.skipif(
-    torch.cuda.get_device_properties(0).total_memory < 16 * 1000 * 1000 * 1000,
-    reason="Needs 16GB+ GPU memory.",
-)
 def test_correctness(B, T, V, scalar, dtype, reduction, atol, rtol):
     liger_ce = LigerCrossEntropyLoss(reduction=reduction)
     _test_correctness_once(liger_ce, B, T, V, reduction, scalar, dtype, atol, rtol)
@@ -233,12 +206,8 @@ def test_correctness(B, T, V, scalar, dtype, reduction, atol, rtol):
 @pytest.mark.parametrize(
     "scalar, dtype, atol, rtol",
     [
-        (0.1, torch.bfloat16, 1e-8, 5e-2),
         (1.0, torch.bfloat16, 1e-8, 5e-2),
-        (10.0, torch.bfloat16, 1e-7, 5e-2),
-        (0.1, torch.float32, 1e-8, 1e-6),
         (1.0, torch.float32, 1e-8, 1e-6),
-        (10.0, torch.float32, 1e-8, 1e-6),
     ],
 )
 def test_correctness_functional(B, T, V, scalar, dtype, atol, rtol):
@@ -248,9 +217,7 @@ def test_correctness_functional(B, T, V, scalar, dtype, atol, rtol):
 @pytest.mark.parametrize(
     "B, T, V, ignore_index",
     [
-        (2, 4096, 32000, -100),  # llama2, mistral
-        (2, 4096, 32000, 2),  # llama2, mistral
-        (1, 4096, 128256, -300),  # llama3
+        (2, 4096, 32000, 2),
         # weird shapes
         (3, 423, 32000, -123),
     ],
@@ -259,15 +226,6 @@ def test_correctness_functional(B, T, V, scalar, dtype, atol, rtol):
 @pytest.mark.parametrize(
     "scalar, dtype, atol, rtol",
     [
-        pytest.param(
-            0.1,
-            torch.bfloat16,
-            1e-8,
-            5e-2,
-            marks=pytest.mark.skipif(
-                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
-            ),
-        ),
         pytest.param(
             1.0,
             torch.bfloat16,
@@ -277,24 +235,9 @@ def test_correctness_functional(B, T, V, scalar, dtype, atol, rtol):
                 not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
            ),
         ),
-        pytest.param(
-            10.0,
-            torch.bfloat16,
-            1e-8,
-            5e-2,
-            marks=pytest.mark.skipif(
-                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
-            ),
-        ),
-        (0.1, torch.float32, 1e-8, 1e-6),
         (1.0, torch.float32, 1e-8, 1e-6),
-        (10.0, torch.float32, 1e-8, 1e-6),
     ],
 )
-@pytest.mark.skipif(
-    torch.cuda.get_device_properties(0).total_memory < 16 * 1000 * 1000 * 1000,
-    reason="Needs 16GB+ GPU memory.",
-)
 def test_correctness_with_ignore_index(
     B, T, V, ignore_index, reduction, scalar, dtype, atol, rtol
 ):
@@ -307,25 +250,14 @@ def test_correctness_with_ignore_index(
 @pytest.mark.parametrize(
     "B, T, V, label_smoothing",
     [
-        (2, 4096, 32000, 0.1),  # llama2, mistral
-        (2, 4096, 32000, 0.1),  # llama2, mistral
-        (1, 4096, 128256, 0.1),  # llama3
+        (2, 4096, 32000, 0.1),
         # weird shapes
         (3, 423, 32000, 0.1),
     ],
 )
 @pytest.mark.parametrize(
     "scalar, dtype, atol, rtol",
     [
-        pytest.param(
-            0.1,
-            torch.bfloat16,
-            1e-8,
-            5e-2,
-            marks=pytest.mark.skipif(
-                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
-            ),
-        ),
         pytest.param(
             1.0,
             torch.bfloat16,
@@ -335,24 +267,9 @@ def test_correctness_with_ignore_index(
                 not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
             ),
         ),
-        pytest.param(
-            10.0,
-            torch.bfloat16,
-            1e-8,
-            5e-2,
-            marks=pytest.mark.skipif(
-                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
-            ),
-        ),
-        (0.1, torch.float32, 1e-8, 1e-6),
         (1.0, torch.float32, 1e-8, 1e-6),
-        (10.0, torch.float32, 1e-8, 1e-6),
     ],
 )
-@pytest.mark.skipif(
-    torch.cuda.get_device_properties(0).total_memory < 16 * 1000 * 1000 * 1000,
-    reason="Needs 16GB+ GPU memory.",
-)
 def test_correctness_with_label_smoothing_once(
     B, T, V, label_smoothing, scalar, dtype, atol, rtol
 ):
@@ -365,25 +282,14 @@ def test_correctness_with_label_smoothing_once(
 @pytest.mark.parametrize(
     "B, T, V, ignore_index, label_smoothing",
     [
-        (2, 4096, 32000, 1, 0.1),  # llama2, mistral
-        (2, 4096, 32000, -100, 0.2),  # llama2, mistral
-        (1, 4096, 128256, 2, 0.1),  # llama3
+        (2, 4096, 32000, 1, 0.1),
         # weird shapes
         (3, 423, 32000, -300, 0.2),
     ],
 )
 @pytest.mark.parametrize(
     "scalar, dtype, atol, rtol",
     [
-        pytest.param(
-            0.1,
-            torch.bfloat16,
-            1e-8,
-            5e-2,
-            marks=pytest.mark.skipif(
-                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
-            ),
-        ),
         pytest.param(
             1.0,
             torch.bfloat16,
@@ -393,24 +299,9 @@ def test_correctness_with_label_smoothing_once(
                 not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
            ),
         ),
-        pytest.param(
-            10.0,
-            torch.bfloat16,
-            1e-6,
-            5e-2,
-            marks=pytest.mark.skipif(
-                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
-            ),
-        ),
-        (0.1, torch.float32, 1e-8, 1e-6),
         (1.0, torch.float32, 1e-8, 1e-6),
-        (10.0, torch.float32, 1e-8, 1e-6),
     ],
 )
-@pytest.mark.skipif(
-    torch.cuda.get_device_properties(0).total_memory < 16 * 1000 * 1000 * 1000,
-    reason="Needs 16GB+ GPU memory.",
-)
 def test_correctness_with_label_smoothing_with_ignore_index_once(
     B, T, V, ignore_index, label_smoothing, scalar, dtype, atol, rtol
 ):
@@ -427,8 +318,6 @@ def test_correctness_with_label_smoothing_with_ignore_index_once(
     "B, T, V",
     [
         (2, 4096, 32000),  # llama2, mistral
-        (2, 4096, 32000),  # llama2, mistral
-        (1, 4096, 128256),  # llama3
         # # weird shapes
         (3, 423, 32000),
     ],
@@ -449,52 +338,8 @@ def test_correctness_with_label_smoothing_with_ignore_index_once(
         (1.0, torch.float32, 1e-8, 1e-6),
     ],
 )
-@pytest.mark.skipif(
-    torch.cuda.get_device_properties(0).total_memory < 16 * 1000 * 1000 * 1000,
-    reason="Needs 16GB+ GPU memory.",
-)
 def test_correctness_not_last_layer(B, T, V, reduction, scalar, dtype, atol, rtol):
     liger_ce = LigerCrossEntropyLoss(reduction=reduction)
     _test_correctness_not_last_layer_once(
         liger_ce, B, T, V, reduction, scalar, dtype, atol, rtol
     )
-
-
-#############################################################################
-# Test full pass of the liger cross entropy loss to ensure it doesn't crash
-#############################################################################
-
-
-def _full_pass_once(B, T, V, reduction):
-
-    liger_ce = LigerCrossEntropyLoss(reduction=reduction)
-
-    _input = torch.randn(
-        B * T, V, requires_grad=True, device="cuda", dtype=torch.bfloat16
-    )
-    target = torch.randint(V, (B * T, 1), device="cuda").squeeze(1)
-
-    output = liger_ce(_input, target)
-    output.backward()
-
-
-@pytest.mark.parametrize(
-    "B, T, V",
-    [
-        (
-            8,
-            8192,
-            128256,
-        ),  # _input = 16GB, total = ~32GB, 8405385216 > 2,147,483,647, so we need int64
-        (8, 16384, 128256),  # _input = 32GB, total = ~64GB
-    ],
-)
-@pytest.mark.parametrize("reduction", ["sum", "mean"])
-@pytest.mark.skipif(
-    torch.cuda.get_device_properties(0).total_memory < 64 * 1000 * 1000 * 1000,
-    reason="Needs 64GB+ GPU memory.",
-)
-def test_large_no_exception(B, T, V, reduction):
-    # The large inputs were hitting cuda illegal memory access because of
-    # https://github.com/triton-lang/triton/issues/1058
-    _full_pass_once(B, T, V, reduction)
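As a sanity check on the int64 comment in the removed `test_large_no_exception` above (standalone arithmetic, not repo code):

```python
# The flattened logits in the removed large-input test have B * T * V elements,
# which exceeds the int32 range and is why the original comment called for int64.
B, T, V = 8, 8192, 128256
num_elements = B * T * V            # 8,405,385,216
int32_max = 2**31 - 1               # 2,147,483,647
print(num_elements > int32_max)     # True
print(num_elements * 2 / 1024**3)   # ~15.7 GiB for the bf16 input alone, hence the 64GB guard
```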

test/transformers/test_embedding.py

Lines changed: 1 addition & 0 deletions
@@ -7,6 +7,7 @@
 SLEEP_SECONDS = 0.1
 
 
+@pytest.mark.skip(reason="LigerEmbedding is under experimentation")
 @pytest.mark.parametrize(
     "num_embeddings, embedding_dim, padding_idx",
     [

test/transformers/test_fused_linear_cross_entropy.py

Lines changed: 4 additions & 12 deletions
@@ -86,12 +86,8 @@ def forward(self, x, y):
 @pytest.mark.parametrize(
     "B, T, H, V",
     [
-        # (2, 4, 512, 512),  # The test does not work on some CI GPUs. Issue #160
-        (8, 2048, 4096, 32000),  # llama2, mistral
-        # Comment out to speed up testing
-        # (4, 2048, 4096, 128256),  # llama3 8B
-        # (4, 1024, 8192, 128256),  # llama3 70B
-        (4, 423, 8192, 32000),  # random shape
+        (8, 128, 1024, 4096),
+        (4, 47, 31, 123),  # random shape
     ],
 )
 @pytest.mark.parametrize(
@@ -233,12 +229,8 @@ def test_correctness_functional(B, T, H, V, scalar, dtype, bias, atol, rtol):
 @pytest.mark.parametrize(
     "B, T, H, V",
     [
-        (2, 4, 512, 512),  # The test does not work on some CI GPUs. Issue #160
-        (8, 2048, 4096, 32000),  # llama2, mistral
-        # Comment out to speed up testing
-        (4, 2048, 4096, 128256),  # llama3 8B
-        (4, 1024, 8192, 128256),  # llama3 70B
-        (4, 423, 8192, 32000),  # random shape
+        (8, 128, 1024, 4096),
+        (4, 47, 31, 123),  # random shape
     ],
 )
 @pytest.mark.parametrize(
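A rough sense of why these shapes were shrunk (back-of-envelope arithmetic, not repo code): the non-fused reference used for comparison in these tests materializes the full (B*T, V) logits, and that tensor shrinks by orders of magnitude with the new shapes.

```python
# Approximate bf16 size of the (B*T, V) logits for the old vs. new test shapes
# (H does not enter the logits size; weights and gradients are ignored).
def logits_gib(B, T, V, bytes_per_elem=2):
    return B * T * V * bytes_per_elem / 1024**3

print(f"{logits_gib(8, 2048, 32000):.3f} GiB")  # old (8, 2048, 4096, 32000): ~0.977 GiB
print(f"{logits_gib(8, 128, 4096):.4f} GiB")    # new (8, 128, 1024, 4096): ~0.0078 GiB
```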
