Skip to content

Commit 8eea5b4

Browse files
committed
Update types
1 parent debf117 commit 8eea5b4

File tree

1 file changed

+14
-14
lines changed

1 file changed

+14
-14
lines changed

cleanrl/sac_continuous_action_jax.py

Lines changed: 14 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -155,9 +155,9 @@ class RLTrainState(TrainState):
155155
def sample_action(
156156
actor: Actor,
157157
actor_state: TrainState,
158-
observations: jnp.ndarray,
159-
key: jax.random.KeyArray,
160-
) -> jnp.array:
158+
observations: jax.Array,
159+
key: jax.Array,
160+
):
161161
key, subkey = jax.random.split(key, 2)
162162
mean, log_std = actor.apply(actor_state.params, observations)
163163
action_std = jnp.exp(log_std)
@@ -168,9 +168,9 @@ def sample_action(
168168

169169
@jax.jit
170170
def sample_action_and_log_prob(
171-
mean: jnp.ndarray,
172-
log_std: jnp.ndarray,
173-
subkey: jax.random.KeyArray,
171+
mean: jax.Array,
172+
log_std: jax.Array,
173+
subkey: jax.Array,
174174
):
175175
action_std = jnp.exp(log_std)
176176
gaussian_action = mean + action_std * jax.random.normal(subkey, shape=mean.shape)
@@ -182,7 +182,7 @@ def sample_action_and_log_prob(
182182

183183

184184
@partial(jax.jit, static_argnames="actor")
185-
def select_action(actor: Actor, actor_state: TrainState, observations: jnp.ndarray) -> jnp.array:
185+
def select_action(actor: Actor, actor_state: TrainState, observations: jax.Array) -> jax.Array:
186186
return actor.apply(actor_state.params, observations)[0]
187187

188188

@@ -299,12 +299,12 @@ def update_critic(
299299
actor_state: TrainState,
300300
qf_state: RLTrainState,
301301
ent_coef_value: jnp.ndarray,
302-
observations: np.ndarray,
303-
actions: np.ndarray,
304-
next_observations: np.ndarray,
305-
rewards: np.ndarray,
306-
dones: np.ndarray,
307-
key: jax.random.KeyArray,
302+
observations: jax.Array,
303+
actions: jax.Array,
304+
next_observations: jax.Array,
305+
rewards: jax.Array,
306+
dones: jax.Array,
307+
key: jax.Array,
308308
):
309309
key, subkey = jax.random.split(key, 2)
310310
mean, log_std = actor.apply(actor_state.params, next_observations)
@@ -339,7 +339,7 @@ def update_actor(
339339
qf_state: RLTrainState,
340340
ent_coef_value: jnp.ndarray,
341341
observations: np.ndarray,
342-
key: jax.random.KeyArray,
342+
key: jax.Array,
343343
):
344344
key, subkey = jax.random.split(key, 2)
345345

0 commit comments

Comments (0)