28
28
import collections
29
29
import dataclasses
30
30
import functools
31
- from typing import Optional , Sequence , Union
31
+ from typing import Sequence
32
32
33
33
from absl import logging
34
-
35
34
import dataclasses_json
36
-
37
35
import jax
38
36
import jax .numpy as jnp
39
37
import numpy as np
40
38
41
39
42
40
# NOTE: This is likely a good candidate for acceleration with a custom CUDA
43
41
# kernel on GPUs.
44
- def inplane_force (x : jnp .ndarray ,
45
- k : float ,
46
- stride : float ,
47
- prefer_orig_order = False ) -> jnp .ndarray :
42
+ def inplane_force (
43
+ x : jnp .ndarray , k : float , stride : float , prefer_orig_order : bool = False
44
+ ) -> jnp .ndarray :
48
45
"""Computes in-plane forces on the nodes of a spring mesh.
49
46
50
47
Args:
@@ -57,8 +54,8 @@ def inplane_force(x: jnp.ndarray,
57
54
Returns:
58
55
[2, z, y, x] array of forces
59
56
"""
60
- l0 = stride
61
- l0_diag = jnp .sqrt (2.0 ) * l0
57
+ l0 = np . array ( stride )
58
+ l0_diag = np .sqrt (2.0 ) * l0
62
59
63
60
def _xy_vec (x , y ):
64
61
return jnp .array ([x , y ]).reshape ([2 , 1 , 1 , 1 ])
@@ -105,24 +102,29 @@ def _xy_vec(x, y):
105
102
dx = x [..., 1 :] - x [..., :- 1 ] + _xy_vec (l0 , 0 )
106
103
l = jnp .linalg .norm (dx , axis = 0 )
107
104
if prefer_orig_order :
108
- f1 = - k * (
109
- 1. -
110
- l0 * jnp .array ([jnp .sign (dx [0 ]), jnp .ones_like (dx [1 ])]) / l ) * dx
105
+ f1 = (
106
+ - k
107
+ * (1.0 - l0 * jnp .array ([jnp .sign (dx [0 ]), jnp .ones_like (dx [1 ])]) / l )
108
+ * dx
109
+ )
111
110
else :
112
- f1 = - k * (1. - l0 / l ) * dx
113
- f1 = jnp .nan_to_num (f1 , copy = False , posinf = 0. , neginf = 0. )
111
+ f1 = - k * (1.0 - l0 / l ) * dx
112
+ f1 = jnp .nan_to_num (f1 , copy = False , posinf = 0.0 , neginf = 0.0 )
114
113
f1p = jnp .pad (f1 , ((0 , 0 ), (0 , 0 ), (0 , 0 ), (1 , 0 )))
115
114
f1n = jnp .pad (f1 , ((0 , 0 ), (0 , 0 ), (0 , 0 ), (0 , 1 )))
116
115
117
116
# | springs
118
117
dx = x [..., 1 :, :] - x [..., :- 1 , :] + _xy_vec (0 , l0 )
119
118
l = jnp .linalg .norm (dx , axis = 0 )
120
119
if prefer_orig_order :
121
- f2 = - k * (1. - l0 * jnp .array ([jnp .ones_like (dx [0 ]),
122
- jnp .sign (dx [1 ])]) / l ) * dx
120
+ f2 = (
121
+ - k
122
+ * (1.0 - l0 * jnp .array ([jnp .ones_like (dx [0 ]), jnp .sign (dx [1 ])]) / l )
123
+ * dx
124
+ )
123
125
else :
124
- f2 = - k * (1. - l0 / l ) * dx
125
- f2 = jnp .nan_to_num (f2 , copy = False , posinf = 0. , neginf = 0. )
126
+ f2 = - k * (1.0 - l0 / l ) * dx
127
+ f2 = jnp .nan_to_num (f2 , copy = False , posinf = 0.0 , neginf = 0.0 )
126
128
f2p = jnp .pad (f2 , ((0 , 0 ), (0 , 0 ), (1 , 0 ), (0 , 0 )))
127
129
f2n = jnp .pad (f2 , ((0 , 0 ), (0 , 0 ), (0 , 1 ), (0 , 0 )))
128
130
@@ -133,23 +135,29 @@ def _xy_vec(x, y):
133
135
dx = x [:, :, 1 :, 1 :] - x [:, :, :- 1 , :- 1 ] + _xy_vec (l0 , l0 )
134
136
l = jnp .linalg .norm (dx , axis = 0 )
135
137
if prefer_orig_order :
136
- f3 = - k2 * (1. - l0_diag *
137
- jnp .array ([jnp .sign (dx [0 ]), jnp .sign (dx [1 ])]) / l ) * dx
138
+ f3 = (
139
+ - k2
140
+ * (1.0 - l0_diag * jnp .array ([jnp .sign (dx [0 ]), jnp .sign (dx [1 ])]) / l )
141
+ * dx
142
+ )
138
143
else :
139
- f3 = - k2 * (1. - l0_diag / l ) * dx
140
- f3 = jnp .nan_to_num (f3 , copy = False , posinf = 0. , neginf = 0. )
144
+ f3 = - k2 * (1.0 - l0_diag / l ) * dx
145
+ f3 = jnp .nan_to_num (f3 , copy = False , posinf = 0.0 , neginf = 0.0 )
141
146
f3p = jnp .pad (f3 , ((0 , 0 ), (0 , 0 ), (1 , 0 ), (1 , 0 )))
142
147
f3n = jnp .pad (f3 , ((0 , 0 ), (0 , 0 ), (0 , 1 ), (0 , 1 )))
143
148
144
149
# / springs
145
150
dx = x [:, :, 1 :, :- 1 ] - x [:, :, :- 1 , 1 :] + _xy_vec (- l0 , l0 )
146
151
l = jnp .linalg .norm (dx , axis = 0 )
147
152
if prefer_orig_order :
148
- f4 = - k2 * (1. - l0_diag *
149
- jnp .array ([- jnp .sign (dx [0 ]), jnp .sign (dx [1 ])]) / l ) * dx
153
+ f4 = (
154
+ - k2
155
+ * (1.0 - l0_diag * jnp .array ([- jnp .sign (dx [0 ]), jnp .sign (dx [1 ])]) / l )
156
+ * dx
157
+ )
150
158
else :
151
- f4 = - k2 * (1. - l0_diag / l ) * dx
152
- f4 = jnp .nan_to_num (f4 , copy = False , posinf = 0. , neginf = 0. )
159
+ f4 = - k2 * (1.0 - l0_diag / l ) * dx
160
+ f4 = jnp .nan_to_num (f4 , copy = False , posinf = 0.0 , neginf = 0.0 )
153
161
f4p = jnp .pad (f4 , ((0 , 0 ), (0 , 0 ), (1 , 0 ), (0 , 1 )))
154
162
f4n = jnp .pad (f4 , ((0 , 0 ), (0 , 0 ), (0 , 1 ), (1 , 0 )))
155
163
@@ -172,14 +180,17 @@ def _xy_vec(x, y):
172
180
(1 , 1 , 1 ),
173
181
(1 , 1 , - 1 ),
174
182
(1 , - 1 , 1 ),
175
- (- 1 , 1 , 1 ))
183
+ (- 1 , 1 , 1 ),
184
+ )
176
185
177
186
178
- def elastic_mesh_3d (x : jnp .ndarray ,
179
- k : float ,
180
- stride : Union [float , Sequence [float ]],
181
- prefer_orig_order = False ,
182
- links = MESH_LINK_DIRECTIONS ) -> jnp .ndarray :
187
+ def elastic_mesh_3d (
188
+ x : jnp .ndarray ,
189
+ k : float ,
190
+ stride : float | Sequence [float ],
191
+ prefer_orig_order : bool = False ,
192
+ links = MESH_LINK_DIRECTIONS ,
193
+ ) -> jnp .ndarray :
183
194
"""Computes internal forces on the nodes of a 3d spring mesh.
184
195
185
196
Args:
@@ -188,16 +199,15 @@ def elastic_mesh_3d(x: jnp.ndarray,
188
199
according to `stride` for all other springs to maintain constant
189
200
elasticity
190
201
stride: XYZ stride of the spring mesh grid
191
- prefer_orig_order: only False is supported
202
+ prefer_orig_order: whether to change the force formulation so that the
203
+ original relative spatial ordering of the nodes is energetically preferred
192
204
links: sequence of XYZ tuples indcating node links to consider, relative to
193
205
the node at (0, 0, 0); valid component values are {-1, 0, 1}
194
206
195
207
Returns:
196
208
[3, z, y, x] array of forces
197
209
"""
198
210
assert x .shape [0 ] == 3
199
- if prefer_orig_order :
200
- raise NotImplementedError ('prefer_orig_order not supported for 3d mesh.' )
201
211
202
212
if not isinstance (stride , collections .abc .Sequence ):
203
213
stride = (stride ,) * 3
@@ -206,7 +216,6 @@ def elastic_mesh_3d(x: jnp.ndarray,
206
216
f_tot = None
207
217
num_non_spatial = x .ndim - 3
208
218
for direction in links :
209
- l0 = np .array (stride * direction ).reshape ([3 ] + [1 ] * (x .ndim - 1 ))
210
219
# Select everything in non-spatial dimensions.
211
220
sel1 = [np .s_ [:]] * num_non_spatial
212
221
sel2 = list (sel1 )
@@ -232,11 +241,28 @@ def elastic_mesh_3d(x: jnp.ndarray,
232
241
else :
233
242
raise ValueError ('Only |v| <= 1 values supported within links.' )
234
243
244
+ l0 = np .array (stride * direction , dtype = np .float32 ).reshape (
245
+ [3 ] + [1 ] * (x .ndim - 1 )
246
+ )
235
247
dx = x [tuple (sel1 )] - x [tuple (sel2 )] + l0
236
248
l0 = np .linalg .norm (l0 )
237
249
l = jnp .linalg .norm (dx , axis = 0 )
238
- f = - k * l0 / stride [0 ] * (1. - l0 / l ) * dx
239
- f = jnp .nan_to_num (f , copy = False , posinf = 0. , neginf = 0. )
250
+
251
+ # We want to maintain constant elasticity E and E ~ k⋅l0.
252
+ # k is specified for the horizontal direction, and so l0 for it is
253
+ # stride_x.
254
+ k_eff = k * stride [0 ] / l0
255
+ if prefer_orig_order :
256
+ ones = jnp .ones_like (dx [0 ])
257
+ factor = jnp .array ([
258
+ direction [0 ] * jnp .sign (dx [0 ]) if direction [0 ] != 0 else ones ,
259
+ direction [1 ] * jnp .sign (dx [1 ]) if direction [1 ] != 0 else ones ,
260
+ direction [2 ] * jnp .sign (dx [2 ]) if direction [2 ] != 0 else ones ,
261
+ ])
262
+ f = - k_eff * (1.0 - l0 * factor / l ) * dx
263
+ else :
264
+ f = - k_eff * (1.0 - l0 / l ) * dx
265
+ f = jnp .nan_to_num (f , copy = False , posinf = 0.0 , neginf = 0.0 )
240
266
fp = jnp .pad (f , pad_pos )
241
267
if f_tot is None :
242
268
f_tot = fp
@@ -275,7 +301,7 @@ class IntegrationConfig:
275
301
f_dec : float = 0.5
276
302
alpha : float = 0.1
277
303
n_min : int = 5 # Min. number of steps after which to increase step size.
278
- dt_max : float = 10. # Max time step size, in units of 'dt'.
304
+ dt_max : float = 10.0 # Max time step size, in units of 'dt'.
279
305
280
306
# Initial and final values of the inter-section force component magnitude cap.
281
307
# start_cap != final_cap is only supported when using FIRE.
@@ -303,15 +329,17 @@ class IntegrationConfig:
303
329
304
330
305
331
@functools .partial (jax .jit , static_argnames = ['config' , 'mesh_force' , 'prev_fn' ])
306
- def velocity_verlet (x : jnp .ndarray ,
307
- v : jnp .ndarray ,
308
- prev : Optional [jnp .ndarray ],
309
- config : IntegrationConfig ,
310
- force_cap : float ,
311
- fire_dt = None ,
312
- fire_alpha = None ,
313
- mesh_force = inplane_force ,
314
- prev_fn = None ):
332
+ def velocity_verlet (
333
+ x : jnp .ndarray ,
334
+ v : jnp .ndarray ,
335
+ prev : jnp .ndarray | None ,
336
+ config : IntegrationConfig ,
337
+ force_cap : float ,
338
+ fire_dt : float | None = None ,
339
+ fire_alpha : float | None = None ,
340
+ mesh_force = inplane_force ,
341
+ prev_fn = None ,
342
+ ):
315
343
"""Executes a sequence of (damped) velocity Verlet steps.
316
344
317
345
Optionally uses the FIRE integrator. Disabling or reducing
@@ -373,7 +401,7 @@ def vv_step(t, state, dt, force_cap):
373
401
a = _force (x , prev , force_cap )
374
402
375
403
fact0 = 1.0 / (1.0 + 0.5 * dt * config .gamma )
376
- fact1 = ( 1.0 - 0.5 * dt * config .gamma )
404
+ fact1 = 1.0 - 0.5 * dt * config .gamma
377
405
v = fact0 * (v * fact1 + 0.5 * dt * (a_prev + a ))
378
406
return x , v , a
379
407
@@ -396,24 +424,32 @@ def fire_step(t, state):
396
424
dt = jnp .where (
397
425
power >= 0 ,
398
426
jnp .where (
399
- n_pos > config .n_min , #
427
+ n_pos > config .n_min ,
400
428
jnp .minimum (dt * config .f_inc , config .dt_max * config .dt ),
401
- dt ),
402
- dt * config .f_dec )
429
+ dt ,
430
+ ),
431
+ dt * config .f_dec ,
432
+ )
403
433
alpha = jnp .where (
404
434
power >= 0 ,
405
435
jnp .where (n_pos > config .n_min , alpha * config .f_alpha , alpha ),
406
- config .alpha )
436
+ config .alpha ,
437
+ )
407
438
408
439
cap = jnp .minimum (
409
440
jnp .where (
410
441
power >= 0 ,
411
- jnp .where ((n_pos > 0 ) & ((n_pos % config .cap_upscale_every ) == 0 ),
412
- config .cap_scale * cap , cap ), #
413
- cap ),
414
- config .final_cap )
415
-
416
- v *= (power >= 0 )
442
+ jnp .where (
443
+ (n_pos > 0 ) & ((n_pos % config .cap_upscale_every ) == 0 ),
444
+ config .cap_scale * cap ,
445
+ cap ,
446
+ ),
447
+ cap ,
448
+ ),
449
+ config .final_cap ,
450
+ )
451
+
452
+ v *= power >= 0
417
453
418
454
if config .remove_drift :
419
455
# Remove any global drift and recenter the nodes.
@@ -430,21 +466,28 @@ def fire_step(t, state):
430
466
if fire_dt is None :
431
467
fire_dt = config .dt
432
468
433
- return jax .lax .fori_loop (0 , config .num_iters , fire_step ,
434
- (x , v , a , fire_dt , fire_alpha , 0 , force_cap ))
469
+ return jax .lax .fori_loop (
470
+ 0 ,
471
+ config .num_iters ,
472
+ fire_step ,
473
+ (x , v , a , fire_dt , fire_alpha , 0 , force_cap ),
474
+ )
435
475
else :
436
476
return jax .lax .fori_loop (
437
- 0 , config .num_iters ,
477
+ 0 ,
478
+ config .num_iters ,
438
479
functools .partial (vv_step , dt = config .dt , force_cap = force_cap ),
439
- (x , v , a ))
480
+ (x , v , a ),
481
+ )
440
482
441
483
442
484
def relax_mesh (
443
485
x : jnp .ndarray ,
444
- prev : Optional [ jnp .ndarray ] ,
486
+ prev : jnp .ndarray | None ,
445
487
config : IntegrationConfig ,
446
488
mesh_force = inplane_force ,
447
- prev_fn = None ) -> tuple [jnp .ndarray , list [float ], int ]:
489
+ prev_fn = None ,
490
+ ) -> tuple [jnp .ndarray , list [float ], int ]:
448
491
"""Simulates mesh relaxation.
449
492
450
493
Args:
@@ -473,10 +516,13 @@ def relax_mesh(
473
516
if config .start_cap != config .final_cap :
474
517
if not config .fire :
475
518
raise NotImplementedError (
476
- 'Adaptive force capping is only supported with FIRE.' )
519
+ 'Adaptive force capping is only supported with FIRE.'
520
+ )
477
521
if config .cap_scale <= 1 :
478
- raise ValueError ('The scaling factor for the force cap has to be larger '
479
- 'than 1 when the initial and final cap are different.' )
522
+ raise ValueError (
523
+ 'The scaling factor for the force cap has to be larger '
524
+ 'than 1 when the initial and final cap are different.'
525
+ )
480
526
481
527
if prev is not None and prev_fn is not None :
482
528
raise ValueError ('Only one of: "prev" and "prev_fn" can be specified.' )
@@ -491,7 +537,8 @@ def relax_mesh(
491
537
fire_alpha = alpha ,
492
538
force_cap = cap ,
493
539
mesh_force = mesh_force ,
494
- prev_fn = prev_fn )
540
+ prev_fn = prev_fn ,
541
+ )
495
542
t += config .num_iters
496
543
x , v = state [:2 ]
497
544
v_mag = jnp .linalg .norm (v , axis = 0 )
@@ -501,8 +548,15 @@ def relax_mesh(
501
548
if config .fire :
502
549
dt , alpha , n_pos , cap = state [- 4 :]
503
550
logging .info (
504
- 't=%r: dt=%f, alpha=%f, n_pos=%d, cap=%f, v_max=%f, e_kin=%f' , t , dt ,
505
- alpha , n_pos , cap , v_max , e_kin [- 1 ])
551
+ 't=%r: dt=%f, alpha=%f, n_pos=%d, cap=%f, v_max=%f, e_kin=%f' ,
552
+ t ,
553
+ dt ,
554
+ alpha ,
555
+ n_pos ,
556
+ cap ,
557
+ v_max ,
558
+ e_kin [- 1 ],
559
+ )
506
560
507
561
if v_max < config .stop_v_max :
508
562
if cap >= config .final_cap :
0 commit comments