
Commit cb0aef6

Merge branch 'main' into tchow/align-confs
2 parents: 73b04ba + 7248864

File tree

5 files changed: +109, -20 lines

api/python/ai/chronon/cli/compile/conf_validator.py

Lines changed: 80 additions & 0 deletions
@@ -18,6 +18,7 @@
 import json
 import logging
 import re
+import textwrap
 from collections import defaultdict
 from typing import Dict, List, Set
@@ -29,6 +30,7 @@
     ExternalPart,
     GroupBy,
     Join,
+    JoinPart,
     Source,
 )
 from ai.chronon.group_by import get_output_col_names
@@ -377,6 +379,81 @@ def _validate_derivations(
             derived_columns.add(derivation.name)
         return errors

+    def _validate_join_part_keys(self, join_part: JoinPart, left_cols: List[str]) -> BaseException:
+        keys = []
+
+        key_mapping = join_part.keyMapping if join_part.keyMapping else {}
+        for key in join_part.groupBy.keyColumns:
+            keys.append(key_mapping.get(key, key))
+
+        missing = [k for k in keys if k not in left_cols]
+
+        err_string = ""
+        left_cols_as_str = ", ".join(left_cols)
+        group_by_name = join_part.groupBy.metaData.name
+        if missing:
+            key_mapping_str = f"Key Mapping: {key_mapping}" if key_mapping else ""
+            err_string += textwrap.dedent(f"""
+                - Join is missing keys {missing} on left side. Required for JoinPart: {group_by_name}.
+                Existing columns on left side: {left_cols_as_str}
+                All required Keys: {join_part.groupBy.keyColumns}
+                {key_mapping_str}
+                Consider renaming a column on the left, or including the key_mapping argument to your join_part.""")
+
+        if key_mapping:
+            # Left side of key mapping should include columns on the left
+            key_map_keys_missing_from_left = [k for k in key_mapping.keys() if k not in left_cols]
+            if key_map_keys_missing_from_left:
+                err_string += f"\n- The following keys in your key_mapping: {str(key_map_keys_missing_from_left)} for JoinPart {group_by_name} are not included in the left side of the join: {left_cols_as_str}"
+
+            # Right side of key mapping should only include keys in GroupBy
+            keys_missing_from_key_map_values = [v for v in key_mapping.values() if v not in join_part.groupBy.keyColumns]
+            if keys_missing_from_key_map_values:
+                err_string += f"\n- The following values in your key_mapping: {str(keys_missing_from_key_map_values)} for JoinPart {group_by_name} do not cover any group by key columns: {join_part.groupBy.keyColumns}"
+
+            if key_map_keys_missing_from_left or keys_missing_from_key_map_values:
+                err_string += "\n(Key Mapping should be formatted as column_from_left -> group_by_key)"
+
+        if err_string:
+            return ValueError(err_string)
+
+
+    def _validate_keys(self, join: Join) -> List[BaseException]:
+        left = join.left
+
+        left_selects = None
+        if left.events:
+            left_selects = left.events.query.selects
+        elif left.entities:
+            left_selects = left.entities.query.selects
+        elif left.joinSource:
+            left_selects = left.joinSource.query.selects
+        # TODO -- if selects are not selected here, get output cols from join
+
+        left_cols = []
+
+        if left_selects:
+            left_cols = left_selects.keys()
+
+        errors = []
+
+        if left_cols:
+            join_parts = join.joinParts
+
+            # Add label_parts to join_parts to validate if set
+            label_parts = join.labelParts
+            if label_parts:
+                for label_jp in label_parts.labels:
+                    join_parts.append(label_jp)
+
+            # Validate join_parts
+            for join_part in join_parts:
+                join_part_err = self._validate_join_part_keys(join_part, left_cols)
+                if join_part_err:
+                    errors.append(join_part_err)
+
+        return errors
+
     def _validate_join(self, join: Join) -> List[BaseException]:
         """
         Validate join's status with materialized versions of group_bys
@@ -437,6 +514,9 @@ def _validate_join(self, join: Join) -> List[BaseException]:
         keys = get_pre_derived_source_keys(join.left)
         columns = features + keys
         errors.extend(self._validate_derivations(columns, join.derivations))
+
+        errors.extend(self._validate_keys(join))
+
         return errors

     def _validate_group_by(self, group_by: GroupBy) -> List[BaseException]:
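
To make the new validation concrete, here is a minimal standalone sketch of the rule that _validate_join_part_keys enforces, written with plain lists and dicts rather than the Chronon Join/JoinPart objects (all column names below are invented): every GroupBy key column, after passing through the JoinPart's key mapping, must resolve to a column selected on the left side of the join.

# Illustrative inputs only - not a real Chronon config.
left_cols = ["event", "group_by_subject", "subject"]  # selects on the join's left source
group_by_key_columns = ["subject"]                    # key columns declared by the GroupBy
key_mapping = {}                                      # optional JoinPart.keyMapping

# Mirror of the check: resolve each GroupBy key through the mapping,
# then flag any resolved key that the left side does not provide.
required = [key_mapping.get(k, k) for k in group_by_key_columns]
missing = [k for k in required if k not in left_cols]

print(missing)  # [] -> this JoinPart passes; a non-empty list is reported as a ValueError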

api/python/ai/chronon/cli/compile/display/class_tracker.py

Lines changed: 7 additions & 2 deletions
@@ -96,9 +96,14 @@ def to_errors(self) -> Text:
         text = Text(overflow="fold", no_wrap=False)

         if self.files_to_errors:
-            for file, error in self.files_to_errors.items():
+            for file, errors in self.files_to_errors.items():
                 text.append(" ERROR ", style="bold red")
-                text.append(f"- {file}: {str(error)}\n")
+                text.append(f"- {file}:\n")
+
+                for error in errors:
+                    # Format each error properly, handling newlines
+                    error_msg = str(error)
+                    text.append(f"    {error_msg}\n", style="red")

         return text
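
For illustration, a short sketch of how the updated rendering reads now that each file maps to a list of errors rather than a single one (the file path and error text are invented; rich's Text and Console are the same primitives the tracker already uses):

from rich.console import Console
from rich.text import Text

# Hypothetical compile result: one file with several validation errors.
files_to_errors = {
    "joins/sample_team/sample_join.py": [
        ValueError("- Join is missing keys ['subject'] on left side."),
        ValueError("- The following keys in your key_mapping: ['user_id'] are not included in the left side of the join."),
    ],
}

text = Text(overflow="fold", no_wrap=False)
for file, errors in files_to_errors.items():
    text.append(" ERROR ", style="bold red")
    text.append(f"- {file}:\n")
    for error in errors:
        text.append(f"    {error}\n", style="red")  # one line per error, indented under the file

Console().print(text)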

api/python/test/sample/sources/test_sources.py

Lines changed: 1 addition & 0 deletions
@@ -40,6 +40,7 @@ def basic_event_source(table):
         selects=selects(
             event="event_expr",
             group_by_subject="group_by_expr",
+            subject="subject",
         ),
         start_partition="2021-04-09",
         time_column="ts",

online/src/main/scala/ai/chronon/online/serde/AvroCodec.scala

Lines changed: 14 additions & 8 deletions
@@ -25,8 +25,9 @@ import org.apache.avro.generic.{GenericData, GenericRecord}
 import org.apache.avro.io._
 import com.linkedin.avro.fastserde.FastGenericDatumReader
 import com.linkedin.avro.fastserde.FastGenericDatumWriter
+import java.util.concurrent.ConcurrentHashMap
+
 import java.io.ByteArrayOutputStream
-import scala.collection.mutable

 class AvroCodec(val schemaStr: String) extends Serializable {
   @transient private lazy val parser = new Schema.Parser()
@@ -130,11 +131,16 @@ class ArrayRow(values: Array[Any], millis: Long, mutation: Boolean = false) exte

 object AvroCodec {
   // creating new codecs is expensive - so we want to do it once per process
-  // but at the same-time we want to avoid contention across threads - hence threadlocal
-  private val codecMap: ThreadLocal[mutable.HashMap[String, AvroCodec]] =
-    new ThreadLocal[mutable.HashMap[String, AvroCodec]] {
-      override def initialValue(): mutable.HashMap[String, AvroCodec] = new mutable.HashMap[String, AvroCodec]()
-    }
-
-  def of(schemaStr: String): AvroCodec = codecMap.get().getOrElseUpdate(schemaStr, new AvroCodec(schemaStr))
+  // but at the same-time we want to avoid contention across threads - hence thread-local
+  private val codecMap: ConcurrentHashMap[String, ThreadLocal[AvroCodec]] =
+    new ConcurrentHashMap[String, ThreadLocal[AvroCodec]]
+
+  def ofThreaded(schemaStr: String): ThreadLocal[AvroCodec] = codecMap.computeIfAbsent(
+    schemaStr,
+    str =>
+      new ThreadLocal[AvroCodec] {
+        override def initialValue(): AvroCodec = new AvroCodec(str)
+      })
+
+  def of(schemaStr: String): AvroCodec = ofThreaded(schemaStr).get()
 }
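
The new caching scheme keeps one ThreadLocal[AvroCodec] per schema string in a shared ConcurrentHashMap, so each codec is still built at most once per thread while the map itself is created once per process. A rough Python analogue of that pattern, for illustration only (PerThreadCodecCache and make_codec are invented names, and the explicit lock stands in for computeIfAbsent's atomicity):

import threading

class PerThreadCodecCache:
    """One shared entry per schema string; within an entry, one codec per thread."""

    def __init__(self, make_codec):
        self._make_codec = make_codec  # e.g. a function that parses a schema and builds a codec
        self._lock = threading.Lock()
        self._entries = {}             # schema_str -> threading.local()

    def of(self, schema_str):
        # Like ConcurrentHashMap.computeIfAbsent: create the per-schema slot once,
        # so every thread shares the same entry for a given schema.
        with self._lock:
            slot = self._entries.setdefault(schema_str, threading.local())
        # Like ThreadLocal.initialValue: build the codec lazily, once per thread.
        if not hasattr(slot, "codec"):
            slot.codec = self._make_codec(schema_str)
        return slot.codec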

online/src/main/scala/ai/chronon/online/serde/AvroSerDe.scala

Lines changed: 7 additions & 10 deletions
@@ -5,24 +5,21 @@ import org.apache.avro.Schema
 import org.apache.avro.generic.GenericRecord
 import org.apache.avro.io.{BinaryDecoder, DecoderFactory}
 import org.apache.avro.specific.SpecificDatumReader
-
+import com.linkedin.avro.fastserde.FastDeserializer
 import java.io.{ByteArrayInputStream, InputStream}

 class AvroSerDe(avroSchema: Schema) extends SerDe {

-  lazy val chrononSchema = AvroConversions.toChrononSchema(avroSchema).asInstanceOf[StructType]
+  lazy val chrononSchema: StructType = AvroConversions.toChrononSchema(avroSchema).asInstanceOf[StructType]

-  @transient lazy val avroToRowConverter = AvroConversions.genericRecordToChrononRowConverter(chrononSchema)
+  @transient private lazy val avroToRowConverter = AvroConversions.genericRecordToChrononRowConverter(chrononSchema)

-  private def byteArrayToAvro(avro: Array[Byte], schema: Schema): GenericRecord = {
-    val reader = new SpecificDatumReader[GenericRecord](schema)
-    val input: InputStream = new ByteArrayInputStream(avro)
-    val decoder: BinaryDecoder = DecoderFactory.get().binaryDecoder(input, null)
-    reader.read(null, decoder)
-  }
+  lazy val schemaString: String = avroSchema.toString()
+
+  def avroCodec: ThreadLocal[AvroCodec] = AvroCodec.ofThreaded(schemaString)

   override def fromBytes(bytes: Array[Byte]): Mutation = {
-    val avroRecord = byteArrayToAvro(bytes, avroSchema)
+    val avroRecord = avroCodec.get().decode(bytes)

     val row: Array[Any] = avroToRowConverter(avroRecord)