Skip to content

Commit

Permalink
Consumer with MappingMessage (#8)
Browse files Browse the repository at this point in the history
* Consumer with MappingMessage
  • Loading branch information
aamalev authored Dec 19, 2024
1 parent 89c1467 commit d89a929
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 13 deletions.
78 changes: 71 additions & 7 deletions aioworkers_kafka/consumer.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
from typing import Dict, List, Optional
from dataclasses import dataclass
from typing import Any, Dict, List, Mapping, Optional

import confluent_kafka
from aioworkers.core.base import AbstractConnector, AbstractReader, ExecutorEntity
Expand All @@ -10,6 +11,54 @@
from aioworkers_kafka.utils import flat_conf


@dataclass
class IncommingMessage(Mapping):
    """Base container for a consumed Kafka message.

    Implements the ``Mapping`` protocol: the base class exposes no keys
    of its own (iteration and length report an empty mapping), while
    subscript access falls through to the raw message headers.
    Subclasses override this for structured payloads.
    """

    value: Any  # decoded payload (subclasses narrow the type)
    topic: str  # topic the message was consumed from
    key: Optional[bytes]  # partition key, if the producer set one
    headers: Mapping[str, bytes]  # raw message headers

    def __getitem__(self, key: str) -> bytes:
        # Subscript access resolves against the headers; raises KeyError
        # when the header is absent.
        return self.headers[key]

    def __iter__(self):
        # No mapping keys of its own — an empty iterator.
        yield from ()

    def __len__(self) -> int:
        return 0


@dataclass
class RawMessage(IncommingMessage):
    """Consumed message whose payload stayed as raw ``bytes``.

    Produced when no formatter decoded the payload into a mapping.
    Subscript access (``msg[key]``) is inherited from the base class and
    resolves against the headers.
    """

    # Redeclarations narrow the base-class annotations; field order and
    # semantics are unchanged.
    value: bytes
    topic: str
    key: Optional[bytes]
    headers: Mapping[str, bytes]


@dataclass
class MappingMessage(IncommingMessage):
    """Consumed message whose payload was decoded into a mapping.

    Mapping access resolves against the decoded payload first and falls
    back to the raw headers, so both ``msg["field"]`` and
    ``msg["content-type"]`` work.  Iteration and length reflect the
    payload only.
    """

    value: Mapping[str, Any]
    topic: str
    key: Optional[bytes]
    headers: Mapping[str, bytes]

    def __getitem__(self, key: str) -> Any:
        try:
            return self.value[key]
        except KeyError:
            # Fall back to the headers.  A membership check (rather than
            # truth-testing the value, as `if v := headers.get(key)` did)
            # keeps falsy header values such as b"" retrievable instead
            # of raising KeyError for keys that do exist.
            if key in self.headers:
                return self.headers[key]
            raise

    def __iter__(self):
        # Only payload keys are enumerated; headers stay a lookup fallback.
        return iter(self.value)

    def __len__(self) -> int:
        return len(self.value)


class KafkaConsumer(AbstractReader, FormattedEntity, ExecutorEntity, AbstractConnector):
def __init__(
self,
Expand Down Expand Up @@ -51,22 +100,37 @@ async def connect(self):
async def disconnect(self):
await self.run_in_executor(self.consumer.close)

async def get(self, timeout: Optional[float] = None):
async def get(self, timeout: Optional[float] = None) -> Optional[IncommingMessage]:
poll_timeout = timeout or 1.0
while True:
msg = await self.run_in_executor(self.consumer.poll, poll_timeout)
if msg is None:
if timeout is not None:
break
return None
elif msg.error():
if msg.error().code() not in {KafkaError._PARTITION_EOF}:
self.logger.error("Consume with error %s", msg.error())
await asyncio.sleep(1)
elif (h := msg.headers()) and (ct := h.get("content-type")):
f = self.registry.get(ct)
return f.decode(msg.value())
else:
return self.decode(msg.value())
return self.decode_msg(msg)

def decode_msg(self, msg: confluent_kafka.Message) -> IncommingMessage:
    """Convert a raw confluent-kafka message into an IncommingMessage.

    The payload is decoded with the formatter selected by the
    ``content-type`` header when present, otherwise with the entity's
    default formatter.  A payload that is still ``bytes`` after decoding
    yields a :class:`RawMessage`; anything else is assumed to be a
    mapping and yields a :class:`MappingMessage`.
    """
    # msg.headers() may return None; normalize to a plain dict.
    headers = dict(msg.headers() or ())

    raw_value = msg.value()
    if ct := headers.get("content-type"):
        # Header values arrive as bytes, registry keys are str.
        formatter = self.registry.get(ct.decode())
        # Reuse the already-fetched payload instead of calling
        # msg.value() a second time.
        value = formatter.decode(raw_value)
    else:
        value = self.decode(raw_value)

    # isinstance (not `type(value) is bytes`) also routes bytes
    # subclasses to RawMessage, which would otherwise break
    # MappingMessage iteration.
    if isinstance(value, bytes):
        return RawMessage(value=raw_value, topic=msg.topic(), key=msg.key(), headers=headers)
    return MappingMessage(value=value, topic=msg.topic(), key=msg.key(), headers=headers)

async def __aenter__(self):
self.set_context(Context())
Expand Down
60 changes: 54 additions & 6 deletions tests/test_consumer.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,63 @@
import asyncio
import json

from aioworkers_kafka.consumer import KafkaConsumer
from aioworkers_kafka.consumer import KafkaConsumer, MappingMessage, RawMessage
from aioworkers_kafka.producer import KafkaProducer

CONTENT_TYPE = "application/json"


def test_decode_message_header_ct(mocker):
    """A JSON payload with a content-type header decodes to MappingMessage."""
    payload = {"test": 1}
    raw = json.dumps(payload).encode()
    msg = mocker.Mock(
        value=lambda: raw,
        headers=lambda: {"content-type": CONTENT_TYPE.encode()},
        key=lambda: None,
        topic=lambda: "test",
    )
    consumer = KafkaConsumer()
    decoded = consumer.decode_msg(msg)
    assert isinstance(decoded, MappingMessage)
    assert dict(decoded) == payload
    assert decoded.value == payload
    # Header lookup falls through the payload to the raw headers.
    assert decoded["content-type"] == CONTENT_TYPE.encode()


def test_decode_message_formatted(mocker):
    """Without headers the consumer's configured formatter decodes the payload."""
    payload = {"test": 1}
    raw = json.dumps(payload).encode()
    msg = mocker.Mock(
        value=lambda: raw,
        headers=lambda: None,
        key=lambda: None,
        topic=lambda: "test",
    )
    consumer = KafkaConsumer(content_type=CONTENT_TYPE)
    decoded = consumer.decode_msg(msg)
    assert isinstance(decoded, MappingMessage)
    assert dict(decoded) == payload
    assert decoded.value == payload


def test_decode_message_bytes(mocker):
    """An undecoded bytes payload becomes RawMessage with header lookup intact."""
    raw = json.dumps({"test": 1}).encode()
    msg = mocker.Mock(
        value=lambda: raw,
        headers=lambda: {"a": b"b"},
        key=lambda: None,
        topic=lambda: "test",
    )
    consumer = KafkaConsumer()
    decoded = consumer.decode_msg(msg)
    assert isinstance(decoded, RawMessage)
    assert decoded.value == raw
    assert decoded["a"] == b"b"


async def test_get(bootstrap_servers, topic, mocker):
data = {"test": 1}
ct = "application/json"
async with KafkaConsumer(
bootstrap_servers=bootstrap_servers, group_id="test", topics=[topic], content_type=ct
bootstrap_servers=bootstrap_servers, group_id="test", topics=[topic], content_type=CONTENT_TYPE
) as c:
m = mocker.patch.object(c, "consumer")
m.poll.return_value = None
Expand All @@ -19,10 +67,10 @@ async def test_get(bootstrap_servers, topic, mocker):
msg = m.poll.return_value = mocker.Mock()
msg.error.return_value = None
msg.value.return_value = json.dumps(data).encode()
msg.headers.return_value = {"content-type": ct}
msg.headers.return_value = {"content-type": CONTENT_TYPE.encode()}

async def produce():
async with KafkaProducer(content_type=ct) as p:
async with KafkaProducer(content_type=CONTENT_TYPE) as p:
for _ in range(2):
await asyncio.sleep(0.3)
await p.put(data, topic=topic)
Expand All @@ -34,4 +82,4 @@ async def produce():
result = await c.get()

await task
assert result == data
assert dict(result) == data

0 comments on commit d89a929

Please sign in to comment.