Skip to content

Commit ba06186

Browse files
authored
Merge branch 'unique_temp_file_ids' into unique_files_id
2 parents 74a76e7 + fd2c349 commit ba06186

File tree

2 files changed

+44
-23
lines changed

2 files changed

+44
-23
lines changed

tiralib/tiramisu/compiling_service.py

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def compile_legality(cls, schedule: Schedule, with_ast: bool = False):
4040

4141
output_path = os.path.join(
4242
BaseConfig.base_config.workspace,
43-
f"{schedule.tiramisu_program.name}_legality",
43+
f"{schedule.tiramisu_program.temp_files_identifier}_legality",
4444
)
4545

4646
cpp_code = cls.get_legality_code(schedule=schedule, with_ast=with_ast)
@@ -129,7 +129,7 @@ def compile_annotations(cls, tiramisu_program: TiramisuProgram):
129129

130130
output_path = os.path.join(
131131
BaseConfig.base_config.workspace,
132-
f"{tiramisu_program.name}_annotations",
132+
f"{tiramisu_program.temp_files_identifier}_annotations",
133133
)
134134
# Add code to the original file to get json annotations
135135

@@ -167,7 +167,7 @@ def compile_isl_ast_tree(
167167

168168
output_path = os.path.join(
169169
BaseConfig.base_config.workspace,
170-
f"{tiramisu_program.name}_isl_ast",
170+
f"{tiramisu_program.temp_files_identifier}_isl_ast",
171171
)
172172
get_isl_ast_lines = ""
173173
if schedule:
@@ -323,7 +323,7 @@ def call_skewing_solver(
323323
logger.debug("Skewing Solver Code:\n" + solver_code)
324324
output_path = os.path.join(
325325
BaseConfig.base_config.workspace,
326-
f"{schedule.tiramisu_program.name}_skewing_solver",
326+
f"{schedule.tiramisu_program.temp_files_identifier}_skewing_solver",
327327
)
328328

329329
result_str = cls.run_cpp_code(cpp_code=solver_code, output_path=output_path)
@@ -443,7 +443,7 @@ def get_cpu_exec_times( # noqa: C901
443443
cpp_code = cls.get_schedule_code(tiramisu_program, optims_list)
444444
# Write the code to a file
445445
output_path = os.path.join(
446-
BaseConfig.base_config.workspace, tiramisu_program.name
446+
BaseConfig.base_config.workspace, tiramisu_program.temp_files_identifier
447447
)
448448

449449
cls.write_to_disk(cpp_code, output_path + "_schedule")
@@ -473,17 +473,17 @@ def get_cpu_exec_times( # noqa: C901
473473
shell_script = [
474474
# Compile intermidiate tiramisu file
475475
f"cd {BaseConfig.base_config.workspace}",
476-
f"$CXX -Wl,--no-as-needed -ldl -g -fno-rtti -lpthread -fopenmp -std=c++17 -O0 -o {tiramisu_program.name}.o -c {tiramisu_program.name}_schedule.cpp",
476+
f"$CXX -Wl,--no-as-needed -ldl -g -fno-rtti -lpthread -fopenmp -std=c++17 -O0 -o {tiramisu_program.temp_files_identifier}.o -c {tiramisu_program.temp_files_identifier}_schedule.cpp",
477477
# Link generated file with executer
478-
f"$CXX -Wl,--no-as-needed -ldl -g -fno-rtti -lpthread -fopenmp -std=c++17 -O0 {tiramisu_program.name}.o -o {tiramisu_program.name}.out -ltiramisu -ltiramisu_auto_scheduler -lHalide -lisl",
478+
f"$CXX -Wl,--no-as-needed -ldl -g -fno-rtti -lpthread -fopenmp -std=c++17 -O0 {tiramisu_program.temp_files_identifier}.o -o {tiramisu_program.temp_files_identifier}.out -ltiramisu -ltiramisu_auto_scheduler -lHalide -lisl",
479479
# Run the program
480-
f"./{tiramisu_program.name}.out",
481-
f"$CXX -shared -o {tiramisu_program.name}.so {tiramisu_program.name}.o", # noqa: E501
480+
f"./{tiramisu_program.temp_files_identifier}.out",
481+
f"$CXX -shared -o {tiramisu_program.temp_files_identifier}.so {tiramisu_program.temp_files_identifier}.o", # noqa: E501
482482
]
483483
if not tiramisu_program.wrapper_obj:
484484
shell_script += [
485485
# compile the wrapper
486-
f"$CXX -std=c++17 -fno-rtti -o {tiramisu_program.name}_wrapper -ltiramisu -lHalide -ldl -lpthread -fopenmp -lm {tiramisu_program.name}_wrapper.cpp ./{tiramisu_program.name}.so -ltiramisu -lHalide -ldl -lpthread -fopenmp -lm -lisl"
486+
f"$CXX -std=c++17 -fno-rtti -o {tiramisu_program.temp_files_identifier}_wrapper -ltiramisu -lHalide -ldl -lpthread -fopenmp -lm {tiramisu_program.temp_files_identifier}_wrapper.cpp ./{tiramisu_program.temp_files_identifier}.so -ltiramisu -lHalide -ldl -lpthread -fopenmp -lm -lisl"
487487
]
488488
try:
489489
# run the compilation of the generator and wrapper
@@ -558,11 +558,6 @@ def get_cpu_exec_times( # noqa: C901
558558
check=False,
559559
)
560560

561-
if delete_files:
562-
CompilingService.delete_temporary_files(
563-
tiramisu_program=tiramisu_program
564-
)
565-
566561
# if the command has to quit properly, that is either on timeout or (noraml completion and non-empty stdout)
567562
if not (
568563
compiler.returncode == 124
@@ -589,7 +584,10 @@ def get_cpu_exec_times( # noqa: C901
589584
f"Execution of wrapper timed-out. Completed {len(results)} out of {min_runs} min_runs and {len(compiler.stdout.split())} out of {nb_exec_left} extra runs. Collected measurements are [{' '.join(list(map(str, results)))}]+[{compiler.stdout}]."
590585
)
591586
results += [float(x) for x in compiler.stdout.split()]
592-
587+
if delete_files:
588+
CompilingService.delete_temporary_files(
589+
tiramisu_program=tiramisu_program
590+
)
593591
return results
594592

595593
except subprocess.CalledProcessError as e:
@@ -619,9 +617,9 @@ def get_n_runs_script(
619617
# set the env variables
620618
f"export NB_EXEC={nb_exec}",
621619
# run the wrapper
622-
f"./{tiramisu_program.name}_wrapper"
620+
f"./{tiramisu_program.temp_files_identifier}_wrapper"
623621
if timeout is None
624-
else f"timeout {timeout / 1000} ./{tiramisu_program.name}_wrapper",
622+
else f"timeout {timeout / 1000} ./{tiramisu_program.temp_files_identifier}_wrapper",
625623
]
626624

627625
@classmethod
@@ -630,7 +628,7 @@ def delete_temporary_files(cls, tiramisu_program: TiramisuProgram):
630628
subprocess.run(
631629
[
632630
# cd to the workspace and clean generated files
633-
f"cd {BaseConfig.base_config.workspace} && rm {tiramisu_program.name}*"
631+
f"cd {BaseConfig.base_config.workspace} && rm {tiramisu_program.temp_files_identifier}*"
634632
],
635633
capture_output=True,
636634
text=True,

tiralib/tiramisu/tiramisu_program.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import json
22
import random
33
import re
4+
import string
45
from pathlib import Path
56
from typing import Dict
67

@@ -23,6 +24,9 @@ class TiramisuProgram:
2324
The list of computations in the function
2425
`name`: str
2526
The name of the function
27+
`temp_files_identifier`: str
28+
A unique name prefix used to differentiate between temporary files
29+
generated by concurrently running instances of TiraLib
2630
`schedules_legality`: dict
2731
The legality of the schedules of the function
2832
`schedules_solver`: dict
@@ -43,6 +47,7 @@ def __init__(self: "TiramisuProgram"):
4347
self.isl_ast_string: str | None = None
4448
self.comps: list[str] | None = None
4549
self.name: str | None = None
50+
self.temp_files_identifier: str | None = None
4651
# self.schedules_legality = {}
4752
# self.schedules_solver = {}
4853
self.schedules_dict: Dict = {}
@@ -67,6 +72,11 @@ def from_dict(
6772
# Initiate an instante of the TiramisuProgram class
6873
tiramisu_prog = cls()
6974
tiramisu_prog.name = name
75+
tiramisu_prog.temp_files_identifier = (
76+
tiramisu_prog.name
77+
+ "_"
78+
+ "".join(random.choices(string.ascii_letters + string.digits, k=5))
79+
)
7080
tiramisu_prog.annotations = data["program_annotation"]
7181
if tiramisu_prog.annotations:
7282
tiramisu_prog.comps = list(tiramisu_prog.annotations["computations"].keys())
@@ -225,10 +235,21 @@ def load_code_lines(self, original_str: str | None = None):
225235
self.original_str,
226236
)[0]
227237
self.name = re.findall(r"tiramisu::init\(\"(\w+)\"\);", self.original_str)[0]
238+
self.temp_files_identifier = (
239+
self.name
240+
+ "_"
241+
+ "".join(random.choices(string.ascii_letters + string.digits, k=5))
242+
)
228243
# Remove the wrapper include from the original string
229-
self.wrapper_str = f'#include "{self.name}_wrapper.h"'
244+
self.wrapper_str = f'#include "{self.temp_files_identifier}_wrapper.h"'
230245
self.original_str = self.original_str.replace(
231-
self.wrapper_str, f"// {self.wrapper_str}"
246+
f'#include "{self.name}_wrapper.h"', f"// {self.wrapper_str}"
247+
)
248+
# Change the file name of the generated object file in the codegen line to use the temp files identifier
249+
self.original_str = re.sub(
250+
r'(tiramisu::codegen\(\{[^}]+\},\s*")([^"]*)("\);)',
251+
r"\1" + f"./{self.temp_files_identifier}.o" + r"\3",
252+
self.original_str,
232253
)
233254
self.comps = re.findall(r"computation (\w+)\(", self.original_str)
234255
self.code_gen_line = re.findall(r"tiramisu::codegen\({.+;", self.original_str)[
@@ -257,7 +278,9 @@ def construct_wrapper_code(
257278
if self.name is None:
258279
raise Exception("TiramisuProgram.name is None")
259280

260-
wrapper_cpp_code = wrapper_cpp_template.replace("$func_name$", self.name)
281+
wrapper_cpp_code = wrapper_cpp_template.replace(
282+
"$func_id$", self.temp_files_identifier
283+
).replace("$func_name$", self.name)
261284
wrapper_cpp_code = wrapper_cpp_code.replace(
262285
"$buffers_init$", buffers_init_lines
263286
)
@@ -285,7 +308,7 @@ def __repr__(self) -> str:
285308

286309

287310
wrapper_cpp_template = """#include "Halide.h"
288-
#include "$func_name$_wrapper.h"
311+
#include "$func_id$_wrapper.h"
289312
#include "tiramisu/utils.h"
290313
#include <iostream>
291314
#include <time.h>

0 commit comments

Comments
 (0)