|
import json
import logging
import re
import textwrap
from collections import defaultdict
from typing import Dict, List, Optional, Set
23 | 24 |
|
|
29 | 30 | ExternalPart, |
30 | 31 | GroupBy, |
31 | 32 | Join, |
| 33 | + JoinPart, |
32 | 34 | Source, |
33 | 35 | ) |
34 | 36 | from ai.chronon.group_by import get_output_col_names |
@@ -377,6 +379,81 @@ def _validate_derivations( |
377 | 379 | derived_columns.add(derivation.name) |
378 | 380 | return errors |
379 | 381 |
|
def _validate_join_part_keys(self, join_part: "JoinPart", left_cols: List[str]) -> Optional[BaseException]:
    """
    Check that the left side of a join supplies every key a JoinPart needs.

    The JoinPart's key mapping is read as ``column_from_left -> group_by_key``.
    Three things are checked and all findings are reported together:

    * every GroupBy key column (translated to its left-side column name when
      the key mapping renames it) is present in ``left_cols``;
    * every key of the key mapping is a column in ``left_cols``;
    * every value of the key mapping is one of the GroupBy's key columns.

    :param join_part: the JoinPart whose ``groupBy.keyColumns`` and optional
        ``keyMapping`` are validated.
    :param left_cols: names of the columns available on the left side.
    :return: a ValueError describing every problem found (it is returned, not
        raised), or None when the JoinPart's keys are consistent.
    """
    group_by_keys = join_part.groupBy.keyColumns
    key_mapping = join_part.keyMapping if join_part.keyMapping else {}

    # The mapping is keyed by left column, so invert it to find which left
    # column feeds each GroupBy key. Unmapped keys must exist on the left
    # under their own name.
    left_col_for_key = {gb_key: left_col for left_col, gb_key in key_mapping.items()}
    keys = [left_col_for_key.get(key, key) for key in group_by_keys]

    missing = [k for k in keys if k not in left_cols]

    err_string = ""
    left_cols_as_str = ", ".join(left_cols)
    group_by_name = join_part.groupBy.metaData.name
    if missing:
        key_mapping_str = f"Key Mapping: {key_mapping}" if key_mapping else ""
        err_string += textwrap.dedent(f"""
            - Join is missing keys {missing} on left side. Required for JoinPart: {group_by_name}.
            Existing columns on left side: {left_cols_as_str}
            All required Keys: {join_part.groupBy.keyColumns}
            {key_mapping_str}
            Consider renaming a column on the left, or including the key_mapping argument to your join_part.""")

    if key_mapping:
        # Left side of key mapping should include columns on the left
        key_map_keys_missing_from_left = [k for k in key_mapping.keys() if k not in left_cols]
        if key_map_keys_missing_from_left:
            err_string += f"\n- The following keys in your key_mapping: {str(key_map_keys_missing_from_left)} for JoinPart {group_by_name} are not included in the left side of the join: {left_cols_as_str}"

        # Right side of key mapping should only include keys in GroupBy
        keys_missing_from_key_map_values = [v for v in key_mapping.values() if v not in group_by_keys]
        if keys_missing_from_key_map_values:
            err_string += f"\n- The following values in your key_mapping: {str(keys_missing_from_key_map_values)} for JoinPart {group_by_name} do not cover any group by key columns: {join_part.groupBy.keyColumns}"

        if key_map_keys_missing_from_left or keys_missing_from_key_map_values:
            err_string += "\n(Key Mapping should be formatted as column_from_left -> group_by_key)"

    if err_string:
        return ValueError(err_string)
    return None
| 419 | + |
| 420 | + |
def _validate_keys(self, join: "Join") -> List[BaseException]:
    """
    Validate the keys of every JoinPart (and label part) against the join's left side.

    The left-side columns are the keys of the ``selects`` of whichever left
    source is set, checked in the order events, entities, joinSource. When no
    selects are available nothing is validated and an empty list is returned.

    :param join: the Join whose ``joinParts`` and, if set, ``labelParts.labels``
        are validated. The join is not modified.
    :return: one error per JoinPart that failed key validation; empty when all
        pass or when the left side declares no selects.
    """
    left = join.left

    if left.events:
        left_selects = left.events.query.selects
    elif left.entities:
        left_selects = left.entities.query.selects
    elif left.joinSource:
        left_selects = left.joinSource.query.selects
    else:
        left_selects = None
    # TODO -- if selects are not selected here, get output cols from join

    if not left_selects:
        return []

    left_cols = list(left_selects.keys())

    # Work on a copy: appending label parts to join.joinParts itself would
    # alter the Join being validated and add duplicates on every call.
    join_parts = list(join.joinParts or [])

    # Add label_parts to join_parts to validate if set
    label_parts = join.labelParts
    if label_parts:
        join_parts.extend(label_parts.labels or [])

    # Validate join_parts
    errors = []
    for join_part in join_parts:
        join_part_err = self._validate_join_part_keys(join_part, left_cols)
        if join_part_err:
            errors.append(join_part_err)

    return errors
| 456 | + |
380 | 457 | def _validate_join(self, join: Join) -> List[BaseException]: |
381 | 458 | """ |
382 | 459 | Validate join's status with materialized versions of group_bys |
@@ -437,6 +514,9 @@ def _validate_join(self, join: Join) -> List[BaseException]: |
437 | 514 | keys = get_pre_derived_source_keys(join.left) |
438 | 515 | columns = features + keys |
439 | 516 | errors.extend(self._validate_derivations(columns, join.derivations)) |
| 517 | + |
| 518 | + errors.extend(self._validate_keys(join)) |
| 519 | + |
440 | 520 | return errors |
441 | 521 |
|
442 | 522 | def _validate_group_by(self, group_by: GroupBy) -> List[BaseException]: |
|
0 commit comments