diff --git a/Dockerfile b/.docker/Dockerfile similarity index 100% rename from Dockerfile rename to .docker/Dockerfile diff --git a/.docker/docker-compose.dev.yml b/.docker/docker-compose.dev.yml new file mode 100644 index 00000000..71d0569d --- /dev/null +++ b/.docker/docker-compose.dev.yml @@ -0,0 +1,80 @@ +version: "3" +services: + kafka1: + image: confluentinc/cp-kafka + container_name: kafka1 + hostname: kafka1 + ports: + - "9092:9092" + environment: + KAFKA_NODE_ID: 1 + KAFKA_CONTROLLER_LISTENER_NAMES: "CONTROLLER" + KAFKA_LISTENER_SECURITY_PROTOCOL_MAP: "CONTROLLER:PLAINTEXT,INTERNAL:PLAINTEXT,EXTERNAL:PLAINTEXT" + KAFKA_LISTENERS: "INTERNAL://kafka1:29092,CONTROLLER://kafka1:29093,EXTERNAL://0.0.0.0:9092" + KAFKA_ADVERTISED_LISTENERS: "INTERNAL://kafka1:29092,EXTERNAL://localhost:9092" + KAFKA_INTER_BROKER_LISTENER_NAME: "INTERNAL" + KAFKA_CONTROLLER_QUORUM_VOTERS: "1@kafka1:29093" + KAFKA_PROCESS_ROLES: "broker,controller" + KAFKA_GROUP_INITIAL_REBALANCE_DELAY_MS: 0 + KAFKA_OFFSETS_TOPIC_REPLICATION_FACTOR: 1 + KAFKA_TRANSACTION_STATE_LOG_REPLICATION_FACTOR: 1 + CLUSTER_ID: "vGAyBqiIjJIyN9Tp5B3aVQ==" + KAFKA_LOG_DIRS: "/tmp/kraft-combined-logs" + + healthcheck: + test: nc -z localhost 9092 || exit 1 + interval: 10s + timeout: 5s + retries: 15 + + # schema-registry0: + # image: confluentinc/cp-schema-registry + # container_name: schema-registry0 + # hostname: schema-registry0 + # ports: + # - "8081:8081" + # environment: + # SCHEMA_REGISTRY_HOST_NAME: schema-registry0 + # SCHEMA_REGISTRY_KAFKASTORE_BOOTSTRAP_SERVERS: "kafka1:29092" + # SCHEMA_REGISTRY_LISTENERS: "http://0.0.0.0:8081" + # depends_on: + # - kafka1 + + kafka-ui: + image: provectuslabs/kafka-ui + container_name: kafka-ui + ports: + - "3001:8080" + environment: + KAFKA_CLUSTERS_0_NAME: local + KAFKA_CLUSTERS_0_BOOTSTRAPSERVERS: kafka1:29092 + + # KAFKA_CLUSTERS_0_METRICS_PORT: 9997 + # KAFKA_CLUSTERS_0_SCHEMAREGISTRY: http://schema-registry0:8081 + + depends_on: + - kafka1 + # kafka-init-topics: + # image: confluentinc/cp-kafka:7.2.1 + # volumes: + # - ./data/message.json:/data/message.json + # depends_on: + # - kafka1 + # command: "bash -c 'echo Waiting for Kafka to be ready... && \ + # cub kafka-ready -b kafka1:29092 1 30 && \ + # kafka-topics --create --topic second.users --partitions 3 --replication-factor 1 --if-not-exists --bootstrap-server kafka1:29092 && \ + # kafka-topics --create --topic second.messages --partitions 2 --replication-factor 1 --if-not-exists --bootstrap-server kafka1:29092 && \ + # kafka-topics --create --topic first.messages --partitions 2 --replication-factor 1 --if-not-exists --bootstrap-server kafka0:29092 && \ + # kafka-console-producer --bootstrap-server kafka1:29092 -topic second.users < /data/message.json'" + redis: + image: redis:latest + container_name: redis + ports: + - "6379:6379" + networks: + - my-network + +networks: + my-network: + name: my-network + external: true diff --git a/.env.example b/.env.example index 31ff8a25..15d40430 100644 --- a/.env.example +++ b/.env.example @@ -19,3 +19,15 @@ PINECONE_INDEX= # Unstructured API UNSTRUCTURED_IO_API_KEY= UNSTRUCTURED_IO_SERVER_URL= + +# Redis +REDIS_HOST=localhost +REDIS_PORT=6379 + +# Kafka +KAFKA_TOPIC_INGESTION=ingestion +KAFKA_BOOTSTRAP_SERVERS=localhost:9092 # Comma separated list of kafka brokers (e.g. 
localhost:9092,localhost:9093) +KAFKA_SECURITY_PROTOCOL= +KAFKA_SASL_MECHANISM= +KAFKA_SASL_PLAIN_USERNAME= +KAFKA_SASL_PLAIN_PASSWORD= \ No newline at end of file diff --git a/README.md b/README.md index 4bf1ba4c..6f51d6b1 100644 --- a/README.md +++ b/README.md @@ -55,7 +55,18 @@ Easiest way to get started is to use our [Cloud API](https://d3jvqvcd9u4534.clou 5. Run server ```bash uvicorn main:app --reload + ``` +6. Start Kafka & Redis + ```bash + docker compose -f .docker/docker-compose.dev.yml up -d + ``` +7. Start Kafka consumer + ```bash + cd ./service/kafka && python3 ./consume.py ``` + + + ## 🤖 Interpreter mode Super-Rag has built in support for running computational Q&A using code interpreters powered by [E2B.dev](https://e2b.dev) custom runtimes. You can signup to receive an API key to leverage they sandboxes in a cloud environment or setup your own by following [these instructions](https://github.com/e2b-dev/infra). diff --git a/api/ingest.py b/api/ingest.py index 7bde0d04..45767ea5 100644 --- a/api/ingest.py +++ b/api/ingest.py @@ -1,62 +1,168 @@ import asyncio +import time +import logging from typing import Dict import aiohttp -from fastapi import APIRouter +from fastapi import APIRouter, status +from fastapi.responses import JSONResponse -from models.ingest import RequestPayload +from models.api import ApiError +from models.ingest import RequestPayload, TaskStatus from service.embedding import EmbeddingService from service.ingest import handle_google_drive, handle_urls +from service.kafka.config import ingest_topic +from service.kafka.producer import kafka_producer +from service.redis.client import redis_client +from service.redis.ingest_task_manager import ( + CreateTaskDto, + IngestTaskManager, + UpdateTaskDto, +) from utils.summarise import SUMMARY_SUFFIX + router = APIRouter() +logger = logging.getLogger(__name__) + + +class IngestPayload(RequestPayload): + task_id: str + + @router.post("/ingest") -async def ingest(payload: RequestPayload) -> Dict: - encoder = payload.document_processor.encoder.get_encoder() - embedding_service = EmbeddingService( - encoder=encoder, - index_name=payload.index_name, - vector_credentials=payload.vector_database, - dimensions=payload.document_processor.encoder.dimensions, - ) - chunks = [] - summary_documents = [] - if payload.files: - chunks, summary_documents = await handle_urls( - embedding_service=embedding_service, - files=payload.files, - config=payload.document_processor, +async def add_ingest_queue(payload: RequestPayload): + try: + task_manager = IngestTaskManager(redis_client) + task_id = task_manager.create(CreateTaskDto(status=TaskStatus.PENDING)) + + message = IngestPayload(**payload.model_dump(), task_id=str(task_id)) + + msg = message.model_dump_json().encode() + + kafka_producer.send(ingest_topic, msg) + kafka_producer.flush() + + logger.info(f"Task {task_id} added to the queue") + + return {"success": True, "task": {"id": task_id}} + + except Exception as err: + logger.error(f"Error adding task to the queue: {err}") + return JSONResponse( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + content={"success": False, "error": {"message": "Internal server error"}}, ) - elif payload.google_drive: - chunks, summary_documents = await handle_google_drive( - embedding_service, payload.google_drive - ) # type: ignore TODO: Fix typing - tasks = [ - embedding_service.embed_and_upsert( - chunks=chunks, encoder=encoder, index_name=payload.index_name - ), - ] +@router.get("/ingest/tasks/{task_id}") +async def get_task( + task_id: str, +
long_polling: bool = False, +): + if long_polling: + logger.info(f"Long polling is enabled for task {task_id}") + else: + logger.info(f"Long polling is disabled for task {task_id}") - task_manager = IngestTaskManager(redis_client) + + def handle_task_not_found(task_id: str): + logger.warning(f"Task {task_id} not found") + return JSONResponse( + status_code=status.HTTP_404_NOT_FOUND, + content={"success": False, "error": {"message": "Task not found"}}, ) - if summary_documents and all(item is not None for item in summary_documents): - tasks.append( - embedding_service.embed_and_upsert( - chunks=summary_documents, - encoder=encoder, - index_name=f"{payload.index_name}{SUMMARY_SUFFIX}", - ) + if not long_polling: + task = task_manager.get(task_id) + if not task: + return handle_task_not_found(task_id) + return {"success": True, "task": task.model_dump()} + else: + start_time = time.time() + timeout_time = start_time + 30 # 30 seconds from now + sleep_interval = 3 # seconds + + while time.time() < timeout_time: + task = task_manager.get(task_id) + + if task is None: + return handle_task_not_found(task_id) - if payload.webhook_url: - async with aiohttp.ClientSession() as session: - await session.post( - url=payload.webhook_url, - json={"index_name": payload.index_name, "status": "completed"}, + if task.status != TaskStatus.PENDING: + return {"success": True, "task": task.model_dump()} + await asyncio.sleep(sleep_interval) + + logger.warning(f"Request timeout for task {task_id}") + + return JSONResponse( + status_code=status.HTTP_408_REQUEST_TIMEOUT, + content={"success": False, "error": {"message": "Request timeout"}}, ) + + +async def ingest(payload: IngestPayload, task_manager: IngestTaskManager) -> Dict: + try: + encoder = payload.document_processor.encoder.get_encoder() + embedding_service = EmbeddingService( + encoder=encoder, + index_name=payload.index_name, + vector_credentials=payload.vector_database, + dimensions=payload.document_processor.encoder.dimensions, + ) + chunks = [] + summary_documents = [] + if payload.files: + chunks, summary_documents = await handle_urls( + embedding_service=embedding_service, + files=payload.files, + config=payload.document_processor, ) - return {"success": True, "index_name": payload.index_name} + elif payload.google_drive: + chunks, summary_documents = await handle_google_drive( + embedding_service, payload.google_drive + ) # type: ignore TODO: Fix typing + + tasks = [ + embedding_service.embed_and_upsert( + chunks=chunks, encoder=encoder, index_name=payload.index_name + ), + ] + + if summary_documents and all(item is not None for item in summary_documents): + tasks.append( + embedding_service.embed_and_upsert( + chunks=summary_documents, + encoder=encoder, + index_name=f"{payload.index_name}{SUMMARY_SUFFIX}", + ) + ) + + await asyncio.gather(*tasks) + + if payload.webhook_url: + async with aiohttp.ClientSession() as session: + await session.post( + url=payload.webhook_url, + json={"index_name": payload.index_name, "status": "completed"}, + ) + + task_manager.update( + task_id=payload.task_id, + task=UpdateTaskDto( + status=TaskStatus.DONE, + ), + ) + except Exception as e: + logger.error(f"Error processing ingest task: {e}") + task_manager.update( + task_id=payload.task_id, + task=UpdateTaskDto( + status=TaskStatus.FAILED, + error=ApiError(message=str(e)), + ), + ) diff --git a/models/api.py b/models/api.py new file mode 100644 index 00000000..1ce29cf5 --- /dev/null +++ b/models/api.py @@ -0,0 +1,7 @@ +from typing import 
Optional + +from pydantic import BaseModel + + +class ApiError(BaseModel): + message: Optional[str] diff --git a/models/ingest.py b/models/ingest.py index e82920d9..690bcb92 100644 --- a/models/ingest.py +++ b/models/ingest.py @@ -4,6 +4,7 @@ from pydantic import BaseModel, Field from semantic_router.encoders import BaseEncoder, CohereEncoder, OpenAIEncoder +from models.api import ApiError from models.file import File from models.google_drive import GoogleDrive from models.vector_database import VectorDatabase @@ -90,3 +91,14 @@ class RequestPayload(BaseModel): files: Optional[List[File]] = None google_drive: Optional[GoogleDrive] = None webhook_url: Optional[str] = None + + +class TaskStatus(str, Enum): + DONE = "DONE" + PENDING = "PENDING" + FAILED = "FAILED" + + +class IngestTaskResponse(BaseModel): + status: TaskStatus + error: Optional[ApiError] = None diff --git a/poetry.lock b/poetry.lock index 6889e530..98d10fd6 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1800,6 +1800,20 @@ traitlets = ">=5.3" docs = ["myst-parser", "pydata-sphinx-theme", "sphinx-autodoc-typehints", "sphinxcontrib-github-alt", "sphinxcontrib-spelling", "traitlets"] test = ["ipykernel", "pre-commit", "pytest", "pytest-cov", "pytest-timeout"] +[[package]] +name = "kafka-python" +version = "2.0.2" +description = "Pure Python client for Apache Kafka" +optional = false +python-versions = "*" +files = [ + {file = "kafka-python-2.0.2.tar.gz", hash = "sha256:04dfe7fea2b63726cd6f3e79a2d86e709d608d74406638c5da33a01d45a9d7e3"}, + {file = "kafka_python-2.0.2-py2.py3-none-any.whl", hash = "sha256:2d92418c7cb1c298fa6c7f0fb3519b520d0d7526ac6cb7ae2a4fc65a51a94b6e"}, +] + +[package.extras] +crc32c = ["crc32c"] + [[package]] name = "langdetect" version = "1.0.9" @@ -3166,6 +3180,24 @@ files = [ [package.extras] full = ["numpy"] +[[package]] +name = "redis" +version = "5.0.2" +description = "Python client for Redis database and key-value store" +optional = false +python-versions = ">=3.7" +files = [ + {file = "redis-5.0.2-py3-none-any.whl", hash = "sha256:4caa8e1fcb6f3c0ef28dba99535101d80934b7d4cd541bbb47f4a3826ee472d1"}, + {file = "redis-5.0.2.tar.gz", hash = "sha256:3f82cc80d350e93042c8e6e7a5d0596e4dd68715babffba79492733e1f367037"}, +] + +[package.dependencies] +async-timeout = ">=4.0.3" + +[package.extras] +hiredis = ["hiredis (>=1.0.0)"] +ocsp = ["cryptography (>=36.0.1)", "pyopenssl (==20.0.1)", "requests (>=2.26.0)"] + [[package]] name = "regex" version = "2023.12.25" @@ -4258,4 +4290,4 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p [metadata] lock-version = "2.0" python-versions = ">=3.9,<3.12" -content-hash = "b9774decb9338d39235bb4599347540fa5a684cecd89d194a95b26257be43364" +content-hash = "0bc57d0e29028de68b822188a12711cf8ee6b2cd6ead23c34598dd948b13cb06" diff --git a/pyproject.toml b/pyproject.toml index 7263d176..d8b99692 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -33,6 +33,9 @@ gunicorn = "^21.2.0" unstructured-client = "^0.18.0" unstructured = {extras = ["google-drive"], version = "^0.12.4"} tiktoken = "^0.6.0" +kafka-python = "^2.0.2" +pydantic = "^2.6.3" +redis = "^5.0.2" [tool.poetry.group.dev.dependencies] termcolor = "^2.4.0" diff --git a/service/kafka/config.py b/service/kafka/config.py new file mode 100644 index 00000000..80b7f5f9 --- /dev/null +++ b/service/kafka/config.py @@ -0,0 +1,8 @@ +from decouple import config + +ingest_topic = config("KAFKA_TOPIC_INGEST", default="ingestion") + + +kafka_bootstrap_servers: str = config( + "KAFKA_BOOTSTRAP_SERVERS", 
default="localhost:9092" +) diff --git a/service/kafka/consume.py b/service/kafka/consume.py new file mode 100644 index 00000000..b1bde197 --- /dev/null +++ b/service/kafka/consume.py @@ -0,0 +1,41 @@ +import asyncio + +from kafka.consumer.fetcher import ConsumerRecord + +from api.ingest import IngestPayload +from api.ingest import ingest as _ingest +from service.kafka.config import ingest_topic +from service.kafka.consumer import get_kafka_consumer +from service.redis.client import redis_client +from service.redis.ingest_task_manager import IngestTaskManager + + +async def ingest(msg: ConsumerRecord): + payload = IngestPayload(**msg.value) + task_manager = IngestTaskManager(redis_client) + await _ingest(payload, task_manager) + + +kafka_actions = { + ingest_topic: ingest, +} + + +async def process_msg(msg: ConsumerRecord, topic: str, consumer): + await kafka_actions[topic](msg) + consumer.commit() + + +async def consume(): + consumer = get_kafka_consumer(ingest_topic) + + while True: + # Response format is {TopicPartiton('topic1', 1): [msg1, msg2]} + msg_pack = consumer.poll(timeout_ms=3000) + + for tp, messages in msg_pack.items(): + for message in messages: + await process_msg(message, tp.topic, consumer) + + +asyncio.run(consume()) diff --git a/service/kafka/consumer.py b/service/kafka/consumer.py new file mode 100644 index 00000000..de150c2e --- /dev/null +++ b/service/kafka/consumer.py @@ -0,0 +1,23 @@ +import json + +from kafka import KafkaConsumer +from decouple import config + +from service.kafka.config import kafka_bootstrap_servers + + +def get_kafka_consumer(topic: str): + consumer = KafkaConsumer( + topic, + bootstrap_servers=kafka_bootstrap_servers, + group_id="my-group", + security_protocol=config("KAFKA_SECURITY_PROTOCOL", "PLAINTEXT"), + sasl_mechanism=config("KAFKA_SASL_MECHANISM", "PLAIN"), + sasl_plain_username=config("KAFKA_SASL_PLAIN_USERNAME", None), + sasl_plain_password=config("KAFKA_SASL_PLAIN_PASSWORD", None), + auto_offset_reset="earliest", + value_deserializer=lambda m: json.loads(m.decode("ascii")), + enable_auto_commit=False, + ) + + return consumer diff --git a/service/kafka/producer.py b/service/kafka/producer.py new file mode 100644 index 00000000..b1933bae --- /dev/null +++ b/service/kafka/producer.py @@ -0,0 +1,12 @@ +from kafka import KafkaProducer +from decouple import config +from service.kafka.config import kafka_bootstrap_servers + +kafka_producer = KafkaProducer( + bootstrap_servers=kafka_bootstrap_servers, + security_protocol=config("KAFKA_SECURITY_PROTOCOL", "PLAINTEXT"), + sasl_mechanism=config("KAFKA_SASL_MECHANISM", "PLAIN"), + sasl_plain_username=config("KAFKA_SASL_PLAIN_USERNAME", None), + sasl_plain_password=config("KAFKA_SASL_PLAIN_PASSWORD", None), + api_version_auto_timeout_ms=100000, +) diff --git a/service/redis/client.py b/service/redis/client.py new file mode 100644 index 00000000..0a6f61cd --- /dev/null +++ b/service/redis/client.py @@ -0,0 +1,6 @@ +from decouple import config +from redis import Redis + +redis_client = Redis( + host=config("REDIS_HOST", "localhost"), port=config("REDIS_PORT", 6379) +) diff --git a/service/redis/ingest_task_manager.py b/service/redis/ingest_task_manager.py new file mode 100644 index 00000000..964eaa14 --- /dev/null +++ b/service/redis/ingest_task_manager.py @@ -0,0 +1,47 @@ +import json + +from redis import Redis + +from models.ingest import IngestTaskResponse + + +class CreateTaskDto(IngestTaskResponse): + pass + + +class UpdateTaskDto(IngestTaskResponse): + pass + + +class IngestTaskManager: + 
TASK_PREFIX = "ingest:task:" + INGESTION_TASK_ID_KEY = "ingestion_task_id" + + def __init__(self, redis_client: Redis): + self.redis_client = redis_client + + def _get_task_key(self, task_id): + return f"{self.TASK_PREFIX}{task_id}" + + def create(self, task: CreateTaskDto): + task_id = self.redis_client.incr(self.INGESTION_TASK_ID_KEY) + task_key = self._get_task_key(task_id) + self.redis_client.set(task_key, task.model_dump_json()) + return task_id + + def get(self, task_id): + task_key = self._get_task_key(task_id) + task_data = self.redis_client.get(task_key) + + if task_data: + return IngestTaskResponse(**json.loads(task_data)) + else: + return None + + def update(self, task_id, task: UpdateTaskDto): + task_key = self._get_task_key(task_id) + self.redis_client.set(task_key, task.model_dump_json()) + + def delete(self, task_id): + task_key = self._get_task_key(task_id) + self.redis_client.delete(task_key)
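For context, the end-to-end flow this diff introduces is: `POST /ingest` creates a `PENDING` task in Redis and publishes the payload to the ingestion Kafka topic, `service/kafka/consume.py` picks the message up and runs the embedding pipeline, and clients read the task status (`PENDING`, `DONE`, or `FAILED`) back from `GET /ingest/tasks/{task_id}`, optionally with `long_polling=true`. Below is a minimal client-side sketch of that flow; the base URL, the `requests` dependency, and the placeholder request body are assumptions for illustration, while the endpoint paths and response shapes are taken from `api/ingest.py` above.

```python
# Illustrative client for the queue-based ingestion flow added in this diff.
# Assumptions: the API runs on uvicorn's default local port (8000) and the
# placeholder payload is filled in per models/ingest.py (RequestPayload).
import time

import requests

BASE_URL = "http://localhost:8000"  # assumed local dev server

payload = {
    "index_name": "my-index",   # placeholder values
    "vector_database": {},      # fill in per the VectorDatabase model
    "document_processor": {},   # fill in per the document-processor config
    "files": [],                # or "google_drive": {...}
}

# 1. Enqueue the task; /ingest now returns a task id instead of blocking
#    until embedding finishes.
task_id = requests.post(f"{BASE_URL}/ingest", json=payload).json()["task"]["id"]

# 2. Poll until the Kafka consumer marks the task DONE or FAILED in Redis.
#    (Alternatively, pass ?long_polling=true and the server holds the request
#    for up to ~30 seconds while the task is still PENDING.)
while True:
    task = requests.get(f"{BASE_URL}/ingest/tasks/{task_id}").json()["task"]
    if task["status"] != "PENDING":
        break
    time.sleep(3)

print(task["status"], task.get("error"))
```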