|
| 1 | +""" |
| 2 | +DC Resistivity Inversion Example using New PDE Framework |
| 3 | +
|
| 4 | +This example demonstrates a complete inversion workflow: |
| 5 | +1. Create synthetic "true" model |
| 6 | +2. Generate synthetic data with noise |
| 7 | +3. Set up inversion with regularization |
| 8 | +4. Run inversion with beta cooling |
| 9 | +5. Visualize results |
| 10 | +""" |
| 11 | + |
| 12 | +import torch |
| 13 | +import numpy as np |
| 14 | +import matplotlib.pyplot as plt |
| 15 | +from simpegtorch.discretize import TensorMesh |
| 16 | +from simpegtorch.simulation.resistivity import ( |
| 17 | + DC3DCellCentered, |
| 18 | + SrcDipole, |
| 19 | + RxDipole, |
| 20 | + Survey, |
| 21 | +) |
| 22 | +from simpegtorch.simulation.base import DirectSolver, SimulationWrapper, mappings |
| 23 | +from simpegtorch.inversion import ( |
| 24 | + BaseInversion, |
| 25 | + BaseInvProblem, |
| 26 | + BetaSchedule, |
| 27 | + TargetMisfit, |
| 28 | + BetaEstimate_ByEig, |
| 29 | +) |
| 30 | +from simpegtorch.data_misfit import L2DataMisfit |
| 31 | +from simpegtorch.regularization import TikhonovRegularization |
| 32 | + |
| 33 | +# Set default dtype |
| 34 | +torch.set_default_dtype(torch.float64) |
| 35 | + |
| 36 | +print("=" * 70) |
| 37 | +print("DC Resistivity Inversion with New PDE Framework") |
| 38 | +print("=" * 70) |
| 39 | + |
| 40 | +# ============================================================================ |
| 41 | +# 1. Create mesh |
| 42 | +# ============================================================================ |
| 43 | +print("\n[1/6] Creating mesh...") |
| 44 | +hx = torch.ones(20) * 10.0 # 20 cells of 10m |
| 45 | +hy = torch.ones(20) * 10.0 |
| 46 | +hz = torch.ones(10) * 10.0 |
| 47 | + |
| 48 | +# 200m x 200m x 100m mesh |
| 49 | +origin = torch.tensor([0.0, 0.0, -100.0]) # Start 100m below surface |
| 50 | +mesh = TensorMesh([hx, hy, hz], origin=origin) |
| 51 | + |
| 52 | +print(f"Mesh: {mesh.n_cells} cells ({mesh.shape_cells[0]}x{mesh.shape_cells[1]}x{mesh.shape_cells[2]})") |
| 53 | + |
| 54 | +# ============================================================================ |
| 55 | +# 2. Create survey |
| 56 | +# ============================================================================ |
| 57 | +print("\n[2/6] Creating survey...") |
| 58 | + |
| 59 | +# Create a dipole-dipole survey with multiple measurements |
| 60 | +n_data_points = 15 |
| 61 | +x_locs = torch.linspace(20.0, 180.0, n_data_points) |
| 62 | +y_center = 100.0 |
| 63 | + |
| 64 | +# M electrodes (negative voltage terminal) |
| 65 | +M_locs = torch.stack( |
| 66 | + [x_locs, torch.full_like(x_locs, y_center), torch.zeros_like(x_locs)], dim=1 |
| 67 | +) |
| 68 | + |
| 69 | +# N electrodes (positive voltage terminal) - 20m spacing |
| 70 | +N_locs = torch.stack( |
| 71 | + [ |
| 72 | + x_locs + 20.0, |
| 73 | + torch.full_like(x_locs, y_center), |
| 74 | + torch.zeros_like(x_locs), |
| 75 | + ], |
| 76 | + dim=1, |
| 77 | +) |
| 78 | + |
| 79 | +rx = RxDipole(locations_m=M_locs, locations_n=N_locs) |
| 80 | + |
| 81 | +# Dipole source (A-B electrodes) - 40m dipole at center |
| 82 | +src = SrcDipole( |
| 83 | + [rx], |
| 84 | + torch.tensor([60.0, 100.0, 0.0]), # A electrode |
| 85 | + torch.tensor([140.0, 100.0, 0.0]), # B electrode |
| 86 | + current=1.0, |
| 87 | +) |
| 88 | + |
| 89 | +survey = Survey([src]) |
| 90 | +print(f"Survey: {survey.nD} data points") |
| 91 | + |
| 92 | +# ============================================================================ |
| 93 | +# 3. Create true model and generate synthetic data |
| 94 | +# ============================================================================ |
| 95 | +print("\n[3/6] Generating synthetic data...") |
| 96 | + |
| 97 | +# True model: background of 100 Ohm-m with a conductive block (10 Ohm-m) |
| 98 | +sigma_background = 0.01 # 1/100 S/m = 100 Ohm-m |
| 99 | +sigma_true = torch.ones(mesh.n_cells) * sigma_background |
| 100 | + |
| 101 | +# Add a conductive anomaly: centered block |
| 102 | +# Find cells in the anomaly region (x: 80-120m, y: 80-120m, z: -60 to -40m) |
| 103 | +cell_centers = mesh.cell_centers |
| 104 | +anomaly_mask = ( |
| 105 | + (cell_centers[:, 0] > 80.0) |
| 106 | + & (cell_centers[:, 0] < 120.0) |
| 107 | + & (cell_centers[:, 1] > 80.0) |
| 108 | + & (cell_centers[:, 1] < 120.0) |
| 109 | + & (cell_centers[:, 2] > -60.0) |
| 110 | + & (cell_centers[:, 2] < -40.0) |
| 111 | +) |
| 112 | +sigma_true[anomaly_mask] = 0.1 # 10 Ohm-m anomaly |
| 113 | + |
| 114 | +# Create mapping and PDE for forward modeling |
| 115 | +sigma_map_true = mappings.BaseMapping(sigma_true) |
| 116 | +pde_true = DC3DCellCentered(mesh, survey, sigma_map_true, bc_type="Dirichlet") |
| 117 | +solver_true = DirectSolver(pde_true) |
| 118 | + |
| 119 | +# Generate clean data |
| 120 | +with torch.no_grad(): |
| 121 | + data_clean = solver_true.forward() |
| 122 | + |
| 123 | +# Add noise (5% relative + 1e-6 V floor) |
| 124 | +torch.manual_seed(42) # For reproducibility |
| 125 | +noise_level = 0.05 |
| 126 | +noise_floor = 1e-6 |
| 127 | +noise = noise_level * torch.abs(data_clean) + noise_floor |
| 128 | +noise = noise * torch.randn_like(data_clean) |
| 129 | +data_obs = data_clean + noise |
| 130 | + |
| 131 | +print(f"Data range: [{data_obs.min():.3e}, {data_obs.max():.3e}] V") |
| 132 | +print(f"SNR: ~{1/noise_level:.0f}:1") |
| 133 | + |
| 134 | +# ============================================================================ |
| 135 | +# 4. Set up inversion |
| 136 | +# ============================================================================ |
| 137 | +print("\n[4/6] Setting up inversion...") |
| 138 | + |
| 139 | +# Starting model: homogeneous half-space (best guess) |
| 140 | +sigma_start = torch.ones(mesh.n_cells, requires_grad=True) * sigma_background |
| 141 | +sigma_map_inv = mappings.BaseMapping(sigma_start) |
| 142 | + |
| 143 | +# Create PDE and simulation wrapper for inversion |
| 144 | +pde_inv = DC3DCellCentered(mesh, survey, sigma_map_inv, bc_type="Dirichlet") |
| 145 | +solver_inv = DirectSolver(pde_inv) |
| 146 | +simulation = SimulationWrapper(pde_inv, solver_inv) |
| 147 | + |
| 148 | +# Data misfit with uncertainty weighting |
| 149 | +uncertainties = noise_level * torch.abs(data_obs) + noise_floor |
| 150 | +dmisfit = L2DataMisfit(simulation, data_obs, weights=1.0 / uncertainties) |
| 151 | + |
| 152 | +# Regularization (Tikhonov smooth inversion) |
| 153 | +alpha_s = 1e-4 # Smoothness weight |
| 154 | +alpha_x = 1.0 # x-derivative weight |
| 155 | +alpha_y = 1.0 # y-derivative weight |
| 156 | +alpha_z = 1.0 # z-derivative weight |
| 157 | + |
| 158 | +reg = TikhonovRegularization( |
| 159 | + mesh, |
| 160 | + alpha_s=alpha_s, |
| 161 | + alpha_x=alpha_x, |
| 162 | + alpha_y=alpha_y, |
| 163 | + alpha_z=alpha_z, |
| 164 | + reference_model=sigma_background * torch.ones(mesh.n_cells), |
| 165 | +) |
| 166 | + |
| 167 | +# Optimizer (Adam works well for geophysical inversions) |
| 168 | +optimizer = torch.optim.Adam([sigma_map_inv.trainable_parameters], lr=0.01) |
| 169 | + |
| 170 | +# Inverse problem |
| 171 | +inv_prob = BaseInvProblem( |
| 172 | + dmisfit, reg, optimizer, beta=1.0, max_iter=50 # Will be estimated |
| 173 | +) |
| 174 | + |
| 175 | +# Directives for inversion |
| 176 | +directives = [ |
| 177 | + BetaEstimate_ByEig(beta0_ratio=1.0), # Estimate initial beta |
| 178 | + BetaSchedule(cooling_factor=2.0, cooling_rate=3), # Cool beta every 3 iterations |
| 179 | + TargetMisfit(chi_factor=1.0), # Stop when misfit ~ # of data points |
| 180 | +] |
| 181 | + |
| 182 | +inversion = BaseInversion(inv_prob, directives=directives, device="cpu") |
| 183 | + |
| 184 | +print(f"Starting model: {sigma_start.mean():.4f} ± {sigma_start.std():.4f} S/m") |
| 185 | +print(f"True anomaly contrast: {sigma_true[anomaly_mask].mean() / sigma_background:.1f}x") |
| 186 | + |
| 187 | +# ============================================================================ |
| 188 | +# 5. Run inversion |
| 189 | +# ============================================================================ |
| 190 | +print("\n[5/6] Running inversion...") |
| 191 | +print("-" * 70) |
| 192 | + |
| 193 | +sigma_recovered = inversion.run(sigma_start) |
| 194 | + |
| 195 | +print("-" * 70) |
| 196 | +print(f"\nInversion completed in {inversion.iteration} iterations") |
| 197 | +print( |
| 198 | + f"Final misfit: φ_d = {inversion.phi_d_history[-1]:.2e} (target: {dmisfit.n_data:.2e})" |
| 199 | +) |
| 200 | +print(f"Recovered model: {sigma_recovered.mean():.4f} ± {sigma_recovered.std():.4f} S/m") |
| 201 | + |
| 202 | +# ============================================================================ |
| 203 | +# 6. Plot results |
| 204 | +# ============================================================================ |
| 205 | +print("\n[6/6] Plotting results...") |
| 206 | + |
| 207 | +fig, axes = plt.subplots(2, 3, figsize=(15, 10)) |
| 208 | + |
| 209 | +# Plot 1: Convergence curve |
| 210 | +ax = axes[0, 0] |
| 211 | +ax.semilogy(inversion.phi_d_history, "b-o", label="Data misfit") |
| 212 | +ax.semilogy(inversion.phi_m_history, "r-s", label="Model norm") |
| 213 | +ax.axhline(dmisfit.n_data, color="k", linestyle="--", label="Target") |
| 214 | +ax.set_xlabel("Iteration") |
| 215 | +ax.set_ylabel("Misfit") |
| 216 | +ax.set_title("Convergence") |
| 217 | +ax.legend() |
| 218 | +ax.grid(True, alpha=0.3) |
| 219 | + |
| 220 | +# Plot 2: Beta schedule |
| 221 | +ax = axes[0, 1] |
| 222 | +ax.semilogy(inversion.beta_history, "g-d") |
| 223 | +ax.set_xlabel("Iteration") |
| 224 | +ax.set_ylabel("Beta (trade-off parameter)") |
| 225 | +ax.set_title("Beta Cooling Schedule") |
| 226 | +ax.grid(True, alpha=0.3) |
| 227 | + |
| 228 | +# Plot 3: Data fit |
| 229 | +ax = axes[0, 2] |
| 230 | +ax.plot(data_obs.numpy(), data_clean.numpy(), "b.", alpha=0.5, label="Observed") |
| 231 | +pred_final = simulation.dpred() |
| 232 | +ax.plot(data_obs.numpy(), pred_final.detach().numpy(), "r.", alpha=0.5, label="Predicted") |
| 233 | +lims = [data_obs.min().item(), data_obs.max().item()] |
| 234 | +ax.plot(lims, lims, "k--", alpha=0.3) |
| 235 | +ax.set_xlabel("Observed Data (V)") |
| 236 | +ax.set_ylabel("Predicted Data (V)") |
| 237 | +ax.set_title("Data Fit") |
| 238 | +ax.legend() |
| 239 | +ax.grid(True, alpha=0.3) |
| 240 | + |
| 241 | +# Extract middle slice (y = 100m) for visualization |
| 242 | +y_slice_idx = mesh.shape_cells[1] // 2 |
| 243 | +x_coords = mesh.cell_centers_x |
| 244 | +z_coords = mesh.cell_centers_z |
| 245 | + |
| 246 | +# Reshape models to grid |
| 247 | +true_grid = sigma_true.reshape(mesh.shape_cells)[:, y_slice_idx, :] |
| 248 | +recovered_grid = ( |
| 249 | + torch.tensor(sigma_recovered).reshape(mesh.shape_cells)[:, y_slice_idx, :] |
| 250 | +) |
| 251 | + |
| 252 | +# Convert to resistivity for plotting |
| 253 | +rho_true = 1.0 / true_grid.numpy() |
| 254 | +rho_recovered = 1.0 / recovered_grid.numpy() |
| 255 | + |
| 256 | +# Plot 4: True model (resistivity) |
| 257 | +ax = axes[1, 0] |
| 258 | +im = ax.pcolormesh( |
| 259 | + x_coords.numpy(), |
| 260 | + z_coords.numpy(), |
| 261 | + rho_true.T, |
| 262 | + cmap="jet_r", |
| 263 | + vmin=10, |
| 264 | + vmax=100, |
| 265 | + shading="auto", |
| 266 | +) |
| 267 | +ax.set_xlabel("X (m)") |
| 268 | +ax.set_ylabel("Z (m)") |
| 269 | +ax.set_title("True Resistivity (Ω⋅m)") |
| 270 | +ax.set_aspect("equal") |
| 271 | +plt.colorbar(im, ax=ax) |
| 272 | + |
| 273 | +# Plot 5: Recovered model (resistivity) |
| 274 | +ax = axes[1, 1] |
| 275 | +im = ax.pcolormesh( |
| 276 | + x_coords.numpy(), |
| 277 | + z_coords.numpy(), |
| 278 | + rho_recovered.T, |
| 279 | + cmap="jet_r", |
| 280 | + vmin=10, |
| 281 | + vmax=100, |
| 282 | + shading="auto", |
| 283 | +) |
| 284 | +ax.set_xlabel("X (m)") |
| 285 | +ax.set_ylabel("Z (m)") |
| 286 | +ax.set_title("Recovered Resistivity (Ω⋅m)") |
| 287 | +ax.set_aspect("equal") |
| 288 | +plt.colorbar(im, ax=ax) |
| 289 | + |
| 290 | +# Plot 6: Difference |
| 291 | +ax = axes[1, 2] |
| 292 | +diff = rho_recovered - rho_true |
| 293 | +im = ax.pcolormesh( |
| 294 | + x_coords.numpy(), |
| 295 | + z_coords.numpy(), |
| 296 | + diff.T, |
| 297 | + cmap="RdBu_r", |
| 298 | + vmin=-20, |
| 299 | + vmax=20, |
| 300 | + shading="auto", |
| 301 | +) |
| 302 | +ax.set_xlabel("X (m)") |
| 303 | +ax.set_ylabel("Z (m)") |
| 304 | +ax.set_title("Difference (Ω⋅m)") |
| 305 | +ax.set_aspect("equal") |
| 306 | +plt.colorbar(im, ax=ax) |
| 307 | + |
| 308 | +plt.tight_layout() |
| 309 | +plt.savefig("dc_inversion_results.png", dpi=150, bbox_inches="tight") |
| 310 | +print("Results saved to: dc_inversion_results.png") |
| 311 | +plt.show() |
| 312 | + |
| 313 | +print("\n" + "=" * 70) |
| 314 | +print("Inversion complete!") |
| 315 | +print("=" * 70) |
0 commit comments