Skip to content

Commit 14436ca

Browse files
committed
Added FDEM simulation code
1 parent 9dd2de4 commit 14436ca

File tree

1 file changed

+70
-0
lines changed

1 file changed

+70
-0
lines changed
Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import torch
2+
from simpegtorch.torchmatsolver import batched_mumps_solve, batched_sparse_solve
3+
4+
from simpegtorch.discretize import TensorMesh
5+
6+
# from simpegtorch.discretize.utils import sdiag
7+
8+
9+
class BaseFDEMSimulation:
10+
"""
11+
Base class for FDEM simulations.
12+
"""
13+
14+
def __init__(self, mesh: TensorMesh, survey=None, **kwargs):
15+
self.mesh = mesh
16+
self.survey = survey
17+
18+
def getRHS():
19+
"""
20+
Get the right-hand side (RHS) vector for the simulation.
21+
This method should be implemented in derived classes.
22+
"""
23+
raise NotImplementedError("getRHS must be implemented in derived classes.")
24+
25+
def getA():
26+
"""
27+
Get the system matrix (A) for the simulation.
28+
This method should be implemented in derived classes.
29+
"""
30+
raise NotImplementedError("getA must be implemented in derived classes.")
31+
32+
def fields(self, m):
33+
"""
34+
Compute the fields for the given model parameters.
35+
36+
Parameters
37+
----------
38+
m : torch.Tensor
39+
Model parameters.
40+
41+
Returns
42+
-------
43+
torch.Tensor
44+
Computed fields.
45+
"""
46+
A = self.getA(m)
47+
rhs = self.getRHS(m)
48+
49+
# Solve the system of equations
50+
if isinstance(A, torch.sparse.FloatTensor):
51+
return batched_sparse_solve(A, rhs)
52+
else:
53+
return batched_mumps_solve(A, rhs)
54+
55+
def dpred(self, m):
56+
"""
57+
Compute the predicted data for the given model parameters.
58+
59+
Parameters
60+
----------
61+
m : torch.Tensor
62+
Model parameters.
63+
64+
Returns
65+
-------
66+
torch.Tensor
67+
Predicted data.
68+
"""
69+
fields = self.fields(m)
70+
return self.survey.get_data(fields) if self.survey else fields

0 commit comments

Comments
 (0)