@@ -154,6 +154,7 @@ def loop_body_pass2(carry: Pass2Carry, i: jtp.Int) -> tuple[Pass2Carry, None]:
154154
155155 # Compute the articulated-body inertia and bias force of this link.
156156 Ma = MA [i ] - U [i ] / d [i ] @ U [i ].T
157+ Fa = F [i , :, ν [i ]] + U [i ] @ M_inv [i , ν [i ]]
157158
158159 M_inv_ii = 1 / d [i ]
159160 M_inv = M_inv .at [i , i ].set (M_inv_ii )
@@ -167,7 +168,7 @@ def propagate(
167168 ) -> tuple [jtp .Matrix , jtp .Matrix ]:
168169 MA , F = MA_F
169170
170- Fa_λi = F [:, ν [i ]] + U [i ] @ M_inv [ i , ν [ i ]]
171+ Fa_λi = F [λ [ i ], :, ν [i ]] + i_X_λi [i ]. T @ Fa
171172 F = F .at [:, ν [i ]].set (Fa_λi )
172173
173174 MA_λi = MA [λ [i ]] + i_X_λi [i ].T @ Ma @ i_X_λi [i ]
@@ -198,13 +199,7 @@ def propagate(
198199 # Pass 3
199200 # ======
200201
201- P = jnp .zeros (
202- shape = (
203- model .number_of_links (),
204- model .number_of_links (),
205- model .number_of_links (),
206- )
207- )
202+ P = jnp .zeros_like (F )
208203
209204 Pass3Carry = tuple [jtp .Matrix , jtp .Matrix , jtp .Matrix ]
210205 pass_3_carry = (U , M_inv , P )
@@ -213,17 +208,13 @@ def loop_body_pass3(carry: Pass3Carry, i: jtp.Int) -> tuple[Pass3Carry, None]:
213208
214209 U , M_inv , P = carry
215210
216- mask = jnp .arange (P .shape [1 ]) >= i # equivalent to [i, i:]
211+ mask = jnp .arange (P .shape [1 ]) >= i
217212
218213 def propagate_M_inv (M_inv : jtp .Matrix ) -> jtp .Matrix :
219- P_ii = jax .lax .dynamic_slice (
220- P , (i , i - P .shape [1 ], P .shape [2 ]), (P .shape [0 ], 1 ) * mask
214+ P_λi = jnp .where (mask , i_X_λi [i ].T @ P [λ [i ], i ], P [λ [i ], i ])
215+ M_inv = M_inv .at [i ].set (
216+ jnp .where (mask , M_inv [i ] - U [i ].T @ P_λi / d [i ], M_inv )
221217 )
222- M_inv_ii = jax .lax .dynamic_slice (
223- M_inv .squeeze (), (i , i - M_inv .squeeze ().shape [0 ]), i_X_λi [i ].shape
224- )
225- M_inv_ii = M_inv_ii .at [:].set (M_inv_ii - U [i ].T @ i_X_λi [i ] @ P_ii / d [i ])
226- jax .lax .dynamic_update_slice (M , M_inv_ii , (i , i ))
227218
228219 return M_inv
229220
@@ -234,19 +225,14 @@ def propagate_M_inv(M_inv: jtp.Matrix) -> jtp.Matrix:
234225 operand = M_inv ,
235226 )
236227
237- M_inv_ii = jax .lax .dynamic_slice (
238- M_inv , (i , i - d .shape [0 ], 1 ), (1 , d [i ].shape - i , 1 )
239- )
228+ M_inv_ii = M_inv [i ] * mask
240229
241- P_i = S [i ].T @ M_inv_ii
242- P = P .at [i ].set (P_i .squeeze ())
230+ P_ii = S [i ].T @ M_inv_ii
231+ P = P .at [i ].set (P_ii .squeeze ())
243232
244233 def propagate_P (P : jtp .Vector ) -> jtp .Vector :
245- P_λii = jax .lax .dynamic_slice (P , (λ [i ], i ), (1 , i ))
246- P_iii = jax .lax .dynamic_slice (P , (i , i ), (1 , i ))
247-
248- P_iii = P_iii .at [:].set (P_iii + i_X_λi [i ].T @ P_λii )
249- jax .lax .dynamic_update_slice (P , P_iii , (i , i ))
234+ P_λi = jnp .where (mask , i_X_λi [i ].T @ P [λ [i ], i ], P [λ [i ], i ])
235+ P = P .at [i ].set (jnp .where (mask , P [i ] + P_λi , P [i ]))
250236
251237 return P
252238
0 commit comments