Skip to content

Commit 49aaad7

Browse files
committed
Added base classes for new architecture, updated mat solver for smart batching
1 parent e09afc7 commit 49aaad7

File tree

10 files changed

+1102
-145
lines changed

10 files changed

+1102
-145
lines changed
Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,12 @@
1+
from .universal_simulation import UniversalSimulation
2+
from .dc_pde import DCResistivityPDE
3+
from .fdem_pde import FDEMPDE
4+
from ..basePDE import BasePDE, BaseMapping
5+
6+
__all__ = [
7+
"UniversalSimulation",
8+
"DCResistivityPDE",
9+
"FDEMPDE",
10+
"BasePDE",
11+
"BaseMapping",
12+
]
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
from abc import ABC, abstractmethod
2+
import torch
3+
from simpegtorch.discretize.base import BaseMesh
4+
from mappings import BaseMapping
5+
6+
7+
class BasePDE(ABC):
8+
"""
9+
Minimal PyTorch-native PDE interface for forward simulations.
10+
11+
This class defines the essential interface for any PDE formulation,
12+
focusing on the three core operations needed for simulation:
13+
1. System matrix construction
14+
2. RHS vector construction
15+
3. Field-to-data projection
16+
4. A vec product (optional, for experimental indirect solver)
17+
18+
All gradient computation is handled automatically by PyTorch autograd.
19+
"""
20+
21+
def __init__(self, mesh: BaseMesh, mapping: BaseMapping):
22+
self.mesh = mesh
23+
self.mapping = mapping
24+
25+
@abstractmethod
26+
def get_system_matrices(self) -> torch.Tensor:
27+
"""
28+
Construct system matrices for the PDE.
29+
30+
Parameters
31+
----------
32+
model_params : torch.Tensor
33+
Physical model parameters (conductivity, permeability, etc.)
34+
35+
Returns
36+
-------
37+
Union[torch.Tensor, List[torch.Tensor]]
38+
Single system matrix for problems like DC resistivity,
39+
or list of matrices for frequency-domain problems.
40+
Shape depends on problem type:
41+
- DC: [1, n_cells, n_cells]
42+
- FDEM: [n_frequencies, n_cells, n_cells]
43+
"""
44+
pass
45+
46+
@abstractmethod
47+
def get_rhs_tensors(self) -> torch.Tensor:
48+
"""
49+
Construct right-hand side vectors for all sources.
50+
51+
Parameters
52+
53+
Returns
54+
-------
55+
torch.Tensor
56+
Batched RHS tensors. Shape depends on problem type:
57+
- DC: [1, n_sources, n_cells]
58+
- FDEM: [n_frequencies, n_sources, n_cells]
59+
"""
60+
pass
61+
62+
@abstractmethod
63+
def fields_to_data(self, fields: torch.Tensor) -> torch.Tensor:
64+
"""
65+
Project solution fields to predicted data.
66+
67+
Parameters
68+
----------
69+
fields : torch.Tensor
70+
Solution fields from PDE solve
71+
72+
Returns
73+
-------
74+
torch.Tensor
75+
Predicted data vector [n_data]
76+
"""
77+
pass
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
import torch
2+
from torch import nn
3+
from .basePDE import BasePDE
4+
from simpegtorch.torchmatsolver import batched_mumps_solve, batched_sparse_solve
5+
6+
7+
class DirectSolver(nn.Module):
8+
"""
9+
PyTorch-native simulation class for a direct solver on any PDE formualtion.
10+
11+
This class provides a minimal, clean interface that works with any PDE. All gradient computation is handled
12+
automatically by PyTorch autograd.
13+
14+
The entire simulation is a single forward pass which gives predicted data
15+
"""
16+
17+
def __init__(self, pde: BasePDE):
18+
"""
19+
Initialize universal simulation.
20+
21+
Parameters
22+
----------
23+
pde : BasePDE
24+
PDE formulation defining the physics
25+
survey : Survey
26+
Survey object with sources and receivers
27+
solver_method : str
28+
Solver method ("Direct" or "Iterative")
29+
"""
30+
super().__init__()
31+
self.pde = pde
32+
33+
def forward(self) -> torch.Tensor:
34+
"""
35+
Forward simulation: model parameters → predicted data.
36+
37+
Solves the PDE using a direct solver through matrixn inversion.
38+
Total process:
39+
1. Gets a RHS tensor of shape 1xkxn for n parameters problem and k different sources
40+
2. Gets a A matrix tensor of fx1xnxn for n parameterrs and f different problems (e.g. different frequencies)
41+
3. Runs a direct solver over this to solve for a fxkxn field solution tensor
42+
4. for j different receivers projects the fields into fxj different measurements
43+
"""
44+
# 1. Apply mapping if provided
45+
46+
# 2. Get system matrices from PDE
47+
system_matrices = self.pde.get_system_matrices()
48+
49+
# 3. Get RHS vectors from PDE
50+
rhs_tensors = self.pde.get_rhs_tensors()
51+
52+
try:
53+
fields = batched_mumps_solve(system_matrices, rhs_tensors)
54+
except ImportError:
55+
print("MUMPS not installed, falling back")
56+
fields = batched_sparse_solve(system_matrices, rhs_tensors)
57+
58+
# 5. Project to data and return
59+
return self.pde.fields_to_data(fields)

0 commit comments

Comments
 (0)