@@ -155,9 +155,9 @@ class RLTrainState(TrainState):
 def sample_action(
     actor: Actor,
     actor_state: TrainState,
-    observations: jnp.ndarray,
-    key: jax.random.KeyArray,
-) -> jnp.array:
+    observations: jax.Array,
+    key: jax.Array,
+):
     key, subkey = jax.random.split(key, 2)
     mean, log_std = actor.apply(actor_state.params, observations)
     action_std = jnp.exp(log_std)
@@ -168,9 +168,9 @@ def sample_action(
 
 @jax.jit
 def sample_action_and_log_prob(
-    mean: jnp.ndarray,
-    log_std: jnp.ndarray,
-    subkey: jax.random.KeyArray,
+    mean: jax.Array,
+    log_std: jax.Array,
+    subkey: jax.Array,
 ):
     action_std = jnp.exp(log_std)
     gaussian_action = mean + action_std * jax.random.normal(subkey, shape=mean.shape)
@@ -182,7 +182,7 @@ def sample_action_and_log_prob(
 
 
 @partial(jax.jit, static_argnames="actor")
-def select_action(actor: Actor, actor_state: TrainState, observations: jnp.ndarray) -> jnp.array:
+def select_action(actor: Actor, actor_state: TrainState, observations: jax.Array) -> jax.Array:
     return actor.apply(actor_state.params, observations)[0]
 
 
@@ -299,12 +299,12 @@ def update_critic(
     actor_state: TrainState,
     qf_state: RLTrainState,
     ent_coef_value: jnp.ndarray,
-    observations: np.ndarray,
-    actions: np.ndarray,
-    next_observations: np.ndarray,
-    rewards: np.ndarray,
-    dones: np.ndarray,
-    key: jax.random.KeyArray,
+    observations: jax.Array,
+    actions: jax.Array,
+    next_observations: jax.Array,
+    rewards: jax.Array,
+    dones: jax.Array,
+    key: jax.Array,
 ):
     key, subkey = jax.random.split(key, 2)
     mean, log_std = actor.apply(actor_state.params, next_observations)
@@ -339,7 +339,7 @@ def update_actor(
     qf_state: RLTrainState,
     ent_coef_value: jnp.ndarray,
     observations: np.ndarray,
-    key: jax.random.KeyArray,
+    key: jax.Array,
 ):
     key, subkey = jax.random.split(key, 2)
 
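Note: these annotation swaps track the JAX 0.4.x typing cleanup, in which jax.random.KeyArray was deprecated and PRNG keys, like all array values, are instances of jax.Array. A minimal sketch of why a single annotation now covers keys, inputs, and outputs (assuming jax >= 0.4.1):

    import jax

    # A PRNG key and the values sampled with it are both jax.Array
    # instances, so jax.Array annotates keys and data alike.
    key = jax.random.PRNGKey(0)
    values = jax.random.normal(key, shape=(3,))
    assert isinstance(key, jax.Array)
    assert isinstance(values, jax.Array)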