Skip to content

Commit f4f8415

Browse files
committed
Refactor for clarity, deal with overlapping env vars
1 parent 6d21537 commit f4f8415

File tree

1 file changed

+70
-37
lines changed

1 file changed

+70
-37
lines changed

alphapulldown/scripts/run_multimer_jobs.py

Lines changed: 70 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -80,46 +80,69 @@ def main(argv):
8080
elif FLAGS.use_unifold:
8181
fold_backend, model_dir = "unifold", FLAGS.unifold_param
8282

83-
constant_args = {
84-
"--input": None,
85-
"--output_directory": FLAGS.output_path,
86-
"--num_cycle": FLAGS.num_cycle,
87-
"--num_predictions_per_model": FLAGS.num_predictions_per_model,
88-
"--data_directory": model_dir,
89-
"--features_directory": FLAGS.monomer_objects_dir,
90-
"--pair_msa": FLAGS.pair_msa,
91-
"--msa_depth_scan": FLAGS.msa_depth_scan,
92-
"--multimeric_template": FLAGS.multimeric_template,
93-
"--model_names": FLAGS.model_names,
94-
"--msa_depth": FLAGS.msa_depth,
95-
"--crosslinks": FLAGS.crosslinks,
96-
"--fold_backend": fold_backend,
97-
"--description_file": FLAGS.description_file,
98-
"--path_to_mmt": FLAGS.path_to_mmt,
99-
"--compress_result_pickles": FLAGS.compress_result_pickles,
100-
"--remove_result_pickles": FLAGS.remove_result_pickles,
101-
"--remove_keys_from_pickles": FLAGS.remove_keys_from_pickles,
102-
"--use_ap_style": True,
103-
"--use_gpu_relax": FLAGS.use_gpu_relax,
104-
"--protein_delimiter": FLAGS.protein_delimiter,
105-
"--desired_num_res": FLAGS.desired_num_res,
106-
"--desired_num_msa": FLAGS.desired_num_msa,
107-
"--models_to_relax": FLAGS.models_to_relax
108-
}
109-
if FLAGS.fold_backend == "alphafold3" : #remove unnecessary args
110-
unnecessary_keys = ["--num_cycle","--num_predictions_per_model","--pair_msa","--msa_depth_scan","--multimeric_template","--msa_depth","--compress_result_pickles","--remove_result_pickles","--remove_keys_from_pickles","--use_gpu_relax","--models_to_relax"]
111-
for key in unnecessary_keys :
112-
constant_args.pop(key, None)
83+
# Build arguments to forward to run_structure_prediction.py
84+
# For AF3, only forward AF3-supported flags and map num_cycle -> num_recycles.
85+
if fold_backend == "alphafold3":
86+
constant_args = {
87+
"--input": None,
88+
"--output_directory": FLAGS.output_path,
89+
"--data_directory": model_dir,
90+
"--features_directory": FLAGS.monomer_objects_dir,
91+
"--fold_backend": fold_backend,
92+
"--protein_delimiter": FLAGS.protein_delimiter,
93+
"--use_ap_style": FLAGS.use_ap_style,
94+
# AF3-specific knobs:
95+
"--num_recycles": FLAGS.num_cycle, # map from APD wrapper's num_cycle
96+
"--num_diffusion_samples": getattr(FLAGS, "num_diffusion_samples", None),
97+
"--num_seeds": getattr(FLAGS, "num_seeds", None),
98+
"--flash_attention_implementation": getattr(FLAGS, "flash_attention_implementation", None),
99+
"--buckets": getattr(FLAGS, "buckets", None),
100+
"--jax_compilation_cache_dir": getattr(FLAGS, "jax_compilation_cache_dir", None),
101+
"--save_embeddings": getattr(FLAGS, "save_embeddings", None),
102+
"--save_distogram": getattr(FLAGS, "save_distogram", None),
103+
"--debug_templates": getattr(FLAGS, "debug_templates", None),
104+
"--debug_msas": getattr(FLAGS, "debug_msas", None),
105+
}
106+
else:
107+
constant_args = {
108+
"--input": None,
109+
"--output_directory": FLAGS.output_path,
110+
"--num_cycle": FLAGS.num_cycle,
111+
"--num_predictions_per_model": FLAGS.num_predictions_per_model,
112+
"--data_directory": model_dir,
113+
"--features_directory": FLAGS.monomer_objects_dir,
114+
"--pair_msa": FLAGS.pair_msa,
115+
"--msa_depth_scan": FLAGS.msa_depth_scan,
116+
"--multimeric_template": FLAGS.multimeric_template,
117+
"--model_names": FLAGS.model_names,
118+
"--msa_depth": FLAGS.msa_depth,
119+
"--crosslinks": FLAGS.crosslinks,
120+
"--fold_backend": fold_backend,
121+
"--description_file": FLAGS.description_file,
122+
"--path_to_mmt": FLAGS.path_to_mmt,
123+
"--compress_result_pickles": FLAGS.compress_result_pickles,
124+
"--remove_result_pickles": FLAGS.remove_result_pickles,
125+
"--remove_keys_from_pickles": FLAGS.remove_keys_from_pickles,
126+
"--use_ap_style": True,
127+
"--use_gpu_relax": FLAGS.use_gpu_relax,
128+
"--protein_delimiter": FLAGS.protein_delimiter,
129+
"--desired_num_res": FLAGS.desired_num_res,
130+
"--desired_num_msa": FLAGS.desired_num_msa,
131+
"--models_to_relax": FLAGS.models_to_relax
132+
}
113133

114134
command_args = {}
115135
for k, v in constant_args.items():
116136
if v is None:
117137
continue
118-
elif v is False:
119-
updated_key = f"--no{k.split('--')[-1]}"
120-
command_args[updated_key] = ""
121-
elif v is True:
122-
command_args[k] = ""
138+
if isinstance(v, bool):
139+
if v:
140+
command_args[k] = ""
141+
else:
142+
# For AF3, don't emit negative boolean flags; for others, use --no-*
143+
if fold_backend != "alphafold3":
144+
updated_key = f"--no{k.split('--')[-1]}"
145+
command_args[updated_key] = ""
123146
elif isinstance(v, list):
124147
command_args[k] = ",".join([str(x) for x in v])
125148
else:
@@ -141,15 +164,25 @@ def main(argv):
141164
command = base_command.copy()
142165
for arg, value in command_args.items():
143166
command.extend([str(arg), str(value)])
144-
subprocess.run(" ".join(command), check=True, shell=True)
167+
# Sanitize environment to avoid JAX conflicts in nested subprocess.
168+
child_env = os.environ.copy()
169+
if ("XLA_CLIENT_MEM_FRACTION" in child_env) and ("XLA_PYTHON_CLIENT_MEM_FRACTION" in child_env):
170+
# Prefer the newer var and drop the deprecated one.
171+
del child_env["XLA_PYTHON_CLIENT_MEM_FRACTION"]
172+
logging.info(f"command: {command}")
173+
subprocess.run(command, check=True, env=child_env)
145174
else:
146175
for job_index in job_indices:
147176
command_args["--input"] = all_folds[job_index]
148177
command = base_command.copy()
149178
for arg, value in command_args.items():
150179
command.extend([str(arg), str(value)])
180+
# Sanitize environment to avoid JAX conflicts in nested subprocess.
181+
child_env = os.environ.copy()
182+
if ("XLA_CLIENT_MEM_FRACTION" in child_env) and ("XLA_PYTHON_CLIENT_MEM_FRACTION" in child_env):
183+
del child_env["XLA_PYTHON_CLIENT_MEM_FRACTION"]
151184
logging.info(f"command: {command}")
152-
subprocess.run(" ".join(command), check=True, shell=True)
185+
subprocess.run(command, check=True, env=child_env)
153186

154187

155188
if __name__ == "__main__":

0 commit comments

Comments
 (0)