File tree Expand file tree Collapse file tree 1 file changed +70
-0
lines changed
simpegtorch/electromagnetics/FDEM Expand file tree Collapse file tree 1 file changed +70
-0
lines changed Original file line number Diff line number Diff line change 1+ import torch
2+ from simpegtorch .torchmatsolver import batched_mumps_solve , batched_sparse_solve
3+
4+ from simpegtorch .discretize import TensorMesh
5+
6+ # from simpegtorch.discretize.utils import sdiag
7+
8+
9+ class BaseFDEMSimulation :
10+ """
11+ Base class for FDEM simulations.
12+ """
13+
14+ def __init__ (self , mesh : TensorMesh , survey = None , ** kwargs ):
15+ self .mesh = mesh
16+ self .survey = survey
17+
18+ def getRHS ():
19+ """
20+ Get the right-hand side (RHS) vector for the simulation.
21+ This method should be implemented in derived classes.
22+ """
23+ raise NotImplementedError ("getRHS must be implemented in derived classes." )
24+
25+ def getA ():
26+ """
27+ Get the system matrix (A) for the simulation.
28+ This method should be implemented in derived classes.
29+ """
30+ raise NotImplementedError ("getA must be implemented in derived classes." )
31+
32+ def fields (self , m ):
33+ """
34+ Compute the fields for the given model parameters.
35+
36+ Parameters
37+ ----------
38+ m : torch.Tensor
39+ Model parameters.
40+
41+ Returns
42+ -------
43+ torch.Tensor
44+ Computed fields.
45+ """
46+ A = self .getA (m )
47+ rhs = self .getRHS (m )
48+
49+ # Solve the system of equations
50+ if isinstance (A , torch .sparse .FloatTensor ):
51+ return batched_sparse_solve (A , rhs )
52+ else :
53+ return batched_mumps_solve (A , rhs )
54+
55+ def dpred (self , m ):
56+ """
57+ Compute the predicted data for the given model parameters.
58+
59+ Parameters
60+ ----------
61+ m : torch.Tensor
62+ Model parameters.
63+
64+ Returns
65+ -------
66+ torch.Tensor
67+ Predicted data.
68+ """
69+ fields = self .fields (m )
70+ return self .survey .get_data (fields ) if self .survey else fields
You can’t perform that action at this time.
0 commit comments