99 "XLA_PYTHON_CLIENT_MEM_FRACTION"
1010] = "0.7" # see https://github.com/google/jax/discussions/6332#discussioncomment-1279991
1111
12+ import functools
13+
1214import flax
1315import flax .linen as nn
1416import gym
@@ -127,9 +129,23 @@ class TrainState(TrainState):
     target_params: flax.core.FrozenDict
 
 
-def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
-    slope = (end_e - start_e) / duration
-    return max(slope * t + start_e, end_e)
+@functools.partial(jax.jit, static_argnums=(1, 2, 3, 5, 6, 7))
+def select_action(rng_seed, start_e, end_e, duration, t, num_actions, num_envs, network_def, online_params, obs):
+    def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
+        slope = (end_e - start_e) / duration
+        return jnp.maximum(slope * t + start_e, end_e)
+
+    epsilon = linear_schedule(start_e, end_e, duration, t)
+    rng, rng_1, rng_2 = jax.random.split(rng_seed, 3)
+    return (
+        rng,
+        jnp.where(
+            jax.random.uniform(rng_1) < epsilon,
+            jax.random.randint(rng_2, (num_envs,), 0, num_actions),
+            network_def.apply(online_params, obs).argmax(axis=-1),
+        ),
+        epsilon,
+    )
 
 
 if __name__ == "__main__":
@@ -157,7 +173,7 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
     random.seed(args.seed)
     np.random.seed(args.seed)
     key = jax.random.PRNGKey(args.seed)
-    key, q_key = jax.random.split(key, 2)
+    key, q_key, act_key = jax.random.split(key, 3)
 
     # env setup
     envs = gym.vector.SyncVectorEnv([make_env(args.env_id, args.seed, 0, args.capture_video, run_name)])
@@ -174,7 +190,6 @@ def linear_schedule(start_e: float, end_e: float, duration: int, t: int):
         tx=optax.adam(learning_rate=args.learning_rate),
     )
 
-    q_network.apply = jax.jit(q_network.apply)
     # This step is not necessary as init called on same observation and key will always lead to same initializations
     q_state = q_state.replace(target_params=optax.incremental_update(q_state.params, q_state.target_params, 1))
 
@@ -208,14 +223,19 @@ def mse_loss(params):
     obs = envs.reset()
     for global_step in range(args.total_timesteps):
         # ALGO LOGIC: put action logic here
-        epsilon = linear_schedule(args.start_e, args.end_e, args.exploration_fraction * args.total_timesteps, global_step)
-        if random.random() < epsilon:
-            actions = np.array([envs.single_action_space.sample() for _ in range(envs.num_envs)])
-        else:
-            # obs = jax.device_put(obs)
-            logits = q_network.apply(q_state.params, obs)
-            actions = logits.argmax(axis=-1)
-            actions = jax.device_get(actions)
+        act_key, actions, epsilon = select_action(
+            act_key,
+            args.start_e,
+            args.end_e,
+            args.exploration_fraction * args.total_timesteps,
+            global_step,
+            envs.single_action_space.n,
+            envs.num_envs,
+            q_network,
+            q_state.params,
+            obs,
+        )
+        actions = jax.device_get(actions)
 
         # TRY NOT TO MODIFY: execute the game and log data.
         next_obs, rewards, dones, infos = envs.step(actions)
@@ -226,7 +246,7 @@ def mse_loss(params):
226246 print (f"global_step={ global_step } , episodic_return={ info ['episode' ]['r' ]} " )
227247 writer .add_scalar ("charts/episodic_return" , info ["episode" ]["r" ], global_step )
228248 writer .add_scalar ("charts/episodic_length" , info ["episode" ]["l" ], global_step )
229- writer .add_scalar ("charts/epsilon" , epsilon , global_step )
249+ writer .add_scalar ("charts/epsilon" , jax . device_get ( epsilon ) , global_step )
230250 break
231251
232252 # TRY NOT TO MODIFY: save data to reply buffer; handle `terminal_observation`
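For reference, below is a minimal, standalone sketch of the jitted epsilon-greedy pattern this diff introduces: the epsilon schedule, the PRNG split, and the random/greedy choice all run inside one jax.jit-compiled function. The toy QNetwork, observation shape, and hyperparameter values are illustrative assumptions for demonstration only; they are not part of the PR.

# Standalone sketch of the jitted epsilon-greedy selection (assumed toy setup, not the PR's code).
import functools

import flax.linen as nn
import jax
import jax.numpy as jnp


class QNetwork(nn.Module):
    action_dim: int  # hypothetical single-layer Q-network standing in for the real one

    @nn.compact
    def __call__(self, x):
        return nn.Dense(self.action_dim)(x)


@functools.partial(jax.jit, static_argnums=(1, 2, 3, 5, 6, 7))
def select_action(rng_seed, start_e, end_e, duration, t, num_actions, num_envs, network_def, online_params, obs):
    # Anneal epsilon linearly inside the jitted function so the whole step stays on device.
    epsilon = jnp.maximum((end_e - start_e) / duration * t + start_e, end_e)
    rng, rng_explore, rng_random = jax.random.split(rng_seed, 3)
    greedy_actions = network_def.apply(online_params, obs).argmax(axis=-1)
    random_actions = jax.random.randint(rng_random, (num_envs,), 0, num_actions)
    actions = jnp.where(jax.random.uniform(rng_explore) < epsilon, random_actions, greedy_actions)
    # Return the fresh key so the caller can thread it into the next call.
    return rng, actions, epsilon


if __name__ == "__main__":
    num_envs, obs_dim, num_actions = 1, 4, 2  # assumed CartPole-like dimensions
    key = jax.random.PRNGKey(0)
    network = QNetwork(action_dim=num_actions)
    obs = jnp.zeros((num_envs, obs_dim))
    params = network.init(key, obs)
    key, actions, epsilon = select_action(key, 1.0, 0.05, 10_000, 0, num_actions, num_envs, network, params, obs)
    print(jax.device_get(actions), float(epsilon))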