Skip to content

Commit d12f307

Browse files
authored
Merge pull request #5 from NSAPH-Projects:mauriciogtec/issue4
Fix installer #4
2 parents 274e2cd + 965686c commit d12f307

File tree

10 files changed

+118
-90
lines changed

10 files changed

+118
-90
lines changed
File renamed without changes.
File renamed without changes.

pyproject.toml

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,44 @@
11
[build-system]
2-
requires = ["setuptools>=42", "wheel"]
2+
requires = ["setuptools", "wheel"]
33
build-backend = "setuptools.build_meta"
4+
5+
[project]
6+
name = "weather2alert"
7+
version = "0.1.0"
8+
description = "A gym environment for optimizing heat alert issuance during heatwaves"
9+
readme = "README.md"
10+
authors = [{ name="Anonymous", email="[email protected]" }]
11+
license = { text = "MIT" }
12+
classifiers = [
13+
"Development Status :: 3 - Alpha",
14+
"Intended Audience :: Developers",
15+
"License :: OSI Approved :: MIT License",
16+
"Programming Language :: Python :: 3.10",
17+
"Programming Language :: Python :: 3.11",
18+
]
19+
requires-python = ">=3.10"
20+
dependencies = [
21+
"scipy",
22+
"tqdm",
23+
"pyarrow",
24+
"pandas",
25+
"torch",
26+
"gymnasium",
27+
]
28+
29+
# Optional dependencies go here under 'optional-dependencies'
30+
[project.optional-dependencies]
31+
dev = ["pytest"]
32+
33+
[tool.setuptools]
34+
packages = ["src.weather2alert"]
35+
include-package-data = true
36+
37+
[tool.setuptools.package-data]
38+
weather2alert = [
39+
"weights/nn_full_medicare/*",
40+
"weights/nn_debug_medicare/*",
41+
"weights/master.yaml",
42+
"data/processed/*.parquet",
43+
"data/raw/*",
44+
]

requirements.txt

Lines changed: 0 additions & 10 deletions
This file was deleted.

src/weather2alert/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
File renamed without changes.

weather2alert/datautils.py renamed to src/weather2alert/datautils.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,4 @@
11
import pandas as pd
2-
from sklearn.preprocessing import StandardScaler
3-
import torch
4-
5-
6-
# import matplotlib.pyplot as plt
72

83
WESTERN_STATES = [
94
"AZ",

weather2alert/env.py renamed to src/weather2alert/env.py

Lines changed: 75 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -1,75 +1,84 @@
1-
from importlib import resources
1+
import os
2+
from importlib.util import find_spec
23
from typing import Literal
34

4-
import gymnasium as gym
55
import numpy as np
6-
from gymnasium import spaces
7-
from scipy.special import expit as sigmoid
86
import pandas as pd
97
import torch
108
import yaml
9+
from gymnasium import Env, spaces
10+
from scipy.special import expit as sigmoid
1111

12-
from weather2alert.datautils import get_similar_counties
13-
14-
15-
with open(resources.path("weather2alert.weights", "master.yaml"), "r") as f:
16-
VALID_WEIGHTS = yaml.safe_load(f)
12+
from .datautils import get_similar_counties
1713

1814

19-
class HeatAlertEnv(gym.Env):
15+
class HeatAlertEnv(Env):
2016
"""Class to simulate the environment for the online RL agent."""
2117

2218
def __init__(
23-
self, weights: str = "nn_full_medicare", valid_years: list | None = None
19+
self,
20+
weights: str = "nn_full_medicare",
21+
years: list | None = None,
22+
fips_list: list | None = None,
23+
similar_climate_counties: bool = False,
24+
budget: int | None = None,
2425
):
2526
"""Initialize the environment."""
2627
super().__init__()
27-
self.valid_years = valid_years
28-
# assert (
29-
# weights in VALID_WEIGHTS
30-
# ), f"Invalid weights: {weights}, valid weights are {VALID_WEIGHTS}"
28+
self.valid_years = years
29+
self.similar_climate_counties = similar_climate_counties
30+
self.budget = budget
31+
if years is None:
32+
years = list(range(2006, 2017))
3133

3234
# load state and confounders data
33-
path = resources.path(
34-
"weather2alert.data.processed", "exogenous_states.parquet"
35-
)
36-
exogenous_states = pd.read_parquet(path)
37-
path = resources.path(
38-
"weather2alert.data.processed", "endogenous_states_actions.parquet"
35+
36+
# check if path data/processed exists, then we are working with local data
37+
if False: # os.path.exists("data/processed"):
38+
root = "./"
39+
else:
40+
root = find_spec("weather2alert").submodule_search_locations[0]
41+
42+
processed_path = os.path.join(root, "data/processed")
43+
weights_path = os.path.join(root, "weights")
44+
45+
exogenous_states = pd.read_parquet(processed_path + "/exogenous_states.parquet")
46+
endogenous_states_actions = pd.read_parquet(
47+
processed_path + "/endogenous_states_actions.parquet"
3948
)
40-
endogenous_states_actions = pd.read_parquet(path)
4149
merged = pd.merge(
4250
exogenous_states, endogenous_states_actions, on=["fips", "date"]
4351
)
4452
merged["year"] = merged.date.str[:4].astype(int)
4553

4654
# make sure merged is order by fips date and remove dates outside of the range
4755
# 152 days of the summer starting on May 1st to Sep 30th
48-
month = merged.date.str[5:7]
56+
month = merged.date.str[5:7]
4957
merged = merged[(month >= "05") & (month <= "09")].copy()
5058
merged = merged.drop_duplicates(["fips", "date"])
5159

5260
# merged.set_index(["fips", "date"], inplace=True)
53-
confounders = pd.read_parquet(
54-
resources.path("weather2alert.data.processed", "confounders.parquet")
55-
)
61+
confounders = pd.read_parquet(processed_path + "/confounders.parquet")
5662

5763
self.merged = merged.set_index(["fips", "year"])
5864
self.confounders = confounders
5965

6066
# load posterior parameters and config
61-
weights_dir = "weather2alert.weights." + weights
62-
path = resources.path(weights_dir, "posterior_samples.pt")
63-
posterior_samples = torch.load(path, weights_only=True)
64-
self.fips_list = posterior_samples["fips_list"]
67+
posterior_samples = torch.load(
68+
f"{weights_path}/{weights}/posterior_samples.pt", weights_only=True
69+
)
70+
71+
self.fips_list = fips_list
72+
if fips_list is None:
73+
self.fips_list = posterior_samples["fips_list"]
6574

6675
self.baseline_coefs = {
6776
k: v for k, v in posterior_samples.items() if k.startswith("baseline")
6877
}
6978
self.effectiveness_coefs = {
7079
k: v for k, v in posterior_samples.items() if k.startswith("effectiveness")
7180
}
72-
with open(resources.path(weights_dir, "config.yaml"), "r") as f:
81+
with open(rf"{weights_path}/{weights}/config.yaml", "r") as f:
7382
self.config = yaml.safe_load(f)
7483

7584
# get num posterior samples
@@ -91,12 +100,15 @@ def __init__(
91100
for k in self.merged.columns:
92101
if k.startswith("bspline_"):
93102
self.merged[k.replace("bspline_", "bsplines_")] = self.merged[k]
94-
# ----
103+
104+
if self.valid_years is None:
105+
self.valid_years = list(self.merged.index.get_level_values("year").unique())
95106

96107
def _get_episode(
97108
self,
98109
location: str,
99110
augment: bool = False,
111+
year: int | None = None,
100112
):
101113
if augment:
102114
# get similar counties
@@ -109,11 +121,9 @@ def _get_episode(
109121
self.location_index = self.fips_list.index(location)
110122

111123
# split by year and index by dos, drop data
112-
valid_years = self.valid_years
113-
if self.valid_years is None:
114-
valid_years = self.merged.loc[self.location].index.unique()
124+
if year is None:
125+
year = self.rng.choice(self.valid_years)
115126

116-
year = self.rng.choice(valid_years)
117127
year_data = self.merged.loc[(location, year)]
118128
year_data = (
119129
year_data.reset_index().drop(columns=["fips", "year"]).set_index("date")
@@ -123,8 +133,9 @@ def _get_episode(
123133
def reset(
124134
self,
125135
location: str | None = None,
126-
similar_climate_counties: bool = False,
136+
similar_climate_counties: bool | None = None,
127137
seed: int | None = None,
138+
budget: int | None = None,
128139
sample_budget: bool = False,
129140
sample_budget_type: Literal["less_than", "centered"] = "less_than",
130141
):
@@ -133,6 +144,9 @@ def reset(
133144
seed = np.random.randint(0, 10000)
134145
self.rng = np.random.default_rng(seed)
135146

147+
if similar_climate_counties is None:
148+
similar_climate_counties = self.similar_climate_counties
149+
136150
# if location is None, pick a random location
137151
if location is None:
138152
location = self.rng.choice(self.fips_list)
@@ -150,18 +164,24 @@ def reset(
150164
self.alert_streak = 0
151165
self.t = 0 # day of summer indicator
152166

153-
b = self.ep["remaining_budget"].iloc[0]
167+
if self.budget is None:
168+
self.budget = (
169+
self.ep["remaining_budget"].iloc[0] if budget is None else budget
170+
)
171+
154172
if sample_budget:
173+
b = self.budget
155174
if sample_budget_type == "less_than":
156175
self.budget = self.rng.integers(0, b + 1)
157176
elif sample_budget_type == "centered":
158177
self.budget = self.rng.integers(0.5 * b, 1.5 * b + 1)
159-
else:
160-
self.budget = b
178+
self.remaining_budget = self.budget
161179

162180
self.at_budget = False
163181
self.observation = self._get_obs()
164-
return self.observation, self._get_info()
182+
if not hasattr(self, "feat_names"):
183+
self.feat_names = self.observation.index.tolist()
184+
return self.observation.values, self._get_info()
165185

166186
def _get_obs(self):
167187
row = self.ep.iloc[self.t].copy()
@@ -195,10 +215,10 @@ def _get_reward(self, action):
195215
x = row[k.replace("effectiveness_", "")]
196216
v = v[self.coef_index, 0, li].item()
197217
effectiveness_contribs.append(x * v)
198-
effectiveness = sigmoid(sum(effectiveness_contribs))
218+
effectiveness = sigmoid(sum(effectiveness_contribs)) * (row["heat_qi"] > 0.5)
199219

200-
# reward is 1 - normalized hospitalization rate
201-
reward = float(1 - baseline * (1 - effectiveness * action))
220+
# reward is - normalized hospitalization rate / 10_000
221+
reward = float(-10_000 * baseline * (1 - effectiveness * action))
202222

203223
if action == 1 and self.at_budget:
204224
reward = -1
@@ -208,8 +228,9 @@ def _get_reward(self, action):
208228
def _get_info(self) -> dict:
209229
return {
210230
"episode_index": self.ep_index,
211-
"budget": self.budget,
212-
"feature_names": self.ep.columns.tolist(),
231+
"remaining_budget": self.remaining_budget,
232+
"at_budget": self.at_budget,
233+
"feature_names": self.feat_names,
213234
"location": self.location,
214235
"location_index": self.location_index,
215236
}
@@ -225,17 +246,20 @@ def step(self, action: int):
225246
actual_action = action
226247

227248
self.actual_alert_buffer.append(actual_action)
249+
if actual_action == 1:
250+
self.remaining_budget -= 1
228251

229252
# compute reward for the new state
230253
reward = self._get_reward(actual_action)
231254

232255
# advance state
233-
self.t += 1
234-
observation = self._get_obs().values
235-
done = self.t == self.n_days - 1
236-
self.alert_streak = self.alert_streak + 1 if actual_action else 0
256+
done = self.t >= self.n_days - 1
257+
if not done:
258+
self.observation = self._get_obs()
259+
self.t += 1
260+
self.alert_streak = self.alert_streak + 1 if actual_action else 0
237261

238-
return observation, reward, done, False, self._get_info()
262+
return self.observation.values, reward, done, False, self._get_info()
239263

240264

241265
if __name__ == "__main__":

tests/test_setup.py

Lines changed: 0 additions & 22 deletions
This file was deleted.

weather2alert/__init__.py

Lines changed: 0 additions & 1 deletion
This file was deleted.

0 commit comments

Comments
 (0)