Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions recipe/transfer_queue/main_ppo.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Note that we don't combine the main with ray_trainer as ray_trainer is used by other mpain.
Note that we don't combine the main with ray_trainer as ray_trainer is used by other main.
"""

import os
Expand All @@ -24,13 +24,14 @@

from verl.experimental.dataset.sampler import AbstractSampler
from verl.trainer.constants_ppo import get_ppo_ray_runtime_env
from verl.trainer.ppo.ray_trainer import RayPPOTrainer
from verl.trainer.ppo.reward import load_reward_manager
from verl.trainer.ppo.utils import need_critic, need_reference_policy
from verl.utils.config import validate_config
from verl.utils.device import is_cuda_available
from verl.utils.import_utils import load_extern_type

from .ray_trainer import RayPPOTrainer


@hydra.main(config_path="config", config_name="ppo_trainer", version_base=None)
def main(config):
Expand Down Expand Up @@ -312,7 +313,6 @@ def run(self, config):
)
# Initialize the workers of the trainer.
trainer.init_workers()

# Start the training process.
trainer.fit()

Expand Down Expand Up @@ -350,6 +350,7 @@ def create_rl_dataset(data_paths, data_config, tokenizer, processor, is_train=Tr

dataset_cls = DynamicGenDataset
print("Using DynamicGenDataset for data generation.")

else:
# Use the default RLHFDataset class if no custom class is specified
dataset_cls = RLHFDataset
Expand Down
Loading
Loading