Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 28 additions & 0 deletions src/tests/architect_tests/test_label_generators.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from datetime import date, timedelta

import testing.postgresql
import pytest
from sqlalchemy import create_engine

from triage.component.architect.label_generators import LabelGenerator
Expand Down Expand Up @@ -169,3 +170,30 @@ def test_generate_all_labels_noreplace():
(4, date(2014, 9, 30), timedelta(90), "outcome", "binary", False),
]
assert records == expected


def test_generate_all_labels_errors_on_duplicates():
    """generate_all_labels must raise ValueError when the label query yields duplicates."""
    # The query below emits one row for each event in the timespan, so it
    # will yield duplicate rows in the generated labels table.
    duplicating_query = """
select
events.entity_id,
1 as outcome
from events
where
'{as_of_date}' <= outcome_date
and outcome_date < '{as_of_date}'::timestamp + interval '{label_timespan}'
"""

    with testing.postgresql.Postgresql() as pg:
        db_engine = create_engine(pg.url())
        create_binary_outcome_events(db_engine, "events", events_data)

        generator = LabelGenerator(
            db_engine=db_engine,
            query=duplicating_query,
            replace=True,
        )
        # Generation is expected to fail rather than hand back a labels
        # table containing repeated rows.
        with pytest.raises(ValueError):
            generator.generate_all_labels(
                labels_table=LABELS_TABLE_NAME,
                as_of_dates=["2014-09-30", "2015-03-30"],
                label_timespans=["6month", "3month"],
            )

2 changes: 1 addition & 1 deletion src/tests/postmodeling_tests/test_add_predictions.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import pytest

from triage.component.postmodeling.utils.add_predictions import add_predictions
from triage.component.architect.database_reflection import table_has_data
from triage.database_reflection import table_has_data


MODEL_IDS_QUERY = """
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from testing.postgresql import Postgresql
from unittest import TestCase

from triage.component.architect import database_reflection as dbreflect
import triage.database_reflection as dbreflect


class TestDatabaseReflection(TestCase):
Expand Down Expand Up @@ -40,6 +40,16 @@ def test_table_has_data(self):
assert dbreflect.table_has_data("compliments", self.engine)
assert not dbreflect.table_has_data("incidents", self.engine)

def test_table_has_duplicates(self):
    """Check table_has_duplicates on an empty table, a partial-key repeat and a full-key repeat."""
    engine = self.engine
    both_columns = ['col1', 'col2']

    engine.execute("create table events (col1 int, col2 int)")
    # A freshly created, empty table reports no duplicates.
    assert not dbreflect.table_has_duplicates("events", both_columns, engine)

    engine.execute("insert into events values (1,2)")
    engine.execute("insert into events values (1,3)")
    # col1 now repeats, but each (col1, col2) pair is still distinct.
    assert dbreflect.table_has_duplicates("events", ['col1'], engine)
    assert not dbreflect.table_has_duplicates("events", both_columns, engine)

    # Inserting (1,2) a second time repeats a full (col1, col2) pair.
    engine.execute("insert into events values (1,2)")
    assert dbreflect.table_has_duplicates("events", both_columns, engine)

def test_table_has_column(self):
self.engine.execute("create table incidents (col1 varchar)")
assert dbreflect.table_has_column("incidents", "col1", self.engine)
Expand Down
10 changes: 8 additions & 2 deletions src/triage/component/architect/entity_date_table_generators.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import verboselogs

from triage.component.architect.database_reflection import table_has_data
from triage.database_reflection import table_row_count, table_exists
from triage.database_reflection import table_has_data, table_row_count, table_exists, table_has_duplicates


logger = verboselogs.VerboseLogger(__name__)
Expand Down Expand Up @@ -55,6 +54,13 @@ def generate_entity_date_table(self, as_of_dates):
if not table_has_data(self.entity_date_table_name, self.db_engine):
raise ValueError(self._empty_table_message(as_of_dates))

if table_has_duplicates(
self.entity_date_table_name,
['entity_id', 'as_of_date'],
self.db_engine
):
raise ValueError(f"Duplicates found in {self.entity_date_table_name}!")

logger.debug(f"Entity-date table generated at {self.entity_date_table_name}")
logger.spam(f"Generating stats on {self.entity_date_table_name}")
logger.spam(f"Row count of {self.entity_date_table_name}: {table_row_count(self.entity_date_table_name, self.db_engine)}")
Expand Down
17 changes: 12 additions & 5 deletions src/triage/component/architect/label_generators.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
logger = verboselogs.VerboseLogger(__name__)

import textwrap
from triage.database_reflection import table_row_count, table_exists
from triage.database_reflection import table_row_count, table_exists, table_has_duplicates

DEFAULT_LABEL_NAME = "outcome"

Expand Down Expand Up @@ -86,10 +86,17 @@ def generate_all_labels(self, labels_table, as_of_dates, label_timespans):

if nrows == 0:
logger.warning(f"Done creating labels, but no rows in {labels_table} table!")
raise ValueError(f"{label_table} is empty!")
else:
logger.debug(f"Labels table generated at {labels_table}")
logger.spam(f"Row count of {labels_table}: {nrows}")
raise ValueError(f"{labels_table} is empty!")

if table_has_duplicates(
labels_table,
['entity_id', 'as_of_date', 'label_timespan', 'label_name', 'label_type'],
self.db_engine
):
raise ValueError(f"Duplicates found in {labels_table}!")

logger.debug(f"Labels table generated at {labels_table}")
logger.spam(f"Row count of {labels_table}: {nrows}")

def generate(self, start_date, label_timespan, labels_table):
"""Generate labels table using a query
Expand Down
2 changes: 1 addition & 1 deletion src/triage/component/architect/validations.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
"""Functions for validating input, mostly around database schema and state"""
from triage.component.architect.database_reflection import (
from triage.database_reflection import (
table_exists,
table_has_column,
column_type,
Expand Down
35 changes: 32 additions & 3 deletions src/triage/database_reflection.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,11 @@ def table_has_data(table_name, db_engine):
"""
if not table_exists(table_name, db_engine):
return False
result = [
row for row in db_engine.execute("select 1 from {} limit 1".format(table_name))
results = [
row for row in db_engine.execute("select * from {} limit 1".format(table_name))
]

return any(result)
return len(results) > 0


def table_row_count(table_name, db_engine):
Expand All @@ -99,6 +99,35 @@ def table_row_count(table_name, db_engine):
)


def table_has_duplicates(table_name, column_list, db_engine):
    """Report whether any combination of values in the given columns repeats.

    Rows are grouped by the columns in column_list and the size of the
    largest group is compared against one. Returns False without running the
    grouping query when table_has_data reports no data for the table.

    Args:
        table_name (string) A table name (with schema)
        column_list (list) Names of the columns that together should be unique
        db_engine (sqlalchemy.engine)

    Returns: (boolean) True if at least one group holds more than one row
    """
    if not table_has_data(table_name, db_engine):
        return False

    # Each column name is wrapped in double quotes before being placed in the query.
    cols = ','.join('"%s"' % column_name for column_name in column_list)
    sql = f"""
WITH counts AS (
SELECT {cols}
, COUNT(*) AS num_records
FROM {table_name}
GROUP BY {cols}
)
SELECT MAX(num_records) FROM counts
"""
    # The query yields a single row whose only column is the largest group size.
    first_row = next(db_engine.execute(sql))
    largest_group_size = first_row[0]
    return largest_group_size > 1


def table_has_column(table_name, column, db_engine):
"""Check whether the table contains a column of the given name

Expand Down