Skip to content

Commit 694f3ce

Browse files
mjanuszcopybara-github
authored andcommitted
Add support for prefer_orig_order in 3d meshes.
PiperOrigin-RevId: 674657197
1 parent 4b3a011 commit 694f3ce

File tree

2 files changed

+157
-77
lines changed

2 files changed

+157
-77
lines changed

mesh.py

Lines changed: 125 additions & 71 deletions
Original file line numberDiff line numberDiff line change
@@ -28,23 +28,20 @@
2828
import collections
2929
import dataclasses
3030
import functools
31-
from typing import Optional, Sequence, Union
31+
from typing import Sequence
3232

3333
from absl import logging
34-
3534
import dataclasses_json
36-
3735
import jax
3836
import jax.numpy as jnp
3937
import numpy as np
4038

4139

4240
# NOTE: This is likely a good candidate for acceleration with a custom CUDA
4341
# kernel on GPUs.
44-
def inplane_force(x: jnp.ndarray,
45-
k: float,
46-
stride: float,
47-
prefer_orig_order=False) -> jnp.ndarray:
42+
def inplane_force(
43+
x: jnp.ndarray, k: float, stride: float, prefer_orig_order: bool = False
44+
) -> jnp.ndarray:
4845
"""Computes in-plane forces on the nodes of a spring mesh.
4946
5047
Args:
@@ -57,8 +54,8 @@ def inplane_force(x: jnp.ndarray,
5754
Returns:
5855
[2, z, y, x] array of forces
5956
"""
60-
l0 = stride
61-
l0_diag = jnp.sqrt(2.0) * l0
57+
l0 = np.array(stride)
58+
l0_diag = np.sqrt(2.0) * l0
6259

6360
def _xy_vec(x, y):
6461
return jnp.array([x, y]).reshape([2, 1, 1, 1])
@@ -105,24 +102,29 @@ def _xy_vec(x, y):
105102
dx = x[..., 1:] - x[..., :-1] + _xy_vec(l0, 0)
106103
l = jnp.linalg.norm(dx, axis=0)
107104
if prefer_orig_order:
108-
f1 = -k * (
109-
1. -
110-
l0 * jnp.array([jnp.sign(dx[0]), jnp.ones_like(dx[1])]) / l) * dx
105+
f1 = (
106+
-k
107+
* (1.0 - l0 * jnp.array([jnp.sign(dx[0]), jnp.ones_like(dx[1])]) / l)
108+
* dx
109+
)
111110
else:
112-
f1 = -k * (1. - l0 / l) * dx
113-
f1 = jnp.nan_to_num(f1, copy=False, posinf=0., neginf=0.)
111+
f1 = -k * (1.0 - l0 / l) * dx
112+
f1 = jnp.nan_to_num(f1, copy=False, posinf=0.0, neginf=0.0)
114113
f1p = jnp.pad(f1, ((0, 0), (0, 0), (0, 0), (1, 0)))
115114
f1n = jnp.pad(f1, ((0, 0), (0, 0), (0, 0), (0, 1)))
116115

117116
# | springs
118117
dx = x[..., 1:, :] - x[..., :-1, :] + _xy_vec(0, l0)
119118
l = jnp.linalg.norm(dx, axis=0)
120119
if prefer_orig_order:
121-
f2 = -k * (1. - l0 * jnp.array([jnp.ones_like(dx[0]),
122-
jnp.sign(dx[1])]) / l) * dx
120+
f2 = (
121+
-k
122+
* (1.0 - l0 * jnp.array([jnp.ones_like(dx[0]), jnp.sign(dx[1])]) / l)
123+
* dx
124+
)
123125
else:
124-
f2 = -k * (1. - l0 / l) * dx
125-
f2 = jnp.nan_to_num(f2, copy=False, posinf=0., neginf=0.)
126+
f2 = -k * (1.0 - l0 / l) * dx
127+
f2 = jnp.nan_to_num(f2, copy=False, posinf=0.0, neginf=0.0)
126128
f2p = jnp.pad(f2, ((0, 0), (0, 0), (1, 0), (0, 0)))
127129
f2n = jnp.pad(f2, ((0, 0), (0, 0), (0, 1), (0, 0)))
128130

@@ -133,23 +135,29 @@ def _xy_vec(x, y):
133135
dx = x[:, :, 1:, 1:] - x[:, :, :-1, :-1] + _xy_vec(l0, l0)
134136
l = jnp.linalg.norm(dx, axis=0)
135137
if prefer_orig_order:
136-
f3 = -k2 * (1. - l0_diag *
137-
jnp.array([jnp.sign(dx[0]), jnp.sign(dx[1])]) / l) * dx
138+
f3 = (
139+
-k2
140+
* (1.0 - l0_diag * jnp.array([jnp.sign(dx[0]), jnp.sign(dx[1])]) / l)
141+
* dx
142+
)
138143
else:
139-
f3 = -k2 * (1. - l0_diag / l) * dx
140-
f3 = jnp.nan_to_num(f3, copy=False, posinf=0., neginf=0.)
144+
f3 = -k2 * (1.0 - l0_diag / l) * dx
145+
f3 = jnp.nan_to_num(f3, copy=False, posinf=0.0, neginf=0.0)
141146
f3p = jnp.pad(f3, ((0, 0), (0, 0), (1, 0), (1, 0)))
142147
f3n = jnp.pad(f3, ((0, 0), (0, 0), (0, 1), (0, 1)))
143148

144149
# / springs
145150
dx = x[:, :, 1:, :-1] - x[:, :, :-1, 1:] + _xy_vec(-l0, l0)
146151
l = jnp.linalg.norm(dx, axis=0)
147152
if prefer_orig_order:
148-
f4 = -k2 * (1. - l0_diag *
149-
jnp.array([-jnp.sign(dx[0]), jnp.sign(dx[1])]) / l) * dx
153+
f4 = (
154+
-k2
155+
* (1.0 - l0_diag * jnp.array([-jnp.sign(dx[0]), jnp.sign(dx[1])]) / l)
156+
* dx
157+
)
150158
else:
151-
f4 = -k2 * (1. - l0_diag / l) * dx
152-
f4 = jnp.nan_to_num(f4, copy=False, posinf=0., neginf=0.)
159+
f4 = -k2 * (1.0 - l0_diag / l) * dx
160+
f4 = jnp.nan_to_num(f4, copy=False, posinf=0.0, neginf=0.0)
153161
f4p = jnp.pad(f4, ((0, 0), (0, 0), (1, 0), (0, 1)))
154162
f4n = jnp.pad(f4, ((0, 0), (0, 0), (0, 1), (1, 0)))
155163

@@ -172,14 +180,17 @@ def _xy_vec(x, y):
172180
(1, 1, 1),
173181
(1, 1, -1),
174182
(1, -1, 1),
175-
(-1, 1, 1))
183+
(-1, 1, 1),
184+
)
176185

177186

178-
def elastic_mesh_3d(x: jnp.ndarray,
179-
k: float,
180-
stride: Union[float, Sequence[float]],
181-
prefer_orig_order=False,
182-
links=MESH_LINK_DIRECTIONS) -> jnp.ndarray:
187+
def elastic_mesh_3d(
188+
x: jnp.ndarray,
189+
k: float,
190+
stride: float | Sequence[float],
191+
prefer_orig_order: bool = False,
192+
links=MESH_LINK_DIRECTIONS,
193+
) -> jnp.ndarray:
183194
"""Computes internal forces on the nodes of a 3d spring mesh.
184195
185196
Args:
@@ -188,16 +199,15 @@ def elastic_mesh_3d(x: jnp.ndarray,
188199
according to `stride` for all other springs to maintain constant
189200
elasticity
190201
stride: XYZ stride of the spring mesh grid
191-
prefer_orig_order: only False is supported
202+
prefer_orig_order: whether to change the force formulation so that the
203+
original relative spatial ordering of the nodes is energetically preferred
192204
links: sequence of XYZ tuples indcating node links to consider, relative to
193205
the node at (0, 0, 0); valid component values are {-1, 0, 1}
194206
195207
Returns:
196208
[3, z, y, x] array of forces
197209
"""
198210
assert x.shape[0] == 3
199-
if prefer_orig_order:
200-
raise NotImplementedError('prefer_orig_order not supported for 3d mesh.')
201211

202212
if not isinstance(stride, collections.abc.Sequence):
203213
stride = (stride,) * 3
@@ -206,7 +216,6 @@ def elastic_mesh_3d(x: jnp.ndarray,
206216
f_tot = None
207217
num_non_spatial = x.ndim - 3
208218
for direction in links:
209-
l0 = np.array(stride * direction).reshape([3] + [1] * (x.ndim - 1))
210219
# Select everything in non-spatial dimensions.
211220
sel1 = [np.s_[:]] * num_non_spatial
212221
sel2 = list(sel1)
@@ -232,11 +241,28 @@ def elastic_mesh_3d(x: jnp.ndarray,
232241
else:
233242
raise ValueError('Only |v| <= 1 values supported within links.')
234243

244+
l0 = np.array(stride * direction, dtype=np.float32).reshape(
245+
[3] + [1] * (x.ndim - 1)
246+
)
235247
dx = x[tuple(sel1)] - x[tuple(sel2)] + l0
236248
l0 = np.linalg.norm(l0)
237249
l = jnp.linalg.norm(dx, axis=0)
238-
f = -k * l0 / stride[0] * (1. - l0 / l) * dx
239-
f = jnp.nan_to_num(f, copy=False, posinf=0., neginf=0.)
250+
251+
# We want to maintain constant elasticity E and E ~ k⋅l0.
252+
# k is specified for the horizontal direction, and so l0 for it is
253+
# stride_x.
254+
k_eff = k * stride[0] / l0
255+
if prefer_orig_order:
256+
ones = jnp.ones_like(dx[0])
257+
factor = jnp.array([
258+
direction[0] * jnp.sign(dx[0]) if direction[0] != 0 else ones,
259+
direction[1] * jnp.sign(dx[1]) if direction[1] != 0 else ones,
260+
direction[2] * jnp.sign(dx[2]) if direction[2] != 0 else ones,
261+
])
262+
f = -k_eff * (1.0 - l0 * factor / l) * dx
263+
else:
264+
f = -k_eff * (1.0 - l0 / l) * dx
265+
f = jnp.nan_to_num(f, copy=False, posinf=0.0, neginf=0.0)
240266
fp = jnp.pad(f, pad_pos)
241267
if f_tot is None:
242268
f_tot = fp
@@ -275,7 +301,7 @@ class IntegrationConfig:
275301
f_dec: float = 0.5
276302
alpha: float = 0.1
277303
n_min: int = 5 # Min. number of steps after which to increase step size.
278-
dt_max: float = 10. # Max time step size, in units of 'dt'.
304+
dt_max: float = 10.0 # Max time step size, in units of 'dt'.
279305

280306
# Initial and final values of the inter-section force component magnitude cap.
281307
# start_cap != final_cap is only supported when using FIRE.
@@ -303,15 +329,17 @@ class IntegrationConfig:
303329

304330

305331
@functools.partial(jax.jit, static_argnames=['config', 'mesh_force', 'prev_fn'])
306-
def velocity_verlet(x: jnp.ndarray,
307-
v: jnp.ndarray,
308-
prev: Optional[jnp.ndarray],
309-
config: IntegrationConfig,
310-
force_cap: float,
311-
fire_dt=None,
312-
fire_alpha=None,
313-
mesh_force=inplane_force,
314-
prev_fn=None):
332+
def velocity_verlet(
333+
x: jnp.ndarray,
334+
v: jnp.ndarray,
335+
prev: jnp.ndarray | None,
336+
config: IntegrationConfig,
337+
force_cap: float,
338+
fire_dt: float | None = None,
339+
fire_alpha: float | None = None,
340+
mesh_force=inplane_force,
341+
prev_fn=None,
342+
):
315343
"""Executes a sequence of (damped) velocity Verlet steps.
316344
317345
Optionally uses the FIRE integrator. Disabling or reducing
@@ -373,7 +401,7 @@ def vv_step(t, state, dt, force_cap):
373401
a = _force(x, prev, force_cap)
374402

375403
fact0 = 1.0 / (1.0 + 0.5 * dt * config.gamma)
376-
fact1 = (1.0 - 0.5 * dt * config.gamma)
404+
fact1 = 1.0 - 0.5 * dt * config.gamma
377405
v = fact0 * (v * fact1 + 0.5 * dt * (a_prev + a))
378406
return x, v, a
379407

@@ -396,24 +424,32 @@ def fire_step(t, state):
396424
dt = jnp.where(
397425
power >= 0,
398426
jnp.where(
399-
n_pos > config.n_min, #
427+
n_pos > config.n_min,
400428
jnp.minimum(dt * config.f_inc, config.dt_max * config.dt),
401-
dt),
402-
dt * config.f_dec)
429+
dt,
430+
),
431+
dt * config.f_dec,
432+
)
403433
alpha = jnp.where(
404434
power >= 0,
405435
jnp.where(n_pos > config.n_min, alpha * config.f_alpha, alpha),
406-
config.alpha)
436+
config.alpha,
437+
)
407438

408439
cap = jnp.minimum(
409440
jnp.where(
410441
power >= 0,
411-
jnp.where((n_pos > 0) & ((n_pos % config.cap_upscale_every) == 0),
412-
config.cap_scale * cap, cap), #
413-
cap),
414-
config.final_cap)
415-
416-
v *= (power >= 0)
442+
jnp.where(
443+
(n_pos > 0) & ((n_pos % config.cap_upscale_every) == 0),
444+
config.cap_scale * cap,
445+
cap,
446+
),
447+
cap,
448+
),
449+
config.final_cap,
450+
)
451+
452+
v *= power >= 0
417453

418454
if config.remove_drift:
419455
# Remove any global drift and recenter the nodes.
@@ -430,21 +466,28 @@ def fire_step(t, state):
430466
if fire_dt is None:
431467
fire_dt = config.dt
432468

433-
return jax.lax.fori_loop(0, config.num_iters, fire_step,
434-
(x, v, a, fire_dt, fire_alpha, 0, force_cap))
469+
return jax.lax.fori_loop(
470+
0,
471+
config.num_iters,
472+
fire_step,
473+
(x, v, a, fire_dt, fire_alpha, 0, force_cap),
474+
)
435475
else:
436476
return jax.lax.fori_loop(
437-
0, config.num_iters,
477+
0,
478+
config.num_iters,
438479
functools.partial(vv_step, dt=config.dt, force_cap=force_cap),
439-
(x, v, a))
480+
(x, v, a),
481+
)
440482

441483

442484
def relax_mesh(
443485
x: jnp.ndarray,
444-
prev: Optional[jnp.ndarray],
486+
prev: jnp.ndarray | None,
445487
config: IntegrationConfig,
446488
mesh_force=inplane_force,
447-
prev_fn=None) -> tuple[jnp.ndarray, list[float], int]:
489+
prev_fn=None,
490+
) -> tuple[jnp.ndarray, list[float], int]:
448491
"""Simulates mesh relaxation.
449492
450493
Args:
@@ -473,10 +516,13 @@ def relax_mesh(
473516
if config.start_cap != config.final_cap:
474517
if not config.fire:
475518
raise NotImplementedError(
476-
'Adaptive force capping is only supported with FIRE.')
519+
'Adaptive force capping is only supported with FIRE.'
520+
)
477521
if config.cap_scale <= 1:
478-
raise ValueError('The scaling factor for the force cap has to be larger '
479-
'than 1 when the initial and final cap are different.')
522+
raise ValueError(
523+
'The scaling factor for the force cap has to be larger '
524+
'than 1 when the initial and final cap are different.'
525+
)
480526

481527
if prev is not None and prev_fn is not None:
482528
raise ValueError('Only one of: "prev" and "prev_fn" can be specified.')
@@ -491,7 +537,8 @@ def relax_mesh(
491537
fire_alpha=alpha,
492538
force_cap=cap,
493539
mesh_force=mesh_force,
494-
prev_fn=prev_fn)
540+
prev_fn=prev_fn,
541+
)
495542
t += config.num_iters
496543
x, v = state[:2]
497544
v_mag = jnp.linalg.norm(v, axis=0)
@@ -501,8 +548,15 @@ def relax_mesh(
501548
if config.fire:
502549
dt, alpha, n_pos, cap = state[-4:]
503550
logging.info(
504-
't=%r: dt=%f, alpha=%f, n_pos=%d, cap=%f, v_max=%f, e_kin=%f', t, dt,
505-
alpha, n_pos, cap, v_max, e_kin[-1])
551+
't=%r: dt=%f, alpha=%f, n_pos=%d, cap=%f, v_max=%f, e_kin=%f',
552+
t,
553+
dt,
554+
alpha,
555+
n_pos,
556+
cap,
557+
v_max,
558+
e_kin[-1],
559+
)
506560

507561
if v_max < config.stop_v_max:
508562
if cap >= config.final_cap:

0 commit comments

Comments
 (0)