Skip to content

Commit edb8dbe

Browse files
committed
[ci skip] Use boolean mask for slicing
1 parent 88b89b5 commit edb8dbe

File tree

1 file changed

+12
-26
lines changed

1 file changed

+12
-26
lines changed

src/jaxsim/rbda/mass_inverse.py

Lines changed: 12 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -154,6 +154,7 @@ def loop_body_pass2(carry: Pass2Carry, i: jtp.Int) -> tuple[Pass2Carry, None]:
154154

155155
# Compute the articulated-body inertia and bias force of this link.
156156
Ma = MA[i] - U[i] / d[i] @ U[i].T
157+
Fa = F[i, :, ν[i]] + U[i] @ M_inv[i, ν[i]]
157158

158159
M_inv_ii = 1 / d[i]
159160
M_inv = M_inv.at[i, i].set(M_inv_ii)
@@ -167,7 +168,7 @@ def propagate(
167168
) -> tuple[jtp.Matrix, jtp.Matrix]:
168169
MA, F = MA_F
169170

170-
Fa_λi = F[:, ν[i]] + U[i] @ M_inv[i, ν[i]]
171+
Fa_λi = F[λ[i], :, ν[i]] + i_X_λi[i].T @ Fa
171172
F = F.at[:, ν[i]].set(Fa_λi)
172173

173174
MA_λi = MA[λ[i]] + i_X_λi[i].T @ Ma @ i_X_λi[i]
@@ -198,13 +199,7 @@ def propagate(
198199
# Pass 3
199200
# ======
200201

201-
P = jnp.zeros(
202-
shape=(
203-
model.number_of_links(),
204-
model.number_of_links(),
205-
model.number_of_links(),
206-
)
207-
)
202+
P = jnp.zeros_like(F)
208203

209204
Pass3Carry = tuple[jtp.Matrix, jtp.Matrix, jtp.Matrix]
210205
pass_3_carry = (U, M_inv, P)
@@ -213,17 +208,13 @@ def loop_body_pass3(carry: Pass3Carry, i: jtp.Int) -> tuple[Pass3Carry, None]:
213208

214209
U, M_inv, P = carry
215210

216-
mask = jnp.arange(P.shape[1]) >= i # equivalent to [i, i:]
211+
mask = jnp.arange(P.shape[1]) >= i
217212

218213
def propagate_M_inv(M_inv: jtp.Matrix) -> jtp.Matrix:
219-
P_ii = jax.lax.dynamic_slice(
220-
P, (i, i - P.shape[1], P.shape[2]), (P.shape[0], 1) * mask
214+
P_λi = jnp.where(mask, i_X_λi[i].T @ P[λ[i], i], P[λ[i], i])
215+
M_inv = M_inv.at[i].set(
216+
jnp.where(mask, M_inv[i] - U[i].T @ P_λi / d[i], M_inv)
221217
)
222-
M_inv_ii = jax.lax.dynamic_slice(
223-
M_inv.squeeze(), (i, i - M_inv.squeeze().shape[0]), i_X_λi[i].shape
224-
)
225-
M_inv_ii = M_inv_ii.at[:].set(M_inv_ii - U[i].T @ i_X_λi[i] @ P_ii / d[i])
226-
jax.lax.dynamic_update_slice(M, M_inv_ii, (i, i))
227218

228219
return M_inv
229220

@@ -234,19 +225,14 @@ def propagate_M_inv(M_inv: jtp.Matrix) -> jtp.Matrix:
234225
operand=M_inv,
235226
)
236227

237-
M_inv_ii = jax.lax.dynamic_slice(
238-
M_inv, (i, i - d.shape[0], 1), (1, d[i].shape - i, 1)
239-
)
228+
M_inv_ii = M_inv[i] * mask
240229

241-
P_i = S[i].T @ M_inv_ii
242-
P = P.at[i].set(P_i.squeeze())
230+
P_ii = S[i].T @ M_inv_ii
231+
P = P.at[i].set(P_ii.squeeze())
243232

244233
def propagate_P(P: jtp.Vector) -> jtp.Vector:
245-
P_λii = jax.lax.dynamic_slice(P, (λ[i], i), (1, i))
246-
P_iii = jax.lax.dynamic_slice(P, (i, i), (1, i))
247-
248-
P_iii = P_iii.at[:].set(P_iii + i_X_λi[i].T @ P_λii)
249-
jax.lax.dynamic_update_slice(P, P_iii, (i, i))
234+
P_λi = jnp.where(mask, i_X_λi[i].T @ P[λ[i], i], P[λ[i], i])
235+
P = P.at[i].set(jnp.where(mask, P[i] + P_λi, P[i]))
250236

251237
return P
252238

0 commit comments

Comments
 (0)