Skip to content

Commit

Permalink
Fix/1375 (#1377)
Browse files Browse the repository at this point in the history
* chore: update dependencies

* fix (#1375): correct Redis Batch message serialization

* fix (#1376): fix Redis connection options priority

* tests: fix Redis connection tests (correct ping usage)
  • Loading branch information
Lancetnik authored Apr 17, 2024
1 parent 463c953 commit a8d3142
Show file tree
Hide file tree
Showing 12 changed files with 175 additions and 122 deletions.
16 changes: 8 additions & 8 deletions faststream/rabbit/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,30 +10,30 @@
gen_cor_id,
)
from faststream.rabbit.message import RabbitMessage
from faststream.utils.context.repository import context

if TYPE_CHECKING:
from re import Pattern

from aio_pika import IncomingMessage
from aio_pika.abc import DateType, HeadersType

from faststream.rabbit.subscriber.usecase import LogicSubscriber
from faststream.rabbit.types import AioPikaSendableMessage
from faststream.types import DecodedMessage


class AioPikaParser:
"""A class for parsing, encoding, and decoding messages using aio-pika."""

@staticmethod
def __init__(self, pattern: Optional["Pattern[str]"] = None) -> None:
    """Store an optional compiled regex used to extract path params from the routing key."""
    self.pattern = pattern

async def parse_message(
self,
message: "IncomingMessage",
) -> StreamMessage["IncomingMessage"]:
"""Parses an incoming message and returns a RabbitMessage object."""
handler: Optional["LogicSubscriber"] = context.get_local("handler_")
if (
handler is not None
and (path_re := handler.queue.path_regex)
and (match := path_re.match(message.routing_key or ""))
if (path_re := self.pattern) and (
match := path_re.match(message.routing_key or "")
):
path = match.groupdict()
else:
Expand Down
7 changes: 5 additions & 2 deletions faststream/rabbit/publisher/producer.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,13 @@ def __init__(
) -> None:
self._channel = channel
self.declarer = declarer
self._parser = resolve_custom_func(parser, AioPikaParser.parse_message)
self._decoder = resolve_custom_func(decoder, AioPikaParser.decode_message)

self._rpc_lock = anyio.Lock()

default_parser = AioPikaParser()
self._parser = resolve_custom_func(parser, default_parser.parse_message)
self._decoder = resolve_custom_func(decoder, default_parser.decode_message)

@override
async def publish( # type: ignore[override]
self,
Expand Down
6 changes: 4 additions & 2 deletions faststream/rabbit/subscriber/usecase.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,11 @@ def __init__(
description_: Optional[str],
include_in_schema: bool,
) -> None:
parser = AioPikaParser(pattern=queue.path_regex)

super().__init__(
default_parser=AioPikaParser.parse_message,
default_decoder=AioPikaParser.decode_message,
default_parser=parser.parse_message,
default_decoder=parser.decode_message,
# Propagated options
no_ack=no_ack,
retry=retry,
Expand Down
70 changes: 37 additions & 33 deletions faststream/redis/broker/broker.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,10 @@ def __init__(
url: str = "redis://localhost:6379",
polling_interval: Optional[float] = None,
*,
host: str = "localhost",
port: Union[str, int] = 6379,
db: Union[str, int] = 0,
host: Union[str, object] = Parameter.empty,
port: Union[str, int, object] = Parameter.empty,
db: Union[str, int, object] = Parameter.empty,
connection_class: Union[Type["Connection"], object] = Parameter.empty,
client_name: Optional[str] = None,
health_check_interval: float = 0,
max_connections: Optional[int] = None,
Expand All @@ -111,7 +112,6 @@ def __init__(
encoding_errors: str = "strict",
decode_responses: bool = False,
parser_class: Type["BaseParser"] = DefaultParser,
connection_class: Type["Connection"] = Connection,
encoder_class: Type["Encoder"] = Encoder,
# broker args
graceful_timeout: Annotated[
Expand Down Expand Up @@ -272,9 +272,10 @@ async def _connect( # type: ignore[override]
self,
url: str,
*,
host: str,
port: Union[str, int],
db: Union[str, int],
host: Union[str, object],
port: Union[str, int, object],
db: Union[str, int, object],
connection_class: Union[Type["Connection"], object],
client_name: Optional[str],
health_check_interval: float,
max_connections: Optional[int],
Expand All @@ -289,34 +290,37 @@ async def _connect( # type: ignore[override]
encoding_errors: str,
decode_responses: bool,
parser_class: Type["BaseParser"],
connection_class: Type["Connection"],
encoder_class: Type["Encoder"],
) -> "Redis[bytes]":
url_options: "AnyDict" = dict(parse_url(url))
url_options.update(
{
"host": host,
"port": port,
"db": db,
"client_name": client_name,
"health_check_interval": health_check_interval,
"max_connections": max_connections,
"socket_timeout": socket_timeout,
"socket_connect_timeout": socket_connect_timeout,
"socket_read_size": socket_read_size,
"socket_keepalive": socket_keepalive,
"socket_keepalive_options": socket_keepalive_options,
"socket_type": socket_type,
"retry_on_timeout": retry_on_timeout,
"encoding": encoding,
"encoding_errors": encoding_errors,
"decode_responses": decode_responses,
"parser_class": parser_class,
"connection_class": connection_class,
"encoder_class": encoder_class,
}
)
url_options.update(parse_security(self.security))
url_options: "AnyDict" = {
**dict(parse_url(url)),
**parse_security(self.security),
"client_name": client_name,
"health_check_interval": health_check_interval,
"max_connections": max_connections,
"socket_timeout": socket_timeout,
"socket_connect_timeout": socket_connect_timeout,
"socket_read_size": socket_read_size,
"socket_keepalive": socket_keepalive,
"socket_keepalive_options": socket_keepalive_options,
"socket_type": socket_type,
"retry_on_timeout": retry_on_timeout,
"encoding": encoding,
"encoding_errors": encoding_errors,
"decode_responses": decode_responses,
"parser_class": parser_class,
"encoder_class": encoder_class,
}

if port is not Parameter.empty:
url_options["port"] = port
if host is not Parameter.empty:
url_options["host"] = host
if db is not Parameter.empty:
url_options["db"] = db
if connection_class is not Parameter.empty:
url_options["connection_class"] = connection_class

pool = ConnectionPool(
**url_options,
lib_name="faststream",
Expand Down
62 changes: 33 additions & 29 deletions faststream/redis/parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
Type,
TypeVar,
Union,
cast,
)

from faststream._compat import dump_json, json_loads
Expand All @@ -27,12 +26,11 @@
bDATA_KEY,
)
from faststream.types import AnyDict, DecodedMessage, SendableMessage
from faststream.utils.context.repository import context

if TYPE_CHECKING:
from re import Pattern

from faststream.broker.message import StreamMessage
from faststream.redis.schemas import PubSub
from faststream.redis.subscriber.usecase import ChannelSubscriber


MsgType = TypeVar("MsgType", bound=Mapping[str, Any])
Expand Down Expand Up @@ -127,16 +125,22 @@ def parse(data: bytes) -> Tuple[bytes, "AnyDict"]:
class SimpleParser:
msg_class: Type["StreamMessage[Any]"]

@classmethod
def __init__(
    self,
    pattern: Optional["Pattern[str]"] = None,
) -> None:
    """Store an optional compiled channel-path regex consumed by ``get_path``."""
    self.pattern = pattern

async def parse_message(
cls, message: Mapping[str, Any]
self,
message: Mapping[str, Any],
) -> "StreamMessage[Mapping[str, Any]]":
data, headers = cls._parse_data(message)
data, headers = self._parse_data(message)
id_ = gen_cor_id()
return cls.msg_class(
return self.msg_class(
raw_message=message,
body=data,
path=cls.get_path(message),
path=self.get_path(message),
headers=headers,
reply_to=headers.get("reply_to", ""),
content_type=headers.get("content-type"),
Expand All @@ -148,9 +152,16 @@ async def parse_message(
def _parse_data(message: Mapping[str, Any]) -> Tuple[bytes, "AnyDict"]:
return RawMessage.parse(message["data"])

@staticmethod
def get_path(message: Mapping[str, Any]) -> "AnyDict":
return {}
def get_path(self, message: Mapping[str, Any]) -> "AnyDict":
    """Extract path parameters from a pattern-subscribed channel.

    Returns the named groups of ``self.pattern`` matched against
    ``message["channel"]`` when the message carries a ``pattern`` key and a
    pattern regex was configured; otherwise an empty dict.
    """
    if (
        message.get("pattern")
        and (path_re := self.pattern)
        and (match := path_re.match(message["channel"]))
    ):
        return match.groupdict()

    else:
        return {}

@staticmethod
async def decode_message(
Expand All @@ -162,20 +173,6 @@ async def decode_message(
class RedisPubSubParser(SimpleParser):
msg_class = RedisMessage

@staticmethod
def get_path(message: Mapping[str, Any]) -> "AnyDict":
if (
message.get("pattern")
and (handler := cast("ChannelSubscriber", context.get_local("handler_")))
and (channel := cast(Optional["PubSub"], getattr(handler, "channel", None)))
and (path_re := channel.path_regex)
and (match := path_re.match(message["channel"]))
):
return match.groupdict()

else:
return {}


class RedisListParser(SimpleParser):
msg_class = RedisListMessage
Expand All @@ -187,7 +184,7 @@ class RedisBatchListParser(SimpleParser):
@staticmethod
def _parse_data(message: Mapping[str, Any]) -> Tuple[bytes, "AnyDict"]:
return (
dump_json(RawMessage.parse(x)[0] for x in message["data"]),
dump_json(_decode_batch_body_item(x) for x in message["data"]),
{"content-type": ContentTypes.json},
)

Expand All @@ -208,8 +205,15 @@ class RedisBatchStreamParser(SimpleParser):
def _parse_data(message: Mapping[str, Any]) -> Tuple[bytes, "AnyDict"]:
return (
dump_json(
RawMessage.parse(data)[0] if (data := x.get(bDATA_KEY)) else x
for x in message["data"]
_decode_batch_body_item(x.get(bDATA_KEY, x)) for x in message["data"]
),
{"content-type": ContentTypes.json},
)


def _decode_batch_body_item(msg_content: bytes) -> Any:
    """Deserialize one batch entry's payload.

    The entry is split into body and headers via ``RawMessage.parse`` (the
    headers are discarded).  The body is returned JSON-decoded when it is
    valid JSON, otherwise as the raw bytes.
    """
    body, _headers = RawMessage.parse(msg_content)
    try:
        decoded = json_loads(body)
    except Exception:
        decoded = body
    return decoded
4 changes: 2 additions & 2 deletions faststream/redis/publisher/producer.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,11 +37,11 @@ def __init__(
self._connection = connection
self._parser = resolve_custom_func(
parser,
RedisPubSubParser.parse_message,
RedisPubSubParser().parse_message,
)
self._decoder = resolve_custom_func(
decoder,
RedisPubSubParser.decode_message,
RedisPubSubParser().decode_message,
)

@override
Expand Down
25 changes: 15 additions & 10 deletions faststream/redis/subscriber/usecase.py
Original file line number Diff line number Diff line change
Expand Up @@ -215,9 +215,10 @@ def __init__(
description_: Optional[str],
include_in_schema: bool,
) -> None:
parser = RedisPubSubParser(pattern=channel.path_regex)
super().__init__(
default_parser=RedisPubSubParser.parse_message,
default_decoder=RedisPubSubParser.decode_message,
default_parser=parser.parse_message,
default_decoder=parser.decode_message,
# Propagated options
no_ack=no_ack,
retry=retry,
Expand Down Expand Up @@ -367,10 +368,11 @@ def __init__(
description_: Optional[str],
include_in_schema: bool,
) -> None:
parser = RedisListParser()
super().__init__(
list=list,
default_parser=RedisListParser.parse_message,
default_decoder=RedisListParser.decode_message,
default_parser=parser.parse_message,
default_decoder=parser.decode_message,
# Propagated options
no_ack=no_ack,
retry=retry,
Expand Down Expand Up @@ -413,10 +415,11 @@ def __init__(
description_: Optional[str],
include_in_schema: bool,
) -> None:
parser = RedisBatchListParser()
super().__init__(
list=list,
default_parser=RedisBatchListParser.parse_message,
default_decoder=RedisBatchListParser.decode_message,
default_parser=parser.parse_message,
default_decoder=parser.decode_message,
# Propagated options
no_ack=no_ack,
retry=retry,
Expand Down Expand Up @@ -610,10 +613,11 @@ def __init__(
description_: Optional[str],
include_in_schema: bool,
) -> None:
parser = RedisStreamParser()
super().__init__(
stream=stream,
default_parser=RedisStreamParser.parse_message,
default_decoder=RedisStreamParser.decode_message,
default_parser=parser.parse_message,
default_decoder=parser.decode_message,
# Propagated options
no_ack=no_ack,
retry=retry,
Expand Down Expand Up @@ -676,10 +680,11 @@ def __init__(
description_: Optional[str],
include_in_schema: bool,
) -> None:
parser = RedisBatchStreamParser()
super().__init__(
stream=stream,
default_parser=RedisBatchStreamParser.parse_message,
default_decoder=RedisBatchStreamParser.decode_message,
default_parser=parser.parse_message,
default_decoder=parser.decode_message,
# Propagated options
no_ack=no_ack,
retry=retry,
Expand Down
6 changes: 3 additions & 3 deletions tests/brokers/base/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,22 +27,22 @@ async def test_close_before_start(self, async_mock):
async def test_init_connect_by_url(self, settings):
kwargs = self.get_broker_args(settings)
broker = self.broker(**kwargs)
assert await broker.connect()
await broker.connect()
assert await self.ping(broker)
await broker.close()

@pytest.mark.asyncio()
async def test_connection_by_url(self, settings):
kwargs = self.get_broker_args(settings)
broker = self.broker()
assert await broker.connect(**kwargs)
await broker.connect(**kwargs)
assert await self.ping(broker)
await broker.close()

@pytest.mark.asyncio()
async def test_connect_by_url_priority(self, settings):
kwargs = self.get_broker_args(settings)
broker = self.broker("wrong_url")
assert await broker.connect(**kwargs)
await broker.connect(**kwargs)
assert await self.ping(broker)
await broker.close()
Loading

0 comments on commit a8d3142

Please sign in to comment.