Skip to content

Commit a90497b

Browse files
committed
Updated DCR tests to use new formulation
1 parent af629a6 commit a90497b

File tree

9 files changed

+283
-358
lines changed

9 files changed

+283
-358
lines changed

examples/dc_torch_example.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
import torch
2+
from simpegtorch.discretize import TensorMesh
3+
from simpegtorch.simulation.resistivity import (
4+
DC3DCellCentered,
5+
SrcDipole,
6+
RxDipole,
7+
Survey,
8+
)
9+
10+
from simpegtorch.simulation.base import DirectSolver, mappings
11+
12+
from simpegtorch.discretize.utils import (
13+
ndgrid,
14+
)
15+
16+
# Create a tensor mesh
17+
hx = torch.ones(10) * 25
18+
hy = torch.ones(10) * 25
19+
hz = torch.ones(10) * 25
20+
21+
# 250m x 250m x 250m mesh
22+
23+
# Origin at (-125, -125, -250) to center the mesh
24+
origin = torch.tensor([-125.0, -125.0, -250.0])
25+
26+
mesh = TensorMesh(
27+
[hx, hy, hz],
28+
origin=origin,
29+
)
30+
31+
sigma = torch.ones(mesh.nC) * 1e-2 # Uniform conductivity
32+
sigma_map = mappings.BaseMapping(sigma)
33+
34+
# Set up survey parameters for numeric solution
35+
x = mesh.cell_centers_x[(mesh.cell_centers_x > -75.0) & (mesh.cell_centers_x < 75.0)]
36+
y = mesh.cell_centers_y[(mesh.cell_centers_y > -75.0) & (mesh.cell_centers_y < 75.0)]
37+
38+
M = ndgrid(x - 25.0, y, [0.0])
39+
N = ndgrid(x + 25.0, y, [0.0])
40+
41+
# create a dipole dipole survey
42+
rx = RxDipole(
43+
locations_m=M,
44+
locations_n=N,
45+
)
46+
47+
loc_a = torch.tensor([-25.0, 0.0, 0.0])
48+
loc_b = torch.tensor([25.0, 0.0, 0.0])
49+
50+
src = SrcDipole(
51+
[rx],
52+
loc_a, # location of A
53+
loc_b, # location of B
54+
current=1.0, # current in Amperes
55+
)
56+
57+
survey = Survey([src])
58+
59+
## Setup the Problem as a Cell Centered PDE problem
60+
problem = DC3DCellCentered(
61+
mesh,
62+
survey,
63+
sigma_map,
64+
bc_type="Dirichlet",
65+
)
66+
67+
# Create the solver we will use
68+
solver = DirectSolver(problem)
69+
70+
# Solve the forward problem
71+
Data = solver.forward()
72+
73+
# Print results
74+
print("Forward simulation completed successfully!")
75+
print(f"Data values: {Data}")
76+
print(f"Data min: {Data.min():.6e}, max: {Data.max():.6e}")

simpegtorch/__init__.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,2 @@
11
from . import discretize
22
from . import utils
3-
from . import maps
Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,2 @@
1-
from .universal_simulation import UniversalSimulation
2-
from .dc_pde import DCResistivityPDE
3-
from .fdem_pde import FDEMPDE
4-
from ..basePDE import BasePDE, BaseMapping
5-
6-
__all__ = [
7-
"UniversalSimulation",
8-
"DCResistivityPDE",
9-
"FDEMPDE",
10-
"BasePDE",
11-
"BaseMapping",
12-
]
1+
from .basePDE import BasePDE
2+
from .direct_solver import DirectSolver

simpegtorch/simulation/base/basePDE.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
from abc import ABC, abstractmethod
22
import torch
33
from simpegtorch.discretize.base import BaseMesh
4-
from mappings import BaseMapping
4+
from .mappings import BaseMapping
55

66

77
class BasePDE(ABC):

simpegtorch/simulation/resistivity/__init__.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
"""
2626

2727
from .simulation import Simulation3DCellCentered, Simulation3DNodal
28+
from .dc_pde import DC3DCellCentered, DC3DNodal
2829
from .sources import BaseSrc, Pole as SrcPole, Dipole as SrcDipole, Multipole
2930
from .receivers import BaseRx, Pole as RxPole, Dipole as RxDipole
3031
from .survey import Survey
@@ -44,6 +45,9 @@
4445
# 3D simulations
4546
"Simulation3DCellCentered",
4647
"Simulation3DNodal",
48+
# PDE problems
49+
"DC3DCellCentered",
50+
"DC3DNodal",
4751
# Sources
4852
"BaseSrc",
4953
"Src",

simpegtorch/simulation/resistivity/dc_pde.py

Lines changed: 161 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
import torch
2-
from ..basePDE import BasePDE
2+
from simpegtorch.simulation.base.basePDE import BasePDE
33
from simpegtorch.discretize import TensorMesh
44
from simpegtorch.discretize.utils import sdiag
55

@@ -34,7 +34,7 @@ def __init__(
3434
survey : Survey
3535
Survey object with sources and receivers
3636
mapping : BaseMapping
37-
Parameter generating function
37+
Parameter generating function, by default this is conductivity
3838
"""
3939
super().__init__(mesh, mapping)
4040
self.survey = survey
@@ -44,14 +44,10 @@ def __init__(
4444
def setBC(self):
4545
mesh = self.mesh
4646
# Standard cell-centered finite volume discretization
47-
# Volume scaling is needed for proper conservation
4847
V = sdiag(mesh.cell_volumes)
4948
self.Div = V @ mesh.face_divergence
5049
self.Grad = self.Div.T
5150

52-
# Initialize MfRhoI here, it will be updated in getA
53-
self.MfRhoI = None
54-
5551
if self.bc_type == "Dirichlet":
5652
print("Homogeneous Dirichlet is the natural BC for this CC discretization")
5753
return
@@ -108,12 +104,12 @@ def get_system_matrices(self) -> torch.Tensor:
108104
D = self.Div
109105
G = self.Grad
110106

111-
# Use face inner product with inverse resistivity (conductivity)
112-
self.MfRhoI = self.mesh.get_face_inner_product(
107+
# Use face inner product with resistivity from mapping (inverted to get conductivity)
108+
MfRhoI = self.mesh.get_face_inner_product(
113109
self.mapping.forward(), invert_matrix=True
114110
)
115111

116-
A = D @ self.MfRhoI @ G
112+
A = D @ MfRhoI @ G
117113

118114
if self.bc_type.lower() == "neumann":
119115
A = self.condition_matrix(A)
@@ -156,7 +152,162 @@ def fields_to_data(self, fields: torch.Tensor) -> torch.Tensor:
156152

157153
# Get receiver projection tensor and project to data
158154
rx_tensor = src.build_receiver_tensor(self.mesh, "CC")
159-
rx_data = torch.mv(rx_tensor, src_field)
155+
156+
# Handle sparse tensor matrix multiplication
157+
# rx_tensor shape: [1, n_receivers, n_mesh_cells]
158+
# src_field shape: [n_mesh_cells]
159+
if rx_tensor.is_sparse:
160+
# For sparse tensors, we need to handle the multiplication differently
161+
# Select the first batch (index 0) and multiply with each receiver projection
162+
rx_tensor_2d = rx_tensor[0] # [n_receivers, n_mesh_cells]
163+
rx_data = torch.sparse.mm(
164+
rx_tensor_2d, src_field.unsqueeze(-1)
165+
).squeeze(-1)
166+
else:
167+
# Dense tensor multiplication
168+
rx_data = rx_tensor.squeeze(0) @ src_field
169+
170+
data_list.append(rx_data)
171+
172+
return torch.cat(data_list, dim=0)
173+
174+
175+
class DC3DNodal(BasePDE):
176+
"""
177+
DC Resistivity PDE for Node centered formulation
178+
"""
179+
180+
def __init__(
181+
self,
182+
mesh: TensorMesh,
183+
survey,
184+
mapping,
185+
bc_type: str = "Dirichlet",
186+
):
187+
"""
188+
Initialize DC Resistivity PDE.
189+
190+
Parameters
191+
----------
192+
mesh : TensorMesh
193+
Discretization mesh
194+
survey : Survey
195+
Survey object with sources and receivers
196+
mapping : BaseMapping
197+
Parameter generating function
198+
"""
199+
super().__init__(mesh, mapping)
200+
self.survey = survey
201+
self.bc_type = bc_type
202+
self.setBC()
203+
204+
def setBC(self):
205+
valid_bc_types = ["neumann"]
206+
if self.bc_type.lower() not in valid_bc_types:
207+
raise ValueError(f"Unsupported boundary condition type: {self.bc_type}")
208+
209+
def get_system_matrices(self) -> torch.Tensor:
210+
"""
211+
Construct system matrix for the nodal DC resistivity PDE.
212+
Returns
213+
-------
214+
torch.Tensor
215+
Single system matrix for nodal DC resistivity.
216+
Shape [1, nN, nN]
217+
"""
218+
# Use edge inner product with conductivity (1/resistivity)
219+
MeSigma = self.mesh.get_edge_inner_product(1.0 / self.mapping.forward())
220+
Grad = self.mesh.nodal_gradient
221+
A = Grad.T @ MeSigma @ Grad
222+
223+
if self.bc_type.lower() == "neumann":
224+
A = self.condition_matrix(A)
225+
226+
return torch.stack([A], dim=0) # Shape [1, nN, nN]
227+
228+
def condition_matrix(self, A: torch.Tensor) -> torch.Tensor:
229+
"""
230+
Condition the system matrix for numerical stability.
231+
Used for Neumann boundary conditions.
232+
"""
233+
# Create a sparse tensor that represents the modification
234+
# It will have 1 at (0,0) and 0 elsewhere
235+
mod_indices = torch.tensor([[0], [0]], dtype=torch.long, device=A.device)
236+
mod_values = torch.tensor([1.0], dtype=A.dtype, device=A.device)
237+
modification_matrix = torch.sparse_coo_tensor(
238+
mod_indices, mod_values, A.shape
239+
).coalesce()
240+
241+
# Create a mask for the first row of A
242+
first_row_mask = A.indices()[0] == 0
243+
244+
# Get the values of the first row of A
245+
first_row_values = A.values()[first_row_mask]
246+
first_row_cols = A.indices()[1, first_row_mask]
247+
248+
# Create a sparse tensor representing the negative of the first row of A
249+
neg_first_row_indices = torch.stack(
250+
[torch.zeros_like(first_row_cols), first_row_cols]
251+
)
252+
neg_first_row_values = -first_row_values
253+
neg_first_row_matrix = torch.sparse_coo_tensor(
254+
neg_first_row_indices, neg_first_row_values, A.shape
255+
).coalesce()
256+
257+
# Add the modification matrix and the negative of the first row to A
258+
A = A + neg_first_row_matrix + modification_matrix
259+
return A
260+
261+
def get_rhs_tensors(self) -> torch.Tensor:
262+
"""
263+
Construct RHS vectors for all sources.
264+
Returns
265+
-------
266+
torch.Tensor
267+
Batched RHS tensors. Shape [n_sources, nC]
268+
"""
269+
270+
return self.survey.get_source_tensor(self.mesh, projected_grid="N")
271+
272+
def fields_to_data(self, fields: torch.Tensor) -> torch.Tensor:
273+
"""
274+
Project solution fields to predicted data.
275+
Parameters
276+
----------
277+
fields : torch.Tensor
278+
Solution fields from PDE solve, shape [1, n_sources, nC]
279+
Returns
280+
-------
281+
torch.Tensor
282+
Predicted data vector [n_data]
283+
"""
284+
# Remove frequency dimension if present (shape: [n_sources, nC])
285+
if fields.dim() == 3:
286+
fields = fields.squeeze(0)
287+
288+
sources = self.survey.source_list
289+
data_list = []
290+
291+
for i, src in enumerate(sources):
292+
# Get field for this source
293+
src_field = fields[i]
294+
295+
# Get receiver projection tensor and project to data
296+
rx_tensor = src.build_receiver_tensor(self.mesh, "N")
297+
298+
# Handle sparse tensor matrix multiplication
299+
# rx_tensor shape: [1, n_receivers, n_mesh_nodes]
300+
# src_field shape: [n_mesh_nodes]
301+
if rx_tensor.is_sparse:
302+
# For sparse tensors, we need to handle the multiplication differently
303+
# Select the first batch (index 0) and multiply with each receiver projection
304+
rx_tensor_2d = rx_tensor[0] # [n_receivers, n_mesh_nodes]
305+
rx_data = torch.sparse.mm(
306+
rx_tensor_2d, src_field.unsqueeze(-1)
307+
).squeeze(-1)
308+
else:
309+
# Dense tensor multiplication
310+
rx_data = rx_tensor.squeeze(0) @ src_field
160311

161312
data_list.append(rx_data)
162313

simpegtorch/simulation/resistivity/sources.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,16 @@ def evaluate(self, mesh, projected_grid: str = "CC"):
127127
"""
128128

129129
indices = mesh.closest_points_index(self.location, grid_loc=projected_grid)
130-
q = torch.zeros(mesh.nC, dtype=self.current.dtype, device=self.current.device)
130+
131+
# Choose the right vector size based on formulation
132+
if projected_grid in ["N", "nodes"]:
133+
vector_size = mesh.nN # Number of nodes for nodal formulation
134+
else:
135+
vector_size = mesh.nC # Number of cells for cell-centered formulation
136+
137+
q = torch.zeros(
138+
vector_size, dtype=self.current.dtype, device=self.current.device
139+
)
131140
q[indices] = self.current
132141
return q
133142

0 commit comments

Comments
 (0)