|
| 1 | +import re |
| 2 | + |
| 3 | +import copy |
| 4 | + |
| 5 | +r""" |
| 6 | +This is the python script to generate c wrappers and corresponding |
| 7 | +headers for fortran (public) subroutines. |
| 8 | +
|
| 9 | +It scans the files in src folder and generate the c wrappers in prg_c_interface.F90 and |
| 10 | +headers in prg_progress_mod.h. |
| 11 | +
|
| 12 | +Note: the generated wrappers are not bug-free, it is only used to facilitate building |
| 13 | +c interrface. |
| 14 | +""" |
| 15 | + |
| 16 | +def get_public(fortran_code): |
| 17 | + |
| 18 | + # Define the regex pattern to match lines starting with 'public ::' |
| 19 | + pattern = re.compile(r'public *::(.*)') |
| 20 | + |
| 21 | + # Find matches in the fortran code |
| 22 | + matches = pattern.findall("".join(fortran_code)) |
| 23 | + |
| 24 | + public_subroutines = [] |
| 25 | + |
| 26 | + for match in matches: |
| 27 | + # Split by comma to get individual subroutines, strip whitespace |
| 28 | + subroutines = [subroutine.strip() for subroutine in match.split(',')] |
| 29 | + public_subroutines.extend(subroutines) |
| 30 | + |
| 31 | + # print the list of public subroutines |
| 32 | + print("Public subroutines are \n") |
| 33 | + for subroutine in public_subroutines: |
| 34 | + print(f" public :: {subroutine}_c") |
| 35 | + print("") |
| 36 | + return public_subroutines |
| 37 | + |
| 38 | + |
| 39 | +def transform_fortran_file(filename): |
| 40 | + # use to find the argument names of each subroutine |
| 41 | + argument_re = re.compile(r'(real\(8\)|real\(dp\)|real\(PREC\)|integer|logical|type\(bml_matrix_t\)|character\(\d+\)|character\(len=\*\))(,.+?)?\s+::\s+([a-zA-Z0-9_,\s\(\):]+)') |
| 42 | + |
| 43 | + # Replacement rules for argument types |
| 44 | + # i.e., fortran to c datatype mapping |
| 45 | + type_replacement = { |
| 46 | + "real(8)": "real(c_double)", |
| 47 | + "real(dp)": "real(c_double)", |
| 48 | + "real(PREC)": "real(c_double)", |
| 49 | + "type(bml_matrix_t)": "type(c_ptr)", |
| 50 | + "integer": "integer(c_int)", |
| 51 | + "character(len=*)":"character(c_char)", |
| 52 | + "character(10)":"character(c_char)", |
| 53 | + "character(20)":"character(c_char)", |
| 54 | + "character(50)":"character(c_char)", |
| 55 | + "character(3)":"character(c_char)", |
| 56 | + "character(2)":"character(c_char)", |
| 57 | + "logical":"logical(c_bool)" |
| 58 | + } |
| 59 | + |
| 60 | + # Replacement rules for argument types (for c header) |
| 61 | + type_replacement_header = { |
| 62 | + "real(8)": "double", |
| 63 | + "real(dp)": "double", |
| 64 | + "real(PREC)": "double", |
| 65 | + "type(bml_matrix_t)": "bml_matrix_t*", |
| 66 | + "integer": "int", |
| 67 | + "character(len=*)":"char*", |
| 68 | + "character(10)":"char*", |
| 69 | + "character(20)":"char*", |
| 70 | + "character(50)":"char*", |
| 71 | + "character(3)":"char*", |
| 72 | + "character(2)":"char*", |
| 73 | + "logical":"int*" |
| 74 | + } |
| 75 | + |
| 76 | + with open(filename, "r") as f: |
| 77 | + content = f.readlines() |
| 78 | + |
| 79 | + public_subroutines = get_public(content) |
| 80 | + |
| 81 | + new_content = [] |
| 82 | + header_content = [] |
| 83 | + is_inside_subroutine = False |
| 84 | + argument_names = [] |
| 85 | + argument_declarations = [] |
| 86 | + argument_assignments = [] |
| 87 | + subroutine_name = "" |
| 88 | + bml_matrix_t_args = [] |
| 89 | + |
| 90 | + for k, line in enumerate(content): |
| 91 | + if "&" in line: |
| 92 | + line = line.replace("\n","").strip() |
| 93 | + line = line + content[k+1].strip() |
| 94 | + line = line.replace("&", "") |
| 95 | + match = subroutine_re.match(line.strip()) |
| 96 | + if match: |
| 97 | + subroutine_name = match.group(1) |
| 98 | + argument_names = [arg.strip() for arg in match.group(2).split(',')] |
| 99 | + for iarg, argname in enumerate(argument_names): |
| 100 | + if "&" in argname: |
| 101 | + print(f"debug-zy: {argname} has character &") |
| 102 | + argument_names[iarg] = argname.split("&")[0] |
| 103 | + |
| 104 | + |
| 105 | + ispublic = subroutine_name in public_subroutines |
| 106 | + print(f"\ndebug-zy: a subroutine {subroutine_name} is found!", ispublic) |
| 107 | + print(f"debug-zy: its arguments are: {argument_names}") |
| 108 | + # skip the subroutine if it's not public |
| 109 | + if not ispublic: continue |
| 110 | + is_inside_subroutine = True |
| 111 | + arg2typ = {} |
| 112 | + continue |
| 113 | + |
| 114 | + if is_inside_subroutine: |
| 115 | + if 'end subroutine' in line: |
| 116 | + argument_names_c = [arg+"_c" if arg in bml_matrix_t_args else arg for arg in argument_names] |
| 117 | + new_subroutine_header = f' subroutine {subroutine_name}_c({", ".join(argument_names_c)}) bind(C, name="{subroutine_name}")' |
| 118 | + new_subroutine_body = '\n'.join(argument_declarations + argument_assignments) |
| 119 | + new_subroutine_call = f' call {subroutine_name}({", ".join(argument_names)})' |
| 120 | + new_content.append(new_subroutine_header) |
| 121 | + new_content.append(new_subroutine_body) |
| 122 | + new_content.append(new_subroutine_call) |
| 123 | + |
| 124 | + new_content.append(' end subroutine ' + subroutine_name + '_c\n') |
| 125 | + tmpheader = [] |
| 126 | + for arg in argument_names: |
| 127 | + if arg in arg2typ: |
| 128 | + tmpheader.append(f'{type_replacement_header[arg2typ[arg]]} {arg}') |
| 129 | + header_content.append(f'\nvoid {subroutine_name}({", ".join(tmpheader)});') |
| 130 | + |
| 131 | + is_inside_subroutine = False |
| 132 | + argument_names = [] |
| 133 | + argument_declarations = [] |
| 134 | + argument_assignments = [] |
| 135 | + bml_matrix_t_args = [] |
| 136 | + continue |
| 137 | + |
| 138 | + match = argument_re.match(line.strip()) |
| 139 | + if match: |
| 140 | + arg_type = match.group(1) |
| 141 | + arg_modifiers = match.group(2) |
| 142 | + print("matched group(3)=", match.group(3)) |
| 143 | + |
| 144 | + # Remove any spaces within array declarations |
| 145 | + input_string = match.group(3) |
| 146 | + # Split the input string by comma using lookahead assertion |
| 147 | + arg_names = re.split(r',(?![^(]*\))', input_string) |
| 148 | + arg_names = [arg.strip() for arg in arg_names] |
| 149 | + print("debug: arg_names=", arg_names) |
| 150 | + |
| 151 | + # create arg to type mapping for c header |
| 152 | + for arg_name in arg_names: |
| 153 | + if "(" in arg_name: |
| 154 | + arg_name = arg_name.split("(")[0] |
| 155 | + print(f"debug-zy: arg_name={arg_name}, type is {arg_type}") |
| 156 | + if arg_name in argument_names: |
| 157 | + arg2typ[arg_name] = arg_type |
| 158 | + |
| 159 | + if arg_type in type_replacement: |
| 160 | + new_arg_type = type_replacement[arg_type] |
| 161 | + print("new arg_type is:", new_arg_type) |
| 162 | + if arg_modifiers is None: |
| 163 | + arg_modifiers = ", value" |
| 164 | + else: |
| 165 | + arg_modifiers = arg_modifiers.replace("intent(inout)", "value") |
| 166 | + arg_modifiers = arg_modifiers.replace("intent(in)", "value") |
| 167 | + if "allocatable" in arg_modifiers or "optional" in arg_modifiers: |
| 168 | + arg_modifiers = arg_modifiers.replace("value", "") |
| 169 | + |
| 170 | + for arg_name in arg_names: |
| 171 | + tmparg_modifiers = copy.copy(arg_modifiers) |
| 172 | + |
| 173 | + if "(:,:)" in arg_name: |
| 174 | + tmparg_name = arg_name[:-5] |
| 175 | + tmparg_modifiers = tmparg_modifiers.replace("value", "target") |
| 176 | + elif "(:)" in arg_name: |
| 177 | + tmparg_name = arg_name[:-3] |
| 178 | + tmparg_modifiers = tmparg_modifiers.replace("value", "target") |
| 179 | + elif "(" in arg_name: |
| 180 | + tmparg_name = arg_name.split("(")[0] |
| 181 | + tmparg_modifiers = tmparg_modifiers.replace("value", "target") |
| 182 | + else: |
| 183 | + tmparg_name = arg_name |
| 184 | + print("debug: arg_name/tmparg_name=", arg_name, tmparg_name) |
| 185 | + print("debug: argument_names=", argument_names) |
| 186 | + |
| 187 | + if tmparg_name in argument_names: |
| 188 | + #if arg_name in argument_names: |
| 189 | + print(f"{tmparg_name} is added to the declaration\n") |
| 190 | + if arg_type == "type(bml_matrix_t)": |
| 191 | + bml_matrix_t_args.append(arg_name) |
| 192 | + argument_declarations.append(f' {new_arg_type}{tmparg_modifiers} :: {arg_name}_c') |
| 193 | + argument_declarations.append(f' {arg_type} :: {arg_name}') |
| 194 | + argument_assignments.append(f' {arg_name}%ptr = {arg_name}_c') |
| 195 | + else: |
| 196 | + argument_declarations.append(f' {new_arg_type}{tmparg_modifiers} :: {arg_name}') |
| 197 | + else: |
| 198 | + print(f"{tmparg_name} is not added to the declaration\n") |
| 199 | + |
| 200 | + new_filename = filename[:-4] + "_c.F90" |
| 201 | + print("c_wrappers are:") |
| 202 | + with open(new_filename, "w") as f: |
| 203 | + print('\n'.join(new_content)) |
| 204 | + f.write('\n'.join(new_content)) |
| 205 | + |
| 206 | + print("c headers are:") |
| 207 | + #header_filename = filename[:-4] + ".h" |
| 208 | + header_filename = "prg_progress_mod.h" |
| 209 | + with open(header_filename, "a") as f: |
| 210 | + print('\n'.join(header_content)) |
| 211 | + f.write('\n'.join(header_content)) |
| 212 | + f.close() |
| 213 | + |
| 214 | +# Example usage |
| 215 | +import glob, sys |
| 216 | + |
| 217 | +files=glob.glob(f"src/*mod.F90") |
| 218 | +for fname in files: |
| 219 | + transform_fortran_file(fname) |
0 commit comments