|
1 | | -""" |
2 | | -Datetime namespace for expression operations on datetime-typed columns. |
3 | | -
|
4 | | -This module defines the ``_DatetimeNamespace`` class which exposes a set of |
5 | | -convenience methods for working with timestamp and date columns in Ray Data |
6 | | -expressions. The API mirrors pandas' ``Series.dt`` accessor and is backed by |
7 | | -PyArrow compute functions for efficient execution. |
8 | | -
|
9 | | -Example |
10 | | -------- |
11 | | -
|
12 | | ->>> from ray.data.expressions import col |
13 | | ->>> # Extract year, month and day from a timestamp column |
14 | | ->>> expr_year = col("timestamp").dt.year() |
15 | | ->>> expr_month = col("timestamp").dt.month() |
16 | | ->>> expr_day = col("timestamp").dt.day() |
17 | | ->>> # Format the timestamp as a string |
18 | | ->>> expr_fmt = col("timestamp").dt.strftime("%Y-%m-%d") |
19 | | -""" |
20 | | - |
21 | 1 | from __future__ import annotations |
22 | 2 |
|
23 | 3 | from dataclasses import dataclass |
24 | | -from typing import TYPE_CHECKING, Any |
| 4 | +from typing import TYPE_CHECKING, Any, Callable, Literal |
25 | 5 |
|
26 | 6 | import pyarrow |
27 | 7 | import pyarrow.compute as pc |
|
32 | 12 | if TYPE_CHECKING: |
33 | 13 | from ray.data.expressions import Expr, UDFExpr |
34 | 14 |
|
| 15 | +TemporalUnit = Literal[ |
| 16 | + "year", |
| 17 | + "quarter", |
| 18 | + "month", |
| 19 | + "week", |
| 20 | + "day", |
| 21 | + "hour", |
| 22 | + "minute", |
| 23 | + "second", |
| 24 | + "millisecond", |
| 25 | + "microsecond", |
| 26 | + "nanosecond", |
| 27 | +] |
| 28 | + |
| 29 | + |
35 | 30 | @dataclass |
36 | 31 | class _DatetimeNamespace: |
37 | | - """Namespace for datetime operations on expression columns.""" |
| 32 | + """Datetime namespace for operations on datetime-typed expression columns.""" |
| 33 | + |
38 | 34 | _expr: "Expr" |
39 | 35 |
|
40 | | - def year(self) -> "UDFExpr": |
| 36 | + def _unary_temporal_int( |
| 37 | + self, func: Callable[[pyarrow.Array], pyarrow.Array] |
| 38 | + ) -> "UDFExpr": |
| 39 | + """Helper for year/month/… that return int32.""" |
| 40 | + |
41 | 41 | @pyarrow_udf(return_dtype=DataType.int32()) |
42 | | - def _year(arr: pyarrow.Array) -> pyarrow.Array: |
43 | | - return pc.year(arr) |
44 | | - return _year(self._expr) |
| 42 | + def _udf(arr: pyarrow.Array) -> pyarrow.Array: |
| 43 | + return func(arr) |
| 44 | + |
| 45 | + return _udf(self._expr) |
| 46 | + |
| 47 | + # extractors |
| 48 | + |
| 49 | + def year(self) -> "UDFExpr": |
| 50 | + """Extract year component.""" |
| 51 | + return self._unary_temporal_int(pc.year) |
45 | 52 |
|
46 | 53 | def month(self) -> "UDFExpr": |
47 | | - @pyarrow_udf(return_dtype=DataType.int32()) |
48 | | - def _month(arr: pyarrow.Array) -> pyarrow.Array: |
49 | | - return pc.month(arr) |
50 | | - return _month(self._expr) |
| 54 | + """Extract month component.""" |
| 55 | + return self._unary_temporal_int(pc.month) |
51 | 56 |
|
52 | 57 | def day(self) -> "UDFExpr": |
53 | | - @pyarrow_udf(return_dtype=DataType.int32()) |
54 | | - def _day(arr: pyarrow.Array) -> pyarrow.Array: |
55 | | - return pc.day(arr) |
56 | | - return _day(self._expr) |
| 58 | + """Extract day component.""" |
| 59 | + return self._unary_temporal_int(pc.day) |
57 | 60 |
|
58 | 61 | def hour(self) -> "UDFExpr": |
59 | | - @pyarrow_udf(return_dtype=DataType.int32()) |
60 | | - def _hour(arr: pyarrow.Array) -> pyarrow.Array: |
61 | | - return pc.hour(arr) |
62 | | - return _hour(self._expr) |
| 62 | + """Extract hour component.""" |
| 63 | + return self._unary_temporal_int(pc.hour) |
63 | 64 |
|
64 | 65 | def minute(self) -> "UDFExpr": |
65 | | - @pyarrow_udf(return_dtype=DataType.int32()) |
66 | | - def _minute(arr: pyarrow.Array) -> pyarrow.Array: |
67 | | - return pc.minute(arr) |
68 | | - return _minute(self._expr) |
| 66 | + """Extract minute component.""" |
| 67 | + return self._unary_temporal_int(pc.minute) |
69 | 68 |
|
70 | 69 | def second(self) -> "UDFExpr": |
71 | | - @pyarrow_udf(return_dtype=DataType.int32()) |
72 | | - def _second(arr: pyarrow.Array) -> pyarrow.Array: |
73 | | - return pc.second(arr) |
74 | | - return _second(self._expr) |
| 70 | + """Extract second component.""" |
| 71 | + return self._unary_temporal_int(pc.second) |
| 72 | + |
| 73 | + # formatting |
75 | 74 |
|
76 | 75 | def strftime(self, fmt: str) -> "UDFExpr": |
77 | | - """Format each timestamp using a strftime format string.""" |
| 76 | + """Format timestamps with a strftime pattern.""" |
| 77 | + |
78 | 78 | @pyarrow_udf(return_dtype=DataType.string()) |
79 | 79 | def _format(arr: pyarrow.Array) -> pyarrow.Array: |
80 | 80 | return pc.strftime(arr, format=fmt) |
| 81 | + |
81 | 82 | return _format(self._expr) |
82 | 83 |
|
83 | | - def ceil(self, unit: str) -> "UDFExpr": |
84 | | - """Ceil timestamps up to the nearest boundary of the given unit.""" |
| 84 | + # rounding |
| 85 | + |
| 86 | + def ceil(self, unit: TemporalUnit) -> "UDFExpr": |
| 87 | + """Ceil timestamps to the next multiple of the given unit.""" |
85 | 88 | return_dtype = self._expr.data_type |
| 89 | + |
86 | 90 | @pyarrow_udf(return_dtype=return_dtype) |
87 | 91 | def _ceil(arr: pyarrow.Array) -> pyarrow.Array: |
88 | 92 | return pc.ceil_temporal(arr, multiple=1, unit=unit) |
| 93 | + |
89 | 94 | return _ceil(self._expr) |
90 | 95 |
|
91 | | - def floor(self, unit: str) -> "UDFExpr": |
92 | | - """Floor timestamps down to the previous boundary of the given unit.""" |
| 96 | + def floor(self, unit: TemporalUnit) -> "UDFExpr": |
| 97 | + """Floor timestamps to the previous multiple of the given unit.""" |
93 | 98 | return_dtype = self._expr.data_type |
| 99 | + |
94 | 100 | @pyarrow_udf(return_dtype=return_dtype) |
95 | 101 | def _floor(arr: pyarrow.Array) -> pyarrow.Array: |
96 | 102 | return pc.floor_temporal(arr, multiple=1, unit=unit) |
| 103 | + |
97 | 104 | return _floor(self._expr) |
98 | 105 |
|
99 | | - def round(self, unit: str, tie_breaker: str = "half_to_even") -> "UDFExpr": |
100 | | - """Round timestamps to the nearest boundary of the given unit.""" |
| 106 | + def round(self, unit: TemporalUnit) -> "UDFExpr": |
| 107 | + """Round timestamps to the nearest multiple of the given unit.""" |
101 | 108 | return_dtype = self._expr.data_type |
| 109 | + |
102 | 110 | @pyarrow_udf(return_dtype=return_dtype) |
103 | 111 | def _round(arr: pyarrow.Array) -> pyarrow.Array: |
104 | | - return pc.round_temporal(arr, multiple=1, unit=unit, |
105 | | - tie_breaker=tie_breaker) |
| 112 | + |
| 113 | + return pc.round_temporal(arr, multiple=1, unit=unit) |
| 114 | + |
106 | 115 | return _round(self._expr) |
0 commit comments