|
44 | 44 | catchup=False,
|
45 | 45 | ) as dag:
|
46 | 46 | test_name_prefix = "maxtext"
|
| 47 | + |
47 | 48 | test_models_tpu = {
|
48 |
| - "llama2-7b": "tpu/llama2/7b/test_llama2_7b", |
49 |
| - "mistral-7b": "tpu/mistral/7b/test_mistral-7b", |
50 |
| - "gemma-2b": "tpu/gemma/2b/test_gemma", |
51 |
| - "gpt3": "tpu/test_gpt3", |
| 49 | + "llama2-7b": { |
| 50 | + "owner": test_owner.MOHIT_K, |
| 51 | + "commands": ["bash end_to_end/tpu/llama2/7b/test_llama2_7b.sh"], |
| 52 | + }, |
| 53 | + "mistral-7b": { |
| 54 | + "owner": test_owner.MOHIT_K, |
| 55 | + "commands": ["bash end_to_end/tpu/mistral/7b/test_mistral-7b.sh"], |
| 56 | + }, |
| 57 | + "gemma-2b": { |
| 58 | + "owner": test_owner.MOHIT_K, |
| 59 | + "commands": ["bash end_to_end/tpu/gemma/2b/test_gemma.sh"], |
| 60 | + }, |
| 61 | + "gemma2-2b": { |
| 62 | + "owner": test_owner.HENGTAO_G, |
| 63 | + "commands": [ |
| 64 | + "bash end_to_end/tpu/gemma2/2b/test_gemma2_to_mt.sh", |
| 65 | + "bash end_to_end/tpu/gemma2/2b/test_gemma2_to_hf.sh", |
| 66 | + ], |
| 67 | + }, |
| 68 | + "gemma3-4b": { |
| 69 | + "owner": test_owner.HENGTAO_G, |
| 70 | + "commands": [ |
| 71 | + "bash end_to_end/tpu/gemma3/4b/test_gemma3_to_mt.sh", |
| 72 | + "bash end_to_end/tpu/gemma3/4b/test_gemma3_to_hf.sh", |
| 73 | + ], |
| 74 | + }, |
| 75 | + "qwen3-4b": { |
| 76 | + "owner": test_owner.HENGTAO_G, |
| 77 | + "commands": [ |
| 78 | + "bash end_to_end/tpu/qwen3/4b/test_qwen3_to_mt.sh", |
| 79 | + "bash end_to_end/tpu/qwen3/4b/test_qwen3_to_hf.sh", |
| 80 | + ], |
| 81 | + }, |
| 82 | + "gpt3": { |
| 83 | + "owner": test_owner.MOHIT_K, |
| 84 | + "commands": ["bash end_to_end/tpu/test_gpt3.sh"], |
| 85 | + }, |
52 | 86 | }
|
53 | 87 |
|
54 | 88 | quarantine_task_group = TaskGroup(
|
55 | 89 | group_id="Quarantine", dag=dag, prefix_group_id=False
|
56 | 90 | )
|
57 | 91 |
|
58 |
| - for model, test_script in test_models_tpu.items(): |
| 92 | + for model, test_config in test_models_tpu.items(): |
| 93 | + model_cmds = (f"export HF_TOKEN={HF_TOKEN}",) + tuple( |
| 94 | + test_config["commands"] |
| 95 | + ) |
59 | 96 | stable_tpu = gke_config.get_gke_config(
|
60 | 97 | time_out_in_min=60,
|
61 | 98 | test_name=f"{test_name_prefix}-stable-{model}",
|
62 |
| - run_model_cmds=( |
63 |
| - f"export HF_TOKEN={HF_TOKEN}", |
64 |
| - f"bash end_to_end/{test_script}.sh", |
65 |
| - ), |
| 99 | + run_model_cmds=model_cmds, |
66 | 100 | docker_image=DockerImage.MAXTEXT_TPU_JAX_STABLE_STACK_CANDIDATE.value,
|
67 | 101 | cluster=XpkClusters.TPU_V5P_8_CLUSTER,
|
68 |
| - test_owner=test_owner.MOHIT_K, |
| 102 | + test_owner=test_config["owner"], |
69 | 103 | ).run_with_quarantine(quarantine_task_group)
|
70 | 104 | nightly_tpu = gke_config.get_gke_config(
|
71 | 105 | time_out_in_min=60,
|
72 | 106 | test_name=f"{test_name_prefix}-nightly-{model}",
|
73 |
| - run_model_cmds=( |
74 |
| - f"export HF_TOKEN={HF_TOKEN}", |
75 |
| - f"bash end_to_end/{test_script}.sh", |
76 |
| - ), |
| 107 | + run_model_cmds=model_cmds, |
77 | 108 | docker_image=DockerImage.MAXTEXT_TPU_STABLE_STACK_NIGHTLY_JAX.value,
|
78 | 109 | cluster=XpkClusters.TPU_V5P_8_CLUSTER,
|
79 |
| - test_owner=test_owner.MOHIT_K, |
| 110 | + test_owner=test_config["owner"], |
80 | 111 | ).run_with_quarantine(quarantine_task_group)
|
81 | 112 | stable_tpu >> nightly_tpu
|
82 | 113 |
|
|
0 commit comments