Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Azure Tables integration | Support needed #1560

Open
saimanoj1206 opened this issue Dec 2, 2024 · 4 comments
Open

Azure Tables integration | Support needed #1560

saimanoj1206 opened this issue Dec 2, 2024 · 4 comments
Labels
data layer Pertains to data layers.

Comments

@saimanoj1206
Copy link

saimanoj1206 commented Dec 2, 2024

Hi,
I am currently working on integrating Azure Tables as the data layer for our application. I used AWS DynamoDB data layer implementation and translated it for Azure Tables. While the data is being successfully written to the Azure Tables, I am encountering an error when attempting to resume a thread on the UI (nothing is coming up in logs though)

RangeError: Invalid time value
    at iEt (http://localhost:8000/assets/index-DLRdQOIx.js:723:67225)
    at V3 (http://localhost:8000/assets/index-DLRdQOIx.js:66:19552)
    at KN (http://localhost:8000/assets/index-DLRdQOIx.js:68:3143)
    at mW (http://localhost:8000/assets/index-DLRdQOIx.js:68:44880)
    at dW (http://localhost:8000/assets/index-DLRdQOIx.js:68:39822)
    at B0e (http://localhost:8000/assets/index-DLRdQOIx.js:68:39748)
    at Eb (http://localhost:8000/assets/index-DLRdQOIx.js:68:39598)
    at oI (http://localhost:8000/assets/index-DLRdQOIx.js:68:35960)
    at Hk (http://localhost:8000/assets/index-DLRdQOIx.js:68:36765)
    at eu (http://localhost:8000/assets/index-DLRdQOIx.js:66:3288)

I did check my logs and last method that was called is get_thread

2024-12-02 16:11:30 - Azure Tables: get_user identifier=test
2024-12-02 16:11:30 - Azure Tables: get_user identifier=test
2024-12-02 16:11:31 - Translated markdown file for en-US not found. Defaulting to chainlit.md.
2024-12-02 16:11:31 - Azure Tables: get_user identifier=test
2024-12-02 16:11:31 - Azure Tables: get_thread_author thread=1f28e4d4-5e0f-4bce-afe2-14f7d37f9691
2024-12-02 16:11:31 - Azure Tables: list_threads filters.userId=test
2024-12-02 16:11:31 - Azure Tables: get_user identifier=test
2024-12-02 16:11:31 - Unclosed client session
client_session: <aiohttp.client.ClientSession object at 0x7fc7c123e520>
2024-12-02 16:11:31 - Azure Tables: get_thread thread=1f28e4d4-5e0f-4bce-afe2-14f7d37f9691
[{'PartitionKey': 'THREAD-1f28e4d4-5e0f-4bce-afe2-14f7d37f9691', 'RowKey': 'THREAD', 'id': '1f28e4d4-5e0f-4bce-afe2-14f7d37f9691', 'metadata': {'id': '716ba446-e7c3-41ad-b959-d76ca6e06add', 'env': '{}', 'chat_settings': {}, 'user': None, 'chat_profile': None, 'http_referer': 'http://localhost:8000/thread/1f28e4d4-5e0f-4bce-afe2-14f7d37f9691', 'client_type': 'webapp', 'languages': 'en-US,en;q=0.9'}, 'name': 'Hi', 'timestamp': '2024-11-29T15:30:15.649083Z', 'userId': 'test', 'userIdentifier': 'test', 'steps': [{'PartitionKey': 'THREAD-1f28e4d4-5e0f-4bce-afe2-14f7d37f9691', 'RowKey': 'STEP-b3dce18a-a5d3-47d5-95d7-80521901b757', 'createdAt': '2024-11-28T18:57:06.401370Z', 'end': '2024-11-28T18:57:06.404655Z', 'id': 'b3dce18a-a5d3-47d5-95d7-80521901b757', 'input': '{}', 'isError': False, 'metadata': {}, 'name': 'on_chat_start', 'output': '', 'showInput': 'json', 'start': '2024-11-28T18:57:06.401394Z', 'streaming': False, 'threadId': '1f28e4d4-5e0f-4bce-afe2-14f7d37f9691', 'type': 'run'}, {'PartitionKey': 'THREAD-1f28e4d4-5e0f-4bce-afe2-14f7d37f9691', 'RowKey': 'STEP-0d7ccafe-a13c-4ceb-9626-59c85968dac1', 'id': '0d7ccafe-a13c-4ceb-9626-59c85968dac1', 'threadId': '1f28e4d4-5e0f-4bce-afe2-14f7d37f9691', 'parentId': 'b3dce18a-a5d3-47d5-95d7-80521901b757', 'createdAt': '2024-11-28T18:57:06.403676Z', 'start': '2024-11-28T18:57:06.403676Z', 'end': '2024-11-28T18:57:06.403676Z', 'output': 'Hi, how can I help you?', 'name': 'Assistant', 'type': 'assistant_message', 'streaming': False, 'isError': False, 'waitForAnswer': False, 'metadata': {}}, {'PartitionKey': 'THREAD-1f28e4d4-5e0f-4bce-afe2-14f7d37f9691', 'RowKey': 'STEP-1e9d5388-c5e3-4a90-bf3a-2aa240979c04', 'id': '1e9d5388-c5e3-4a90-bf3a-2aa240979c04', 'threadId': '1f28e4d4-5e0f-4bce-afe2-14f7d37f9691', 'createdAt': '2024-11-28T18:57:07.773552Z', 'start': '2024-11-28T18:57:07.773552Z', 'end': '2024-11-28T18:57:07.773552Z', 'output': 'Hi', 'name': 'test', 'type': 'user_message', 'streaming': False, 'isError': False, 'waitForAnswer': False, 'metadata': {}}, {'PartitionKey': 'THREAD-1f28e4d4-5e0f-4bce-afe2-14f7d37f9691', 'RowKey': 'STEP-a4e2d870-c23a-44aa-a322-acc5552024ab', 'name': 'on_message', 'type': 'run', 'id': 'a4e2d870-c23a-44aa-a322-acc5552024ab', 'threadId': '1f28e4d4-5e0f-4bce-afe2-14f7d37f9691', 'parentId': '1e9d5388-c5e3-4a90-bf3a-2aa240979c04', 'streaming': False, 'metadata': {}, 'input': '', 'isError': False, 'output': '', 'createdAt': '2024-11-28T18:57:07.776126Z', 'start': '2024-11-28T18:57:07.776146Z', 'showInput': 'json'}, {'PartitionKey': 'THREAD-1f28e4d4-5e0f-4bce-afe2-14f7d37f9691', 'RowKey': 'STEP-90fb44f0-3912-4c43-82c1-b949f3fb02d1', 'id': '90fb44f0-3912-4c43-82c1-b949f3fb02d1', 'threadId': '1f28e4d4-5e0f-4bce-afe2-14f7d37f9691', 'createdAt': '2024-11-29T09:56:35.125651Z', 'start': '2024-11-29T09:56:35.125651Z', 'end': '2024-11-29T09:56:35.125651Z', 'output': "'parentId'", 'name': 'Error', 'type': 'assistant_message', 'streaming': False, 'isError': True, 'waitForAnswer': False, 'metadata': {}}, {'PartitionKey': 'THREAD-1f28e4d4-5e0f-4bce-afe2-14f7d37f9691', 'RowKey': 'STEP-cdb60bd4-dcf6-4ef8-a4ba-37e32fac48f8', 'id': 'cdb60bd4-dcf6-4ef8-a4ba-37e32fac48f8', 'threadId': '1f28e4d4-5e0f-4bce-afe2-14f7d37f9691', 'createdAt': '2024-11-29T10:00:15.642193Z', 'start': '2024-11-29T10:00:15.642193Z', 'end': '2024-11-29T10:00:15.642193Z', 'output': "Cannot connect to host localhost:8080 ssl:default [Connect call failed ('127.0.0.1', 8080)]", 'name': 'Error', 'type': 'assistant_message', 'streaming': False, 'isError': True, 'waitForAnswer': False, 'metadata': {}}], 'elements': [], 'createdAt': '2024-11-29T15:30:15.649083Z'}]

Here is the data layer that I wrote:

import asyncio
import json
import logging
from dataclasses import asdict
from datetime import datetime
from typing import Any, Dict, List, Optional, Union

import aiofiles
import aiohttp
from azure.core.exceptions import ResourceNotFoundError

# from azure.data.tables import TableClient, TableServiceClient
from azure.data.tables.aio import TableClient, TableServiceClient
from chainlit.context import context
from chainlit.data.base import BaseDataLayer
from chainlit.data.utils import queue_until_user_message
from chainlit.element import ElementDict
from chainlit.step import StepDict
from chainlit.types import (
    Feedback,
    PageInfo,
    PaginatedResponse,
    Pagination,
    ThreadDict,
    ThreadFilter,
)
from chainlit.user import PersistedUser, User

_logger = logging.getLogger(__name__)


class AzureTableDataLayer(BaseDataLayer):
    def __init__(
        self,
        connection_string: str,
        table_name: str,
        storage_provider: None,
        user_thread_limit: int = 10,
    ):
        self.table_service = TableServiceClient.from_connection_string(
            connection_string, connection_verify=False
        )
        self.table_name = table_name
        self.storage_provider = storage_provider
        self.user_thread_limit = user_thread_limit

    def _get_current_timestamp(self) -> str:

        return datetime.now().isoformat() + "Z"

    def _get_table_client(self) -> TableClient:
        return self.table_service.get_table_client(self.table_name)

    @property
    def context(self):
        return context

    async def get_user(self, identifier: str) -> Optional["PersistedUser"]:
        _logger.info("Azure Tables: get_user identifier=%s", identifier)

        table_client = self._get_table_client()
        try:
            entity = await table_client.get_entity(
                partition_key=f"USER-{identifier}", row_key="USER"
            )
            return PersistedUser(
                id=entity["id"],
                identifier=entity["identifier"],
                createdAt=entity["createdAt"],
                metadata=json.loads(entity["metadata"]),
            )
        except ResourceNotFoundError:
            return None

    async def create_user(self, user: "User") -> Optional["PersistedUser"]:
        _logger.info("Azure Tables: create_user user.identifier=%s", user.identifier)

        ts = self._get_current_timestamp()
        metadata = user.metadata or {}

        entity = {
            "PartitionKey": f"USER-{user.identifier}",
            "RowKey": "USER",
            "id": user.identifier,
            "identifier": user.identifier,
            "metadata": json.dumps(metadata),
            "createdAt": ts,
        }

        table_client = self._get_table_client()
        await table_client.create_entity(entity)

        return PersistedUser(
            id=user.identifier,
            identifier=user.identifier,
            createdAt=ts,
            metadata=metadata,
        )

    async def delete_feedback(self, feedback_id: str) -> bool:
        _logger.info("Azure Tables: delete_feedback feedback_id=%s", feedback_id)

        # feedback id = THREAD#{thread_id}::STEP#{step_id}
        thread_id, step_id = feedback_id.split("::")
        thread_id = thread_id.strip("THREAD-")
        step_id = step_id.strip("STEP-")

        table_client = self._get_table_client()
        try:
            entity = await table_client.get_entity(
                partition_key=f"THREAD-{thread_id}", row_key=f"STEP-{step_id}"
            )
            entity.pop("feedback", None)
            await table_client.update_entity(entity)
            return True
        except ResourceNotFoundError:
            return False

    async def upsert_feedback(self, feedback: Feedback) -> str:
        _logger.info(
            "Azure Tables: upsert_feedback thread=%s step=%s value=%s",
            feedback.threadId,
            feedback.forId,
            feedback.value,
        )

        if not feedback.forId:
            raise ValueError(
                "Azure Tables data layer expects value for feedback.threadId got None"
            )

        feedback.id = f"THREAD-{feedback.threadId}::STEP-{feedback.forId}"

        table_client = self._get_table_client()
        try:
            entity = await table_client.get_entity(
                partition_key=f"THREAD-{feedback.threadId}",
                row_key=f"STEP-{feedback.forId}",
            )
            entity["feedback"] = json.dumps(asdict(feedback))
            await table_client.update_entity(entity)
        except ResourceNotFoundError:
            pass

        return feedback.id

    @queue_until_user_message()
    async def create_element(self, element):
        _logger.info(
            "Azure Tables: create_element thread=%s step=%s type=%s",
            element.thread_id,
            element.for_id,
            element.type,
        )

        if not element.for_id:
            return

        if not self.storage_provider:
            _logger.warning(
                "Azure Tables: create_element error. No storage_provider is configured!"
            )
            return

        content: Optional[Union[bytes, str]] = None

        if element.content:
            content = element.content
        elif element.path:
            async with aiofiles.open(element.path, "rb") as f:
                content = await f.read()
        elif element.url:
            async with aiohttp.ClientSession() as session:
                async with session.get(element.url) as response:
                    if response.status == 200:
                        content = await response.read()
                    else:
                        raise ValueError(f"Failed to read content from {element.url}")
        else:
            raise ValueError("Element url, path or content must be provided")

        if content is None:
            raise ValueError("Content is None, cannot upload file")

        if not element.mime:
            element.mime = "application/octet-stream"

        context_user = self.context.session.user
        user_folder = getattr(context_user, "id", "unknown")
        file_object_key = f"{user_folder}/{element.thread_id}/{element.id}"

        uploaded_file = await self.storage_provider.upload_file(
            object_key=file_object_key,
            data=content,
            mime=element.mime,
            overwrite=True,
        )

        if not uploaded_file:
            raise ValueError(
                "Azure Tables Error: Failed to persist data in storage_provider"
            )

        element_dict = element.to_dict()
        entity = {
            "PartitionKey": f"THREAD-{element.thread_id}",
            "RowKey": f"ELEMENT-{element.id}",
            **element_dict,
            "url": uploaded_file.get("url"),
            "objectKey": uploaded_file.get("object_key"),
        }

        table_client = self._get_table_client()
        await table_client.create_entity(entity)

    async def get_element(
        self, thread_id: str, element_id: str
    ) -> Optional["ElementDict"]:
        _logger.info(
            "Azure Tables: get_element thread=%s element=%s", thread_id, element_id
        )

        table_client = self._get_table_client()
        try:
            entity = await table_client.get_entity(
                partition_key=f"THREAD-{thread_id}", row_key=f"ELEMENT-{element_id}"
            )
            return entity
        except ResourceNotFoundError:
            return None

    @queue_until_user_message()
    async def delete_element(self, element_id: str, thread_id: Optional[str] = None):
        thread_id = self.context.session.thread_id
        _logger.info(
            "Azure Tables: delete_element thread=%s element=%s", thread_id, element_id
        )

        table_client = self._get_table_client()
        try:
            await table_client.delete_entity(
                partition_key=f"THREAD-{thread_id}", row_key=f"ELEMENT-{element_id}"
            )
        except ResourceNotFoundError:
            pass

    @queue_until_user_message()
    async def create_step(self, step_dict: "StepDict"):
        _logger.info(
            "Azure Tables: create_step thread=%s step=%s",
            step_dict.get("threadId"),
            step_dict.get("id"),
        )

        entity = {
            "PartitionKey": f"THREAD-{step_dict['threadId']}",
            "RowKey": f"STEP-{step_dict['id']}",
            **step_dict,
        }
        entity = {k: json.dumps(v) if isinstance(v, (dict, list)) else v for k, v in entity.items()}
        table_client = self._get_table_client()
        await table_client.create_entity(entity)

    @queue_until_user_message()
    async def update_step(self, step_dict: "StepDict"):
        _logger.info(
            "Azure Tables: update_step thread=%s step=%s",
            step_dict.get("threadId"),
            step_dict.get("id"),
        )

        table_client = self._get_table_client()
        try:
            entity = await table_client.get_entity(
                partition_key=f"THREAD-{step_dict['threadId']}",
                row_key=f"STEP-{step_dict['id']}",
            )
            entity.update(step_dict)
            entity = {k: json.dumps(v) if isinstance(v, (dict, list)) else v for k, v in entity.items()}
            await table_client.update_entity(entity)
        except ResourceNotFoundError:
            pass

    @queue_until_user_message()
    async def delete_step(self, step_id: str):
        thread_id = self.context.session.thread_id
        _logger.info("Azure Tables: delete_step thread=%s step=%s", thread_id, step_id)

        table_client = self._get_table_client()
        try:
            await table_client.delete_entity(
                partition_key=f"THREAD-{thread_id}", row_key=f"STEP-{step_id}"
            )
        except ResourceNotFoundError:
            pass

    async def get_thread_author(self, thread_id: str) -> str:
        _logger.info("Azure Tables: get_thread_author thread=%s", thread_id)


        table_client = self._get_table_client()
        try:
            entity = await table_client.get_entity(
                partition_key=f"THREAD-{thread_id}", row_key="THREAD"
            )
            return entity["userId"]
        except ResourceNotFoundError:
            raise ValueError(f"Author not found for thread_id {thread_id}")

    async def delete_thread(self, thread_id: str):
        _logger.info("Azure Tables: delete_thread thread=%s", thread_id)

        thread = await self.get_thread(thread_id)
        if not thread:
            return

        table_client = self._get_table_client()

        # Delete all related entities (steps and elements)
        entities = table_client.query_entities(
            query_filter=f"PartitionKey eq 'THREAD-{thread_id}'"
        )

        async for entity in entities:
            try:
                await table_client.delete_entity(
                    partition_key=entity["PartitionKey"], row_key=entity["RowKey"]
                )
            except ResourceNotFoundError:
                continue

    async def list_threads(
        self, pagination: "Pagination", filters: "ThreadFilter"
    ) -> "PaginatedResponse[ThreadDict]":
        _logger.info("Azure Tables: list_threads filters.userId=%s", filters.userId)

        table_client = self._get_table_client()

        # Base query filter
        query_filter = f"PartitionKey eq 'USER-{filters.userId}' and RowKey ne'USER'"

        if filters.search:
            query_filter += f" and name eq '{filters.search}'"

        # Note: Azure Tables doesn't support native feedback filtering
        if filters.feedback:
            _logger.warning("Azure Tables: filters on feedback not supported")

        # Query with continuation token if provided
        entities = table_client.query_entities(
            query_filter=query_filter, results_per_page=self.user_thread_limit
        )

        threads = []
        continuation_token = None

        async for entity in entities:
            thread = ThreadDict(
                id=entity["RowKey"].replace("THREAD-", ""),
                createdAt=entity["timestamp"],
                name=entity.get("name", "Unnamed Thread"),
            )
            threads.append(thread)

            if len(threads) >= self.user_thread_limit:
                continuation_token = entities.continuation_token
                break
        return PaginatedResponse(
            data=threads,
            pageInfo=PageInfo(
                hasNextPage=bool(continuation_token),
                startCursor=pagination.cursor,
                endCursor=continuation_token,
            ),
        )

    async def get_thread(self, thread_id: str) -> "Optional[ThreadDict]":
        _logger.info("Azure Tables: get_thread thread=%s", thread_id)

        table_client = self._get_table_client()

        # Query all entities for the thread
        entities = table_client.query_entities(
            query_filter=f"PartitionKey eq 'THREAD-{thread_id}'"
        )

        thread_dict = None
        steps = []
        elements = []

        async for entity in entities:
            row_key = entity["RowKey"]

            if row_key == "THREAD":
                thread_dict = dict(entity)
            elif row_key.startswith("ELEMENT"):
                elements.append(dict(entity))
            elif row_key.startswith("STEP"):
                if "feedback" in entity:
                    entity["feedback"] = json.loads(entity["feedback"])
                if "metadata" in entity:
                    entity["metadata"] = json.loads(entity["metadata"])
                steps.append(dict(entity))

        if not thread_dict:
            return None

        # Sort steps by creation time
        steps.sort(key=lambda x: x["createdAt"])
        thread_dict.update({"steps": steps, "elements": elements})
        # print(thread_dict)
        thread_dict["createdAt"] = thread_dict["timestamp"]
        thread_dict["metadata"] = json.loads(thread_dict["metadata"])
        return [thread_dict]

    async def update_thread(
        self,
        thread_id: str,
        name: Optional[str] = None,
        user_id: Optional[str] = None,
        metadata: Optional[Dict] = None,
        tags: Optional[List[str]] = None,
    ):
        """
        Update thread information in Azure Table Storage.
        """
        _logger.info(
            "Azure Tables: update_thread thread=%s userId=%s", thread_id, user_id
        )
        _logger.debug(
            "Azure Tables: update_thread name=%s tags=%s metadata=%s",
            name,
            tags,
            metadata,
        )

        ts = self._get_current_timestamp()
        table_client = self._get_table_client()

        # Create the thread entity updates
        thread_entity = {
            "PartitionKey": f"THREAD-{thread_id}",
            "RowKey": "THREAD",
            "timestamp": ts,
            "id": thread_id,
        }

        if name is not None:
            thread_entity["name"] = name
        if metadata is not None:
            thread_entity["metadata"] = json.dumps(metadata)
        if tags is not None:
            thread_entity["tags"] = json.dumps(tags)
        if user_id is not None:
            thread_entity["userId"] = user_id
            thread_entity["userIdentifier"] = user_id

            # Create/update user thread reference
            user_thread_entity = {
                "PartitionKey": f"USER-{user_id}",
                "RowKey": f"THREAD-{thread_id}",
                "timestamp": ts,
                "threadId": thread_id,
                "name": name if name is not None else "Unnamed Thread",
            }

        try:
            # Update thread entity
            existing_entity = await table_client.get_entity(
                partition_key=thread_entity["PartitionKey"],
                row_key=thread_entity["RowKey"],
            )
            existing_entity.update(thread_entity)
            await table_client.update_entity(existing_entity)

            # Update user thread reference if user_id is provided
            if user_id is not None:
                try:
                    existing_user_thread = await table_client.get_entity(
                        partition_key=user_thread_entity["PartitionKey"],
                        row_key=user_thread_entity["RowKey"],
                    )
                    existing_user_thread.update(user_thread_entity)
                    await table_client.update_entity(existing_user_thread)
                except ResourceNotFoundError:
                    await table_client.create_entity(user_thread_entity)

        except ResourceNotFoundError:
            # If thread doesn't exist, create it
            await table_client.create_entity(thread_entity)
            if user_id is not None:
                await table_client.create_entity(user_thread_entity)

    async def build_debug_url(self) -> str:
        """
        Return an empty string as debug URL is not implemented for Azure Tables.
        """

        return ""

Any advice on debugging would be greatly appreciated.

@dosubot dosubot bot added the data layer Pertains to data layers. label Dec 2, 2024
@dokterbob
Copy link
Collaborator

We don't actively support the implementation for 3rd party data layers but perhaps @dosu can offer some suggestions.

In addition, have you tried using something like pyright to ensure there's no typing errors?

Copy link

dosubot bot commented Dec 3, 2024

The RangeError: Invalid time value error in JavaScript typically occurs when a date string is not in a format that can be parsed by the JavaScript Date object. In your case, this error might be related to the timestamps being retrieved from Azure Tables.

Here are a few steps to help you debug and resolve the issue:

  1. Check Timestamp Format: Ensure that all timestamps in your Azure Tables are in the correct ISO 8601 format with a 'Z' suffix, indicating UTC time. For example, 2024-11-29T15:30:15.649083Z. This format is expected by the Chainlit data layer [1].

  2. Verify Data Conversion: In your get_thread method, make sure that the createdAt and other timestamp fields are correctly set to the timestamp field from your Azure Tables data. Double-check that no additional formatting or conversion is altering these values.

  3. Inspect Data in Logs: Add logging to print out the exact values of the timestamps being processed in your get_thread method. This can help you identify if any unexpected values are being returned.

  4. Check for Missing or Null Values: Ensure that none of the timestamp fields are missing or null, as this could also lead to parsing errors in JavaScript.

  5. Use a JavaScript Date Parser: If the issue persists, consider using a JavaScript library like date-fns or moment.js to parse and validate the date strings before they are used in your application.

By following these steps, you should be able to identify and correct the source of the RangeError. If the timestamps are correctly formatted and the error persists, further investigation into how the data is being handled on the client side may be necessary.

To continue talking to Dosu, mention @dosu.

@Okaya20
Copy link

Okaya20 commented Dec 27, 2024

@saimanoj1206 hi there did you managed to find this error , after utilizing your method i also experienced this :/

@Okaya20
Copy link

Okaya20 commented Jan 2, 2025

`import asyncio
import json
import logging
from dataclasses import asdict
from datetime import datetime
from typing import Any, Dict, List, Optional, Union

import aiofiles
import aiohttp
from azure.core.exceptions import ResourceNotFoundError

from azure.data.tables import TableClient, TableServiceClient

from azure.data.tables.aio import TableClient, TableServiceClient
from chainlit.context import context
from chainlit.data.base import BaseDataLayer
from chainlit.data.utils import queue_until_user_message
from chainlit.element import ElementDict
from chainlit.step import StepDict
from chainlit.types import (
Feedback,
PageInfo,
PaginatedResponse,
Pagination,
ThreadDict,
ThreadFilter,
)
from chainlit.user import PersistedUser, User

_logger = logging.getLogger(name)

class AzureTableDataLayer(BaseDataLayer):
def init(
self,
connection_string: str,
table_name: str,
storage_provider: None,
user_thread_limit: int = 10,
):
self.table_service = TableServiceClient.from_connection_string(
connection_string, connection_verify=False
)
self.table_name = table_name
self.storage_provider = storage_provider
self.user_thread_limit = user_thread_limit

def _get_current_timestamp(self) -> str:

    return datetime.now().isoformat() + "Z"

def _get_table_client(self) -> TableClient:
    return self.table_service.get_table_client(self.table_name)

@property
def context(self):
    return context

async def get_user(self, identifier: str) -> Optional["PersistedUser"]:
    _logger.info("Azure Tables: get_user identifier=%s", identifier)

    table_client = self._get_table_client()
    try:
        entity = await table_client.get_entity(
            partition_key=f"USER-{identifier}", row_key="USER"
        )
        return PersistedUser(
            id=entity["id"],
            identifier=entity["identifier"],
            createdAt=entity["createdAt"],
            metadata=json.loads(entity["metadata"]),
        )
    except ResourceNotFoundError:
        return None

async def create_user(self, user: "User") -> Optional["PersistedUser"]:
    _logger.info("Azure Tables: create_user user.identifier=%s", user.identifier)

    ts = self._get_current_timestamp()
    metadata = user.metadata or {}

    entity = {
        "PartitionKey": f"USER-{user.identifier}",
        "RowKey": "USER",
        "id": user.identifier,
        "identifier": user.identifier,
        "metadata": json.dumps(metadata),
        "createdAt": ts,
    }

    table_client = self._get_table_client()
    await table_client.create_entity(entity)

    return PersistedUser(
        id=user.identifier,
        identifier=user.identifier,
        createdAt=ts,
        metadata=metadata,
    )

async def delete_feedback(self, feedback_id: str) -> bool:
    _logger.info("Azure Tables: delete_feedback feedback_id=%s", feedback_id)

    # feedback id = THREAD#{thread_id}::STEP#{step_id}
    thread_id, step_id = feedback_id.split("::")
    thread_id = thread_id.strip("THREAD-")
    step_id = step_id.strip("STEP-")

    table_client = self._get_table_client()
    try:
        entity = await table_client.get_entity(
            partition_key=f"THREAD-{thread_id}", row_key=f"STEP-{step_id}"
        )
        entity.pop("feedback", None)
        await table_client.update_entity(entity)
        return True
    except ResourceNotFoundError:
        return False

async def upsert_feedback(self, feedback: Feedback) -> str:
    _logger.info(
        "Azure Tables: upsert_feedback thread=%s step=%s value=%s",
        feedback.threadId,
        feedback.forId,
        feedback.value
    )

    if not feedback.forId:
        raise ValueError(
            "Azure Tables data layer expects value for feedback.threadId got None"
        )

    feedback.id = f"THREAD-{feedback.threadId}::STEP-{feedback.forId}"

    table_client = self._get_table_client()
    try:
        entity = await table_client.get_entity(
            partition_key=f"THREAD-{feedback.threadId}",
            row_key=f"STEP-{feedback.forId}",
        )
        entity["feedback"] = json.dumps(asdict(feedback))
        await table_client.update_entity(entity)
    except ResourceNotFoundError:
        pass
    userrent = await table_client.get_entity(
        partition_key=f'THREAD-{feedback.threadId}',row_key='THREAD',
    )
    us_id=userrent['userId']
    thread_ent = await table_client.get_entity(
        partition_key=f'USER-{us_id}',row_key=f'THREAD-{feedback.threadId}',
    )
    thread_ent[f"isfeedback{feedback.value}"] = feedback.value
    await table_client.update_entity(thread_ent)
    return feedback.id

@queue_until_user_message()
async def create_element(self, element):
    _logger.info(
        "Azure Tables: create_element thread=%s step=%s type=%s",
        element.thread_id,
        element.for_id,
        element.type,
    )

    if not element.for_id:
        return

    if not self.storage_provider:
        _logger.warning(
            "Azure Tables: create_element error. No storage_provider is configured!"
        )
        return

    content: Optional[Union[bytes, str]] = None

    if element.content:
        content = element.content
    elif element.path:
        async with aiofiles.open(element.path, "rb") as f:
            content = await f.read()
    elif element.url:
        async with aiohttp.ClientSession() as session:
            async with session.get(element.url) as response:
                if response.status == 200:
                    content = await response.read()
                else:
                    raise ValueError(f"Failed to read content from {element.url}")
    else:
        raise ValueError("Element url, path or content must be provided")

    if content is None:
        raise ValueError("Content is None, cannot upload file")

    if not element.mime:
        element.mime = "application/octet-stream"

    context_user = self.context.session.user
    user_folder = getattr(context_user, "id", "unknown")
    file_object_key = f"{user_folder}/{element.thread_id}/{element.id}"

    uploaded_file = await self.storage_provider.upload_file(
        object_key=file_object_key,
        data=content,
        mime=element.mime,
        overwrite=True,
    )

    if not uploaded_file:
        raise ValueError(
            "Azure Tables Error: Failed to persist data in storage_provider"
        )

    element_dict = element.to_dict()
    entity = {
        "PartitionKey": f"THREAD-{element.thread_id}",
        "RowKey": f"ELEMENT-{element.id}",
        **element_dict,
        "url": uploaded_file.get("url"),
        "objectKey": uploaded_file.get("object_key"),
    }

    table_client = self._get_table_client()
    await table_client.create_entity(entity)

async def get_element(
    self, thread_id: str, element_id: str
) -> Optional["ElementDict"]:
    _logger.info(
        "Azure Tables: get_element thread=%s element=%s", thread_id, element_id
    )

    table_client = self._get_table_client()
    try:
        entity = await table_client.get_entity(
            partition_key=f"THREAD-{thread_id}", row_key=f"ELEMENT-{element_id}"
        )
        return entity
    except ResourceNotFoundError:
        return None

@queue_until_user_message()
async def delete_element(self, element_id: str, thread_id: Optional[str] = None):
    thread_id = self.context.session.thread_id
    _logger.info(
        "Azure Tables: delete_element thread=%s element=%s", thread_id, element_id
    )

    table_client = self._get_table_client()
    try:
        await table_client.delete_entity(
            partition_key=f"THREAD-{thread_id}", row_key=f"ELEMENT-{element_id}"
        )
    except ResourceNotFoundError:
        pass

@queue_until_user_message()
async def create_step(self, step_dict: "StepDict"):
    _logger.info(
        "Azure Tables: create_step thread=%s step=%s",
        step_dict.get("threadId"),
        step_dict.get("id"),
    )

    entity = {
        "PartitionKey": f"THREAD-{step_dict['threadId']}",
        "RowKey": f"STEP-{step_dict['id']}",
        **step_dict,
    }
    entity = {k: json.dumps(v) if isinstance(v, (dict, list)) else v for k, v in entity.items()}
    table_client = self._get_table_client()
    await table_client.create_entity(entity)

@queue_until_user_message()
async def update_step(self, step_dict: "StepDict"):
    _logger.info(
        "Azure Tables: update_step thread=%s step=%s",
        step_dict.get("threadId"),
        step_dict.get("id"),
    )

    table_client = self._get_table_client()
    try:
        entity = await table_client.get_entity(
            partition_key=f"THREAD-{step_dict['threadId']}",
            row_key=f"STEP-{step_dict['id']}",
        )
        entity.update(step_dict)
        entity = {k: json.dumps(v) if isinstance(v, (dict, list)) else v for k, v in entity.items()}
        await table_client.update_entity(entity)
    except ResourceNotFoundError:
        pass

@queue_until_user_message()
async def delete_step(self, step_id: str):
    thread_id = self.context.session.thread_id
    _logger.info("Azure Tables: delete_step thread=%s step=%s", thread_id, step_id)

    table_client = self._get_table_client()
    try:
        await table_client.delete_entity(
            partition_key=f"THREAD-{thread_id}", row_key=f"STEP-{step_id}"
        )
    except ResourceNotFoundError:
        pass

async def get_thread_author(self, thread_id: str) -> str:
    _logger.info("Azure Tables: get_thread_author thread=%s", thread_id)


    table_client = self._get_table_client()
    try:
        entity = await table_client.get_entity(
            partition_key=f"THREAD-{thread_id}", row_key="THREAD"
        )
        return entity["userId"]
    except ResourceNotFoundError:
        raise ValueError(f"Author not found for thread_id {thread_id}")

async def delete_thread(self, thread_id: str):
    _logger.info("Azure Tables: delete_thread thread=%s", thread_id)

    thread = await self.get_thread(thread_id)
    if not thread:
        return
    _logger.info(f"kerekere {[i for i in thread]}")
    _logger.info(f"karakara {[thread['userId']]}")
    _logger.info(f"karakara {[thread_id]}")
    table_client = self._get_table_client()

    # Delete all related entities (steps and elements)
    entities = table_client.query_entities(
        query_filter=f"PartitionKey eq 'THREAD-{thread_id}'"
    )
    #kerekere
    async for entity in entities:
        try:
            await table_client.delete_entity(
                partition_key=entity["PartitionKey"], row_key=entity["RowKey"]
            )
    
        except ResourceNotFoundError:
            continue
    await table_client.delete_entity(
        partition_key=f"USER-{thread['userId']}",row_key=f"THREAD-{thread_id}"
    )

async def list_threads(
    self, pagination: "Pagination", filters: "ThreadFilter"
) -> "PaginatedResponse[ThreadDict]":
    _logger.info("Azure Tables: list_threads filters.userId=%s", filters.userId)

    table_client = self._get_table_client()

    # Base query filter
    query_filter = f"PartitionKey eq 'USER-{filters.userId}' and RowKey ne'USER'"

    #if filters.search:
    #    query_filter += f" and name eq '{filters.search}'"

    # Note: Azure Tables doesn't support native feedback filtering
    if filters.feedback == 0 or filters.feedback == 1:
        _logger.warning("Azure Tables: filters on feedback not supported")
        _logger.info(f'Feedbackkk {filters.feedback}')
        query_filter += f" and isfeedback{filters.feedback} eq {filters.feedback}"

    # Query with continuation token if provided
    entities = table_client.query_entities(
        query_filter=query_filter, results_per_page=self.user_thread_limit
    )
    
    search_keyword = filters.search.lower() if filters.search else None
    feedback_value = int(filters.feedback) if filters.feedback else None

    if filters.search:
        #thread_keys = [item["PartitionKey"] for item in entities if item["PartitionKey"].startswith("THREAD-")]
        thread_query = f"name eq '{filters.userId}' and type eq 'user_message' "
        #thread_query += " or ".join(f"PartitionKey eq '{key}'" for key in thread_keys)
        entities2 = table_client.query_entities(
        query_filter=thread_query, results_per_page=self.user_thread_limit)
        #thread_partition=[ti['PartitionKey'] async for ti in  entities2 if any(search_keyword in ti["output"].lower())]
        thread_partition = []
        async for ti in entities2:
            if search_keyword and search_keyword in ti["output"].lower():
                thread_partition.append(ti['PartitionKey'])
        _logger.info(f"thread_partition {thread_partition}")

    if feedback_value:
        print('zim')
    threads = []

    _logger.info(f"search_keyword {search_keyword}")
    _logger.info(f"feedback_value {filters.feedback}")
    filtered_threads = []
    thread_ids = []
    async for entity in entities:
        _logger.info(f"time stampp {entity['timestamp']}")
        
        # Handle timestamp formatting
        if isinstance(entity["timestamp"], datetime):
            entity["timestamp"] = entity["timestamp"].isoformat() + "Z"
        elif isinstance(entity["timestamp"], str) and not entity["timestamp"].endswith("Z"):
            entity["timestamp"] += "Z"
        
        # If no search filter, include all threads
        # If search filter exists, only include threads in thread_partition
        if not filters.search or (filters.search and entity['RowKey'] in thread_partition):
            thread = ThreadDict(
                id=entity["RowKey"].replace("THREAD-", ""),
                createdAt=entity["timestamp"],
                name=entity.get("name", "Unnamed Thread"),
            )
            thread_ids.append(entity["RowKey"])
            threads.append(thread)

        _logger.info(f"threaddd {entity}")
        if len(threads) >= self.user_thread_limit:
            #continuation_token = entities.continuation_token
            break

    start = 0
    if pagination.cursor:
        for i, thread in enumerate(filtered_threads):
            if (
                thread["id"] == pagination.cursor
            ):  # Find the start index using pagination.cursor
                start = i + 1
                break
    end = start + pagination.first
    paginated_threads = filtered_threads[start:end] or []

    has_next_page = len(filtered_threads) > end
    start_cursor = paginated_threads[0]["id"] if paginated_threads else None
    end_cursor = paginated_threads[-1]["id"] if paginated_threads else None
    return PaginatedResponse(
        data=threads,
        pageInfo=PageInfo(
            hasNextPage=has_next_page,
            startCursor=start_cursor,
            endCursor=end_cursor,
        ),
    )

async def get_thread(self, thread_id: str) -> Optional[ThreadDict]:
    """
    Retrieve a single thread from Azure Table Storage.
    Returns a ThreadDict or None if not found.
    """
    _logger.info("Azure Tables: get_thread thread=%s", thread_id)

    table_client = self._get_table_client()

    # Query all entities for this thread
    entities = table_client.query_entities(
        query_filter=f"PartitionKey eq 'THREAD-{thread_id}'"
    )

    thread_dict = None  # We'll store the main thread row here
    steps = []
    elements = []

    async for entity in entities:
        row_key = entity["RowKey"]

        if row_key == "THREAD":
            # This is the main thread entity
            thread_dict = dict(entity)
          #  _logger.info(f"THREAD entity found in get_thread: {thread_dict}")

        elif row_key.startswith("ELEMENT"):
            # This is an Element entity
            elements.append(dict(entity))

        elif row_key.startswith("STEP"):
            # This is a Step entity
            # Convert feedback/metadata from JSON if they exist
            if "feedback" in entity:
                try:
                    entity["feedback"] = json.loads(entity["feedback"])
                except json.JSONDecodeError:
                    entity["feedback"] = None

            if "metadata" in entity:
                try:
                    entity["metadata"] = json.loads(entity["metadata"])
                except json.JSONDecodeError:
                    entity["metadata"] = {}

            steps.append(dict(entity))

    # If we never found the main thread entity, return None
    if not thread_dict:
        return None

    # Sort the steps by `createdAt`. If `createdAt` is missing/invalid, handle gracefully.
    # 1) Convert string to datetime for sorting
    for s in steps:
        if "createdAt" in s and isinstance(s["createdAt"], str):
            try:
                # Remove trailing 'Z' before parsing, if present
                dt_str = s["createdAt"].replace("Z", "")
                dt = datetime.fromisoformat(dt_str)
                s["createdAt"] = dt
            except ValueError:
                # If invalid, fallback to a default or remove
                _logger.warning(f"Invalid createdAt for step: {s['createdAt']}")
                s["createdAt"] = datetime.now()

        elif "createdAt" not in s:
            # If missing entirely, fallback
            s["createdAt"] = datetime.now()

    steps.sort(key=lambda x: x["createdAt"])

    # 2) Convert them back to ISO strings + 'Z'
    for s in steps:
        if isinstance(s["createdAt"], datetime):
            s["createdAt"] = s["createdAt"].isoformat() + "Z"

    # Attach steps and elements to the thread entity
    thread_dict["steps"] = steps
    thread_dict["elements"] = elements

    # Handle timestamp on the main thread entity
    # If 'timestamp' is missing or invalid, fallback
    if "timestamp" not in thread_dict:
        _logger.warning("THREAD entity missing 'timestamp' field, defaulting to now.")
        thread_dict["timestamp"] = datetime.now().isoformat() + "Z"
    else:
        # If 'timestamp' is a Python datetime object
        if isinstance(thread_dict["timestamp"], datetime):
            thread_dict["timestamp"] = thread_dict["timestamp"].isoformat() + "Z"
        elif isinstance(thread_dict["timestamp"], str):
            if not thread_dict["timestamp"].endswith("Z"):
                thread_dict["timestamp"] += "Z"

    # For Chainlit UI, we often set 'createdAt' from 'timestamp'
    thread_dict["createdAt"] = thread_dict["timestamp"]

    # Safely parse thread metadata if present
    if "metadata" in thread_dict:
        try:
            thread_dict["metadata"] = json.loads(thread_dict["metadata"])
        except (TypeError, json.JSONDecodeError):
            thread_dict["metadata"] = {}
    else:
        thread_dict["metadata"] = {}

   # _logger.info(f"Final thread_dict => {thread_dict}")

    # Return a single dict, not a list
    return thread_dict

async def update_thread(
    self,
    thread_id: str,
    name: Optional[str] = None,
    user_id: Optional[str] = None,
    metadata: Optional[Dict] = None,
    tags: Optional[List[str]] = None,
):
    """
    Update thread information in Azure Table Storage.
    """
    _logger.info(
        "Azure Tables: update_thread thread=%s userId=%s", thread_id, user_id
    )
    _logger.info(
        "Azure Tables: update_thread debug name=%s tags=%s metadata=%s",
        name,
        tags,
        metadata,
    )

    ts = self._get_current_timestamp()
    _logger.info(f'lelelele {ts}')
    table_client = self._get_table_client()

    # Create the thread entity updates
    thread_entity = {
        "PartitionKey": f"THREAD-{thread_id}",
        "RowKey": "THREAD",
        "timestamp": ts,
        "id": thread_id,
        'createdAt':ts
    }

    if name is not None:
        thread_entity["name"] = name
    if metadata is not None:
        thread_entity["metadata"] = json.dumps(metadata)
    if tags is not None:
        thread_entity["tags"] = json.dumps(tags)
    if user_id is not None:
        thread_entity["userId"] = user_id
        thread_entity["userIdentifier"] = user_id

        # Create/update user thread reference
        user_thread_entity = {
            "PartitionKey": f"USER-{user_id}",
            "RowKey": f"THREAD-{thread_id}",
            "timestamp": ts,
            "threadId": thread_id,
            "name": name if name is not None else "Unnamed Thread",
        }

    try:
        # Update thread entity
        existing_entity = await table_client.get_entity(
            partition_key=thread_entity["PartitionKey"],
            row_key=thread_entity["RowKey"],
        )
        existing_entity.update(thread_entity)
        await table_client.update_entity(existing_entity)

        # Update user thread reference if user_id is provided
        if user_id is not None:
            try:
                existing_user_thread = await table_client.get_entity(
                    partition_key=user_thread_entity["PartitionKey"],
                    row_key=user_thread_entity["RowKey"],
                )
                existing_user_thread.update(user_thread_entity)
                await table_client.update_entity(existing_user_thread)
            except ResourceNotFoundError:
                await table_client.create_entity(user_thread_entity)

    except ResourceNotFoundError:
        # If thread doesn't exist, create it
        await table_client.create_entity(thread_entity)
        if user_id is not None:
            await table_client.create_entity(user_thread_entity)

async def build_debug_url(self) -> str:
    """
    Return an empty string as debug URL is not implemented for Azure Tables.
    """

    return ""

`

I have even handled feedback and keyword search, its not 100% optimal but nonetheless its working :)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
data layer Pertains to data layers.
Projects
None yet
Development

No branches or pull requests

3 participants