Skip to content

Commit

Permalink
Implement tumbling_count_window
Browse files Browse the repository at this point in the history
  • Loading branch information
quentin-quix committed Jan 14, 2025
1 parent 53515e5 commit 89c3ac0
Show file tree
Hide file tree
Showing 11 changed files with 437 additions and 236 deletions.
36 changes: 27 additions & 9 deletions quixstreams/dataframe/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,10 @@
from .series import StreamingSeries
from .utils import ensure_milliseconds
from .windows import (
HoppingWindowDefinition,
SlidingWindowDefinition,
TumblingWindowDefinition,
CountTumblingWindowDefinition,
FixedTimeHoppingWindowDefinition,
FixedTimeSlidingWindowDefinition,
FixedTimeTumblingWindowDefinition,
)

ApplyCallbackStateful = Callable[[Any, State], Any]
Expand Down Expand Up @@ -844,7 +845,15 @@ def tumbling_window(
duration_ms: Union[int, timedelta],
grace_ms: Union[int, timedelta] = 0,
name: Optional[str] = None,
) -> TumblingWindowDefinition:
) -> FixedTimeTumblingWindowDefinition:
return self.tumbling_time_window(duration_ms, grace_ms, name)

def tumbling_time_window(
self,
duration_ms: Union[int, timedelta],
grace_ms: Union[int, timedelta] = 0,
name: Optional[str] = None,
) -> FixedTimeTumblingWindowDefinition:
"""
Create a tumbling window transformation on this StreamingDataFrame.
Tumbling windows divide time into fixed-sized, non-overlapping windows.
Expand Down Expand Up @@ -911,17 +920,26 @@ def tumbling_window(
duration_ms = ensure_milliseconds(duration_ms)
grace_ms = ensure_milliseconds(grace_ms)

return TumblingWindowDefinition(
return FixedTimeTumblingWindowDefinition(
duration_ms=duration_ms, grace_ms=grace_ms, dataframe=self, name=name
)

def tumbling_count_window(
self, count: int, name: Optional[str] = None
) -> CountTumblingWindowDefinition:
return CountTumblingWindowDefinition(
count=count,
dataframe=self,
name=name,
)

def hopping_window(
self,
duration_ms: Union[int, timedelta],
step_ms: Union[int, timedelta],
grace_ms: Union[int, timedelta] = 0,
name: Optional[str] = None,
) -> HoppingWindowDefinition:
) -> FixedTimeHoppingWindowDefinition:
"""
Create a hopping window transformation on this StreamingDataFrame.
Hopping windows divide the data stream into overlapping windows based on time.
Expand Down Expand Up @@ -999,7 +1017,7 @@ def hopping_window(
step_ms = ensure_milliseconds(step_ms)
grace_ms = ensure_milliseconds(grace_ms)

return HoppingWindowDefinition(
return FixedTimeHoppingWindowDefinition(
duration_ms=duration_ms,
grace_ms=grace_ms,
step_ms=step_ms,
Expand All @@ -1012,7 +1030,7 @@ def sliding_window(
duration_ms: Union[int, timedelta],
grace_ms: Union[int, timedelta] = 0,
name: Optional[str] = None,
) -> SlidingWindowDefinition:
) -> FixedTimeSlidingWindowDefinition:
"""
Create a sliding window transformation on this StreamingDataFrame.
Sliding windows continuously evaluate the stream with a fixed step of 1 ms
Expand Down Expand Up @@ -1084,7 +1102,7 @@ def sliding_window(
duration_ms = ensure_milliseconds(duration_ms)
grace_ms = ensure_milliseconds(grace_ms)

return SlidingWindowDefinition(
return FixedTimeSlidingWindowDefinition(
duration_ms=duration_ms, grace_ms=grace_ms, dataframe=self, name=name
)

Expand Down
14 changes: 8 additions & 6 deletions quixstreams/dataframe/windows/__init__.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,15 @@
from .base import WindowResult
from .definitions import (
HoppingWindowDefinition,
SlidingWindowDefinition,
TumblingWindowDefinition,
CountTumblingWindowDefinition,
FixedTimeHoppingWindowDefinition,
FixedTimeSlidingWindowDefinition,
FixedTimeTumblingWindowDefinition,
)

__all__ = [
"HoppingWindowDefinition",
"SlidingWindowDefinition",
"TumblingWindowDefinition",
"FixedTimeHoppingWindowDefinition",
"FixedTimeSlidingWindowDefinition",
"FixedTimeTumblingWindowDefinition",
"CountTumblingWindowDefinition",
"WindowResult",
]
180 changes: 178 additions & 2 deletions quixstreams/dataframe/windows/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,21 @@
import abc
import functools
import logging
from abc import abstractmethod
from collections import deque
from typing import Any, Callable, Deque, Optional
from typing import Any, Callable, Deque, Iterable, Optional, cast

from typing_extensions import TypedDict
from typing_extensions import TYPE_CHECKING, TypedDict

from quixstreams.context import message_context
from quixstreams.core.stream import TransformExpandedCallback
from quixstreams.processing import ProcessingContext
from quixstreams.state import WindowedPartitionTransaction, WindowedState

if TYPE_CHECKING:
from quixstreams.dataframe.dataframe import StreamingDataFrame

logger = logging.getLogger(__name__)


class WindowResult(TypedDict):
Expand All @@ -13,6 +27,164 @@ class WindowResult(TypedDict):
WindowAggregateFunc = Callable[[Any, Any], Any]
WindowMergeFunc = Callable[[Any], Any]

TransformRecordCallbackExpandedWindowed = Callable[
[Any, Any, int, Any, WindowedState], list[tuple[WindowResult, Any, int, Any]]
]


class Window(abc.ABC):
def __init__(
self,
name: str,
dataframe: "StreamingDataFrame",
) -> None:
if not name:
raise ValueError("Window name must not be empty")

self._name = name
self._dataframe = dataframe

@property
def name(self) -> str:
return self._name

@abstractmethod
def process_window(
self,
value: Any,
timestamp_ms: int,
state: WindowedState,
) -> tuple[Iterable[WindowResult], Iterable[WindowResult]]:
pass

def register_store(self):
self._dataframe.processing_context.state_manager.register_windowed_store(
topic_name=self._dataframe.topic.name, store_name=self._name
)

def _apply_window(
self,
func: TransformRecordCallbackExpandedWindowed,
name: str,
) -> "StreamingDataFrame":
self.register_store()

windowed_func = _as_windowed(
func=func,
processing_context=self._dataframe.processing_context,
store_name=name,
)
# Manually modify the Stream and clone the source StreamingDataFrame
# to avoid adding "transform" API to it.
# Transform callbacks can modify record key and timestamp,
# and it's prone to misuse.
stream = self._dataframe.stream.add_transform(func=windowed_func, expand=True)
return self._dataframe.__dataframe_clone__(stream=stream)

def final(self) -> "StreamingDataFrame":
"""
Apply the window aggregation and return results only when the windows are
closed.
The format of returned windows:
```python
{
"start": <window start time in milliseconds>,
"end": <window end time in milliseconds>,
"value: <aggregated window value>,
}
```
The individual window is closed when the event time
(the maximum observed timestamp across the partition) passes
its end timestamp + grace period.
The closed windows cannot receive updates anymore and are considered final.
>***NOTE:*** Windows can be closed only within the same message key.
If some message keys appear irregularly in the stream, the latest windows
can remain unprocessed until the message the same key is received.
"""

def window_callback(
value: Any, key: Any, timestamp_ms: int, _headers: Any, state: WindowedState
) -> list[tuple[WindowResult, Any, int, Any]]:
_, expired_windows = self.process_window(
value=value, timestamp_ms=timestamp_ms, state=state
)
# Use window start timestamp as a new record timestamp
return [(window, key, window["start"], None) for window in expired_windows]

return self._apply_window(
func=window_callback,
name=self._name,
)

def current(self) -> "StreamingDataFrame":
"""
Apply the window transformation to the StreamingDataFrame to return results
for each updated window.
The format of returned windows:
```python
{
"start": <window start time in milliseconds>,
"end": <window end time in milliseconds>,
"value: <aggregated window value>,
}
```
This method processes streaming data and returns results as they come,
regardless of whether the window is closed or not.
"""

def window_callback(
value: Any, key: Any, timestamp_ms: int, _headers: Any, state: WindowedState
) -> list[tuple[WindowResult, Any, int, Any]]:
updated_windows, _ = self.process_window(
value=value, timestamp_ms=timestamp_ms, state=state
)
return [(window, key, window["start"], None) for window in updated_windows]

return self._apply_window(func=window_callback, name=self._name)


def _noop() -> Any:
"""
No-operation function for skipping messages due to None keys.
Messages with None keys are ignored because keys are essential for performing
accurate and meaningful windowed aggregation.
"""
return []


def _as_windowed(
func: TransformRecordCallbackExpandedWindowed,
processing_context: ProcessingContext,
store_name: str,
) -> TransformExpandedCallback:
@functools.wraps(func)
def wrapper(
value: Any, key: Any, timestamp: int, headers: Any
) -> list[tuple[WindowResult, Any, int, Any]]:
ctx = message_context()
transaction = cast(
WindowedPartitionTransaction,
processing_context.checkpoint.get_store_transaction(
topic=ctx.topic, partition=ctx.partition, store_name=store_name
),
)
if key is None:
logger.warning(
f"Skipping window processing for a message because the key is None, "
f"partition='{ctx.topic}[{ctx.partition}]' offset='{ctx.offset}'."
)
return _noop()
state = transaction.as_state(prefix=key)
return func(value, key, timestamp, headers, state)

return wrapper


def get_window_ranges(
timestamp_ms: int, duration_ms: int, step_ms: Optional[int] = None
Expand All @@ -38,3 +210,7 @@ def get_window_ranges(
current_window_start -= step_ms

return window_ranges


def default_merge_func(state_value: Any) -> Any:
return state_value
72 changes: 72 additions & 0 deletions quixstreams/dataframe/windows/count_based.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import logging
from typing import (
TYPE_CHECKING,
Any,
Iterable,
Optional,
)

from quixstreams.state import WindowedState

from .base import (
Window,
WindowAggregateFunc,
WindowMergeFunc,
WindowResult,
default_merge_func,
)

if TYPE_CHECKING:
from quixstreams.dataframe.dataframe import StreamingDataFrame


logger = logging.getLogger(__name__)


class FixedCountWindow(Window):
def __init__(
self,
name: str,
count: int,
aggregate_func: WindowAggregateFunc,
aggregate_default: Any,
dataframe: "StreamingDataFrame",
merge_func: Optional[WindowMergeFunc] = None,
):
super().__init__(name, dataframe)

self._max_count = count
self._aggregate_func = aggregate_func
self._aggregate_default = aggregate_default
self._merge_func = merge_func or default_merge_func

def process_window(
self,
value: Any,
timestamp_ms: int,
state: WindowedState,
) -> tuple[Iterable[WindowResult], Iterable[WindowResult]]:
data = state.get(key="window")
if data is None:
metadata = {"count": 0, "start": timestamp_ms}
previous_value = self._aggregate_default
else:
metadata, previous_value = data

aggregated = self._aggregate_func(previous_value, value)

metadata["count"] += 1
windows = [
WindowResult(
start=metadata["start"],
end=timestamp_ms,
value=self._merge_func(aggregated),
)
]

if metadata["count"] >= self._max_count:
state.delete(key="window")
return windows, windows

state.set(key="window", value=(metadata, aggregated))
return windows, []
Loading

0 comments on commit 89c3ac0

Please sign in to comment.