@@ -29,6 +29,7 @@ Adding a task for RLHF training depends on the desired training method and pre-e
 git clone https://github.com/CarperAI/trlx.git
 cd trlx
 pip install -e ".[dev]"
+pre-commit install # see .pre-commit-config.yaml
 ```
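
The added `pre-commit install` line registers the hooks from `.pre-commit-config.yaml` with git, so formatting and lint checks run on every commit. As a minimal sketch (assuming the standard `pre-commit` CLI; the specific hooks are whatever that config file lists), the same checks can also be run by hand across the whole tree:

```bash
# Run every configured hook against all files, not only the staged ones
pre-commit run --all-files
```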
## Example: How to add a task
@@ -46,35 +47,52 @@ accelerate config
 ```python
 @register_datapipeline
 class PPOPipeline(BasePipeline):
-    def __init__(self, tokenizer, config, prompt_dataset_path=None):
+    def __init__(self, tokenizer, config, prompt_dataset_path=None):
         super().__init__()

-        ds = load_dataset('imdb', split='test')
-        ds = ds.rename_columns({'text': 'review', 'label': 'sentiment'})
-        ds = ds.filter(lambda x: len(x["review"])<500, batched=False)
-
-        self.tokens = [tokenizer(text,
-                                 truncation=True,
-                                 padding='max_length',
-                                 max_length=config.train.input_size,
-                                 return_tensors="pt"
-                                 )['input_ids'].long().flatten() for text in ds['review']]
+        ds = load_dataset("imdb", split="test")
+        ds = ds.rename_columns({"text": "review", "label": "sentiment"})
+        ds = ds.filter(lambda x: len(x["review"]) < 500, batched=False)
+
+        self.tokens = [
+            tokenizer(
+                text,
+                truncation=True,
+                padding="max_length",
+                max_length=config.train.input_size,
+                return_tensors="pt",
+            )["input_ids"]
+            .long()
+            .flatten()
+            for text in ds["review"]
+        ]
         self.text = [tokenizer.decode(tokens.tolist()) for tokens in self.tokens]

-    def __getitem__(self, index: int) -> PromptElement:
+    def __getitem__(self, index: int) -> PromptElement:
         return PromptElement(self.text[index], self.tokens[index])

     def __len__(self) -> int:
         return len(self.text)

-    def create_loader(self, batch_size: int, shuffle: bool, prep_fn: Callable = None, num_workers: int = 0) -> DataLoader:
-        # TODO(dahoas): Decide how to support varying sizes of prompts without having to tokenize on fly
-        def collate_fn(elems: Iterable[PromptElement]) -> PromptElement:
+    def create_loader(
+        self,
+        batch_size: int,
+        shuffle: bool,
+        prep_fn: Callable = None,
+        num_workers: int = 0,
+    ) -> DataLoader:
+        # TODO(dahoas): Decide how to support varying sizes of prompts without having to tokenize on fly
+        def collate_fn(elems: Iterable[PromptElement]) -> PromptElement:
             return PromptBatch(
-                [elem.text for elem in elems], torch.stack([elem.tokens for elem in elems])  # Assumes token tensors all same size
+                [elem.text for elem in elems],
+                torch.stack(
+                    [elem.tokens for elem in elems]
+                ),  # Assumes token tensors all same size
             )

-        return DataLoader(self, batch_size, shuffle, collate_fn=collate_fn, num_workers=num_workers)
+        return DataLoader(
+            self, batch_size, shuffle, collate_fn=collate_fn, num_workers=num_workers
+        )
 ```

 ### Launch training
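
Training for the registered task is then launched through Hugging Face Accelerate. A minimal sketch, reusing the `accelerate config` step shown in the hunk context above, with `my_ppo_task.py` as a placeholder for whatever script builds the pipeline:

```bash
# One-time: answer the prompts to describe the available GPUs / distributed setup
accelerate config

# Launch the (hypothetical) training script for the new task
accelerate launch my_ppo_task.py
```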