Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 19 additions & 1 deletion llm/finetune_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
import json
import os
import sys
from dataclasses import dataclass, field
from functools import partial

import paddle
Expand Down Expand Up @@ -49,6 +50,23 @@
from paddlenlp.utils.log import logger


def add_start_docstrings(*docstr):
def docstring_decorator(fn):
fn.__doc__ = "".join(docstr) + (fn.__doc__ if fn.__doc__ is not None else "")
return fn

return docstring_decorator


@dataclass
@add_start_docstrings(TrainingArguments.__doc__)
class FinetuneArguments(TrainingArguments):
decay_steps: int = field(
default=0,
metadata={"help": "The steps use to control the learing rate."},
)


def read_local_dataset(path):
with open(path, "r", encoding="utf-8") as fp:
for line in fp:
Expand All @@ -57,7 +75,7 @@ def read_local_dataset(path):

def main():
# Arguments
parser = PdArgumentParser((GenerateArgument, QuantArgument, ModelArgument, DataArgument, TrainingArguments))
parser = PdArgumentParser((GenerateArgument, QuantArgument, ModelArgument, DataArgument, FinetuneArguments))
# Support format as "args.json --arg1 value1 --arg2 value2.”
# In case of conflict, command line arguments take precedence.
if len(sys.argv) >= 2 and sys.argv[1].endswith(".json"):
Expand Down
5 changes: 4 additions & 1 deletion paddlenlp/trainer/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -1650,13 +1650,16 @@ def create_scheduler(self, num_training_steps: int):
warmup = (
self.args.warmup_steps if self.args.warmup_steps > 0 else int(self.args.warmup_ratio * num_training_steps)
)
decay_steps = num_training_steps
if hasattr(self.args, "decay_steps") and self.args.decay_steps > 0:
decay_steps = self.args.decay_steps

if self.lr_scheduler is None:
self.lr_scheduler = get_scheduler(
self.args.lr_scheduler_type,
learning_rate=self.args.learning_rate,
num_warmup_steps=warmup,
num_training_steps=num_training_steps,
num_training_steps=decay_steps,
num_cycles=self.args.num_cycles,
lr_end=self.args.lr_end,
power=self.args.power,
Expand Down