1- from importlib import resources
1+ import os
2+ from importlib .util import find_spec
23from typing import Literal
34
4- import gymnasium as gym
55import numpy as np
6- from gymnasium import spaces
7- from scipy .special import expit as sigmoid
86import pandas as pd
97import torch
108import yaml
9+ from gymnasium import Env , spaces
10+ from scipy .special import expit as sigmoid
1111
12- from weather2alert .datautils import get_similar_counties
13-
14-
15- with open (resources .path ("weather2alert.weights" , "master.yaml" ), "r" ) as f :
16- VALID_WEIGHTS = yaml .safe_load (f )
12+ from .datautils import get_similar_counties
1713
1814
19- class HeatAlertEnv (gym . Env ):
15+ class HeatAlertEnv (Env ):
2016 """Class to simulate the environment for the online RL agent."""
2117
2218 def __init__ (
23- self , weights : str = "nn_full_medicare" , valid_years : list | None = None
19+ self ,
20+ weights : str = "nn_full_medicare" ,
21+ years : list | None = None ,
22+ fips_list : list | None = None ,
23+ similar_climate_counties : bool = False ,
24+ budget : int | None = None ,
2425 ):
2526 """Initialize the environment."""
2627 super ().__init__ ()
27- self .valid_years = valid_years
28- # assert (
29- # weights in VALID_WEIGHTS
30- # ), f"Invalid weights: {weights}, valid weights are {VALID_WEIGHTS}"
28+ self .valid_years = years
29+ self .similar_climate_counties = similar_climate_counties
30+ self .budget = budget
31+ if years is None :
32+ years = list (range (2006 , 2017 ))
3133
3234 # load state and confounders data
33- path = resources .path (
34- "weather2alert.data.processed" , "exogenous_states.parquet"
35- )
36- exogenous_states = pd .read_parquet (path )
37- path = resources .path (
38- "weather2alert.data.processed" , "endogenous_states_actions.parquet"
35+
36+ # if path data/processed exists we are working with local data; this check is currently disabled (hard-coded False)
37+ if False : # os.path.exists("data/processed"):
38+ root = "./"
39+ else :
40+ root = find_spec ("weather2alert" ).submodule_search_locations [0 ]
41+
42+ processed_path = os .path .join (root , "data/processed" )
43+ weights_path = os .path .join (root , "weights" )
44+
45+ exogenous_states = pd .read_parquet (processed_path + "/exogenous_states.parquet" )
46+ endogenous_states_actions = pd .read_parquet (
47+ processed_path + "/endogenous_states_actions.parquet"
3948 )
40- endogenous_states_actions = pd .read_parquet (path )
4149 merged = pd .merge (
4250 exogenous_states , endogenous_states_actions , on = ["fips" , "date" ]
4351 )
4452 merged ["year" ] = merged .date .str [:4 ].astype (int )
4553
4654 # make sure merged is ordered by fips and date, and remove dates outside of the range
4755 # 152 days of the summer starting on May 1st to Sep 30th
48- month = merged .date .str [5 :7 ]
56+ month = merged .date .str [5 :7 ]
4957 merged = merged [(month >= "05" ) & (month <= "09" )].copy ()
5058 merged = merged .drop_duplicates (["fips" , "date" ])
5159
5260 # merged.set_index(["fips", "date"], inplace=True)
53- confounders = pd .read_parquet (
54- resources .path ("weather2alert.data.processed" , "confounders.parquet" )
55- )
61+ confounders = pd .read_parquet (processed_path + "/confounders.parquet" )
5662
5763 self .merged = merged .set_index (["fips" , "year" ])
5864 self .confounders = confounders
5965
6066 # load posterior parameters and config
61- weights_dir = "weather2alert.weights." + weights
62- path = resources .path (weights_dir , "posterior_samples.pt" )
63- posterior_samples = torch .load (path , weights_only = True )
64- self .fips_list = posterior_samples ["fips_list" ]
67+ posterior_samples = torch .load (
68+ f"{ weights_path } /{ weights } /posterior_samples.pt" , weights_only = True
69+ )
70+
71+ self .fips_list = fips_list
72+ if fips_list is None :
73+ self .fips_list = posterior_samples ["fips_list" ]
6574
6675 self .baseline_coefs = {
6776 k : v for k , v in posterior_samples .items () if k .startswith ("baseline" )
6877 }
6978 self .effectiveness_coefs = {
7079 k : v for k , v in posterior_samples .items () if k .startswith ("effectiveness" )
7180 }
72- with open (resources . path ( weights_dir , " config.yaml") , "r" ) as f :
81+ with open (rf" { weights_path } / { weights } / config.yaml" , "r" ) as f :
7382 self .config = yaml .safe_load (f )
7483
7584 # get num posterior samples
@@ -91,12 +100,15 @@ def __init__(
91100 for k in self .merged .columns :
92101 if k .startswith ("bspline_" ):
93102 self .merged [k .replace ("bspline_" , "bsplines_" )] = self .merged [k ]
94- # ----
103+
104+ if self .valid_years is None :
105+ self .valid_years = list (self .merged .index .get_level_values ("year" ).unique ())
95106
96107 def _get_episode (
97108 self ,
98109 location : str ,
99110 augment : bool = False ,
111+ year : int | None = None ,
100112 ):
101113 if augment :
102114 # get similar counties
@@ -109,11 +121,9 @@ def _get_episode(
109121 self .location_index = self .fips_list .index (location )
110122
111123 # split by year and index by dos, drop data
112- valid_years = self .valid_years
113- if self .valid_years is None :
114- valid_years = self .merged .loc [self .location ].index .unique ()
124+ if year is None :
125+ year = self .rng .choice (self .valid_years )
115126
116- year = self .rng .choice (valid_years )
117127 year_data = self .merged .loc [(location , year )]
118128 year_data = (
119129 year_data .reset_index ().drop (columns = ["fips" , "year" ]).set_index ("date" )
@@ -123,8 +133,9 @@ def _get_episode(
123133 def reset (
124134 self ,
125135 location : str | None = None ,
126- similar_climate_counties : bool = False ,
136+ similar_climate_counties : bool | None = None ,
127137 seed : int | None = None ,
138+ budget : int | None = None ,
128139 sample_budget : bool = False ,
129140 sample_budget_type : Literal ["less_than" , "centered" ] = "less_than" ,
130141 ):
@@ -133,6 +144,9 @@ def reset(
133144 seed = np .random .randint (0 , 10000 )
134145 self .rng = np .random .default_rng (seed )
135146
147+ if similar_climate_counties is None :
148+ similar_climate_counties = self .similar_climate_counties
149+
136150 # if location is None, pick a random location
137151 if location is None :
138152 location = self .rng .choice (self .fips_list )
@@ -150,18 +164,24 @@ def reset(
150164 self .alert_streak = 0
151165 self .t = 0 # day of summer indicator
152166
153- b = self .ep ["remaining_budget" ].iloc [0 ]
167+ if self .budget is None :
168+ self .budget = (
169+ self .ep ["remaining_budget" ].iloc [0 ] if budget is None else budget
170+ )
171+
154172 if sample_budget :
173+ b = self .budget
155174 if sample_budget_type == "less_than" :
156175 self .budget = self .rng .integers (0 , b + 1 )
157176 elif sample_budget_type == "centered" :
158177 self .budget = self .rng .integers (0.5 * b , 1.5 * b + 1 )
159- else :
160- self .budget = b
178+ self .remaining_budget = self .budget
161179
162180 self .at_budget = False
163181 self .observation = self ._get_obs ()
164- return self .observation , self ._get_info ()
182+ if not hasattr (self , "feat_names" ):
183+ self .feat_names = self .observation .index .tolist ()
184+ return self .observation .values , self ._get_info ()
165185
166186 def _get_obs (self ):
167187 row = self .ep .iloc [self .t ].copy ()
@@ -195,10 +215,10 @@ def _get_reward(self, action):
195215 x = row [k .replace ("effectiveness_" , "" )]
196216 v = v [self .coef_index , 0 , li ].item ()
197217 effectiveness_contribs .append (x * v )
198- effectiveness = sigmoid (sum (effectiveness_contribs ))
218+ effectiveness = sigmoid (sum (effectiveness_contribs )) * ( row [ "heat_qi" ] > 0.5 )
199219
200- # reward is 1 - normalized hospitalization rate
201- reward = float (1 - baseline * (1 - effectiveness * action ))
220+ # reward is the negative normalized hospitalization rate, scaled by 10_000
221+ reward = float (- 10_000 * baseline * (1 - effectiveness * action ))
202222
203223 if action == 1 and self .at_budget :
204224 reward = - 1
@@ -208,8 +228,9 @@ def _get_reward(self, action):
208228 def _get_info (self ) -> dict :
209229 return {
210230 "episode_index" : self .ep_index ,
211- "budget" : self .budget ,
212- "feature_names" : self .ep .columns .tolist (),
231+ "remaining_budget" : self .remaining_budget ,
232+ "at_budget" : self .at_budget ,
233+ "feature_names" : self .feat_names ,
213234 "location" : self .location ,
214235 "location_index" : self .location_index ,
215236 }
@@ -225,17 +246,20 @@ def step(self, action: int):
225246 actual_action = action
226247
227248 self .actual_alert_buffer .append (actual_action )
249+ if actual_action == 1 :
250+ self .remaining_budget -= 1
228251
229252 # compute reward for the new state
230253 reward = self ._get_reward (actual_action )
231254
232255 # advance state
233- self .t += 1
234- observation = self ._get_obs ().values
235- done = self .t == self .n_days - 1
236- self .alert_streak = self .alert_streak + 1 if actual_action else 0
256+ done = self .t >= self .n_days - 1
257+ if not done :
258+ self .observation = self ._get_obs ()
259+ self .t += 1
260+ self .alert_streak = self .alert_streak + 1 if actual_action else 0
237261
238- return observation , reward , done , False , self ._get_info ()
262+ return self . observation . values , reward , done , False , self ._get_info ()
239263
240264
241265if __name__ == "__main__" :
0 commit comments