Skip to content

Commit 0ced89c

Browse files
committed
updated tests to work with new DCR formulation
1 parent a90497b commit 0ced89c

File tree

3 files changed

+266
-213
lines changed

3 files changed

+266
-213
lines changed

tests/test_gradients/test_full_dc_simulation.py

Lines changed: 117 additions & 111 deletions
Original file line numberDiff line numberDiff line change
@@ -4,24 +4,29 @@
44

55
import torch
66
from simpegtorch.discretize import TensorMesh
7-
from simpegtorch.electromagnetics.resistivity.simulation import (
8-
Simulation3DCellCentered,
7+
from simpegtorch.simulation.resistivity import (
8+
DC3DCellCentered,
9+
SrcDipole,
10+
SrcPole,
11+
RxDipole,
12+
RxPole,
13+
Survey,
914
)
10-
from simpegtorch.electromagnetics.resistivity import sources, receivers, survey
15+
from simpegtorch.simulation.base import DirectSolver, mappings
1116

1217
torch.set_default_dtype(torch.float64)
1318

1419

1520
def test_dc_simulation_fields_with_gradients():
16-
"""Test complete DC simulation with field computation and gradients using new source/receiver classes."""
21+
"""Test complete DC simulation with field computation and gradients using new PDE architecture."""
1722

1823
# Create a simple 3D mesh with explicit cell sizes
1924
h = torch.ones(10, dtype=torch.float64) # 10 cells of size 1.0 each for more room
2025
mesh = TensorMesh([h, h, h], dtype=torch.float64)
2126

22-
# Create resistivity with gradients
23-
resistivity = torch.full(
24-
(mesh.n_cells,), 100.0, dtype=torch.float64, requires_grad=True
27+
# Create conductivity with gradients (note: using conductivity as base parameter)
28+
sigma = torch.full(
29+
(mesh.n_cells,), 0.01, dtype=torch.float64, requires_grad=True
2530
)
2631

2732
# Create receivers for measuring potential differences
@@ -42,32 +47,32 @@ def test_dc_simulation_fields_with_gradients():
4247
dtype=torch.float64,
4348
)
4449

45-
rx = receivers.Dipole(locations_m=rx_locations_m, locations_n=rx_locations_n)
50+
rx = RxDipole(locations_m=rx_locations_m, locations_n=rx_locations_n)
4651

4752
# Create dipole source within mesh bounds
4853
src_location_a = torch.tensor([2.0, 3.5, 2.0], dtype=torch.float64) # A electrode
4954
src_location_b = torch.tensor([6.0, 3.5, 2.0], dtype=torch.float64) # B electrode
50-
src = sources.Dipole(
51-
[rx], location_a=src_location_a, location_b=src_location_b, current=1.0
55+
src = SrcDipole(
56+
[rx], src_location_a, src_location_b, current=1.0
5257
)
5358

5459
# Create survey
55-
surv = survey.Survey([src])
60+
surv = Survey([src])
5661

57-
# Create DC simulation with survey
58-
sim = Simulation3DCellCentered(mesh, survey=surv)
59-
sim.setBC()
62+
# Create resistivity mapping (resistivity = 1/sigma)
63+
resistivity_map = mappings.InverseMapping(sigma)
6064

61-
# Compute fields - this tests the complete gradient flow through:
65+
# Create DC PDE and solver
66+
pde = DC3DCellCentered(mesh, surv, resistivity_map, bc_type="Dirichlet")
67+
solver = DirectSolver(pde)
68+
69+
# Compute predicted data - this tests the complete gradient flow through:
6270
# 1. Source discretization to mesh
6371
# 2. Face inner product with inversion
6472
# 3. System matrix assembly (D @ MfRhoI @ G)
6573
# 4. Linear system solve with TorchMatSolver
6674
# 5. Receiver evaluation from fields
67-
fields = sim.fields(resistivity)
68-
69-
# Compute predicted data
70-
predicted_data = sim.dpred(resistivity)
75+
predicted_data = solver.forward()
7176

7277
# Compute a simple objective function based on data
7378
loss = torch.sum(predicted_data**2)
@@ -76,47 +81,35 @@ def test_dc_simulation_fields_with_gradients():
7681
loss.backward()
7782

7883
# Verify results
79-
assert isinstance(
80-
fields, dict
81-
), "Fields should be a dictionary for multiple sources"
82-
assert src in fields, "Fields should contain entry for source"
83-
84-
src_fields = fields[src]
85-
assert src_fields is not None, "Fields should be computed"
86-
assert src_fields.shape[0] == mesh.n_cells, "Fields should have correct shape"
87-
assert torch.all(torch.isfinite(src_fields)), "Fields should be finite"
88-
8984
assert predicted_data is not None, "Predicted data should be computed"
90-
assert predicted_data.shape[0] == rx.nD, "Data should have correct shape"
85+
assert predicted_data.shape[0] == rx.locations_m.shape[0], "Data should have correct shape"
9186
assert torch.all(torch.isfinite(predicted_data)), "Data should be finite"
9287

93-
assert resistivity.grad is not None, "Gradients should be computed"
94-
assert torch.all(torch.isfinite(resistivity.grad)), "Gradients should be finite"
95-
assert torch.any(resistivity.grad != 0), "Some gradients should be non-zero"
88+
assert resistivity_map.trainable_parameters.grad is not None, "Gradients should be computed"
89+
assert torch.all(torch.isfinite(resistivity_map.trainable_parameters.grad)), "Gradients should be finite"
90+
assert torch.any(resistivity_map.trainable_parameters.grad != 0), "Some gradients should be non-zero"
9691

97-
print("✅ Complete DC simulation test with new source/receiver classes passed")
92+
print("✅ Complete DC simulation test with new PDE architecture passed")
9893
print(f"Mesh cells: {mesh.n_cells}")
9994
print(f"Source: Dipole at A={src.location_a}, B={src.location_b}")
100-
print(f"Receivers: {rx.nD} dipole measurements")
101-
print(f"Fields shape: {src_fields.shape}")
102-
print(f"Fields range: [{src_fields.min():.6f}, {src_fields.max():.6f}]")
95+
print(f"Receivers: {rx.locations_m.shape[0]} dipole measurements")
10396
print(f"Data shape: {predicted_data.shape}")
10497
print(f"Data range: [{predicted_data.min():.6f}, {predicted_data.max():.6f}] V")
105-
print(f"Gradient mean: {resistivity.grad.mean():.2e}")
106-
print(f"Gradient std: {resistivity.grad.std():.2e}")
98+
print(f"Gradient mean: {resistivity_map.trainable_parameters.grad.mean():.2e}")
99+
print(f"Gradient std: {resistivity_map.trainable_parameters.grad.std():.2e}")
107100

108101

109102
def test_dc_simulation_jtvec():
110-
"""Test Jtvec functionality with new source/receiver classes."""
103+
"""Test Jacobian transpose vector product functionality with new PDE architecture."""
111104

112105
# Create mesh and model
113106
h = torch.ones(6, dtype=torch.float64)
114107
mesh = TensorMesh([h, h, h], dtype=torch.float64)
115-
resistivity = torch.full(
116-
(mesh.n_cells,), 50.0, dtype=torch.float64, requires_grad=True
108+
sigma = torch.full(
109+
(mesh.n_cells,), 0.02, dtype=torch.float64, requires_grad=True
117110
)
118111

119-
# Create multiple receivers within mesh bounds [0,6] x [0,6] x [0,4]
112+
# Create multiple receivers within mesh bounds [0,6] x [0,6] x [0,6]
120113
rx_locations_m = torch.tensor(
121114
[
122115
[2.0, 2.0, 0.5],
@@ -135,44 +128,53 @@ def test_dc_simulation_jtvec():
135128
dtype=torch.float64,
136129
)
137130

138-
rx = receivers.Dipole(locations_m=rx_locations_m, locations_n=rx_locations_n)
131+
rx = RxDipole(locations_m=rx_locations_m, locations_n=rx_locations_n)
139132

140133
# Create source within mesh bounds
141-
src = sources.Dipole(
134+
src = SrcDipole(
142135
[rx],
143-
location_a=torch.tensor([1.0, 2.5, 0.5], dtype=torch.float64),
144-
location_b=torch.tensor([5.0, 2.5, 0.5], dtype=torch.float64),
136+
torch.tensor([1.0, 2.5, 0.5], dtype=torch.float64),
137+
torch.tensor([5.0, 2.5, 0.5], dtype=torch.float64),
145138
current=2.0,
146139
)
147140

148-
# Create survey and simulation
149-
surv = survey.Survey([src])
150-
sim = Simulation3DCellCentered(mesh, survey=surv)
151-
sim.setBC()
152-
153-
# Test Jtvec
154-
data_residuals = torch.randn(rx.nD, dtype=torch.float64)
155-
gradient = sim.Jtvec(resistivity, data_residuals)
141+
# Create survey, mapping and solver
142+
surv = Survey([src])
143+
resistivity_map = mappings.InverseMapping(sigma)
144+
pde = DC3DCellCentered(mesh, surv, resistivity_map, bc_type="Dirichlet")
145+
solver = DirectSolver(pde)
146+
147+
# Test Jtvec by computing gradients manually
148+
data_residuals = torch.randn(rx.locations_m.shape[0], dtype=torch.float64)
149+
150+
# Forward pass
151+
predicted_data = solver.forward()
152+
153+
# Compute loss using data residuals (Jtvec equivalent)
154+
loss = torch.sum(data_residuals * predicted_data)
155+
loss.backward()
156+
157+
gradient = resistivity_map.trainable_parameters.grad
156158

157159
# Verify results
158-
assert gradient is not None, "Jtvec should return gradient"
160+
assert gradient is not None, "Gradients should be computed"
159161
assert gradient.shape[0] == mesh.n_cells, "Gradient should have correct shape"
160162
assert torch.all(torch.isfinite(gradient)), "Gradient should be finite"
161163

162-
print("✅ Jtvec test with new source/receiver classes passed")
164+
print("✅ Jacobian transpose test with new PDE architecture passed")
163165
print(f"Data residuals shape: {data_residuals.shape}")
164166
print(f"Gradient shape: {gradient.shape}")
165167
print(f"Gradient range: [{gradient.min():.2e}, {gradient.max():.2e}]")
166168

167169

168170
def test_dc_simulation_multiple_sources():
169-
"""Test simulation with multiple sources using new classes."""
171+
"""Test simulation with multiple sources using new PDE architecture."""
170172

171173
# Create mesh
172174
h = torch.ones(8, dtype=torch.float64)
173175
mesh = TensorMesh([h, h, h], dtype=torch.float64)
174-
resistivity = torch.full(
175-
(mesh.n_cells,), 75.0, dtype=torch.float64, requires_grad=True
176+
sigma = torch.full(
177+
(mesh.n_cells,), 1.0/75.0, dtype=torch.float64, requires_grad=True
176178
)
177179

178180
# Create shared receivers
@@ -194,66 +196,69 @@ def test_dc_simulation_multiple_sources():
194196
dtype=torch.float64,
195197
)
196198

197-
rx = receivers.Dipole(locations_m=rx_locations_m, locations_n=rx_locations_n)
199+
rx = RxDipole(locations_m=rx_locations_m, locations_n=rx_locations_n)
198200

199201
# Create multiple sources
200-
src1 = sources.Dipole(
202+
src1 = SrcDipole(
201203
[rx],
202-
location_a=torch.tensor([1.0, 3.5, 1.0], dtype=torch.float64),
203-
location_b=torch.tensor([2.0, 3.5, 1.0], dtype=torch.float64),
204+
torch.tensor([1.0, 3.5, 1.0], dtype=torch.float64),
205+
torch.tensor([2.0, 3.5, 1.0], dtype=torch.float64),
204206
)
205207

206-
src2 = sources.Dipole(
208+
src2 = SrcDipole(
207209
[rx],
208-
location_a=torch.tensor([6.0, 3.5, 1.0], dtype=torch.float64),
209-
location_b=torch.tensor([7.0, 3.5, 1.0], dtype=torch.float64),
210+
torch.tensor([6.0, 3.5, 1.0], dtype=torch.float64),
211+
torch.tensor([7.0, 3.5, 1.0], dtype=torch.float64),
210212
)
211213

212214
# Test pole source too
213-
rx_pole = receivers.Pole(
215+
rx_pole = RxPole(
214216
locations=torch.tensor([[4.0, 3.5, 2.0]], dtype=torch.float64)
215217
)
216-
src3 = sources.Pole(
217-
[rx_pole], location=torch.tensor([4.0, 1.0, 1.0], dtype=torch.float64)
218+
src3 = SrcPole(
219+
[rx_pole], torch.tensor([4.0, 1.0, 1.0], dtype=torch.float64)
218220
)
219221

220222
# Create survey with multiple sources
221-
surv = survey.Survey([src1, src2, src3])
222-
sim = Simulation3DCellCentered(mesh, survey=surv)
223-
sim.setBC()
223+
surv = Survey([src1, src2, src3])
224+
resistivity_map = mappings.InverseMapping(sigma)
225+
pde = DC3DCellCentered(mesh, surv, resistivity_map, bc_type="Dirichlet")
226+
solver = DirectSolver(pde)
224227

225228
# Test forward modeling
226-
predicted_data = sim.dpred(resistivity)
229+
predicted_data = solver.forward()
227230

228-
# Test Jtvec with multiple sources
231+
# Test Jacobian transpose with multiple sources
229232
data_residuals = torch.randn_like(predicted_data)
230-
gradient = sim.Jtvec(resistivity, data_residuals)
233+
loss = torch.sum(data_residuals * predicted_data)
234+
loss.backward()
235+
gradient = resistivity_map.trainable_parameters.grad
231236

232237
# Verify results
233-
expected_data_count = rx.nD * 2 + rx_pole.nD # 2 dipole sources + 1 pole source
238+
expected_data_count = rx.locations_m.shape[0] * 2 + rx_pole.locations.shape[0] # 2 dipole sources + 1 pole source
234239
assert (
235240
predicted_data.shape[0] == expected_data_count
236241
), f"Expected {expected_data_count} data points"
237242
assert gradient.shape[0] == mesh.n_cells, "Gradient should have correct shape"
238243
assert torch.all(torch.isfinite(predicted_data)), "Data should be finite"
239244
assert torch.all(torch.isfinite(gradient)), "Gradient should be finite"
240245

241-
print("✅ Multiple sources test with new source/receiver classes passed")
242-
print(f"Sources: {surv.nSrc} (2 dipole + 1 pole)")
243-
print(f"Total data points: {surv.nD}")
246+
print("✅ Multiple sources test with new PDE architecture passed")
247+
print(f"Sources: {len(surv.source_list)} (2 dipole + 1 pole)")
248+
print(f"Total data points: {predicted_data.shape[0]}")
244249
print(f"Data range: [{predicted_data.min():.6f}, {predicted_data.max():.6f}] V")
245250
print(f"Gradient range: [{gradient.min():.2e}, {gradient.max():.2e}]")
246251

247252

248253
def test_dc_simulation_apparent_resistivity():
249-
"""Test apparent resistivity calculations."""
254+
"""Test apparent resistivity calculations with new PDE architecture."""
250255

251256
# Create mesh
252257
h = torch.ones(10, dtype=torch.float64)
253258
mesh = TensorMesh([h, h, h], dtype=torch.float64)
254-
resistivity = torch.full((mesh.n_cells,), 100.0, dtype=torch.float64)
259+
sigma = torch.full((mesh.n_cells,), 0.01, dtype=torch.float64) # 1/100 S/m
255260

256-
# Create receivers with apparent resistivity data type within mesh bounds [0,10] x [0,10] x [0,10]
261+
# Create receivers for potential differences within mesh bounds [0,10] x [0,10] x [0,10]
257262
# Use a standard dipole-dipole configuration for proper geometric factors
258263
rx_locations_m = torch.tensor(
259264
[
@@ -271,48 +276,49 @@ def test_dc_simulation_apparent_resistivity():
271276
dtype=torch.float64,
272277
)
273278

274-
rx = receivers.Dipole(
275-
locations_m=rx_locations_m,
276-
locations_n=rx_locations_n,
277-
data_type="apparent_resistivity",
278-
)
279+
rx = RxDipole(locations_m=rx_locations_m, locations_n=rx_locations_n)
279280

280281
# Create source within mesh bounds - standard dipole-dipole array
281-
src = sources.Dipole(
282+
src = SrcDipole(
282283
[rx],
283-
location_a=torch.tensor([1.0, 5.0, 1.0], dtype=torch.float64), # A electrode
284-
location_b=torch.tensor([2.0, 5.0, 1.0], dtype=torch.float64), # B electrode
284+
torch.tensor([1.0, 5.0, 1.0], dtype=torch.float64), # A electrode
285+
torch.tensor([2.0, 5.0, 1.0], dtype=torch.float64), # B electrode
285286
)
286287

287-
# Create survey and set geometric factors
288-
surv = survey.Survey([src])
289-
geometric_factors = surv.set_geometric_factor(space_type="halfspace")
290-
291-
# Create simulation
292-
sim = Simulation3DCellCentered(mesh, survey=surv)
293-
sim.setBC()
294-
295-
# Test apparent resistivity calculation
296-
apparent_resistivity = sim.dpred(resistivity)
288+
# Create survey and simulation
289+
surv = Survey([src])
290+
resistivity_map = mappings.InverseMapping(sigma)
291+
pde = DC3DCellCentered(mesh, surv, resistivity_map, bc_type="Dirichlet")
292+
solver = DirectSolver(pde)
293+
294+
# Test potential difference calculation
295+
predicted_data = solver.forward()
296+
297+
# Calculate apparent resistivity manually using simple geometric factor
298+
# For dipole-dipole: rho_app = K * (V_MN / I) where K is geometric factor
299+
current = 1.0 # Default current
300+
# Simplified geometric factor for dipole-dipole (depends on electrode spacing)
301+
electrode_spacing = 1.0 # Unit spacing
302+
geometric_factor = 2 * torch.pi * electrode_spacing # Simplified K
303+
apparent_resistivity = geometric_factor * torch.abs(predicted_data) / current
297304

298305
# Debug output
306+
print(f"Potential differences: {predicted_data}")
299307
print(f"Apparent resistivity values: {apparent_resistivity}")
300-
print(f"Geometric factors: {geometric_factors}")
301308

302309
# Verify results
310+
assert torch.all(torch.isfinite(predicted_data))
303311
assert torch.all(torch.isfinite(apparent_resistivity))
312+
assert torch.all(apparent_resistivity > 0)
304313

305-
print("✅ Apparent resistivity test with new source/receiver classes passed")
306-
print(
307-
f"Geometric factors range: [{geometric_factors.min():.6f}, {geometric_factors.max():.6f}]"
308-
)
309-
print(
310-
f"Apparent resistivity range: [{apparent_resistivity.min():.1f}, {apparent_resistivity.max():.1f}] Ω⋅m"
311-
)
314+
print("✅ Apparent resistivity test with new PDE architecture passed")
315+
print(f"Geometric factor: {geometric_factor:.6f}")
316+
print(f"Potential difference range: [{predicted_data.min():.6e}, {predicted_data.max():.6e}] V")
317+
print(f"Apparent resistivity range: [{apparent_resistivity.min():.1f}, {apparent_resistivity.max():.1f}] Ω⋅m")
312318

313319

314320
if __name__ == "__main__":
315-
print("🧪 Testing Full DC Simulation with New Source/Receiver Classes")
321+
print("🧪 Testing Full DC Simulation with New PDE Architecture")
316322
print("=" * 65)
317323

318324
test_dc_simulation_fields_with_gradients()
@@ -328,4 +334,4 @@ def test_dc_simulation_apparent_resistivity():
328334
print()
329335

330336
print("=" * 65)
331-
print("🎉 All DC simulation tests with new source/receiver classes passed!")
337+
print("🎉 All DC simulation tests with new PDE architecture passed!")

0 commit comments

Comments
 (0)