@@ -80,46 +80,69 @@ def main(argv):
8080 elif FLAGS .use_unifold :
8181 fold_backend , model_dir = "unifold" , FLAGS .unifold_param
8282
83- constant_args = {
84- "--input" : None ,
85- "--output_directory" : FLAGS .output_path ,
86- "--num_cycle" : FLAGS .num_cycle ,
87- "--num_predictions_per_model" : FLAGS .num_predictions_per_model ,
88- "--data_directory" : model_dir ,
89- "--features_directory" : FLAGS .monomer_objects_dir ,
90- "--pair_msa" : FLAGS .pair_msa ,
91- "--msa_depth_scan" : FLAGS .msa_depth_scan ,
92- "--multimeric_template" : FLAGS .multimeric_template ,
93- "--model_names" : FLAGS .model_names ,
94- "--msa_depth" : FLAGS .msa_depth ,
95- "--crosslinks" : FLAGS .crosslinks ,
96- "--fold_backend" : fold_backend ,
97- "--description_file" : FLAGS .description_file ,
98- "--path_to_mmt" : FLAGS .path_to_mmt ,
99- "--compress_result_pickles" : FLAGS .compress_result_pickles ,
100- "--remove_result_pickles" : FLAGS .remove_result_pickles ,
101- "--remove_keys_from_pickles" : FLAGS .remove_keys_from_pickles ,
102- "--use_ap_style" : True ,
103- "--use_gpu_relax" : FLAGS .use_gpu_relax ,
104- "--protein_delimiter" : FLAGS .protein_delimiter ,
105- "--desired_num_res" : FLAGS .desired_num_res ,
106- "--desired_num_msa" : FLAGS .desired_num_msa ,
107- "--models_to_relax" : FLAGS .models_to_relax
108- }
109- if FLAGS .fold_backend == "alphafold3" : #remove unnecessary args
110- unnecessary_keys = ["--num_cycle" ,"--num_predictions_per_model" ,"--pair_msa" ,"--msa_depth_scan" ,"--multimeric_template" ,"--msa_depth" ,"--compress_result_pickles" ,"--remove_result_pickles" ,"--remove_keys_from_pickles" ,"--use_gpu_relax" ,"--models_to_relax" ]
111- for key in unnecessary_keys :
112- constant_args .pop (key , None )
83+ # Build arguments to forward to run_structure_prediction.py
84+ # For AF3, only forward AF3-supported flags and map num_cycle -> num_recycles.
85+ if fold_backend == "alphafold3" :
86+ constant_args = {
87+ "--input" : None ,
88+ "--output_directory" : FLAGS .output_path ,
89+ "--data_directory" : model_dir ,
90+ "--features_directory" : FLAGS .monomer_objects_dir ,
91+ "--fold_backend" : fold_backend ,
92+ "--protein_delimiter" : FLAGS .protein_delimiter ,
93+ "--use_ap_style" : FLAGS .use_ap_style ,
94+ # AF3-specific knobs:
95+ "--num_recycles" : FLAGS .num_cycle , # map from APD wrapper's num_cycle
96+ "--num_diffusion_samples" : getattr (FLAGS , "num_diffusion_samples" , None ),
97+ "--num_seeds" : getattr (FLAGS , "num_seeds" , None ),
98+ "--flash_attention_implementation" : getattr (FLAGS , "flash_attention_implementation" , None ),
99+ "--buckets" : getattr (FLAGS , "buckets" , None ),
100+ "--jax_compilation_cache_dir" : getattr (FLAGS , "jax_compilation_cache_dir" , None ),
101+ "--save_embeddings" : getattr (FLAGS , "save_embeddings" , None ),
102+ "--save_distogram" : getattr (FLAGS , "save_distogram" , None ),
103+ "--debug_templates" : getattr (FLAGS , "debug_templates" , None ),
104+ "--debug_msas" : getattr (FLAGS , "debug_msas" , None ),
105+ }
106+ else :
107+ constant_args = {
108+ "--input" : None ,
109+ "--output_directory" : FLAGS .output_path ,
110+ "--num_cycle" : FLAGS .num_cycle ,
111+ "--num_predictions_per_model" : FLAGS .num_predictions_per_model ,
112+ "--data_directory" : model_dir ,
113+ "--features_directory" : FLAGS .monomer_objects_dir ,
114+ "--pair_msa" : FLAGS .pair_msa ,
115+ "--msa_depth_scan" : FLAGS .msa_depth_scan ,
116+ "--multimeric_template" : FLAGS .multimeric_template ,
117+ "--model_names" : FLAGS .model_names ,
118+ "--msa_depth" : FLAGS .msa_depth ,
119+ "--crosslinks" : FLAGS .crosslinks ,
120+ "--fold_backend" : fold_backend ,
121+ "--description_file" : FLAGS .description_file ,
122+ "--path_to_mmt" : FLAGS .path_to_mmt ,
123+ "--compress_result_pickles" : FLAGS .compress_result_pickles ,
124+ "--remove_result_pickles" : FLAGS .remove_result_pickles ,
125+ "--remove_keys_from_pickles" : FLAGS .remove_keys_from_pickles ,
126+ "--use_ap_style" : True ,
127+ "--use_gpu_relax" : FLAGS .use_gpu_relax ,
128+ "--protein_delimiter" : FLAGS .protein_delimiter ,
129+ "--desired_num_res" : FLAGS .desired_num_res ,
130+ "--desired_num_msa" : FLAGS .desired_num_msa ,
131+ "--models_to_relax" : FLAGS .models_to_relax
132+ }
113133
114134 command_args = {}
115135 for k , v in constant_args .items ():
116136 if v is None :
117137 continue
118- elif v is False :
119- updated_key = f"--no{ k .split ('--' )[- 1 ]} "
120- command_args [updated_key ] = ""
121- elif v is True :
122- command_args [k ] = ""
138+ if isinstance (v , bool ):
139+ if v :
140+ command_args [k ] = ""
141+ else :
142+ # For AF3, don't emit negative boolean flags; for others, use --no-*
143+ if fold_backend != "alphafold3" :
144+ updated_key = f"--no{ k .split ('--' )[- 1 ]} "
145+ command_args [updated_key ] = ""
123146 elif isinstance (v , list ):
124147 command_args [k ] = "," .join ([str (x ) for x in v ])
125148 else :
@@ -141,15 +164,25 @@ def main(argv):
141164 command = base_command .copy ()
142165 for arg , value in command_args .items ():
143166 command .extend ([str (arg ), str (value )])
144- subprocess .run (" " .join (command ), check = True , shell = True )
167+ # Sanitize environment to avoid JAX conflicts in nested subprocess.
168+ child_env = os .environ .copy ()
169+ if ("XLA_CLIENT_MEM_FRACTION" in child_env ) and ("XLA_PYTHON_CLIENT_MEM_FRACTION" in child_env ):
170+ # Prefer the newer var and drop the deprecated one.
171+ del child_env ["XLA_PYTHON_CLIENT_MEM_FRACTION" ]
172+ logging .info (f"command: { command } " )
173+ subprocess .run (command , check = True , env = child_env )
145174 else :
146175 for job_index in job_indices :
147176 command_args ["--input" ] = all_folds [job_index ]
148177 command = base_command .copy ()
149178 for arg , value in command_args .items ():
150179 command .extend ([str (arg ), str (value )])
180+ # Sanitize environment to avoid JAX conflicts in nested subprocess.
181+ child_env = os .environ .copy ()
182+ if ("XLA_CLIENT_MEM_FRACTION" in child_env ) and ("XLA_PYTHON_CLIENT_MEM_FRACTION" in child_env ):
183+ del child_env ["XLA_PYTHON_CLIENT_MEM_FRACTION" ]
151184 logging .info (f"command: { command } " )
152- subprocess .run (" " . join ( command ) , check = True , shell = True )
185+ subprocess .run (command , check = True , env = child_env )
153186
154187
155188if __name__ == "__main__" :
0 commit comments