Skip to content

Commit 86ce74f

Browse files
committed
[Data] Fix datetime namespace review comments
1 parent 159bdfe commit 86ce74f

File tree

2 files changed

+136
-55
lines changed

2 files changed

+136
-55
lines changed
Lines changed: 64 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,7 @@
1-
"""
2-
Datetime namespace for expression operations on datetime-typed columns.
3-
4-
This module defines the ``_DatetimeNamespace`` class which exposes a set of
5-
convenience methods for working with timestamp and date columns in Ray Data
6-
expressions. The API mirrors pandas' ``Series.dt`` accessor and is backed by
7-
PyArrow compute functions for efficient execution.
8-
9-
Example
10-
-------
11-
12-
>>> from ray.data.expressions import col
13-
>>> # Extract year, month and day from a timestamp column
14-
>>> expr_year = col("timestamp").dt.year()
15-
>>> expr_month = col("timestamp").dt.month()
16-
>>> expr_day = col("timestamp").dt.day()
17-
>>> # Format the timestamp as a string
18-
>>> expr_fmt = col("timestamp").dt.strftime("%Y-%m-%d")
19-
"""
20-
211
from __future__ import annotations
222

233
from dataclasses import dataclass
24-
from typing import TYPE_CHECKING, Any
4+
from typing import TYPE_CHECKING, Any, Callable, Literal
255

266
import pyarrow
277
import pyarrow.compute as pc
@@ -32,75 +12,104 @@
3212
if TYPE_CHECKING:
3313
from ray.data.expressions import Expr, UDFExpr
3414

15+
TemporalUnit = Literal[
16+
"year",
17+
"quarter",
18+
"month",
19+
"week",
20+
"day",
21+
"hour",
22+
"minute",
23+
"second",
24+
"millisecond",
25+
"microsecond",
26+
"nanosecond",
27+
]
28+
29+
3530
@dataclass
3631
class _DatetimeNamespace:
37-
"""Namespace for datetime operations on expression columns."""
32+
"""Datetime namespace for operations on datetime-typed expression columns."""
33+
3834
_expr: "Expr"
3935

40-
def year(self) -> "UDFExpr":
36+
def _unary_temporal_int(
37+
self, func: Callable[[pyarrow.Array], pyarrow.Array]
38+
) -> "UDFExpr":
39+
"""Helper for year/month/… that return int32."""
40+
4141
@pyarrow_udf(return_dtype=DataType.int32())
42-
def _year(arr: pyarrow.Array) -> pyarrow.Array:
43-
return pc.year(arr)
44-
return _year(self._expr)
42+
def _udf(arr: pyarrow.Array) -> pyarrow.Array:
43+
return func(arr)
44+
45+
return _udf(self._expr)
46+
47+
# extractors
48+
49+
def year(self) -> "UDFExpr":
50+
"""Extract year component."""
51+
return self._unary_temporal_int(pc.year)
4552

4653
def month(self) -> "UDFExpr":
47-
@pyarrow_udf(return_dtype=DataType.int32())
48-
def _month(arr: pyarrow.Array) -> pyarrow.Array:
49-
return pc.month(arr)
50-
return _month(self._expr)
54+
"""Extract month component."""
55+
return self._unary_temporal_int(pc.month)
5156

5257
def day(self) -> "UDFExpr":
53-
@pyarrow_udf(return_dtype=DataType.int32())
54-
def _day(arr: pyarrow.Array) -> pyarrow.Array:
55-
return pc.day(arr)
56-
return _day(self._expr)
58+
"""Extract day component."""
59+
return self._unary_temporal_int(pc.day)
5760

5861
def hour(self) -> "UDFExpr":
59-
@pyarrow_udf(return_dtype=DataType.int32())
60-
def _hour(arr: pyarrow.Array) -> pyarrow.Array:
61-
return pc.hour(arr)
62-
return _hour(self._expr)
62+
"""Extract hour component."""
63+
return self._unary_temporal_int(pc.hour)
6364

6465
def minute(self) -> "UDFExpr":
65-
@pyarrow_udf(return_dtype=DataType.int32())
66-
def _minute(arr: pyarrow.Array) -> pyarrow.Array:
67-
return pc.minute(arr)
68-
return _minute(self._expr)
66+
"""Extract minute component."""
67+
return self._unary_temporal_int(pc.minute)
6968

7069
def second(self) -> "UDFExpr":
71-
@pyarrow_udf(return_dtype=DataType.int32())
72-
def _second(arr: pyarrow.Array) -> pyarrow.Array:
73-
return pc.second(arr)
74-
return _second(self._expr)
70+
"""Extract second component."""
71+
return self._unary_temporal_int(pc.second)
72+
73+
# formatting
7574

7675
def strftime(self, fmt: str) -> "UDFExpr":
77-
"""Format each timestamp using a strftime format string."""
76+
"""Format timestamps with a strftime pattern."""
77+
7878
@pyarrow_udf(return_dtype=DataType.string())
7979
def _format(arr: pyarrow.Array) -> pyarrow.Array:
8080
return pc.strftime(arr, format=fmt)
81+
8182
return _format(self._expr)
8283

83-
def ceil(self, unit: str) -> "UDFExpr":
84-
"""Ceil timestamps up to the nearest boundary of the given unit."""
84+
# rounding
85+
86+
def ceil(self, unit: TemporalUnit) -> "UDFExpr":
87+
"""Ceil timestamps to the next multiple of the given unit."""
8588
return_dtype = self._expr.data_type
89+
8690
@pyarrow_udf(return_dtype=return_dtype)
8791
def _ceil(arr: pyarrow.Array) -> pyarrow.Array:
8892
return pc.ceil_temporal(arr, multiple=1, unit=unit)
93+
8994
return _ceil(self._expr)
9095

91-
def floor(self, unit: str) -> "UDFExpr":
92-
"""Floor timestamps down to the previous boundary of the given unit."""
96+
def floor(self, unit: TemporalUnit) -> "UDFExpr":
97+
"""Floor timestamps to the previous multiple of the given unit."""
9398
return_dtype = self._expr.data_type
99+
94100
@pyarrow_udf(return_dtype=return_dtype)
95101
def _floor(arr: pyarrow.Array) -> pyarrow.Array:
96102
return pc.floor_temporal(arr, multiple=1, unit=unit)
103+
97104
return _floor(self._expr)
98105

99-
def round(self, unit: str, tie_breaker: str = "half_to_even") -> "UDFExpr":
100-
"""Round timestamps to the nearest boundary of the given unit."""
106+
def round(self, unit: TemporalUnit) -> "UDFExpr":
107+
"""Round timestamps to the nearest multiple of the given unit."""
101108
return_dtype = self._expr.data_type
109+
102110
@pyarrow_udf(return_dtype=return_dtype)
103111
def _round(arr: pyarrow.Array) -> pyarrow.Array:
104-
return pc.round_temporal(arr, multiple=1, unit=unit,
105-
tie_breaker=tie_breaker)
112+
113+
return pc.round_temporal(arr, multiple=1, unit=unit)
114+
106115
return _round(self._expr)

python/ray/data/tests/test_namespace_expressions.py

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
"""
66

77
from typing import Any
8+
import datetime
89

910
import pandas as pd
1011
import pyarrow as pa
@@ -521,6 +522,77 @@ def test_struct_nested_bracket(self, dataset_format):
521522
assert_df_equal(result, expected)
522523

523524

525+
# ──────────────────────────────────────
526+
# Struct Namespace Tests
527+
# ──────────────────────────────────────
528+
529+
530+
def test_dt_namespace_extractors(ray_start_regular):
531+
ds = ray.data.from_items(
532+
[
533+
{
534+
"ts": datetime.datetime(2024, 1, 2, 3, 4, 5),
535+
}
536+
]
537+
)
538+
539+
result_ds = ds.select(
540+
[
541+
col("ts").dt.year().alias("year"),
542+
col("ts").dt.month().alias("month"),
543+
col("ts").dt.day().alias("day"),
544+
col("ts").dt.hour().alias("hour"),
545+
col("ts").dt.minute().alias("minute"),
546+
col("ts").dt.second().alias("second"),
547+
]
548+
)
549+
550+
row = result_ds.take(1)[0]
551+
assert row["year"] == 2024
552+
assert row["month"] == 1
553+
assert row["day"] == 2
554+
assert row["hour"] == 3
555+
assert row["minute"] == 4
556+
assert row["second"] == 5
557+
558+
559+
def test_dt_namespace_strftime(ray_start_regular):
560+
ds = ray.data.from_items(
561+
[
562+
{
563+
"ts": datetime.datetime(2024, 1, 2, 3, 4, 5),
564+
}
565+
]
566+
)
567+
568+
result_ds = ds.select(
569+
[col("ts").dt.strftime("%Y-%m-%d").alias("date_str")]
570+
)
571+
572+
row = result_ds.take(1)[0]
573+
assert row["date_str"] == "2024-01-02"
574+
575+
576+
def test_dt_namespace_rounding(ray_start_regular):
577+
ts = datetime.datetime(2024, 1, 2, 10, 30, 0)
578+
579+
ds = ray.data.from_items([{"ts": ts}])
580+
581+
floored = ds.select(
582+
[col("ts").dt.floor("day").alias("ts_floor")]
583+
).take(1)[0]["ts_floor"]
584+
ceiled = ds.select(
585+
[col("ts").dt.ceil("day").alias("ts_ceil")]
586+
).take(1)[0]["ts_ceil"]
587+
rounded = ds.select(
588+
[col("ts").dt.round("day").alias("ts_round")]
589+
).take(1)[0]["ts_round"]
590+
591+
assert floored == datetime.datetime(2024, 1, 2, 0, 0, 0)
592+
assert ceiled == datetime.datetime(2024, 1, 3, 0, 0, 0)
593+
assert rounded == datetime.datetime(2024, 1, 3, 0, 0, 0)
594+
595+
524596
# ──────────────────────────────────────
525597
# Integration Tests
526598
# ──────────────────────────────────────

0 commit comments

Comments
 (0)