Skip to content

Commit ad3c0b0

Browse files
committed
some abstraction to the code + random walk implementation
1 parent ae18648 commit ad3c0b0

File tree

1 file changed

+94
-63
lines changed

1 file changed

+94
-63
lines changed

person_story.py

Lines changed: 94 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,10 @@ def random_normal(mean: float, std_dev: Optional[float] = None) -> float:
2424

2525

2626
def gen_death(
27-
generic: Generic, person: SqlRow, src_stats: SrcStats
27+
generic: Generic, person: SqlRow, src_stats: SrcStats
2828
) -> Optional[tuple[str, SqlRow]]:
2929
"""Generate a row for the death table."""
30+
3031
def with_probability(p: float) -> bool:
3132
"""Return True with probability p (0 ≤ p ≤ 1)."""
3233
return random.random() < p
@@ -36,7 +37,8 @@ def with_probability(p: float) -> bool:
3637
else:
3738
avg_age_at_death_days = src_stats["age_at_death"][0]["average_age_years"] * 365
3839
std_dev_age_at_death_days = src_stats["age_at_death"][0]["stddev_age_years"] * 365
39-
age_at_death_days = abs(random_normal(cast(float,avg_age_at_death_days), cast(float,std_dev_age_at_death_days)))
40+
age_at_death_days = abs(
41+
random_normal(cast(float, avg_age_at_death_days), cast(float, std_dev_age_at_death_days)))
4042
death_datetime = cast(dt.datetime, person["birth_datetime"]) + dt.timedelta(
4143
days=age_at_death_days)
4244
return "death", {
@@ -45,35 +47,36 @@ def with_probability(p: float) -> bool:
4547
"death_date": death_datetime.date(),
4648
}
4749

50+
4851
def gen_visit_occurrence(
49-
person: SqlRow, death: Optional[SqlRow], src_stats: SrcStats
52+
person: SqlRow, death: Optional[SqlRow], src_stats: SrcStats
5053
) -> tuple[str, SqlRow]:
5154
"""Generate a row for the visit_occurrence table."""
5255
age_days_at_visit_start = abs(
5356
random_normal(
54-
cast(float, 63*365), cast(float, 13*365)
57+
cast(float, 63 * 365), cast(float, 13 * 365)
5558
)
5659
)
5760
if person["gender_concept_id"] == 8532:
5861
age_days_at_visit_start = abs(
5962
random_normal(
60-
cast(float, src_stats["age_first_admission"][0]["average_age_years"]*365),
61-
cast(float, src_stats["age_first_admission"][0]["stddev_age_years"]*365)
63+
cast(float, src_stats["age_first_admission"][0]["average_age_years"] * 365),
64+
cast(float, src_stats["age_first_admission"][0]["stddev_age_years"] * 365)
6265
)
6366
)
6467
if person["gender_concept_id"] == 8507:
6568
age_days_at_visit_start = abs(
6669
random_normal(
67-
cast(float, src_stats["age_first_admission"][1]["average_age_years"]*365),
68-
cast(float, src_stats["age_first_admission"][1]["stddev_age_years"]*365)
70+
cast(float, src_stats["age_first_admission"][1]["average_age_years"] * 365),
71+
cast(float, src_stats["age_first_admission"][1]["stddev_age_years"] * 365)
6972
)
7073
)
7174
visit_start_datetime = cast(dt.datetime, person["birth_datetime"]) + dt.timedelta(
7275
days=age_days_at_visit_start
7376
)
7477
visit_length_hours = abs(
7578
random_normal(
76-
cast(float, src_stats["visit_duration"][0]["average_hours"]),
79+
cast(float, src_stats["visit_duration"][0]["average_hours"]),
7780
cast(float, src_stats["visit_duration"][0]["stddev_hours"])
7881
# cast(float, 6), cast(float, 29*24)
7982
)
@@ -114,15 +117,15 @@ def random_event_times(avg_rate: float, visit_occurrence: SqlRow) -> list[dt.dat
114117

115118

116119
def gen_events( # pylint: disable=too-many-arguments
117-
generic: Generic,
118-
avg_rate: float,
119-
visit_occurrence: SqlRow,
120-
person: SqlRow,
121-
generator_function: Callable[
122-
[Generic, int, int, dt.datetime, SrcStats], Optional[SqlRow]
123-
],
124-
table_name: str,
125-
src_stats: SrcStats,
120+
generic: Generic,
121+
avg_rate: float,
122+
visit_occurrence: SqlRow,
123+
person: SqlRow,
124+
generator_function: Callable[
125+
[Generic, int, int, dt.datetime, SrcStats], Optional[SqlRow]
126+
],
127+
table_name: str,
128+
src_stats: SrcStats,
126129
) -> list[tuple[str, SqlRow]]:
127130
"""Generate events for a visit occurrence, at a given rate with a given generator.
128131
@@ -143,31 +146,30 @@ def gen_events( # pylint: disable=too-many-arguments
143146
events.append((table_name, event))
144147
return events
145148

149+
146150
def gen_blood_pressure_events( # pylint: disable=too-many-arguments
147-
avg_rate: float,
148-
visit_occurrence: SqlRow,
149-
person: SqlRow,
150-
src_stats: SrcStats,
151+
avg_rate: float,
152+
visit_occurrence: SqlRow,
153+
person: SqlRow,
154+
src_stats: SrcStats,
151155
) -> list[tuple[str, SqlRow]]:
152156
"""Generate events for a visit occurrence, at a given rate with a given generator.
153157
154158
This is a utility function for generating multiple rows for one of the "event"
155159
tables (measurements, observation, etc.).
156160
"""
157161

158-
159162
def generate_paired_measurement(
160-
person_id: int,
161-
visit_occurrence_id: int,
162-
event_datetime: dt.datetime,
163-
values: tuple[float, float],
164-
measurement_concept_id: tuple[int,int],
165-
measurement_type_concept_ids: int,
166-
unit_concept_id: int,
167-
unit_source_value: str,
163+
person_id: int,
164+
visit_occurrence_id: int,
165+
event_datetime: dt.datetime,
166+
values: tuple[float, float],
167+
measurement_concept_id: tuple[int, int],
168+
measurement_type_concept_ids: int,
169+
unit_concept_id: int,
170+
unit_source_value: str,
168171
) -> tuple[SqlRow, SqlRow]:
169172

170-
171173
### This can be abastracted to generate any number of set of measurements
172174
"""Generate two rows for the measurement table."""
173175
measurement1: SqlRow = {
@@ -194,50 +196,78 @@ def generate_paired_measurement(
194196
"value_as_number": values[1],
195197
}
196198
return measurement1, measurement2
197-
199+
198200
event_datetimes = random_event_times(avg_rate, visit_occurrence)
199201

200-
avg_systolic = 114.236842
201-
avg_diastolic = 74.447368
202+
if len(event_datetimes) == 0:
203+
return []
204+
205+
# can we get this from the data?
202206
sys_bp_non_invasive_concept_id = 21492239
203207
dias_bp_non_invasive_concept_id = 21492240
204208
measurement_type_concept_id = 32817 # EHR measurement
205209
unit_source_value = "mmHg"
206210
unit_concept_id = 8876 # mmHg
207211

208212
gender = cast(int, person["gender_concept_id"])
213+
age = (cast(dt.datetime, visit_occurrence["visit_start_datetime"]) - cast(dt.datetime,
214+
person["birth_datetime"])).days / 365.25
215+
216+
main_key = 'bp_profile'
217+
relative_change_key = 'bp_sys_relative_change_stats'
218+
if age < 60:
219+
key_mean = 'average_under_60_systolic'
220+
key_std = 'stddev_under_60_systolic'
221+
222+
key_epsilon_mean = 'avg_under_60_systolic_rel_var'
223+
key_epsilon_std = 'stddev_under_60_systolic_rel_var'
224+
225+
226+
else:
227+
key_mean = 'average_over_60_systolic'
228+
key_std = 'stddev_over_60_systolic'
229+
230+
key_epsilon_mean = 'avg_over_60_systolic_rel_var'
231+
key_epsilon_std = 'stddev_over_60_systolic_rel_var'
232+
209233
if gender == 8507:
210-
systolic_value = np.round(generate_time_series(len(event_datetimes), 'iid',
211-
{'mean': src_stats["bp_profile"][0]["average_under_60_systolic"],
212-
'std': src_stats["bp_profile"][0]["stddev_under_60_systolic"]},
213-
random_state=42))
214-
diastolic_value = np.round(random_normal(src_stats["bp_profile"][0]["average_systolic_diastolic_difference"],src_stats["bp_profile"][0]["average_systolic_diastolic_difference"]*0.1) + systolic_value)
215-
elif gender == 8532:
216-
systolic_value = np.round(generate_time_series(len(event_datetimes), 'iid',
217-
{'mean': src_stats["bp_profile"][1]["average_under_60_systolic"],
218-
'std': src_stats["bp_profile"][1]["stddev_under_60_systolic"]},
219-
random_state=42))
220-
diastolic_value = np.round(random_normal(src_stats["bp_profile"][1]["average_systolic_diastolic_difference"],
221-
src_stats["bp_profile"][1][
222-
"average_systolic_diastolic_difference"] * 0.1) + systolic_value)
234+
index_gender = 0
223235
else:
224-
systolic_value = avg_systolic
225-
diastolic_value = avg_diastolic
236+
index_gender = 1
237+
238+
sample_epsilon = np.random.normal(src_stats[relative_change_key][index_gender][key_epsilon_mean],
239+
src_stats[relative_change_key][index_gender][key_epsilon_std], 1)
240+
241+
systolic_value = np.round(generate_time_series(len(event_datetimes), 'random_walk',
242+
{'mean': src_stats[main_key][index_gender][key_mean],
243+
'std': src_stats[main_key][0][key_std],
244+
'epsilon_std': sample_epsilon, 'drift': 0},
245+
random_state=42))
246+
247+
# diastolic value is calculated based on systolic value plus the average difference extrated from data
248+
# we add some variation to the difference between systolic and diastolic
249+
diastolic_value = np.round(random_normal(src_stats[main_key][index_gender]['average_systolic_diastolic_difference'],
250+
src_stats[main_key][index_gender][
251+
"average_systolic_diastolic_difference"] * 0.1) + systolic_value)
226252

227253
events: list[tuple[str, SqlRow]] = []
228254
for index, event_datetime in enumerate(sorted(event_datetimes)):
229255
systolic_dict, diastolic_dict = generate_paired_measurement(cast(int, person["person_id"]),
230-
cast(int, visit_occurrence["visit_occurrence_id"]),
231-
event_datetime,(systolic_value[index], diastolic_value[index]),
232-
(sys_bp_non_invasive_concept_id,dias_bp_non_invasive_concept_id),
233-
measurement_type_concept_id,unit_concept_id,unit_source_value)
256+
cast(int, visit_occurrence["visit_occurrence_id"]),
257+
event_datetime,
258+
(systolic_value[index], diastolic_value[index]),
259+
(sys_bp_non_invasive_concept_id,
260+
dias_bp_non_invasive_concept_id),
261+
measurement_type_concept_id, unit_concept_id,
262+
unit_source_value)
234263
events.append(("measurement", systolic_dict)),
235264
events.append(("measurement", diastolic_dict))
236265
return events
237266

267+
238268
def generate(
239-
generic: Generic,
240-
src_stats: SrcStats,
269+
generic: Generic,
270+
src_stats: SrcStats,
241271
) -> Generator[tuple[str, SqlRow], SqlRow, None]:
242272
"""Yield all the data related to a single patient.
243273
@@ -263,15 +293,16 @@ def generate(
263293
# abs to avoid negative rates due to random normal variation
264294
avg_rate = abs(random_normal(
265295
src_stats["avg_measurements_per_visit_hour"][0]['avg_measurements_per_hour'],
266-
src_stats["avg_measurements_per_visit_hour"][0]['stddev_measurements_per_hour'] )
296+
src_stats["avg_measurements_per_visit_hour"][0]['stddev_measurements_per_hour'])
267297
)
268298

269-
270299
print(f"Generating blood pressure events at an average rate of {avg_rate} per hour.")
271300
for event in gen_blood_pressure_events(
272-
avg_rate,
273-
visit_occurrence,
274-
person,
275-
src_stats,
301+
avg_rate,
302+
visit_occurrence,
303+
person,
304+
src_stats,
276305
):
277-
yield event
306+
# Yield each measurement event if is not empty dictionary
307+
if len(event) > 0:
308+
yield event

0 commit comments

Comments
 (0)