11import torch
2- from . .basePDE import BasePDE
2+ from simpegtorch . simulation . base .basePDE import BasePDE
33from simpegtorch .discretize import TensorMesh
44from simpegtorch .discretize .utils import sdiag
55
@@ -34,7 +34,7 @@ def __init__(
3434 survey : Survey
3535 Survey object with sources and receivers
3636 mapping : BaseMapping
37- Parameter generating function
37+ Parameter generating function, by default this is conductivity
3838 """
3939 super ().__init__ (mesh , mapping )
4040 self .survey = survey
@@ -44,14 +44,10 @@ def __init__(
4444 def setBC (self ):
4545 mesh = self .mesh
4646 # Standard cell-centered finite volume discretization
47- # Volume scaling is needed for proper conservation
4847 V = sdiag (mesh .cell_volumes )
4948 self .Div = V @ mesh .face_divergence
5049 self .Grad = self .Div .T
5150
52- # Initialize MfRhoI here, it will be updated in getA
53- self .MfRhoI = None
54-
5551 if self .bc_type == "Dirichlet" :
5652 print ("Homogeneous Dirichlet is the natural BC for this CC discretization" )
5753 return
@@ -108,12 +104,12 @@ def get_system_matrices(self) -> torch.Tensor:
108104 D = self .Div
109105 G = self .Grad
110106
111- # Use face inner product with inverse resistivity ( conductivity)
112- self . MfRhoI = self .mesh .get_face_inner_product (
107+ # Use face inner product with resistivity from mapping (inverted to get conductivity)
108+ MfRhoI = self .mesh .get_face_inner_product (
113109 self .mapping .forward (), invert_matrix = True
114110 )
115111
116- A = D @ self . MfRhoI @ G
112+ A = D @ MfRhoI @ G
117113
118114 if self .bc_type .lower () == "neumann" :
119115 A = self .condition_matrix (A )
@@ -156,7 +152,162 @@ def fields_to_data(self, fields: torch.Tensor) -> torch.Tensor:
156152
157153 # Get receiver projection tensor and project to data
158154 rx_tensor = src .build_receiver_tensor (self .mesh , "CC" )
159- rx_data = torch .mv (rx_tensor , src_field )
155+
156+ # Handle sparse tensor matrix multiplication
157+ # rx_tensor shape: [1, n_receivers, n_mesh_cells]
158+ # src_field shape: [n_mesh_cells]
159+ if rx_tensor .is_sparse :
160+ # For sparse tensors, we need to handle the multiplication differently
161+ # Select the first batch (index 0) and multiply with each receiver projection
162+ rx_tensor_2d = rx_tensor [0 ] # [n_receivers, n_mesh_cells]
163+ rx_data = torch .sparse .mm (
164+ rx_tensor_2d , src_field .unsqueeze (- 1 )
165+ ).squeeze (- 1 )
166+ else :
167+ # Dense tensor multiplication
168+ rx_data = rx_tensor .squeeze (0 ) @ src_field
169+
170+ data_list .append (rx_data )
171+
172+ return torch .cat (data_list , dim = 0 )
173+
174+
175+ class DC3DNodal (BasePDE ):
176+ """
177+ DC Resistivity PDE for Node centered formulation
178+ """
179+
180+ def __init__ (
181+ self ,
182+ mesh : TensorMesh ,
183+ survey ,
184+ mapping ,
185+ bc_type : str = "Dirichlet" ,
186+ ):
187+ """
188+ Initialize DC Resistivity PDE.
189+
190+ Parameters
191+ ----------
192+ mesh : TensorMesh
193+ Discretization mesh
194+ survey : Survey
195+ Survey object with sources and receivers
196+ mapping : BaseMapping
197+ Parameter generating function
198+ """
199+ super ().__init__ (mesh , mapping )
200+ self .survey = survey
201+ self .bc_type = bc_type
202+ self .setBC ()
203+
204+ def setBC (self ):
205+ valid_bc_types = ["neumann" ]
206+ if self .bc_type .lower () not in valid_bc_types :
207+ raise ValueError (f"Unsupported boundary condition type: { self .bc_type } " )
208+
209+ def get_system_matrices (self ) -> torch .Tensor :
210+ """
211+ Construct system matrix for the nodal DC resistivity PDE.
212+ Returns
213+ -------
214+ torch.Tensor
215+ Single system matrix for nodal DC resistivity.
216+ Shape [1, nN, nN]
217+ """
218+ # Use edge inner product with conductivity (1/resistivity)
219+ MeSigma = self .mesh .get_edge_inner_product (1.0 / self .mapping .forward ())
220+ Grad = self .mesh .nodal_gradient
221+ A = Grad .T @ MeSigma @ Grad
222+
223+ if self .bc_type .lower () == "neumann" :
224+ A = self .condition_matrix (A )
225+
226+ return torch .stack ([A ], dim = 0 ) # Shape [1, nN, nN]
227+
228+ def condition_matrix (self , A : torch .Tensor ) -> torch .Tensor :
229+ """
230+ Condition the system matrix for numerical stability.
231+ Used for Neumann boundary conditions.
232+ """
233+ # Create a sparse tensor that represents the modification
234+ # It will have 1 at (0,0) and 0 elsewhere
235+ mod_indices = torch .tensor ([[0 ], [0 ]], dtype = torch .long , device = A .device )
236+ mod_values = torch .tensor ([1.0 ], dtype = A .dtype , device = A .device )
237+ modification_matrix = torch .sparse_coo_tensor (
238+ mod_indices , mod_values , A .shape
239+ ).coalesce ()
240+
241+ # Create a mask for the first row of A
242+ first_row_mask = A .indices ()[0 ] == 0
243+
244+ # Get the values of the first row of A
245+ first_row_values = A .values ()[first_row_mask ]
246+ first_row_cols = A .indices ()[1 , first_row_mask ]
247+
248+ # Create a sparse tensor representing the negative of the first row of A
249+ neg_first_row_indices = torch .stack (
250+ [torch .zeros_like (first_row_cols ), first_row_cols ]
251+ )
252+ neg_first_row_values = - first_row_values
253+ neg_first_row_matrix = torch .sparse_coo_tensor (
254+ neg_first_row_indices , neg_first_row_values , A .shape
255+ ).coalesce ()
256+
257+ # Add the modification matrix and the negative of the first row to A
258+ A = A + neg_first_row_matrix + modification_matrix
259+ return A
260+
261+ def get_rhs_tensors (self ) -> torch .Tensor :
262+ """
263+ Construct RHS vectors for all sources.
264+ Returns
265+ -------
266+ torch.Tensor
267+ Batched RHS tensors. Shape [n_sources, nC]
268+ """
269+
270+ return self .survey .get_source_tensor (self .mesh , projected_grid = "N" )
271+
272+ def fields_to_data (self , fields : torch .Tensor ) -> torch .Tensor :
273+ """
274+ Project solution fields to predicted data.
275+ Parameters
276+ ----------
277+ fields : torch.Tensor
278+ Solution fields from PDE solve, shape [1, n_sources, nC]
279+ Returns
280+ -------
281+ torch.Tensor
282+ Predicted data vector [n_data]
283+ """
284+ # Remove frequency dimension if present (shape: [n_sources, nC])
285+ if fields .dim () == 3 :
286+ fields = fields .squeeze (0 )
287+
288+ sources = self .survey .source_list
289+ data_list = []
290+
291+ for i , src in enumerate (sources ):
292+ # Get field for this source
293+ src_field = fields [i ]
294+
295+ # Get receiver projection tensor and project to data
296+ rx_tensor = src .build_receiver_tensor (self .mesh , "N" )
297+
298+ # Handle sparse tensor matrix multiplication
299+ # rx_tensor shape: [1, n_receivers, n_mesh_nodes]
300+ # src_field shape: [n_mesh_nodes]
301+ if rx_tensor .is_sparse :
302+ # For sparse tensors, we need to handle the multiplication differently
303+ # Select the first batch (index 0) and multiply with each receiver projection
304+ rx_tensor_2d = rx_tensor [0 ] # [n_receivers, n_mesh_nodes]
305+ rx_data = torch .sparse .mm (
306+ rx_tensor_2d , src_field .unsqueeze (- 1 )
307+ ).squeeze (- 1 )
308+ else :
309+ # Dense tensor multiplication
310+ rx_data = rx_tensor .squeeze (0 ) @ src_field
160311
161312 data_list .append (rx_data )
162313
0 commit comments