diff --git a/examples/topic/basic_example.py b/examples/topic/basic_example.py index 50dd9a5d..fce1c4b7 100644 --- a/examples/topic/basic_example.py +++ b/examples/topic/basic_example.py @@ -7,9 +7,9 @@ async def connect(endpoint: str, database: str) -> ydb.aio.Driver: config = ydb.DriverConfig(endpoint=endpoint, database=database) - config.credentials = ydb.credentials_from_env_variables() + # config.credentials = ydb.credentials_from_env_variables() driver = ydb.aio.Driver(config) - await driver.wait(15) + await driver.wait(5,fail_fast=True) return driver @@ -25,7 +25,11 @@ async def create_topic(driver: ydb.aio.Driver, topic: str, consumer: str): async def write_messages(driver: ydb.aio.Driver, topic: str): async with driver.topic_client.writer(topic) as writer: for i in range(10): - await writer.write(f"mess-{i}") + mess = ydb.TopicWriterMessage( + data = f"mess-{i}", + metadata_items= {"index": f"{i}"} + ) + await writer.write(mess) await asyncio.sleep(1) @@ -38,6 +42,7 @@ async def read_messages(driver: ydb.aio.Driver, topic: str, consumer: str): print(mess.seqno) print(mess.created_at) print(mess.data.decode()) + print(mess.metadata_items) reader.commit(mess) except asyncio.TimeoutError: return diff --git a/ydb/_grpc/grpcwrapper/ydb_topic.py b/ydb/_grpc/grpcwrapper/ydb_topic.py index c1789b6c..b001504b 100644 --- a/ydb/_grpc/grpcwrapper/ydb_topic.py +++ b/ydb/_grpc/grpcwrapper/ydb_topic.py @@ -208,6 +208,7 @@ class MessageData(IToProto): data: bytes uncompressed_size: int partitioning: "StreamWriteMessage.PartitioningType" + metadata_items: Dict[str, bytes] def to_proto( self, @@ -218,6 +219,11 @@ def to_proto( proto.data = self.data proto.uncompressed_size = self.uncompressed_size + for key, value in self.metadata_items.items(): + # TODO: CHECK + item = ydb_topic_pb2.MetadataItem(key=key, value=value) + proto.metadata_items.append(item) + if self.partitioning is None: pass elif isinstance(self.partitioning, StreamWriteMessage.PartitioningPartitionID): @@ -489,16 +495,19 @@ class MessageData(IFromProto): data: bytes uncompresed_size: int message_group_id: str + metadata_items: Dict[str, bytes] @staticmethod def from_proto( msg: ydb_topic_pb2.StreamReadMessage.ReadResponse.MessageData, ) -> "StreamReadMessage.ReadResponse.MessageData": + metadata_items = {meta.key: meta.value for meta in msg.metadata_items} return StreamReadMessage.ReadResponse.MessageData( offset=msg.offset, seq_no=msg.seq_no, created_at=msg.created_at.ToDatetime(), data=msg.data, + metadata_items=metadata_items, uncompresed_size=msg.uncompressed_size, message_group_id=msg.message_group_id, ) diff --git a/ydb/_topic_reader/datatypes.py b/ydb/_topic_reader/datatypes.py index 01501638..a9c811ac 100644 --- a/ydb/_topic_reader/datatypes.py +++ b/ydb/_topic_reader/datatypes.py @@ -40,6 +40,7 @@ class PublicMessage(ICommittable, ISessionAlive): written_at: datetime.datetime producer_id: str data: Union[bytes, Any] # set as original decompressed bytes or deserialized object if deserializer set in reader + metadata_items: Dict[str, bytes] _partition_session: PartitionSession _commit_start_offset: int _commit_end_offset: int diff --git a/ydb/_topic_reader/topic_reader_asyncio.py b/ydb/_topic_reader/topic_reader_asyncio.py index 6833492d..54cb3804 100644 --- a/ydb/_topic_reader/topic_reader_asyncio.py +++ b/ydb/_topic_reader/topic_reader_asyncio.py @@ -624,6 +624,7 @@ def _read_response_to_batches(self, message: StreamReadMessage.ReadResponse) -> written_at=server_batch.written_at, producer_id=server_batch.producer_id, data=message_data.data, + metadata_items=message_data.metadata_items, _partition_session=partition_session, _commit_start_offset=partition_session._next_message_start_commit_offset, _commit_end_offset=message_data.offset + 1, diff --git a/ydb/_topic_writer/topic_writer.py b/ydb/_topic_writer/topic_writer.py index 527bf03e..f8d2add7 100644 --- a/ydb/_topic_writer/topic_writer.py +++ b/ydb/_topic_writer/topic_writer.py @@ -15,7 +15,7 @@ from .._grpc.grpcwrapper.ydb_topic_public_types import PublicCodec from .. import connection -Message = typing.Union["PublicMessage", "PublicMessage.SimpleMessageSourceType"] +Message = typing.Union["PublicMessage", "PublicMessage.SimpleSourceType"] @dataclass @@ -91,20 +91,23 @@ class PublicWriterInitInfo: class PublicMessage: seqno: Optional[int] created_at: Optional[datetime.datetime] - data: "PublicMessage.SimpleMessageSourceType" + data: "PublicMessage.SimpleSourceType" + metadata_items: Optional[Dict[str, "PublicMessage.SimpleSourceType"]] - SimpleMessageSourceType = Union[str, bytes] # Will be extend + SimpleSourceType = Union[str, bytes] # Will be extend def __init__( self, - data: SimpleMessageSourceType, + data: SimpleSourceType, *, + metadata_items: Optional[Dict[str, "PublicMessage.SimpleSourceType"]] = None, seqno: Optional[int] = None, created_at: Optional[datetime.datetime] = None, ): self.seqno = seqno self.created_at = created_at self.data = data + self.metadata_items = metadata_items @staticmethod def _create_message(data: Message) -> "PublicMessage": @@ -121,26 +124,29 @@ def __init__(self, mess: PublicMessage): seq_no=mess.seqno, created_at=mess.created_at, data=mess.data, + metadata_items=mess.metadata_items, uncompressed_size=len(mess.data), partitioning=None, ) self.codec = PublicCodec.RAW - def get_bytes(self) -> bytes: - if self.data is None: + def get_bytes(self, obj: Optional[PublicMessage.SimpleSourceType]) -> bytes: + if obj is None: return bytes() - if isinstance(self.data, bytes): - return self.data - if isinstance(self.data, str): - return self.data.encode("utf-8") + if isinstance(obj, bytes): + return obj + if isinstance(obj, str): + return obj.encode("utf-8") raise ValueError("Bad data type") def to_message_data(self) -> StreamWriteMessage.WriteRequest.MessageData: - data = self.get_bytes() + data = self.get_bytes(self.data) + metadata_items = {key: self.get_bytes(value) for key, value in self.metadata_items.items()} return StreamWriteMessage.WriteRequest.MessageData( seq_no=self.seq_no, created_at=self.created_at, data=data, + metadata_items=metadata_items, uncompressed_size=len(data), partitioning=None, # unsupported by server now ) @@ -221,6 +227,7 @@ def messages_to_proto_requests( seq_no=_max_int, created_at=datetime.datetime(3000, 1, 1, 1, 1, 1, 1), data=bytes(1), + metadata_items={}, uncompressed_size=_max_int, partitioning=StreamWriteMessage.PartitioningMessageGroupID( message_group_id="a" * 100,