
Commit cd7eadd

closes #989 #988 preparing for bumping version 5.5.4

2 parents: f0ca919 + 832c520

8 files changed: +25740, -84 lines

requirement/main.txt

Lines changed: 1 addition & 0 deletions
@@ -25,6 +25,7 @@ verboselogs==1.7
 s3fs==0.4.2 # pyup: ignore
 scikit-learn==1.6.1
 matplotlib==3.5.1
+matplotlib-inline==0.1.3
 seaborn==0.11.2
 ohio==0.5.0
 aequitas==0.42.0

src/triage/component/postmodeling/base.py

Lines changed: 50 additions & 10 deletions
@@ -5,11 +5,11 @@
 import seaborn as sns
 import matplotlib.table as tab
 import matplotlib.pyplot as plt
+import itertools
 #from tabulate import tabulate

 from IPython.display import display
-import itertools
-
+from io import StringIO
 from descriptors import cachedproperty
 from sqlalchemy import create_engine
 from sklearn.calibration import calibration_curve
@@ -475,15 +475,55 @@ def mean_ratio(pos, neg):
         crosstabs_df['model_id'] = self.model_id
         crosstabs_df['matrix_uuid'] = matrix_uuid

-
         if push_to_db:
-            logging.info('Pushing the results to the DB')
+            # logging.info('Pushing the results to the DB')
+            # crosstabs_df.set_index(
+            #     ['model_id', 'matrix_uuid', 'feature', 'metric', 'threshold_type', 'threshold'], inplace=True
+            # )
+
+            # # TODO: Figure out to change the owner of the table
+            # crosstabs_df.pg_copy_to(schema='test_results', name=table_name, con=self.engine, if_exists='append')
+            logging.info(f'Pushing the results to the database, {len(crosstabs_df)} rows')
+
             crosstabs_df.set_index(
-                ['model_id', 'matrix_uuid', 'feature', 'metric', 'threshold_type', 'threshold'], inplace=True
+                ['model_id', 'matrix_uuid', 'feature', 'metric', 'threshold_type', 'threshold'],
+                inplace=True
             )
-
-            # TODO: Figure out to change the owner of the table
-            crosstabs_df.pg_copy_to(schema='test_results', name=table_name, con=self.engine, if_exists='append')
+
+            crosstabs_df = crosstabs_df.reset_index()
+
+            if not table_exists(f'test_results.{table_name}', self.engine):
+                q = f'''
+                    create schema if not exists test_results;
+
+                    create table test_results.{table_name} (
+                        model_id INTEGER,
+                        matrix_uuid TEXT,
+                        feature TEXT,
+                        metric TEXT,
+                        threshold_type TEXT,
+                        threshold FLOAT,
+                        value FLOAT
+                    );
+                '''
+                # q = _generate_create_table_sql_statement_from_df(results, f'{table_schema}.{table_name}')
+                self.engine.execute(q)
+
+            conn = self.engine.raw_connection()
+            cursor = conn.cursor()
+
+            buffer = StringIO()
+            crosstabs_df.to_csv(buffer, index=False, header=False)
+            buffer.seek(0)
+
+            columns = ', '.join(crosstabs_df.columns)
+            print(columns)
+            cursor.copy_expert(f"COPY test_results.{table_name} ({columns}) FROM STDIN WITH CSV", buffer)
+            # results.to_sql(con=db_engine, schema=table_schema, name=table_name, if_exists='append')
+            conn.commit()
+            cursor.close()
+            conn.close()

         if return_df:
             return crosstabs_df
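The hunk above replaces `pg_copy_to` with a manual Postgres COPY: the frame is serialized to CSV in memory and streamed through `copy_expert`. A minimal, self-contained sketch of that pattern (hypothetical table name, connection string, and data; assumes a psycopg2-backed SQLAlchemy engine, which is what `raw_connection()` exposes here):

    from io import StringIO

    import pandas as pd
    from sqlalchemy import create_engine

    # hypothetical connection string and frame, for illustration only
    engine = create_engine('postgresql://user:pass@localhost/triage_db')
    df = pd.DataFrame({'model_id': [1, 2], 'value': [0.4, 0.7]})

    # serialize to CSV in memory; COPY streams it in one round trip,
    # which is far faster than row-by-row INSERTs for large frames
    buffer = StringIO()
    df.to_csv(buffer, index=False, header=False)
    buffer.seek(0)

    conn = engine.raw_connection()  # underlying psycopg2 connection
    try:
        cursor = conn.cursor()
        columns = ', '.join(df.columns)
        cursor.copy_expert(f"COPY test_results.my_table ({columns}) FROM STDIN WITH CSV", buffer)
        conn.commit()
        cursor.close()
    finally:
        conn.close()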
@@ -1136,7 +1176,7 @@ def get_model_ids(self):
             and model_group_id in ('{model_groups}')
         """
         # TODO do we really need experiment_hashes here? can we query with only model_group_ids?
-
+
         # TODO: modify to remove pandas
         models = pd.read_sql(q, self.engine).to_dict(orient='records')

@@ -1197,7 +1237,7 @@ def _make_plot_grid(self, plot_type, subplot_width=3, subplot_len=None, sharey=F
         """
         fig, axes = self._get_subplots(subplot_width=subplot_width, subplot_len=subplot_len, sharey=sharey, sharex=sharex)

-        print(len(axes), len(axes[0]))
+        logging.info(f"{len(axes), len(axes[0])}")

         for j, mg in enumerate(self.models):
             for i, train_end_time in enumerate(self.models[mg]):

src/triage/component/postmodeling/experiment_summarizer.py

Lines changed: 111 additions & 66 deletions
@@ -1,15 +1,17 @@

 """ This is a module for moving the ad-hoc code we wrote in generating the modeling report"""
+import verboselogs, logging
+
+logger = verboselogs.VerboseLogger(__name__)

 import pandas as pd
 import json
 import matplotlib.pyplot as plt
 import seaborn as sns
-import logging
-from matplotlib.lines import Line2D
-
 import warnings

+from matplotlib.lines import Line2D
+
 from triage.component.timechop.plotting import visualize_chops_plotly
 from triage.component.timechop import Timechop
 from triage.component.audition.plotting import plot_cats
from triage.component.audition.plotting import plot_cats
@@ -112,7 +114,12 @@ def get_most_recent_experiment_hash(engine):

 class ExperimentReport:

-    def __init__(self, engine, experiment_hashes, performance_priority_metric, threshold, bias_priority_metric, bias_priority_groups):
+    def __init__(self, engine,
+                 experiment_hashes,
+                 performance_priority_metric='recall@',
+                 threshold="1_pct",
+                 bias_priority_metric='tpr_disparity',
+                 bias_priority_groups=None):
         self.engine = engine
         self.experiment_hashes = experiment_hashes
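With these keyword defaults, a report can now be constructed from just an engine and the experiment hashes. A hypothetical usage sketch (illustrative connection string and hash):

    from sqlalchemy import create_engine

    engine = create_engine('postgresql://user:pass@localhost/triage_db')

    # remaining arguments fall back to recall@ at 1_pct for performance
    # and tpr_disparity across all groups for bias
    report = ExperimentReport(engine, experiment_hashes=['abc123'])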
@@ -429,33 +436,51 @@ def model_performance(self, metric=None, parameter=None, generate_plot=True):
             parameter = self.threshold

         # fetch model groups
-        q = f'''
-        with models as (
+        def fetch_evaluation_values(with_parameter=True):
+            if with_parameter:
+                parameter_stmt = f"and e.parameter = '{parameter}'"
+            else:
+                parameter_stmt = ''
+
+            q = f'''
+            with models as (
+                select
+                distinct model_id,
+                train_end_time,
+                model_group_id,
+                model_type,
+                hyperparameters
+                from triage_metadata.experiment_models join triage_metadata.models using(model_hash)
+                where experiment_hash in ('{"','".join(self.experiment_hashes)}')
+            )
             select
-            distinct model_id,
-            train_end_time,
-            model_group_id,
+            m.model_id,
+            train_end_time::date as train_end_time_dt,
+            to_char(train_end_time, 'YYYY-MM-DD') as train_end_time,
             model_type,
-            hyperparameters
-            from triage_metadata.experiment_models join triage_metadata.models using(model_hash)
-            where experiment_hash in ('{"','".join(self.experiment_hashes)}')
-            )
-            select
-            m.model_id,
-            train_end_time::date as train_end_time_dt,
-            to_char(train_end_time, 'YYYY-MM-DD') as train_end_time,
-            model_type,
-            model_group_id,
-            stochastic_value as metric_value
-            from models m left join test_results.evaluations e
-            on m.model_id = e.model_id
-            and e.metric = '{metric}'
-            and e.parameter = '{parameter}'
-            and e.subset_hash = ''
-        '''
-
-
+            model_group_id,
+            stochastic_value as metric_value,
+            parameter
+            from models m left join test_results.evaluations e
+            on m.model_id = e.model_id
+            and e.metric = '{metric}'
+            {parameter_stmt}
+            and e.subset_hash = ''
+            '''
+            return q
+
+        # 1. fetch evaluation values to check if we have
+        q = fetch_evaluation_values()
         df = pd.read_sql(q, self.engine)
+        # Validate that we have value for the DEFAULT metric and parameter
+        if df.metric_value.isna().unique():
+            q = fetch_evaluation_values(with_parameter=False)
+            df = pd.read_sql(q, self.engine)
+            # fetch the first available value
+            parameter_ = df.loc[0, 'parameter']
+            self.threshold = parameter_
+            df = df[df.parameter == parameter_]
+
         df['train_end_time'] = pd.to_datetime(df.train_end_time, format='%Y-%m-%d')

         models_per_train_end_time = df.groupby(['model_group_id', 'train_end_time']).count()['model_id']
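The intent of the new control flow above: run the evaluations query with the default parameter filter first, and if no metric values come back, re-run without the filter and adopt the first parameter that does exist, storing it back on `self.threshold` so later output stays consistent. A condensed sketch of that fallback with hypothetical names; it tests emptiness with the more explicit `.isna().all()` rather than the `.isna().unique()` truthiness check in the committed code:

    import pandas as pd

    def load_metric_values(engine, build_query, default_parameter):
        # build_query(with_parameter=...) is a stand-in for fetch_evaluation_values
        df = pd.read_sql(build_query(with_parameter=True), engine)
        if df.metric_value.isna().all():
            # no evaluations at the default parameter; take whatever exists
            df = pd.read_sql(build_query(with_parameter=False), engine)
            fallback = df.loc[0, 'parameter']
            return df[df.parameter == fallback], fallback
        return df, default_parameter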
@@ -508,27 +533,45 @@ def model_performance_subsets(self, metric=None, parameter=None, generate_plot=T
         if parameter is None:
             parameter = self.threshold

-        q = f'''
-        select
-            case when e.subset_hash is null then 'full_cohort'
-            else s.config ->> 'name'
-            end as "subset",
-            e.subset_hash,
-            m.model_id,
-            m.model_group_id,
-            m.model_type,
-            m.train_end_time::date,
-            e.stochastic_value as metric_value
-        from triage_metadata.experiment_models join triage_metadata.models m using(model_hash)
-            left join test_results.evaluations e
-            on m.model_id = e.model_id
-            and e.parameter = '{parameter}'
-            and e.metric = '{metric}'
-            left join triage_metadata.subsets s on e.subset_hash = s.subset_hash
-        where experiment_hash in ('{"','".join(self.experiment_hashes)}')
-        '''
+        def fetch_evaluation_values_subsets(with_parameter=True):
+            if with_parameter:
+                parameter_stmt = f"and e.parameter = '{parameter}'"
+            else:
+                parameter_stmt = ''
+
+            q = f'''
+            select
+                case when e.subset_hash is null then 'full_cohort'
+                else s.config ->> 'name'
+                end as "subset",
+                e.subset_hash,
+                m.model_id,
+                m.model_group_id,
+                m.model_type,
+                m.train_end_time::date,
+                e.stochastic_value as metric_value,
+                parameter
+            from triage_metadata.experiment_models join triage_metadata.models m using(model_hash)
+                left join test_results.evaluations e
+                on m.model_id = e.model_id
+                and e.parameter = '{parameter}'
+                {parameter_stmt}
+                left join triage_metadata.subsets s on e.subset_hash = s.subset_hash
+            where experiment_hash in ('{"','".join(self.experiment_hashes)}')
+            '''
+
+            return q

+        q = fetch_evaluation_values_subsets()
         df = pd.read_sql(q, self.engine)
+        # Validate that we have value for the DEFAULT metric and parameter
+        if df.metric_value.isna().unique():
+            q = fetch_evaluation_values_subsets(with_parameter=False)
+            df = pd.read_sql(q, self.engine)
+            # fetch the first available value
+            parameter_ = df.loc[0, 'parameter']
+            self.threshold = parameter_
+            df = df[df.parameter == parameter_]

         if (df.empty) or (None in df.subset.unique()):
             return None
@@ -921,39 +964,41 @@ def generate_summary(self, metric=None, parameter=None, equity_metric=None):
         if equity_metric is None:
             equity_metric = self.bias_metric

-
+        logger.notice(f"Default performance parameters are set to recall@1_pct and bias metric to tpr_disparity!")
+        logger.notice("==> In case your experiment doesn't have those parameters Triage will use one of the available. <==")
         stats = self.experiment_stats()

         if stats['implemented_fewer_splits'] == 1:
-            print(f"Temporal config suggests {stats['timesplits_from_temporal_config']} temporal splits, but experiment implemented only {stats['validation_splits']} splits. Was this intentional?")
+            logger.notice(f"Temporal config suggests {stats['timesplits_from_temporal_config']} temporal splits, but experiment implemented only {stats['validation_splits']} splits. Was this intentional?")
         else:
-            print(f'Experiment contained {stats["timesplits_from_temporal_config"]} temporal splits')
+            logger.notice(f'Experiment contained {stats["timesplits_from_temporal_config"]} temporal splits')

-
-        print(f"Experiment contained {stats['as_of_times']} distinct as_of_times")
+        logger.notice(f"Experiment contained {stats['as_of_times']} distinct as_of_times")

         cohorts = self.cohorts(generate_plots=False)
-        print(f'On average, your cohorts contained around {round(cohorts.cohort_size.mean())} entities with a baserate of {round(cohorts.baserate.mean(), 3)}')
-
-        print(f"You built {stats['features']} features organized into {stats['feature_groups']} groups/blocks")
+        logger.notice(f'On average, your cohorts contained around {round(cohorts.cohort_size.mean())} entities with a baserate of {round(cohorts.baserate.mean(), 3)}')

-        print(f"Your model grid specification contained {stats['grid_size']} model types with {stats['models_needed']} individual models")
+        logger.notice(f"You built {stats['features']} features organized into {stats['feature_groups']} groups/blocks")
+
+        logger.notice(f"Your model grid specification contained {stats['grid_size']} model types with {stats['models_needed']} individual models")

         ## Models
         num_models = len(self.models())
         if num_models < stats['models_needed']:
-            print(f"However, the experiment only built {num_models} models. You are missing {stats['models_needed'] - num_models} models")
+            logger.notice(f"However, the experiment only built {num_models} models. You are missing {stats['models_needed'] - num_models} models")

         else:
-            print(f"You successfully built all the {num_models} models")
+            logger.notice(f"You successfully built all the {num_models} models")

         # Model Performance
         performance = self.model_performance(metric=metric, parameter=parameter, generate_plot=False)
         best_performance = performance.groupby(['model_group_id', 'model_type'])['metric_value'].mean().max()
         best_model_group = performance.groupby(['model_group_id', 'model_type'])['metric_value'].mean().idxmax()[0]
         best_model_type = performance.groupby(['model_group_id', 'model_type'])['metric_value'].mean().idxmax()[1]
-
-        print(f"Your models acheived a best average {metric}{parameter} of {round(best_performance, 3)} over the {stats['validation_splits']} validation splits, with the Model Group {best_model_group},{best_model_type}. Note that model selection is more nuanced than average predictive performance over time. You could use Audition for model selection.")
+
+        # because we could change the value of the default parameter in case it doesn't exist,
+        # it is safer to take it from the object itself.
+        logger.notice(f"Your models achieved a best average {self.performance_metric}{self.threshold} of {round(best_performance, 3)} over the {stats['validation_splits']} validation splits, with the Model Group {best_model_group},{best_model_type}. Note that model selection is more nuanced than average predictive performance over time. You could use Audition for model selection.")

         ## Subsets
         subset_performance = self.model_performance_subsets(metric=metric, parameter=parameter, generate_plot=False)
@@ -969,11 +1014,11 @@ def generate_summary(self, metric=None, parameter=None, equity_metric=None):
             res.append(d)

         if len(res) > 0:
-            print(f"You created {len(res)} subsets of your cohort -- {', '.join([x['subset'] for x in res])}")
+            logger.notice(f"You created {len(res)} subsets of your cohort -- {', '.join([x['subset'] for x in res])}")
             for d in res:
-                print(f"For subset '{d['subset'] }', Model Group {d['best_mod'][0]}, {d['best_mod'][1]} achieved the best average {metric}{parameter} of {d['best_perf']}")
+                logger.notice(f"For subset '{d['subset'] }', Model Group {d['best_mod'][0]}, {d['best_mod'][1]} achieved the best average {metric}{parameter} of {d['best_perf']}")
         else:
-            print("No subsets defined.")
+            logger.notice("No subsets defined.")

         ## Bias
         equity_metrics = self.efficiency_and_equity(
@@ -986,11 +1031,11 @@ def generate_summary(self, metric=None, parameter=None, equity_metric=None):
         if equity_metrics is not None:
             grpobj = equity_metrics[(equity_metrics.baserate > 0) & (equity_metrics.model_group_id == best_model_group)].groupby('attribute_name')
             for attr, gdf in grpobj:
-                print(f'Measuring biases across {attr} groups using {equity_metric} for the best performing model:')
+                logger.notice(f"Measuring biases across {attr} groups using {equity_metric} for the best performing model:")
                 d = gdf.groupby('attribute_value')[equity_metric].mean()
-                print(", ".join(f"{k}: {round(v, 3)}" for k, v, in d.to_dict().items()))
+                logger.notice(", ".join(f"{k}: {round(v, 3)}" for k, v, in d.to_dict().items()))
         else:
-            print(f"No bias audit results were found in the database for the experiment.")
+            logger.notice(f"No bias audit results were found in the database for the experiment.")


     def precision_recall_curves(self, plot_size=(3,3)):
