Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/simsopt/geo/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from .curvehelical import *
from .curverzfourier import *
from .curvexyzfourier import *
from .curvealongz import *
from .curvexyzfouriersymmetries import *
from .curveperturbed import *
from .curveobjectives import *
Expand Down
81 changes: 81 additions & 0 deletions src/simsopt/geo/curvealongz.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import jax.numpy as jnp
from math import pi
import numpy as np
from .curve import JaxCurve

__all__ = ['CurveAlongZ']


def jaxcurvealongz_pure(dofs, quadpoints, zscale):
"""
Pure function for the CurveAlongZ, returns the points along the curve given the degrees of freedom and the quadpoints.

Args:
dofs: [xpos, ypos]; degrees of freedom of the curve, must be len(3)
quadpoints: points in [0,1]; points on which to evaluate the curve
"""
lenquadpoints = len(quadpoints)
x = dofs[0]*jnp.ones(lenquadpoints)
y = dofs[1]*jnp.ones(lenquadpoints)
if lenquadpoints < 2:
halfstep = 0 # don't offset for one point or two points
else:
halfstep = .5/(lenquadpoints-1) # avoid evaluating at 0 or 1.
z = jnp.tan(((quadpoints + halfstep) - .5)*pi)*zscale
gamma = jnp.stack((x, y, z), axis=1)
return gamma


class CurveAlongZ(JaxCurve):
r'''
Straight vertical curve, parallel to the z-axis.

Useful for quickly generating a toroidal field, comparing to tokamak equilibria
where an axisymmetric 1/R field is present, and testing.
Copy link
Contributor

@andrewgiuliani andrewgiuliani May 20, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do you really want this approximation of an axisymmetric 1/R field? why not just implement the analytical magnetic field instead of this approximation?

see how we do this here:

class ScalarPotentialRZMagneticField(MagneticField):


Degrees of freedom are the x, y coordinates of the vertical coil.
Displacing this from [0,0] can give a 1/1 perturbation if you feel like it.


Args:
quadpoints: number of grid points/resolution along the curve;
xpos: the x-coordinate of the vertical coil
ypos: the y-coordinate of the vertical coil
zscale: points are closer together at z=0, and spread apart using a zscale*tan(pi*(gamma-.5)) scaling.
'''

def __init__(self, quadpoints, xpos=0., ypos=0., zscale=10, fix_dofs=True, **kwargs):
if isinstance(quadpoints, int):
quadpoints = np.linspace(0, 1, quadpoints, endpoint=False)
self.xpos = xpos
self.ypos = ypos
self.set_dofs_impl([self.xpos, self.ypos])
self._fix_dofs = fix_dofs
self.zscale = zscale


pure = lambda dofs, points: jaxcurvealongz_pure(
dofs, points, self.zscale)



super().__init__(quadpoints, pure, x0=np.array([self.xpos, self.ypos]), names=self.make_dof_names(), **kwargs)
# unless you are doing strange things, you don't want to move the coil so we
# set the dofs fixed.
if fix_dofs:
self.fix_all()

def num_dofs(self):
return 2

def get_dofs(self):
return np.array([self.xpos, self.ypos])

def set_dofs_impl(self, dofs):
self.xpos = dofs[0]
self.ypos = dofs[1]


def make_dof_names(self):
return ['xpos', 'ypos']

2 changes: 2 additions & 0 deletions src/simsopt/util/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os

from .helper_coils import *
from .mpi import *
from .logger import *
from .famus_helpers import *
Expand All @@ -15,5 +16,6 @@
+ famus_helpers.__all__
+ polarization_project.__all__
+ permanent_magnet_helper_functions.__all__
+ helper_coils.__all__
+ ['in_github_actions']
)
32 changes: 32 additions & 0 deletions src/simsopt/util/helper_coils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
"""
This file contains a helper function for coils.
Currently only a simple function that returns a simsopt.field.Coil object that represents a current along the z-axis.

Useful for adding a toroidal field, for example to perturb a stellarator equilibrium.
"""

__all__ = ['current_along_z',]

from simsopt.geo import CurveAlongZ
from simsopt.field import Coil, Current

__all__ = ['current_along_z',]


def current_along_z(current, quadpoints=100, x0=0., y0=0., zscale=10, coil_dofs_fixed=True):
"""
Returns a Coil object that represents a current along the z-axis.
Add this to your a coilset before calling simsopt.BiotSavart to
add a toroidal field.

The dofs of the coil are only the current value, unles you set coil_dofs_fixed=False.

Args:
current: the current in Amperes
quadpoints: number of grid points/resolution along the curve;
x0: (default 0) the x-coordinate .
y0: (default 0) the y-coordinate.
zscale: (default 10) points are closer together at z=0, and spread apart using a zscale*tan(pi*(gamma-.5)) scaling.
coil_dofs_fixed: (default True) unless you are doing strange things, you don't want to move the coil from the axis.
"""
return Coil(CurveAlongZ(quadpoints, x0, y0, zscale, coil_dofs_fixed=coil_dofs_fixed), Current(current))
8 changes: 7 additions & 1 deletion tests/geo/test_curve.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from simsopt.geo.curverzfourier import CurveRZFourier
from simsopt.geo.curveplanarfourier import CurvePlanarFourier
from simsopt.geo.curvehelical import CurveHelical
from simsopt.geo.curvealongz import CurveAlongZ
from simsopt.geo.curvexyzfouriersymmetries import CurveXYZFourierSymmetries
from simsopt.geo.curve import RotatedCurve, curves_to_vtk
from simsopt.geo import parameters
Expand Down Expand Up @@ -83,6 +84,8 @@ def get_curve(curvetype, rotated, x=np.asarray([0.5])):
curve = CurveXYZFourierSymmetries(x, order, 2, False)
elif curvetype == "CurveXYZFourierSymmetries3":
curve = CurveXYZFourierSymmetries(x, order, 2, False, ntor=3)
elif curvetype == "CurveAlongZ":
curve = CurveAlongZ(x, 0., 0., 10, False)
else:
assert False

Expand Down Expand Up @@ -124,6 +127,9 @@ def get_curve(curvetype, rotated, x=np.asarray([0.5])):
curve.set('zc(0)', 1)
curve.set('zs(1)', r)
dofs = curve.get_dofs()
elif curvetype == "CurveAlongZ":
curve.set_dofs(np.array([0.5, 0]))
curve.set('xpos', 0.0)
else:
assert False

Expand All @@ -136,7 +142,7 @@ def get_curve(curvetype, rotated, x=np.asarray([0.5])):

class Testing(unittest.TestCase):

curvetypes = ["CurveXYZFourier", "JaxCurveXYZFourier", "CurveRZFourier", "CurvePlanarFourier", "CurveHelical", "CurveXYZFourierSymmetries1","CurveXYZFourierSymmetries2", "CurveXYZFourierSymmetries3", "CurveHelicalInitx0"]
curvetypes = ["CurveXYZFourier", "JaxCurveXYZFourier", "CurveRZFourier", "CurvePlanarFourier", "CurveHelical", "CurveXYZFourierSymmetries1","CurveXYZFourierSymmetries2", "CurveXYZFourierSymmetries3", "CurveHelicalInitx0", "CurveAlongZ"]

def get_curvexyzfouriersymmetries(self, stellsym=True, x=None, nfp=None, ntor=1):
# returns a CurveXYZFourierSymmetries that is randomly perturbed
Expand Down
Empty file.
Loading