@@ -170,26 +170,14 @@ def _test_correctness_functional(B, T, V, scalar, dtype, atol, rtol):
 @pytest.mark.parametrize(
     "B, T, V",
     [
-        (2, 4096, 32000),  # llama2, mistral
-        (2, 4096, 32000),  # llama2, mistral
-        (1, 4096, 128256),  # llama3
-        # # weird shapes
-        (3, 423, 32000),
+        (2, 4096, 32000),  # llama
+        (3, 423, 32000),  # weird shapes
     ],
 )
 @pytest.mark.parametrize("reduction", ["sum", "mean"])
 @pytest.mark.parametrize(
     "scalar, dtype, atol, rtol",
     [
-        pytest.param(
-            0.1,
-            torch.bfloat16,
-            1e-8,
-            5e-2,
-            marks=pytest.mark.skipif(
-                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
-            ),
-        ),
         pytest.param(
             1.0,
             torch.bfloat16,
@@ -199,24 +187,9 @@ def _test_correctness_functional(B, T, V, scalar, dtype, atol, rtol):
                 not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
             ),
         ),
-        pytest.param(
-            10.0,
-            torch.bfloat16,
-            1e-7,
-            5e-2,
-            marks=pytest.mark.skipif(
-                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
-            ),
-        ),
-        (0.1, torch.float32, 1e-8, 1e-6),
         (1.0, torch.float32, 1e-8, 1e-6),
-        (10.0, torch.float32, 1e-8, 1e-6),
     ],
 )
-@pytest.mark.skipif(
-    torch.cuda.get_device_properties(0).total_memory < 16 * 1000 * 1000 * 1000,
-    reason="Needs 16GB+ GPU memory.",
-)
 def test_correctness(B, T, V, scalar, dtype, reduction, atol, rtol):
     liger_ce = LigerCrossEntropyLoss(reduction=reduction)
     _test_correctness_once(liger_ce, B, T, V, reduction, scalar, dtype, atol, rtol)
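Note: the `_test_correctness_once` helper called above is defined earlier in this test file and is not part of the diff. For orientation, a minimal sketch of what such a helper plausibly does, comparing the Liger kernel against `torch.nn.CrossEntropyLoss` under the parametrized `scalar`, `atol`, and `rtol` (tensor names and the scaling step here are illustrative assumptions, not code from this change):

```python
import torch
from torch.nn import CrossEntropyLoss

# Hypothetical stand-in for _test_correctness_once; the real helper lives
# earlier in this test file and is not shown in the diff.
def _sketch_correctness_once(liger_ce, B, T, V, reduction, scalar, dtype, atol, rtol):
    torch_ce = CrossEntropyLoss(reduction=reduction)

    # Scale the logits by `scalar` to exercise different numeric ranges.
    base = torch.randn(B * T, V, device="cuda", dtype=dtype) * scalar
    x_torch = base.detach().clone().requires_grad_(True)
    x_liger = base.detach().clone().requires_grad_(True)
    target = torch.randint(0, V, (B * T,), device="cuda")

    # Forward values must agree within the parametrized tolerances.
    loss_torch = torch_ce(x_torch, target)
    loss_liger = liger_ce(x_liger, target)
    assert torch.allclose(loss_torch, loss_liger, atol=atol, rtol=rtol)

    # Backward gradients must agree as well.
    loss_torch.backward()
    loss_liger.backward()
    assert torch.allclose(x_torch.grad, x_liger.grad, atol=atol, rtol=rtol)
```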
@@ -233,12 +206,8 @@ def test_correctness(B, T, V, scalar, dtype, reduction, atol, rtol):
 @pytest.mark.parametrize(
     "scalar, dtype, atol, rtol",
     [
-        (0.1, torch.bfloat16, 1e-8, 5e-2),
         (1.0, torch.bfloat16, 1e-8, 5e-2),
-        (10.0, torch.bfloat16, 1e-7, 5e-2),
-        (0.1, torch.float32, 1e-8, 1e-6),
         (1.0, torch.float32, 1e-8, 1e-6),
-        (10.0, torch.float32, 1e-8, 1e-6),
     ],
 )
 def test_correctness_functional(B, T, V, scalar, dtype, atol, rtol):
@@ -248,9 +217,7 @@ def test_correctness_functional(B, T, V, scalar, dtype, atol, rtol):
 @pytest.mark.parametrize(
     "B, T, V, ignore_index",
     [
-        (2, 4096, 32000, -100),  # llama2, mistral
-        (2, 4096, 32000, 2),  # llama2, mistral
-        (1, 4096, 128256, -300),  # llama3
+        (2, 4096, 32000, 2),
         # weird shapes
         (3, 423, 32000, -123),
     ],
@@ -259,15 +226,6 @@ def test_correctness_functional(B, T, V, scalar, dtype, atol, rtol):
 @pytest.mark.parametrize(
     "scalar, dtype, atol, rtol",
     [
-        pytest.param(
-            0.1,
-            torch.bfloat16,
-            1e-8,
-            5e-2,
-            marks=pytest.mark.skipif(
-                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
-            ),
-        ),
         pytest.param(
             1.0,
             torch.bfloat16,
@@ -277,24 +235,9 @@ def test_correctness_functional(B, T, V, scalar, dtype, atol, rtol):
                 not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
             ),
         ),
-        pytest.param(
-            10.0,
-            torch.bfloat16,
-            1e-8,
-            5e-2,
-            marks=pytest.mark.skipif(
-                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
-            ),
-        ),
-        (0.1, torch.float32, 1e-8, 1e-6),
         (1.0, torch.float32, 1e-8, 1e-6),
-        (10.0, torch.float32, 1e-8, 1e-6),
     ],
 )
-@pytest.mark.skipif(
-    torch.cuda.get_device_properties(0).total_memory < 16 * 1000 * 1000 * 1000,
-    reason="Needs 16GB+ GPU memory.",
-)
 def test_correctness_with_ignore_index(
     B, T, V, ignore_index, reduction, scalar, dtype, atol, rtol
 ):
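The same `Needs 16GB+ GPU memory.` gate is removed in several hunks of this diff. If such gating were kept, a common pytest pattern is to compute it once at module level rather than repeating the expression per test (a sketch under that assumption; `requires_16gb` is a hypothetical name, not part of this change):

```python
import pytest
import torch

# Evaluate the memory check once; reuse the mark on any memory-hungry test.
HAS_16GB_GPU = (
    torch.cuda.is_available()
    and torch.cuda.get_device_properties(0).total_memory >= 16 * 1000 * 1000 * 1000
)
requires_16gb = pytest.mark.skipif(not HAS_16GB_GPU, reason="Needs 16GB+ GPU memory.")
```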
@@ -307,25 +250,14 @@ def test_correctness_with_ignore_index(
 @pytest.mark.parametrize(
     "B, T, V, label_smoothing",
     [
-        (2, 4096, 32000, 0.1),  # llama2, mistral
-        (2, 4096, 32000, 0.1),  # llama2, mistral
-        (1, 4096, 128256, 0.1),  # llama3
+        (2, 4096, 32000, 0.1),
         # weird shapes
         (3, 423, 32000, 0.1),
     ],
 )
 @pytest.mark.parametrize(
     "scalar, dtype, atol, rtol",
     [
-        pytest.param(
-            0.1,
-            torch.bfloat16,
-            1e-8,
-            5e-2,
-            marks=pytest.mark.skipif(
-                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
-            ),
-        ),
         pytest.param(
             1.0,
             torch.bfloat16,
@@ -335,24 +267,9 @@ def test_correctness_with_label_smoothing_once(
                 not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
             ),
         ),
-        pytest.param(
-            10.0,
-            torch.bfloat16,
-            1e-8,
-            5e-2,
-            marks=pytest.mark.skipif(
-                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
-            ),
-        ),
-        (0.1, torch.float32, 1e-8, 1e-6),
         (1.0, torch.float32, 1e-8, 1e-6),
-        (10.0, torch.float32, 1e-8, 1e-6),
     ],
 )
-@pytest.mark.skipif(
-    torch.cuda.get_device_properties(0).total_memory < 16 * 1000 * 1000 * 1000,
-    reason="Needs 16GB+ GPU memory.",
-)
 def test_correctness_with_label_smoothing_once(
     B, T, V, label_smoothing, scalar, dtype, atol, rtol
 ):
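The bfloat16 cases that remain are still gated on `supports_bfloat16()`, a helper from the repo's test utilities that this diff does not show. A minimal sketch of such a check, assuming NVIDIA GPUs, where bfloat16 generally requires compute capability 8.0 (Ampere) or newer:

```python
import torch

# Hypothetical stand-in for the repo's supports_bfloat16() helper.
def supports_bfloat16_sketch() -> bool:
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major >= 8  # Ampere (sm_80) and newer expose bfloat16
```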
@@ -365,25 +282,14 @@ def test_correctness_with_label_smoothing_once(
 @pytest.mark.parametrize(
     "B, T, V, ignore_index, label_smoothing",
     [
-        (2, 4096, 32000, 1, 0.1),  # llama2, mistral
-        (2, 4096, 32000, -100, 0.2),  # llama2, mistral
-        (1, 4096, 128256, 2, 0.1),  # llama3
+        (2, 4096, 32000, 1, 0.1),
         # weird shapes
         (3, 423, 32000, -300, 0.2),
     ],
 )
 @pytest.mark.parametrize(
     "scalar, dtype, atol, rtol",
     [
-        pytest.param(
-            0.1,
-            torch.bfloat16,
-            1e-8,
-            5e-2,
-            marks=pytest.mark.skipif(
-                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
-            ),
-        ),
         pytest.param(
             1.0,
             torch.bfloat16,
@@ -393,24 +299,9 @@ def test_correctness_with_label_smoothing_once(
                 not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
             ),
         ),
-        pytest.param(
-            10.0,
-            torch.bfloat16,
-            1e-6,
-            5e-2,
-            marks=pytest.mark.skipif(
-                not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
-            ),
-        ),
-        (0.1, torch.float32, 1e-8, 1e-6),
         (1.0, torch.float32, 1e-8, 1e-6),
-        (10.0, torch.float32, 1e-8, 1e-6),
     ],
 )
-@pytest.mark.skipif(
-    torch.cuda.get_device_properties(0).total_memory < 16 * 1000 * 1000 * 1000,
-    reason="Needs 16GB+ GPU memory.",
-)
 def test_correctness_with_label_smoothing_with_ignore_index_once(
     B, T, V, ignore_index, label_smoothing, scalar, dtype, atol, rtol
 ):
@@ -427,8 +318,6 @@ def test_correctness_with_label_smoothing_with_ignore_index_once(
     "B, T, V",
     [
         (2, 4096, 32000),  # llama2, mistral
-        (2, 4096, 32000),  # llama2, mistral
-        (1, 4096, 128256),  # llama3
         # # weird shapes
         (3, 423, 32000),
     ],
@@ -449,52 +338,8 @@ def test_correctness_with_label_smoothing_with_ignore_index_once(
         (1.0, torch.float32, 1e-8, 1e-6),
     ],
 )
-@pytest.mark.skipif(
-    torch.cuda.get_device_properties(0).total_memory < 16 * 1000 * 1000 * 1000,
-    reason="Needs 16GB+ GPU memory.",
-)
 def test_correctness_not_last_layer(B, T, V, reduction, scalar, dtype, atol, rtol):
     liger_ce = LigerCrossEntropyLoss(reduction=reduction)
     _test_correctness_not_last_layer_once(
         liger_ce, B, T, V, reduction, scalar, dtype, atol, rtol
     )
-
-
-#############################################################################
-# Test full pass of the liger cross entropy loss to ensure it doesn't crash
-#############################################################################
-
-
-def _full_pass_once(B, T, V, reduction):
-
-    liger_ce = LigerCrossEntropyLoss(reduction=reduction)
-
-    _input = torch.randn(
-        B * T, V, requires_grad=True, device="cuda", dtype=torch.bfloat16
-    )
-    target = torch.randint(V, (B * T, 1), device="cuda").squeeze(1)
-
-    output = liger_ce(_input, target)
-    output.backward()
-
-
-@pytest.mark.parametrize(
-    "B, T, V",
-    [
-        (
-            8,
-            8192,
-            128256,
-        ),  # _input = 16GB, total = ~32GB, 8405385216 > 2,147,483,647, so we need int64
-        (8, 16384, 128256),  # _input = 32GB, total = ~64GB
-    ],
-)
-@pytest.mark.parametrize("reduction", ["sum", "mean"])
-@pytest.mark.skipif(
-    torch.cuda.get_device_properties(0).total_memory < 64 * 1000 * 1000 * 1000,
-    reason="Needs 64GB+ GPU memory.",
-)
-def test_large_no_exception(B, T, V, reduction):
-    # The large inputs were hitting cuda illegal memory access because of
-    # https://github.com/triton-lang/triton/issues/1058
-    _full_pass_once(B, T, V, reduction)
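The comments in the deleted `test_large_no_exception` encode concrete arithmetic: with B=8, T=8192, V=128256 the logits hold 8 * 8192 * 128256 = 8,405,385,216 elements, which exceeds the int32 maximum of 2,147,483,647, so flat offsets into the tensor need int64 indexing (the linked Triton issue). A quick sketch verifying those figures, assuming bfloat16 logits at 2 bytes per element:

```python
INT32_MAX = 2_147_483_647  # 2**31 - 1

# Shapes from the removed test; sizes roughly match its ~16GB/~32GB comments.
for B, T, V in [(8, 8192, 128256), (8, 16384, 128256)]:
    n_elements = B * T * V
    input_gb = n_elements * 2 / 1e9  # bfloat16 = 2 bytes per element
    needs_int64 = n_elements > INT32_MAX  # 32-bit offsets would overflow
    print(f"{B=} {T=} {V=}: {n_elements:,} elements, "
          f"needs_int64={needs_int64}, input≈{input_gb:.1f} GB")
```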