2727 maybe_convert_to_chatml ,
2828 maybe_extract_prompt ,
2929 maybe_unpair_preference_dataset ,
30+ pack_dataset ,
3031 pack_examples ,
32+ truncate_dataset ,
3133 unpair_preference_dataset ,
3234)
3335
@@ -395,7 +397,7 @@ def test_maybe_extract_prompt_standard_already_explicit(self):
395397
396398
397399class TestPackExamples (unittest .TestCase ):
398- def test_pack_examples_larger_chunks (self ):
400+ def test_larger_chunks (self ):
399401 examples = {
400402 "input_ids" : [[1 , 2 , 3 ], [4 , 5 , 6 , 7 ], [8 ]],
401403 "attention_mask" : [[0 , 1 , 1 ], [0 , 0 , 1 , 1 ], [1 ]],
@@ -408,7 +410,7 @@ def test_pack_examples_larger_chunks(self):
408410 result = pack_examples (examples , seq_length )
409411 self .assertEqual (result , expected_output )
410412
411- def test_pack_examples_smaller_chunks (self ):
413+ def test_smaller_chunks (self ):
412414 examples = {
413415 "input_ids" : [[1 , 2 , 3 ], [4 , 5 , 6 , 7 ], [8 ]],
414416 "attention_mask" : [[0 , 1 , 1 ], [0 , 0 , 1 , 1 ], [1 ]],
@@ -421,7 +423,7 @@ def test_pack_examples_smaller_chunks(self):
421423 result = pack_examples (examples , seq_length )
422424 self .assertEqual (result , expected_output )
423425
424- def test_pack_with_dataset (self ):
426+ def test_with_dataset (self ):
425427 examples = {
426428 "input_ids" : [[1 , 2 , 3 ], [4 , 5 , 6 , 7 ], [8 ]],
427429 "attention_mask" : [[0 , 1 , 1 ], [0 , 0 , 1 , 1 ], [1 ]],
@@ -436,6 +438,84 @@ def test_pack_with_dataset(self):
436438 self .assertEqual (dataset .to_dict (), expected_output )
437439
438440
441+ class TestPackDataset (unittest .TestCase ):
442+ def test_with_dataset (self ):
443+ examples = {
444+ "input_ids" : [[1 , 2 , 3 ], [4 , 5 , 6 , 7 ], [8 ]],
445+ "attention_mask" : [[0 , 1 , 1 ], [0 , 0 , 1 , 1 ], [1 ]],
446+ }
447+ dataset = Dataset .from_dict (examples )
448+ seq_length = 3
449+ expected_output = {
450+ "input_ids" : [[1 , 2 , 3 ], [4 , 5 , 6 ], [7 , 8 ]],
451+ "attention_mask" : [[0 , 1 , 1 ], [0 , 0 , 1 ], [1 , 1 ]],
452+ }
453+ dataset = pack_dataset (dataset , seq_length )
454+ self .assertEqual (dataset .to_dict (), expected_output )
455+
456+ def test_with_iterable_dataset (self ):
457+ examples = {
458+ "input_ids" : [[1 , 2 , 3 ], [4 , 5 , 6 , 7 ], [8 ]],
459+ "attention_mask" : [[0 , 1 , 1 ], [0 , 0 , 1 , 1 ], [1 ]],
460+ }
461+ dataset = Dataset .from_dict (examples ).to_iterable_dataset ()
462+ seq_length = 3
463+ expected_output = {
464+ "input_ids" : [[1 , 2 , 3 ], [4 , 5 , 6 ], [7 , 8 ]],
465+ "attention_mask" : [[0 , 1 , 1 ], [0 , 0 , 1 ], [1 , 1 ]],
466+ }
467+ dataset = pack_dataset (dataset , seq_length )
468+ num_examples = len (examples [next (iter (examples ))])
469+ self .assertEqual (next (iter (dataset .batch (batch_size = num_examples ))), expected_output )
470+
471+
472+ class TestTruncateExamples (unittest .TestCase ):
473+ def test_with_dataset (self ):
474+ examples = {
475+ "input_ids" : [[1 , 2 , 3 ], [4 , 5 , 6 , 7 ], [8 ]],
476+ "attention_mask" : [[0 , 1 , 1 ], [0 , 0 , 1 , 1 ], [1 ]],
477+ }
478+ dataset = Dataset .from_dict (examples )
479+ max_length = 2
480+ expected_output = {
481+ "input_ids" : [[1 , 2 ], [4 , 5 ], [8 ]],
482+ "attention_mask" : [[0 , 1 ], [0 , 0 ], [1 ]],
483+ }
484+ dataset = truncate_dataset (dataset , max_length )
485+ self .assertEqual (dataset .to_dict (), expected_output )
486+
487+ def test_with_iterable_dataset (self ):
488+ examples = {
489+ "input_ids" : [[1 , 2 , 3 ], [4 , 5 , 6 , 7 ], [8 ]],
490+ "attention_mask" : [[0 , 1 , 1 ], [0 , 0 , 1 , 1 ], [1 ]],
491+ }
492+ dataset = Dataset .from_dict (examples ).to_iterable_dataset ()
493+ max_length = 2
494+ expected_output = {
495+ "input_ids" : [[1 , 2 ], [4 , 5 ], [8 ]],
496+ "attention_mask" : [[0 , 1 ], [0 , 0 ], [1 ]],
497+ }
498+ dataset = truncate_dataset (dataset , max_length )
499+ num_examples = len (examples [next (iter (examples ))])
500+ self .assertEqual (next (iter (dataset .batch (batch_size = num_examples ))), expected_output )
501+
502+ def test_with_extra_column (self ):
503+ examples = {
504+ "input_ids" : [[1 , 2 , 3 ], [4 , 5 , 6 , 7 ], [8 ]],
505+ "attention_mask" : [[0 , 1 , 1 ], [0 , 0 , 1 , 1 ], [1 ]],
506+ "my_column" : ["a" , "b" , "c" ],
507+ }
508+ dataset = Dataset .from_dict (examples )
509+ max_length = 2
510+ expected_output = {
511+ "input_ids" : [[1 , 2 ], [4 , 5 ], [8 ]],
512+ "attention_mask" : [[0 , 1 ], [0 , 0 ], [1 ]],
513+ "my_column" : ["a" , "b" , "c" ],
514+ }
515+ dataset = truncate_dataset (dataset , max_length )
516+ self .assertEqual (dataset .to_dict (), expected_output )
517+
518+
439519class TestMaybeConvertToChatML (unittest .TestCase ):
440520 def test_with_conversations_key (self ):
441521 # Particular case where the key is "conversations": we rename it to "messages"
0 commit comments