44
55import torch
66from simpegtorch .discretize import TensorMesh
7- from simpegtorch .electromagnetics .resistivity .simulation import (
8- Simulation3DCellCentered ,
7+ from simpegtorch .simulation .resistivity import (
8+ DC3DCellCentered ,
9+ SrcDipole ,
10+ SrcPole ,
11+ RxDipole ,
12+ RxPole ,
13+ Survey ,
914)
10- from simpegtorch .electromagnetics . resistivity import sources , receivers , survey
15+ from simpegtorch .simulation . base import DirectSolver , mappings
1116
1217torch .set_default_dtype (torch .float64 )
1318
1419
1520def test_dc_simulation_fields_with_gradients ():
16- """Test complete DC simulation with field computation and gradients using new source/receiver classes ."""
21+ """Test complete DC simulation with field computation and gradients using new PDE architecture ."""
1722
1823 # Create a simple 3D mesh with explicit cell sizes
1924 h = torch .ones (10 , dtype = torch .float64 ) # 10 cells of size 1.0 each for more room
2025 mesh = TensorMesh ([h , h , h ], dtype = torch .float64 )
2126
22- # Create resistivity with gradients
23- resistivity = torch .full (
24- (mesh .n_cells ,), 100.0 , dtype = torch .float64 , requires_grad = True
27+ # Create conductivity with gradients (note: using conductivity as base parameter)
28+ sigma = torch .full (
29+ (mesh .n_cells ,), 0.01 , dtype = torch .float64 , requires_grad = True
2530 )
2631
2732 # Create receivers for measuring potential differences
@@ -42,32 +47,32 @@ def test_dc_simulation_fields_with_gradients():
4247 dtype = torch .float64 ,
4348 )
4449
45- rx = receivers . Dipole (locations_m = rx_locations_m , locations_n = rx_locations_n )
50+ rx = RxDipole (locations_m = rx_locations_m , locations_n = rx_locations_n )
4651
4752 # Create dipole source within mesh bounds
4853 src_location_a = torch .tensor ([2.0 , 3.5 , 2.0 ], dtype = torch .float64 ) # A electrode
4954 src_location_b = torch .tensor ([6.0 , 3.5 , 2.0 ], dtype = torch .float64 ) # B electrode
50- src = sources . Dipole (
51- [rx ], location_a = src_location_a , location_b = src_location_b , current = 1.0
55+ src = SrcDipole (
56+ [rx ], src_location_a , src_location_b , current = 1.0
5257 )
5358
5459 # Create survey
55- surv = survey . Survey ([src ])
60+ surv = Survey ([src ])
5661
57- # Create DC simulation with survey
58- sim = Simulation3DCellCentered (mesh , survey = surv )
59- sim .setBC ()
62+ # Create resistivity mapping (resistivity = 1/sigma)
63+ resistivity_map = mappings .InverseMapping (sigma )
6064
61- # Compute fields - this tests the complete gradient flow through:
65+ # Create DC PDE and solver
66+ pde = DC3DCellCentered (mesh , surv , resistivity_map , bc_type = "Dirichlet" )
67+ solver = DirectSolver (pde )
68+
69+ # Compute predicted data - this tests the complete gradient flow through:
6270 # 1. Source discretization to mesh
6371 # 2. Face inner product with inversion
6472 # 3. System matrix assembly (D @ MfRhoI @ G)
6573 # 4. Linear system solve with TorchMatSolver
6674 # 5. Receiver evaluation from fields
67- fields = sim .fields (resistivity )
68-
69- # Compute predicted data
70- predicted_data = sim .dpred (resistivity )
75+ predicted_data = solver .forward ()
7176
7277 # Compute a simple objective function based on data
7378 loss = torch .sum (predicted_data ** 2 )
@@ -76,47 +81,35 @@ def test_dc_simulation_fields_with_gradients():
7681 loss .backward ()
7782
7883 # Verify results
79- assert isinstance (
80- fields , dict
81- ), "Fields should be a dictionary for multiple sources"
82- assert src in fields , "Fields should contain entry for source"
83-
84- src_fields = fields [src ]
85- assert src_fields is not None , "Fields should be computed"
86- assert src_fields .shape [0 ] == mesh .n_cells , "Fields should have correct shape"
87- assert torch .all (torch .isfinite (src_fields )), "Fields should be finite"
88-
8984 assert predicted_data is not None , "Predicted data should be computed"
90- assert predicted_data .shape [0 ] == rx .nD , "Data should have correct shape"
85+ assert predicted_data .shape [0 ] == rx .locations_m . shape [ 0 ] , "Data should have correct shape"
9186 assert torch .all (torch .isfinite (predicted_data )), "Data should be finite"
9287
93- assert resistivity .grad is not None , "Gradients should be computed"
94- assert torch .all (torch .isfinite (resistivity .grad )), "Gradients should be finite"
95- assert torch .any (resistivity .grad != 0 ), "Some gradients should be non-zero"
88+ assert resistivity_map . trainable_parameters .grad is not None , "Gradients should be computed"
89+ assert torch .all (torch .isfinite (resistivity_map . trainable_parameters .grad )), "Gradients should be finite"
90+ assert torch .any (resistivity_map . trainable_parameters .grad != 0 ), "Some gradients should be non-zero"
9691
97- print ("✅ Complete DC simulation test with new source/receiver classes passed" )
92+ print ("✅ Complete DC simulation test with new PDE architecture passed" )
9893 print (f"Mesh cells: { mesh .n_cells } " )
9994 print (f"Source: Dipole at A={ src .location_a } , B={ src .location_b } " )
100- print (f"Receivers: { rx .nD } dipole measurements" )
101- print (f"Fields shape: { src_fields .shape } " )
102- print (f"Fields range: [{ src_fields .min ():.6f} , { src_fields .max ():.6f} ]" )
95+ print (f"Receivers: { rx .locations_m .shape [0 ]} dipole measurements" )
10396 print (f"Data shape: { predicted_data .shape } " )
10497 print (f"Data range: [{ predicted_data .min ():.6f} , { predicted_data .max ():.6f} ] V" )
105- print (f"Gradient mean: { resistivity .grad .mean ():.2e} " )
106- print (f"Gradient std: { resistivity .grad .std ():.2e} " )
98+ print (f"Gradient mean: { resistivity_map . trainable_parameters .grad .mean ():.2e} " )
99+ print (f"Gradient std: { resistivity_map . trainable_parameters .grad .std ():.2e} " )
107100
108101
109102def test_dc_simulation_jtvec ():
110- """Test Jtvec functionality with new source/receiver classes ."""
103+ """Test Jacobian transpose vector product functionality with new PDE architecture ."""
111104
112105 # Create mesh and model
113106 h = torch .ones (6 , dtype = torch .float64 )
114107 mesh = TensorMesh ([h , h , h ], dtype = torch .float64 )
115- resistivity = torch .full (
116- (mesh .n_cells ,), 50.0 , dtype = torch .float64 , requires_grad = True
108+ sigma = torch .full (
109+ (mesh .n_cells ,), 0.02 , dtype = torch .float64 , requires_grad = True
117110 )
118111
119- # Create multiple receivers within mesh bounds [0,6] x [0,6] x [0,4 ]
112+ # Create multiple receivers within mesh bounds [0,6] x [0,6] x [0,6 ]
120113 rx_locations_m = torch .tensor (
121114 [
122115 [2.0 , 2.0 , 0.5 ],
@@ -135,44 +128,53 @@ def test_dc_simulation_jtvec():
135128 dtype = torch .float64 ,
136129 )
137130
138- rx = receivers . Dipole (locations_m = rx_locations_m , locations_n = rx_locations_n )
131+ rx = RxDipole (locations_m = rx_locations_m , locations_n = rx_locations_n )
139132
140133 # Create source within mesh bounds
141- src = sources . Dipole (
134+ src = SrcDipole (
142135 [rx ],
143- location_a = torch .tensor ([1.0 , 2.5 , 0.5 ], dtype = torch .float64 ),
144- location_b = torch .tensor ([5.0 , 2.5 , 0.5 ], dtype = torch .float64 ),
136+ torch .tensor ([1.0 , 2.5 , 0.5 ], dtype = torch .float64 ),
137+ torch .tensor ([5.0 , 2.5 , 0.5 ], dtype = torch .float64 ),
145138 current = 2.0 ,
146139 )
147140
148- # Create survey and simulation
149- surv = survey .Survey ([src ])
150- sim = Simulation3DCellCentered (mesh , survey = surv )
151- sim .setBC ()
152-
153- # Test Jtvec
154- data_residuals = torch .randn (rx .nD , dtype = torch .float64 )
155- gradient = sim .Jtvec (resistivity , data_residuals )
141+ # Create survey, mapping and solver
142+ surv = Survey ([src ])
143+ resistivity_map = mappings .InverseMapping (sigma )
144+ pde = DC3DCellCentered (mesh , surv , resistivity_map , bc_type = "Dirichlet" )
145+ solver = DirectSolver (pde )
146+
147+ # Test Jtvec by computing gradients manually
148+ data_residuals = torch .randn (rx .locations_m .shape [0 ], dtype = torch .float64 )
149+
150+ # Forward pass
151+ predicted_data = solver .forward ()
152+
153+ # Compute loss using data residuals (Jtvec equivalent)
154+ loss = torch .sum (data_residuals * predicted_data )
155+ loss .backward ()
156+
157+ gradient = resistivity_map .trainable_parameters .grad
156158
157159 # Verify results
158- assert gradient is not None , "Jtvec should return gradient "
160+ assert gradient is not None , "Gradients should be computed "
159161 assert gradient .shape [0 ] == mesh .n_cells , "Gradient should have correct shape"
160162 assert torch .all (torch .isfinite (gradient )), "Gradient should be finite"
161163
162- print ("✅ Jtvec test with new source/receiver classes passed" )
164+ print ("✅ Jacobian transpose test with new PDE architecture passed" )
163165 print (f"Data residuals shape: { data_residuals .shape } " )
164166 print (f"Gradient shape: { gradient .shape } " )
165167 print (f"Gradient range: [{ gradient .min ():.2e} , { gradient .max ():.2e} ]" )
166168
167169
168170def test_dc_simulation_multiple_sources ():
169- """Test simulation with multiple sources using new classes ."""
171+ """Test simulation with multiple sources using new PDE architecture ."""
170172
171173 # Create mesh
172174 h = torch .ones (8 , dtype = torch .float64 )
173175 mesh = TensorMesh ([h , h , h ], dtype = torch .float64 )
174- resistivity = torch .full (
175- (mesh .n_cells ,), 75.0 , dtype = torch .float64 , requires_grad = True
176+ sigma = torch .full (
177+ (mesh .n_cells ,), 1.0 / 75.0 , dtype = torch .float64 , requires_grad = True
176178 )
177179
178180 # Create shared receivers
@@ -194,66 +196,69 @@ def test_dc_simulation_multiple_sources():
194196 dtype = torch .float64 ,
195197 )
196198
197- rx = receivers . Dipole (locations_m = rx_locations_m , locations_n = rx_locations_n )
199+ rx = RxDipole (locations_m = rx_locations_m , locations_n = rx_locations_n )
198200
199201 # Create multiple sources
200- src1 = sources . Dipole (
202+ src1 = SrcDipole (
201203 [rx ],
202- location_a = torch .tensor ([1.0 , 3.5 , 1.0 ], dtype = torch .float64 ),
203- location_b = torch .tensor ([2.0 , 3.5 , 1.0 ], dtype = torch .float64 ),
204+ torch .tensor ([1.0 , 3.5 , 1.0 ], dtype = torch .float64 ),
205+ torch .tensor ([2.0 , 3.5 , 1.0 ], dtype = torch .float64 ),
204206 )
205207
206- src2 = sources . Dipole (
208+ src2 = SrcDipole (
207209 [rx ],
208- location_a = torch .tensor ([6.0 , 3.5 , 1.0 ], dtype = torch .float64 ),
209- location_b = torch .tensor ([7.0 , 3.5 , 1.0 ], dtype = torch .float64 ),
210+ torch .tensor ([6.0 , 3.5 , 1.0 ], dtype = torch .float64 ),
211+ torch .tensor ([7.0 , 3.5 , 1.0 ], dtype = torch .float64 ),
210212 )
211213
212214 # Test pole source too
213- rx_pole = receivers . Pole (
215+ rx_pole = RxPole (
214216 locations = torch .tensor ([[4.0 , 3.5 , 2.0 ]], dtype = torch .float64 )
215217 )
216- src3 = sources . Pole (
217- [rx_pole ], location = torch .tensor ([4.0 , 1.0 , 1.0 ], dtype = torch .float64 )
218+ src3 = SrcPole (
219+ [rx_pole ], torch .tensor ([4.0 , 1.0 , 1.0 ], dtype = torch .float64 )
218220 )
219221
220222 # Create survey with multiple sources
221- surv = survey .Survey ([src1 , src2 , src3 ])
222- sim = Simulation3DCellCentered (mesh , survey = surv )
223- sim .setBC ()
223+ surv = Survey ([src1 , src2 , src3 ])
224+ resistivity_map = mappings .InverseMapping (sigma )
225+ pde = DC3DCellCentered (mesh , surv , resistivity_map , bc_type = "Dirichlet" )
226+ solver = DirectSolver (pde )
224227
225228 # Test forward modeling
226- predicted_data = sim . dpred ( resistivity )
229+ predicted_data = solver . forward ( )
227230
228- # Test Jtvec with multiple sources
231+ # Test Jacobian transpose with multiple sources
229232 data_residuals = torch .randn_like (predicted_data )
230- gradient = sim .Jtvec (resistivity , data_residuals )
233+ loss = torch .sum (data_residuals * predicted_data )
234+ loss .backward ()
235+ gradient = resistivity_map .trainable_parameters .grad
231236
232237 # Verify results
233- expected_data_count = rx .nD * 2 + rx_pole .nD # 2 dipole sources + 1 pole source
238+ expected_data_count = rx .locations_m . shape [ 0 ] * 2 + rx_pole .locations . shape [ 0 ] # 2 dipole sources + 1 pole source
234239 assert (
235240 predicted_data .shape [0 ] == expected_data_count
236241 ), f"Expected { expected_data_count } data points"
237242 assert gradient .shape [0 ] == mesh .n_cells , "Gradient should have correct shape"
238243 assert torch .all (torch .isfinite (predicted_data )), "Data should be finite"
239244 assert torch .all (torch .isfinite (gradient )), "Gradient should be finite"
240245
241- print ("✅ Multiple sources test with new source/receiver classes passed" )
242- print (f"Sources: { surv .nSrc } (2 dipole + 1 pole)" )
243- print (f"Total data points: { surv . nD } " )
246+ print ("✅ Multiple sources test with new PDE architecture passed" )
247+ print (f"Sources: { len ( surv .source_list ) } (2 dipole + 1 pole)" )
248+ print (f"Total data points: { predicted_data . shape [ 0 ] } " )
244249 print (f"Data range: [{ predicted_data .min ():.6f} , { predicted_data .max ():.6f} ] V" )
245250 print (f"Gradient range: [{ gradient .min ():.2e} , { gradient .max ():.2e} ]" )
246251
247252
248253def test_dc_simulation_apparent_resistivity ():
249- """Test apparent resistivity calculations."""
254+ """Test apparent resistivity calculations with new PDE architecture ."""
250255
251256 # Create mesh
252257 h = torch .ones (10 , dtype = torch .float64 )
253258 mesh = TensorMesh ([h , h , h ], dtype = torch .float64 )
254- resistivity = torch .full ((mesh .n_cells ,), 100.0 , dtype = torch .float64 )
259+ sigma = torch .full ((mesh .n_cells ,), 0.01 , dtype = torch .float64 ) # 1/100 S/m
255260
256- # Create receivers with apparent resistivity data type within mesh bounds [0,10] x [0,10] x [0,10]
261+ # Create receivers for potential differences within mesh bounds [0,10] x [0,10] x [0,10]
257262 # Use a standard dipole-dipole configuration for proper geometric factors
258263 rx_locations_m = torch .tensor (
259264 [
@@ -271,48 +276,49 @@ def test_dc_simulation_apparent_resistivity():
271276 dtype = torch .float64 ,
272277 )
273278
274- rx = receivers .Dipole (
275- locations_m = rx_locations_m ,
276- locations_n = rx_locations_n ,
277- data_type = "apparent_resistivity" ,
278- )
279+ rx = RxDipole (locations_m = rx_locations_m , locations_n = rx_locations_n )
279280
280281 # Create source within mesh bounds - standard dipole-dipole array
281- src = sources . Dipole (
282+ src = SrcDipole (
282283 [rx ],
283- location_a = torch .tensor ([1.0 , 5.0 , 1.0 ], dtype = torch .float64 ), # A electrode
284- location_b = torch .tensor ([2.0 , 5.0 , 1.0 ], dtype = torch .float64 ), # B electrode
284+ torch .tensor ([1.0 , 5.0 , 1.0 ], dtype = torch .float64 ), # A electrode
285+ torch .tensor ([2.0 , 5.0 , 1.0 ], dtype = torch .float64 ), # B electrode
285286 )
286287
287- # Create survey and set geometric factors
288- surv = survey .Survey ([src ])
289- geometric_factors = surv .set_geometric_factor (space_type = "halfspace" )
290-
291- # Create simulation
292- sim = Simulation3DCellCentered (mesh , survey = surv )
293- sim .setBC ()
294-
295- # Test apparent resistivity calculation
296- apparent_resistivity = sim .dpred (resistivity )
288+ # Create survey and simulation
289+ surv = Survey ([src ])
290+ resistivity_map = mappings .InverseMapping (sigma )
291+ pde = DC3DCellCentered (mesh , surv , resistivity_map , bc_type = "Dirichlet" )
292+ solver = DirectSolver (pde )
293+
294+ # Test potential difference calculation
295+ predicted_data = solver .forward ()
296+
297+ # Calculate apparent resistivity manually using simple geometric factor
298+ # For dipole-dipole: rho_app = K * (V_MN / I) where K is geometric factor
299+ current = 1.0 # Default current
300+ # Simplified geometric factor for dipole-dipole (depends on electrode spacing)
301+ electrode_spacing = 1.0 # Unit spacing
302+ geometric_factor = 2 * torch .pi * electrode_spacing # Simplified K
303+ apparent_resistivity = geometric_factor * torch .abs (predicted_data ) / current
297304
298305 # Debug output
306+ print (f"Potential differences: { predicted_data } " )
299307 print (f"Apparent resistivity values: { apparent_resistivity } " )
300- print (f"Geometric factors: { geometric_factors } " )
301308
302309 # Verify results
310+ assert torch .all (torch .isfinite (predicted_data ))
303311 assert torch .all (torch .isfinite (apparent_resistivity ))
312+ assert torch .all (apparent_resistivity > 0 )
304313
305- print ("✅ Apparent resistivity test with new source/receiver classes passed" )
306- print (
307- f"Geometric factors range: [{ geometric_factors .min ():.6f} , { geometric_factors .max ():.6f} ]"
308- )
309- print (
310- f"Apparent resistivity range: [{ apparent_resistivity .min ():.1f} , { apparent_resistivity .max ():.1f} ] Ω⋅m"
311- )
314+ print ("✅ Apparent resistivity test with new PDE architecture passed" )
315+ print (f"Geometric factor: { geometric_factor :.6f} " )
316+ print (f"Potential difference range: [{ predicted_data .min ():.6e} , { predicted_data .max ():.6e} ] V" )
317+ print (f"Apparent resistivity range: [{ apparent_resistivity .min ():.1f} , { apparent_resistivity .max ():.1f} ] Ω⋅m" )
312318
313319
314320if __name__ == "__main__" :
315- print ("🧪 Testing Full DC Simulation with New Source/Receiver Classes " )
321+ print ("🧪 Testing Full DC Simulation with New PDE Architecture " )
316322 print ("=" * 65 )
317323
318324 test_dc_simulation_fields_with_gradients ()
@@ -328,4 +334,4 @@ def test_dc_simulation_apparent_resistivity():
328334 print ()
329335
330336 print ("=" * 65 )
331- print ("🎉 All DC simulation tests with new source/receiver classes passed!" )
337+ print ("🎉 All DC simulation tests with new PDE architecture passed!" )
0 commit comments