Skip to content

Commit f71401e

Browse files
ayulockin and jon-tow authored
Fix wandb.errors.RequireError as reported in #162 (#167)
* check wandb version
* Update wandb.py
* add wandb version + mean/reward
* fix api changes
* Update ILQL sweep hparams
* Run `isort`
* Remove `return`s from examples to avoid RayTune errors
* Skip `isort` on wandb reports import
* Remove erroring `tqdm` call in `BaseRLTrainer.evaluate`

Co-authored-by: jon-tow <[email protected]>
1 parent 0c5246f commit f71401e

File tree

7 files changed

+13
-13
lines changed

7 files changed

+13
-13
lines changed

configs/sweeps/ilql_sweep.yml

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -5,7 +5,7 @@ tune_config:
55
scheduler: "fifo"
66
num_samples: 32
77

8-
lr_init:
8+
lr:
99
strategy: "loguniform"
1010
values: [0.00001, 0.01]
1111
tau:

configs/sweeps/ppo_sweep.yml

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,12 +1,12 @@
11
tune_config:
22
mode: "max"
3-
metric: "mean_reward"
3+
metric: "reward/mean"
44
search_alg: "random"
55
scheduler: "fifo"
66
num_samples: 32
77

88
# https://docs.ray.io/en/latest/tune/api_docs/search_space.html#tune-sample-docs
9-
lr_init:
9+
lr:
1010
strategy: "loguniform"
1111
values: [0.00001, 0.01]
1212
init_kl_coef:

examples/ppo_sentiments.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -48,7 +48,7 @@ def reward_fn(samples: List[str]) -> List[float]:
4848
imdb = load_dataset("imdb", split="test")
4949
val_prompts = [" ".join(review.split()[:4]) for review in imdb["text"]]
5050

51-
return trlx.train(
51+
trlx.train(
5252
reward_fn=reward_fn,
5353
prompts=prompts,
5454
eval_prompts=val_prompts[0:1000],

examples/summarize_daily_cnn/t5_summarize_daily_cnn.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -74,7 +74,7 @@ def reward_fn(samples: List[str]):
7474
) # get prompt like trlx's prompt
7575
prompt_label[key.strip()] = val_summaries[i]
7676

77-
model = trlx.train(
77+
trlx.train(
7878
config.model.model_path,
7979
reward_fn=reward_fn,
8080
prompts=prompts,

setup.cfg

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -19,7 +19,7 @@ install_requires =
1919
torchtyping
2020
transformers>=4.21.2
2121
tqdm
22-
wandb
22+
wandb>=0.13.5
2323
ray>=2.0.1
2424
tabulate>=0.9.0
2525
networkx

trlx/ray_tune/wandb.py

Lines changed: 6 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -7,8 +7,8 @@
77

88
import wandb
99

10-
wandb.require("report-editing")
11-
import wandb.apis.reports as wb # noqa: E402
10+
import wandb.apis.reports as wb # isort: skip
11+
1212

1313
ray_info = [
1414
"done",
@@ -84,10 +84,10 @@ def log_trials(trial_path: str, project_name: str):
8484
def create_report(project_name, param_space, tune_config, trial_path, best_config=None):
8585
def get_parallel_coordinate(param_space, metric):
8686
column_names = list(param_space.keys())
87-
columns = [wb.reports.PCColumn(column) for column in column_names]
87+
columns = [wb.PCColumn(column) for column in column_names]
8888

8989
return wb.ParallelCoordinatesPlot(
90-
columns=columns + [wb.reports.PCColumn(metric)],
90+
columns=columns + [wb.PCColumn(metric)],
9191
layout={"x": 0, "y": 0, "w": 12 * 2, "h": 5 * 2},
9292
)
9393

@@ -155,7 +155,7 @@ def get_metrics_with_history(project_name, group_name, entity=None):
155155
get_scatter_plot(tune_config["metric"]),
156156
],
157157
runsets=[
158-
wb.RunSet(project=project_name).set_filters_with_python_expr(
158+
wb.Runset(project=project_name).set_filters_with_python_expr(
159159
f'group == "{trial_path}"'
160160
)
161161
],
@@ -192,7 +192,7 @@ def get_metrics_with_history(project_name, group_name, entity=None):
192192
wb.PanelGrid(
193193
panels=line_plot_panels,
194194
runsets=[
195-
wb.RunSet(project=project_name).set_filters_with_python_expr(
195+
wb.Runset(project=project_name).set_filters_with_python_expr(
196196
f'group == "{trial_path}"'
197197
)
198198
],

trlx/trainer/accelerate_base_trainer.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -221,7 +221,7 @@ def evaluate(self): # noqa: C901
221221
prompts_sizes = []
222222
lst_prompts = []
223223
generate_time = time()
224-
for prompts in tqdm(self.eval_dataloader, desc="Generating samples"):
224+
for prompts in self.eval_dataloader:
225225
if isinstance(prompts, torch.Tensor):
226226
samples = self.generate_eval(prompts)
227227
else:

0 commit comments

Comments
 (0)