Skip to content

Commit 3b74f7e

Browse files
committed
Added FDEM implementation of new PDE framework
1 parent 0ced89c commit 3b74f7e

File tree

15 files changed

+2095
-1184
lines changed

15 files changed

+2095
-1184
lines changed

examples/dc_inversion_example.py

Lines changed: 315 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,315 @@
1+
"""
2+
DC Resistivity Inversion Example using New PDE Framework
3+
4+
This example demonstrates a complete inversion workflow:
5+
1. Create synthetic "true" model
6+
2. Generate synthetic data with noise
7+
3. Set up inversion with regularization
8+
4. Run inversion with beta cooling
9+
5. Visualize results
10+
"""
11+
12+
import torch
13+
import numpy as np
14+
import matplotlib.pyplot as plt
15+
from simpegtorch.discretize import TensorMesh
16+
from simpegtorch.simulation.resistivity import (
17+
DC3DCellCentered,
18+
SrcDipole,
19+
RxDipole,
20+
Survey,
21+
)
22+
from simpegtorch.simulation.base import DirectSolver, SimulationWrapper, mappings
23+
from simpegtorch.inversion import (
24+
BaseInversion,
25+
BaseInvProblem,
26+
BetaSchedule,
27+
TargetMisfit,
28+
BetaEstimate_ByEig,
29+
)
30+
from simpegtorch.data_misfit import L2DataMisfit
31+
from simpegtorch.regularization import TikhonovRegularization
32+
33+
# Set default dtype
34+
torch.set_default_dtype(torch.float64)
35+
36+
print("=" * 70)
37+
print("DC Resistivity Inversion with New PDE Framework")
38+
print("=" * 70)
39+
40+
# ============================================================================
41+
# 1. Create mesh
42+
# ============================================================================
43+
print("\n[1/6] Creating mesh...")
44+
hx = torch.ones(20) * 10.0 # 20 cells of 10m
45+
hy = torch.ones(20) * 10.0
46+
hz = torch.ones(10) * 10.0
47+
48+
# 200m x 200m x 100m mesh
49+
origin = torch.tensor([0.0, 0.0, -100.0]) # Start 100m below surface
50+
mesh = TensorMesh([hx, hy, hz], origin=origin)
51+
52+
print(f"Mesh: {mesh.n_cells} cells ({mesh.shape_cells[0]}x{mesh.shape_cells[1]}x{mesh.shape_cells[2]})")
53+
54+
# ============================================================================
55+
# 2. Create survey
56+
# ============================================================================
57+
print("\n[2/6] Creating survey...")
58+
59+
# Create a dipole-dipole survey with multiple measurements
60+
n_data_points = 15
61+
x_locs = torch.linspace(20.0, 180.0, n_data_points)
62+
y_center = 100.0
63+
64+
# M electrodes (negative voltage terminal)
65+
M_locs = torch.stack(
66+
[x_locs, torch.full_like(x_locs, y_center), torch.zeros_like(x_locs)], dim=1
67+
)
68+
69+
# N electrodes (positive voltage terminal) - 20m spacing
70+
N_locs = torch.stack(
71+
[
72+
x_locs + 20.0,
73+
torch.full_like(x_locs, y_center),
74+
torch.zeros_like(x_locs),
75+
],
76+
dim=1,
77+
)
78+
79+
rx = RxDipole(locations_m=M_locs, locations_n=N_locs)
80+
81+
# Dipole source (A-B electrodes) - 40m dipole at center
82+
src = SrcDipole(
83+
[rx],
84+
torch.tensor([60.0, 100.0, 0.0]), # A electrode
85+
torch.tensor([140.0, 100.0, 0.0]), # B electrode
86+
current=1.0,
87+
)
88+
89+
survey = Survey([src])
90+
print(f"Survey: {survey.nD} data points")
91+
92+
# ============================================================================
93+
# 3. Create true model and generate synthetic data
94+
# ============================================================================
95+
print("\n[3/6] Generating synthetic data...")
96+
97+
# True model: background of 100 Ohm-m with a conductive block (10 Ohm-m)
98+
sigma_background = 0.01 # 1/100 S/m = 100 Ohm-m
99+
sigma_true = torch.ones(mesh.n_cells) * sigma_background
100+
101+
# Add a conductive anomaly: centered block
102+
# Find cells in the anomaly region (x: 80-120m, y: 80-120m, z: -60 to -40m)
103+
cell_centers = mesh.cell_centers
104+
anomaly_mask = (
105+
(cell_centers[:, 0] > 80.0)
106+
& (cell_centers[:, 0] < 120.0)
107+
& (cell_centers[:, 1] > 80.0)
108+
& (cell_centers[:, 1] < 120.0)
109+
& (cell_centers[:, 2] > -60.0)
110+
& (cell_centers[:, 2] < -40.0)
111+
)
112+
sigma_true[anomaly_mask] = 0.1 # 10 Ohm-m anomaly
113+
114+
# Create mapping and PDE for forward modeling
115+
sigma_map_true = mappings.BaseMapping(sigma_true)
116+
pde_true = DC3DCellCentered(mesh, survey, sigma_map_true, bc_type="Dirichlet")
117+
solver_true = DirectSolver(pde_true)
118+
119+
# Generate clean data
120+
with torch.no_grad():
121+
data_clean = solver_true.forward()
122+
123+
# Add noise (5% relative + 1e-6 V floor)
124+
torch.manual_seed(42) # For reproducibility
125+
noise_level = 0.05
126+
noise_floor = 1e-6
127+
noise = noise_level * torch.abs(data_clean) + noise_floor
128+
noise = noise * torch.randn_like(data_clean)
129+
data_obs = data_clean + noise
130+
131+
print(f"Data range: [{data_obs.min():.3e}, {data_obs.max():.3e}] V")
132+
print(f"SNR: ~{1/noise_level:.0f}:1")
133+
134+
# ============================================================================
135+
# 4. Set up inversion
136+
# ============================================================================
137+
print("\n[4/6] Setting up inversion...")
138+
139+
# Starting model: homogeneous half-space (best guess)
140+
sigma_start = torch.ones(mesh.n_cells, requires_grad=True) * sigma_background
141+
sigma_map_inv = mappings.BaseMapping(sigma_start)
142+
143+
# Create PDE and simulation wrapper for inversion
144+
pde_inv = DC3DCellCentered(mesh, survey, sigma_map_inv, bc_type="Dirichlet")
145+
solver_inv = DirectSolver(pde_inv)
146+
simulation = SimulationWrapper(pde_inv, solver_inv)
147+
148+
# Data misfit with uncertainty weighting
149+
uncertainties = noise_level * torch.abs(data_obs) + noise_floor
150+
dmisfit = L2DataMisfit(simulation, data_obs, weights=1.0 / uncertainties)
151+
152+
# Regularization (Tikhonov smooth inversion)
153+
alpha_s = 1e-4 # Smoothness weight
154+
alpha_x = 1.0 # x-derivative weight
155+
alpha_y = 1.0 # y-derivative weight
156+
alpha_z = 1.0 # z-derivative weight
157+
158+
reg = TikhonovRegularization(
159+
mesh,
160+
alpha_s=alpha_s,
161+
alpha_x=alpha_x,
162+
alpha_y=alpha_y,
163+
alpha_z=alpha_z,
164+
reference_model=sigma_background * torch.ones(mesh.n_cells),
165+
)
166+
167+
# Optimizer (Adam works well for geophysical inversions)
168+
optimizer = torch.optim.Adam([sigma_map_inv.trainable_parameters], lr=0.01)
169+
170+
# Inverse problem
171+
inv_prob = BaseInvProblem(
172+
dmisfit, reg, optimizer, beta=1.0, max_iter=50 # Will be estimated
173+
)
174+
175+
# Directives for inversion
176+
directives = [
177+
BetaEstimate_ByEig(beta0_ratio=1.0), # Estimate initial beta
178+
BetaSchedule(cooling_factor=2.0, cooling_rate=3), # Cool beta every 3 iterations
179+
TargetMisfit(chi_factor=1.0), # Stop when misfit ~ # of data points
180+
]
181+
182+
inversion = BaseInversion(inv_prob, directives=directives, device="cpu")
183+
184+
print(f"Starting model: {sigma_start.mean():.4f} ± {sigma_start.std():.4f} S/m")
185+
print(f"True anomaly contrast: {sigma_true[anomaly_mask].mean() / sigma_background:.1f}x")
186+
187+
# ============================================================================
188+
# 5. Run inversion
189+
# ============================================================================
190+
print("\n[5/6] Running inversion...")
191+
print("-" * 70)
192+
193+
sigma_recovered = inversion.run(sigma_start)
194+
195+
print("-" * 70)
196+
print(f"\nInversion completed in {inversion.iteration} iterations")
197+
print(
198+
f"Final misfit: φ_d = {inversion.phi_d_history[-1]:.2e} (target: {dmisfit.n_data:.2e})"
199+
)
200+
print(f"Recovered model: {sigma_recovered.mean():.4f} ± {sigma_recovered.std():.4f} S/m")
201+
202+
# ============================================================================
203+
# 6. Plot results
204+
# ============================================================================
205+
print("\n[6/6] Plotting results...")
206+
207+
fig, axes = plt.subplots(2, 3, figsize=(15, 10))
208+
209+
# Plot 1: Convergence curve
210+
ax = axes[0, 0]
211+
ax.semilogy(inversion.phi_d_history, "b-o", label="Data misfit")
212+
ax.semilogy(inversion.phi_m_history, "r-s", label="Model norm")
213+
ax.axhline(dmisfit.n_data, color="k", linestyle="--", label="Target")
214+
ax.set_xlabel("Iteration")
215+
ax.set_ylabel("Misfit")
216+
ax.set_title("Convergence")
217+
ax.legend()
218+
ax.grid(True, alpha=0.3)
219+
220+
# Plot 2: Beta schedule
221+
ax = axes[0, 1]
222+
ax.semilogy(inversion.beta_history, "g-d")
223+
ax.set_xlabel("Iteration")
224+
ax.set_ylabel("Beta (trade-off parameter)")
225+
ax.set_title("Beta Cooling Schedule")
226+
ax.grid(True, alpha=0.3)
227+
228+
# Plot 3: Data fit
229+
ax = axes[0, 2]
230+
ax.plot(data_obs.numpy(), data_clean.numpy(), "b.", alpha=0.5, label="Observed")
231+
pred_final = simulation.dpred()
232+
ax.plot(data_obs.numpy(), pred_final.detach().numpy(), "r.", alpha=0.5, label="Predicted")
233+
lims = [data_obs.min().item(), data_obs.max().item()]
234+
ax.plot(lims, lims, "k--", alpha=0.3)
235+
ax.set_xlabel("Observed Data (V)")
236+
ax.set_ylabel("Predicted Data (V)")
237+
ax.set_title("Data Fit")
238+
ax.legend()
239+
ax.grid(True, alpha=0.3)
240+
241+
# Extract middle slice (y = 100m) for visualization
242+
y_slice_idx = mesh.shape_cells[1] // 2
243+
x_coords = mesh.cell_centers_x
244+
z_coords = mesh.cell_centers_z
245+
246+
# Reshape models to grid
247+
true_grid = sigma_true.reshape(mesh.shape_cells)[:, y_slice_idx, :]
248+
recovered_grid = (
249+
torch.tensor(sigma_recovered).reshape(mesh.shape_cells)[:, y_slice_idx, :]
250+
)
251+
252+
# Convert to resistivity for plotting
253+
rho_true = 1.0 / true_grid.numpy()
254+
rho_recovered = 1.0 / recovered_grid.numpy()
255+
256+
# Plot 4: True model (resistivity)
257+
ax = axes[1, 0]
258+
im = ax.pcolormesh(
259+
x_coords.numpy(),
260+
z_coords.numpy(),
261+
rho_true.T,
262+
cmap="jet_r",
263+
vmin=10,
264+
vmax=100,
265+
shading="auto",
266+
)
267+
ax.set_xlabel("X (m)")
268+
ax.set_ylabel("Z (m)")
269+
ax.set_title("True Resistivity (Ω⋅m)")
270+
ax.set_aspect("equal")
271+
plt.colorbar(im, ax=ax)
272+
273+
# Plot 5: Recovered model (resistivity)
274+
ax = axes[1, 1]
275+
im = ax.pcolormesh(
276+
x_coords.numpy(),
277+
z_coords.numpy(),
278+
rho_recovered.T,
279+
cmap="jet_r",
280+
vmin=10,
281+
vmax=100,
282+
shading="auto",
283+
)
284+
ax.set_xlabel("X (m)")
285+
ax.set_ylabel("Z (m)")
286+
ax.set_title("Recovered Resistivity (Ω⋅m)")
287+
ax.set_aspect("equal")
288+
plt.colorbar(im, ax=ax)
289+
290+
# Plot 6: Difference
291+
ax = axes[1, 2]
292+
diff = rho_recovered - rho_true
293+
im = ax.pcolormesh(
294+
x_coords.numpy(),
295+
z_coords.numpy(),
296+
diff.T,
297+
cmap="RdBu_r",
298+
vmin=-20,
299+
vmax=20,
300+
shading="auto",
301+
)
302+
ax.set_xlabel("X (m)")
303+
ax.set_ylabel("Z (m)")
304+
ax.set_title("Difference (Ω⋅m)")
305+
ax.set_aspect("equal")
306+
plt.colorbar(im, ax=ax)
307+
308+
plt.tight_layout()
309+
plt.savefig("dc_inversion_results.png", dpi=150, bbox_inches="tight")
310+
print("Results saved to: dc_inversion_results.png")
311+
plt.show()
312+
313+
print("\n" + "=" * 70)
314+
print("Inversion complete!")
315+
print("=" * 70)

0 commit comments

Comments
 (0)