
Commit f0a37d2

Add instruct model option to verifier scripts (#82)
* update verifier scripts
* Update ch03/02_math500-verifier-scripts/evaluate_math500_batched.py
* update
1 parent c8aaf18 · commit f0a37d2
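With this change, both verifier scripts accept the new value on the command line, e.g. python ch03/02_math500-verifier-scripts/evaluate_math500.py --which_model instruct (a representative invocation; the scripts' other flags keep their existing defaults).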

File tree

3 files changed: +48 -52 lines


ch03/02_math500-verifier-scripts/evaluate_math500.py

Lines changed: 1 addition & 1 deletion
@@ -44,7 +44,7 @@ def parse_args():
         "--which_model",
         type=str,
         default="base",
-        choices=["base", "reasoning"],
+        choices=["base", "reasoning", "instruct"],
         help="Model variant to load. Defaults to 'base'.",
     )
     parser.add_argument(
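This one-line change is sufficient at the CLI level because argparse validates --which_model against the declared choices. A minimal, self-contained sketch of that mechanism (illustrative only, not the script's actual parser):

import argparse

# Sketch: argparse rejects any --which_model value outside `choices`,
# so adding "instruct" to the list is all the CLI-level change needed.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--which_model",
    type=str,
    default="base",
    choices=["base", "reasoning", "instruct"],
)

args = parser.parse_args(["--which_model", "instruct"])
print(args.which_model)  # -> instruct
# parser.parse_args(["--which_model", "chat"]) would exit with
# an "invalid choice: 'chat'" error message.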

ch03/02_math500-verifier-scripts/evaluate_math500_batched.py

Lines changed: 2 additions & 49 deletions
@@ -14,14 +14,7 @@
 import torch
 
 from reasoning_from_scratch.ch02 import get_device
-from reasoning_from_scratch.qwen3 import (
-    download_qwen3_small,
-    Qwen3Tokenizer,
-    QWEN_CONFIG_06_B,
-)
-from reasoning_from_scratch.qwen3_batched import (
-    Qwen3Model as Qwen3ModelBatched,
-)
+from reasoning_from_scratch.qwen3_batched import get_model
 from reasoning_from_scratch.ch03 import (
     render_prompt,
     extract_final_candidate,
@@ -47,46 +40,6 @@ def get_data():
     return math_data
 
 
-def get_model(which_model, device, use_compile):
-    if which_model == "base":
-
-        download_qwen3_small(
-            kind="base", tokenizer_only=False, out_dir="qwen3"
-        )
-
-        tokenizer_path = Path("qwen3") / "tokenizer-base.json"
-        model_path = Path("qwen3") / "qwen3-0.6B-base.pth"
-        tokenizer = Qwen3Tokenizer(tokenizer_file_path=tokenizer_path)
-
-    elif which_model == "reasoning":
-
-        download_qwen3_small(
-            kind="reasoning", tokenizer_only=False, out_dir="qwen3"
-        )
-
-        tokenizer_path = Path("qwen3") / "tokenizer-reasoning.json"
-        model_path = Path("qwen3") / "qwen3-0.6B-reasoning.pth"
-        tokenizer = Qwen3Tokenizer(
-            tokenizer_file_path=tokenizer_path,
-            apply_chat_template=True,
-            add_generation_prompt=True,
-            add_thinking=True,
-        )
-
-    else:
-        raise ValueError(f"Invalid choice: WHICH_MODEL={which_model}")
-
-    model = Qwen3ModelBatched(QWEN_CONFIG_06_B)
-    model.load_state_dict(torch.load(model_path, map_location="cpu"))
-    model.to(device)
-
-    if use_compile:
-        torch._dynamo.config.allow_unspec_int_on_nn_module = True
-        model = torch.compile(model)
-
-    return model, tokenizer
-
-
 def evaluate_math500_batched(
     model,
     tokenizer,
@@ -201,7 +154,7 @@ def parse_args():
         "--which_model",
         type=str,
         default="base",
-        choices=["base", "reasoning"],
+        choices=["base", "reasoning", "instruct"],
         help="Model variant to load. Defaults to 'base'.",
     )
     parser.add_argument(
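The net effect is that the batched evaluation script no longer defines its own model setup and instead delegates to the shared helper. A sketch of the resulting call site, assuming get_device from ch02 (as in the script's imports) and eager execution; this is not a verbatim excerpt from the script:

from reasoning_from_scratch.ch02 import get_device
from reasoning_from_scratch.qwen3_batched import get_model

# Obtain the instruct variant through the shared helper; the
# (which_model, device, use_compile) signature matches the diff above.
device = get_device()
model, tokenizer = get_model("instruct", device, use_compile=False)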

reasoning_from_scratch/qwen3_batched.py

Lines changed: 45 additions & 2 deletions
@@ -2,7 +2,9 @@
 # Source for "Build a Reasoning Model (From Scratch)": https://mng.bz/lZ5B
 # Code repository: https://github.com/rasbt/reasoning-from-scratch
 
-from .qwen3 import KVCache
+from .qwen3 import KVCache, download_qwen3_small, Qwen3Tokenizer
+
+from pathlib import Path
 
 import torch
 import torch.nn as nn
@@ -582,4 +584,45 @@ def generate_text_basic_batched_stream_cache_stop(
 
         shrink_kv_cache_inplace(cache, keep_mask_active, model.cfg["n_layers"])
 
-        out = model(next_token_survivors, cache=cache, attn_mask=cur_attn_active)[:, -1]
+        out = model(next_token_survivors, cache=cache, attn_mask=cur_attn_active)[:, -1]
+
+
+def get_model(which_model, device, use_compile):
+    if which_model == "base":
+
+        download_qwen3_small(
+            kind="base", tokenizer_only=False, out_dir="qwen3"
+        )
+
+        tokenizer_path = Path("qwen3") / "tokenizer-base.json"
+        model_path = Path("qwen3") / "qwen3-0.6B-base.pth"
+        tokenizer = Qwen3Tokenizer(tokenizer_file_path=tokenizer_path)
+
+    elif which_model in ("reasoning", "instruct"):
+
+        download_qwen3_small(
+            kind="reasoning", tokenizer_only=False, out_dir="qwen3"
+        )
+
+        tokenizer_path = Path("qwen3") / "tokenizer-reasoning.json"
+        model_path = Path("qwen3") / "qwen3-0.6B-reasoning.pth"
+        tokenizer = Qwen3Tokenizer(
+            tokenizer_file_path=tokenizer_path,
+            apply_chat_template=True,
+            add_generation_prompt=True,
+            add_thinking=which_model == "reasoning",
+        )
+
+    else:
+        raise ValueError(f"Invalid choice: WHICH_MODEL={which_model}")
+
+    model = Qwen3Model(QWEN_CONFIG_06_B)
+    model.load_state_dict(torch.load(model_path))
+
+    model.to(device)
+
+    if use_compile:
+        torch._dynamo.config.allow_unspec_int_on_nn_module = True
+        model = torch.compile(model)
+
+    return model, tokenizer
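For orientation, the dispatch above can be summarized as a mapping from CLI value to checkpoint and tokenizer settings. The table below is illustrative only, not code from the commit, and it assumes the tokenizer's chat-template options default to off for the base variant; the key point is that "instruct" reuses the reasoning checkpoint and differs only in disabling thinking in the chat template:

# Illustrative summary of get_model's dispatch (assumed, not from the commit):
VARIANT_SETTINGS = {
    "base":      {"kind": "base",      "apply_chat_template": False, "add_thinking": False},
    "reasoning": {"kind": "reasoning", "apply_chat_template": True,  "add_thinking": True},
    "instruct":  {"kind": "reasoning", "apply_chat_template": True,  "add_thinking": False},
}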
