Skip to content

Commit f1cece7

Browse files
committed
Add tests for checkpoint conversion to end_to_end DAG
1 parent 528cda8 commit f1cece7

File tree

2 files changed

+47
-15
lines changed

2 files changed

+47
-15
lines changed

dags/common/test_owner.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@ class Team(enum.Enum):
6161
RISHABH_B = "notabee"
6262
NUOJIN_C = "NuojCheng"
6363
BRANDEN_V = "bvandermoon"
64+
HENGTAO_G = "hengtaoguo"
6465

6566
# Multi-tier Checkpointing
6667
ABHINAV_S = "abhinavclemson"

dags/multipod/maxtext_end_to_end.py

Lines changed: 46 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -44,39 +44,70 @@
4444
catchup=False,
4545
) as dag:
4646
test_name_prefix = "maxtext"
47+
4748
test_models_tpu = {
48-
"llama2-7b": "tpu/llama2/7b/test_llama2_7b",
49-
"mistral-7b": "tpu/mistral/7b/test_mistral-7b",
50-
"gemma-2b": "tpu/gemma/2b/test_gemma",
51-
"gpt3": "tpu/test_gpt3",
49+
"llama2-7b": {
50+
"owner": test_owner.MOHIT_K,
51+
"commands": ["bash end_tp_end/tpu/llama2/7b/test_llama2_7b.sh"],
52+
},
53+
"mistral-7b": {
54+
"owner": test_owner.MOHIT_K,
55+
"commands": ["bash end_to_end/tpu/mistral/7b/test_mistral-7b.sh"],
56+
},
57+
"gemma-2b": {
58+
"owner": test_owner.MOHIT_K,
59+
"commands": ["bash end_to_end/tpu/gemma/2b/test_gemma.sh"],
60+
},
61+
"gemma2-2b": {
62+
"owner": test_owner.HENGTAO_G,
63+
"commands": [
64+
"bash end_to_end/tpu/gemma2/2b/test_gemma2_to_mt.sh",
65+
"bash end_to_end/tpu/gemma2/2b/test_gemma2_to_hf.sh",
66+
],
67+
},
68+
"gemma3-4b": {
69+
"owner": test_owner.HENGTAO_G,
70+
"commands": [
71+
"bash end_to_end/tpu/gemma3/4b/test_gemma3_to_mt.sh",
72+
"bash end_to_end/tpu/gemma3/4b/test_gemma3_to_hf.sh",
73+
],
74+
},
75+
"qwen3-4b": {
76+
"owner": test_owner.HENGTAO_G,
77+
"commands": [
78+
"bash end_to_end/tpu/qwen3/4b/test_qwen3_to_mt.sh",
79+
"bash end_to_end/tpu/qwen3/4b/test_qwen3_to_hf.sh",
80+
],
81+
},
82+
"gpt3": {
83+
"owner": test_owner.MOHIT_K,
84+
"commands": ["bash end_to_end/tpu/test_gpt3.sh"],
85+
},
5286
}
5387

5488
quarantine_task_group = TaskGroup(
5589
group_id="Quarantine", dag=dag, prefix_group_id=False
5690
)
5791

58-
for model, test_script in test_models_tpu.items():
92+
for model, test_config in test_models_tpu.items():
93+
model_cmds = (f"export HF_TOKEN={HF_TOKEN}",) + tuple(
94+
test_config["commands"]
95+
)
5996
stable_tpu = gke_config.get_gke_config(
6097
time_out_in_min=60,
6198
test_name=f"{test_name_prefix}-stable-{model}",
62-
run_model_cmds=(
63-
f"export HF_TOKEN={HF_TOKEN}",
64-
f"bash end_to_end/{test_script}.sh",
65-
),
99+
run_model_cmds=model_cmds,
66100
docker_image=DockerImage.MAXTEXT_TPU_JAX_STABLE_STACK_CANDIDATE.value,
67101
cluster=XpkClusters.TPU_V5P_8_CLUSTER,
68-
test_owner=test_owner.MOHIT_K,
102+
test_owner=test_config["owner"],
69103
).run_with_quarantine(quarantine_task_group)
70104
nightly_tpu = gke_config.get_gke_config(
71105
time_out_in_min=60,
72106
test_name=f"{test_name_prefix}-nightly-{model}",
73-
run_model_cmds=(
74-
f"export HF_TOKEN={HF_TOKEN}",
75-
f"bash end_to_end/{test_script}.sh",
76-
),
107+
run_model_cmds=model_cmds,
77108
docker_image=DockerImage.MAXTEXT_TPU_STABLE_STACK_NIGHTLY_JAX.value,
78109
cluster=XpkClusters.TPU_V5P_8_CLUSTER,
79-
test_owner=test_owner.MOHIT_K,
110+
test_owner=test_config["owner"],
80111
).run_with_quarantine(quarantine_task_group)
81112
stable_tpu >> nightly_tpu
82113

0 commit comments

Comments
 (0)