Skip to content

Commit

Permalink
feat: migration to task processing api
Browse files Browse the repository at this point in the history
Signed-off-by: Anupam Kumar <[email protected]>
  • Loading branch information
kyteinsky committed Aug 8, 2024
1 parent e0651d7 commit f9b6053
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 65 deletions.
2 changes: 2 additions & 0 deletions config.json
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
{
"__comment::log_level": "Log level for the app, see https://docs.python.org/3/library/logging.html#logging-levels",
"__comment::idle_polling_interval": "The interval in seconds to check for new messages when the app has no tasks",
"__comment::tokenizer_file": "The tokenizer file name inside the model directory (loader.model_path)",
"__comment::loader": "CTranslate2 loader options, see https://opennmt.net/CTranslate2/python/ctranslate2.Translator.html#ctranslate2.Translator.__init__. Use 'model_path' key for local paths or 'model_name' key for models hosted on Hugging Face. Both can't be used at the same time.",
"__comment::inference": "CTranslate2 inference options, see the kwargs in https://opennmt.net/CTranslate2/python/ctranslate2.Translator.html#ctranslate2.Translator.translate_batch",
"__comment::changes_to_the_config": "the program needs to be restarted if you change this file since it is stored in memory on startup",
"log_level": 20,
"idle_polling_interval": 5,
"tokenizer_file": "spiece.model",
"loader": {
"model_name": "Nextcloud-AI/madlad400-3b-mt-ct2-int8_float32",
Expand Down
15 changes: 11 additions & 4 deletions lib/Service.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from contextlib import contextmanager
from copy import deepcopy
from time import perf_counter
from typing import TypedDict

import ctranslate2
from sentencepiece import SentencePieceProcessor
Expand All @@ -15,6 +16,12 @@

logger = logging.getLogger(__name__)

# Input shape of a translation task as delivered by the task-processing API.
# "origin_language" may be "auto", which main.py uses to mean "detect the
# source language" (see DETECT_LANGUAGE there).
TranslateRequest = TypedDict(
    "TranslateRequest",
    {
        "origin_language": str,
        "input": str,
        "target_language": str,
    },
)


if os.getenv("CI") is not None:
ctranslate2.set_random_seed(420)

Expand Down Expand Up @@ -64,7 +71,7 @@ def __init__(self, config: dict):
"Error reading languages list, ensure languages.json is present in the project root"
) from e

def get_lang_names(self):
def get_languages(self) -> dict[str, str]:
    """Return the supported languages as a mapping of language id -> name.

    The mapping is loaded from languages.json during construction.
    Presumably keys are language ids and values human-readable names —
    the caller in main.py builds ShapeEnumValue(name=lang_name,
    value=lang_id) from its items.
    """
    return self.languages

def load_config(self, config: dict):
Expand All @@ -76,11 +83,11 @@ def load_config(self, config: dict):

self.config = config_copy

def translate(self, to_language: str, text: str) -> str:
logger.debug(f"translating text to: {to_language}")
def translate(self, data: TranslateRequest) -> str:
logger.debug(f"translating text to: {data['target_language']}")

with translate_context(self.config) as (tokenizer, translator):
input_tokens = tokenizer.Encode(f"<2{to_language}> {clean_text(text)}", out_type=str)
input_tokens = tokenizer.Encode(f"<2{data['target_language']}> {clean_text(data['input'])}", out_type=str)
results = translator.translate_batch(
[input_tokens],
batch_type="tokens",
Expand Down
165 changes: 104 additions & 61 deletions lib/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,17 @@

import logging
import os
import queue
import threading
import typing
from contextlib import asynccontextmanager
from contextlib import asynccontextmanager, suppress
from time import sleep

import uvicorn.logging
from dotenv import load_dotenv
from fastapi import Body, FastAPI, Request, responses
from fastapi import FastAPI, Request, responses
from nc_py_api import AsyncNextcloudApp, NextcloudApp
from nc_py_api.ex_app import LogLvl, run_app, set_handlers
from Service import Service
from nc_py_api.ex_app.providers.task_processing import ShapeEnumValue, TaskProcessingProvider
from Service import Service, TranslateRequest
from util import load_config_file, save_config_file

load_dotenv()
Expand Down Expand Up @@ -45,33 +45,45 @@ def __setitem__(self, key, value):
models_to_fetch = { config["loader"]["model_name"]: ModelConfig({ "cache_dir": cache_dir }) }


# Handle to the single background polling thread; managed by lifespan()
# and enabled_handler().
worker = None


@asynccontextmanager
async def lifespan(_: FastAPI):
    """FastAPI lifespan hook.

    On startup: registers the ex-app handlers and starts the background
    task-polling worker (exactly one thread). On shutdown: stops the worker
    and clears the handle.
    """
    global worker
    set_handlers(
        fast_api_app=APP,
        enabled_handler=enabled_handler,  # type: ignore
        models_to_fetch=models_to_fetch,  # type: ignore
    )
    worker = BackgroundProcessTask()
    worker.start()
    yield
    # NOTE(review): Thread._stop() is a private CPython internal — it does not
    # actually terminate a running thread and can raise for threads that are
    # still alive. A cooperative shutdown flag checked inside the worker loop
    # would be the safe replacement; confirm before relying on this.
    if isinstance(worker, threading.Thread):
        worker._stop()  # pyright: ignore[reportAttributeAccessIssue]
    worker = None


APP_ID = "translate2"
TASK_TYPE_ID = "core:text2text:translate"
IDLE_POLLING_INTERVAL = config["idle_polling_interval"]
DETECT_LANGUAGE = ShapeEnumValue(name="Detect Language", value="auto")
APP = FastAPI(lifespan=lifespan)
TASK_LIST: queue.Queue = queue.Queue(maxsize=100)
service = Service(config)


def report_error(task: dict | None, exc: Exception) -> None:
    """Best-effort error reporting: log *exc* to Nextcloud and, if *task* is
    given, report that task as failed.

    Every exception raised while reporting is deliberately suppressed — this
    helper is called from error paths that must never raise again.
    """
    with suppress(Exception):
        nc = NextcloudApp()
        nc.log(LogLvl.ERROR, str(exc))
        if task:
            # Tasks fetched via next_task() nest the actual task object under
            # the "task" key (see BackgroundProcessTask.run). The previous
            # flat task["id"] lookup raised KeyError for those tasks, which
            # was silently suppressed here — so the failure never reached the
            # server. Support both shapes.
            task_id = task["task"]["id"] if "task" in task else task["id"]
            nc.providers.task_processing.report_result(task_id, error_message=str(exc))


@APP.exception_handler(Exception)
async def _(request: Request, exc: Exception):
logger.error("Error processing request", request.url.path, exc)

task: dict | None = getattr(exc, "args", None)

nc = NextcloudApp()
nc.log(LogLvl.ERROR, str(exc))
if task:
nc.providers.translations.report_result(task["id"], error=str(exc))
report_error(task, exc)

return responses.JSONResponse({
"error": "An error occurred while processing the request, please check the logs for more info"
Expand All @@ -80,62 +92,93 @@ async def _(request: Request, exc: Exception):

class BackgroundProcessTask(threading.Thread):
    """Background worker: polls Nextcloud for translation tasks and processes
    them until the app is disabled."""

    def run(self, *args, **kwargs):  # pylint: disable=unused-argument
        nc = NextcloudApp()
        while True:
            if not nc.enabled_state:
                logger.debug("App is disabled")
                break

            task = nc.providers.task_processing.next_task([APP_ID], [TASK_TYPE_ID])
            if not task:
                logger.debug("No tasks found")
                # Back off while idle to avoid hammering the server.
                sleep(IDLE_POLLING_INTERVAL)
                continue

            logger.debug(f"Processing task: {task}")

            # next_task() wraps the task object under the "task" key.
            input_ = task.get("task", {}).get("input")
            if input_ is None or not isinstance(input_, dict):
                logger.error("Invalid task object received, expected task object with input key")
                continue

            output = None
            error = None
            try:
                request = TranslateRequest(**input_)
                translation = service.translate(request)
                logger.debug(f"Translation: {translation}")
                output = translation
            except Exception as e:  # noqa
                # Log here and report exactly once below — additionally
                # calling report_error() in this branch would report the same
                # task twice.
                logger.exception("Error translating the input text")
                error = f"Error translating the input text: {e}"

            try:
                nc.providers.task_processing.report_result(
                    task_id=task["task"]["id"],
                    output={"output": output},
                    error_message=error,
                )
            except Exception as e:  # noqa
                # A reporting failure must not kill the polling thread;
                # fall back to the suppressed best-effort reporter.
                report_error(task, e)


async def enabled_handler(enabled: bool, nc: AsyncNextcloudApp) -> str:
    """Called by Nextcloud when the app is enabled or disabled.

    Registers/unregisters the task-processing provider and starts/stops the
    background polling worker. Returns "" on success (a non-empty string is
    treated as an error message by the framework).
    """
    global worker
    print(f"enabled={enabled}")

    if not enabled:
        await nc.providers.task_processing.unregister(APP_ID)
        if isinstance(worker, threading.Thread):
            # NOTE(review): Thread._stop() is private CPython API and does not
            # actually stop a live thread — the worker loop exits on its own
            # once it observes enabled_state is false. Confirm this call is
            # only bookkeeping before keeping it.
            worker._stop()  # pyright: ignore[reportAttributeAccessIssue]
        worker = None
        return ""

    # Expose every supported target language as an enum value for the
    # task-processing input shape.
    languages = [
        ShapeEnumValue(name=lang_name, value=lang_id)
        for lang_id, lang_name in service.get_languages().items()
    ]

    provider = TaskProcessingProvider(
        id=APP_ID,
        name="Local Machine Translation",
        task_type=TASK_TYPE_ID,
        input_shape_enum_values={
            "origin_language": [DETECT_LANGUAGE],
            "target_language": languages,
        },
        input_shape_defaults={
            # Default to automatic source-language detection.
            "origin_language": DETECT_LANGUAGE.value,
        },
    )
    # task_processing.register() takes the provider object only; the extra
    # positional id/name/task-type arguments were residue of the removed
    # translations.register() call.
    await nc.providers.task_processing.register(provider)

    # Replace any previous worker with a fresh one.
    if isinstance(worker, threading.Thread):
        worker._stop()  # pyright: ignore[reportAttributeAccessIssue]

    worker = BackgroundProcessTask()
    worker.start()

    return ""


if __name__ == "__main__":
uvicorn_log_level = uvicorn.logging.TRACE_LOG_LEVEL if config["log_level"] == logging.DEBUG else config["log_level"]
uvicorn_log_level = (
uvicorn.logging.TRACE_LOG_LEVEL
if config["log_level"] == logging.DEBUG
else config["log_level"]
)
run_app("main:APP", log_level=uvicorn_log_level)

0 comments on commit f9b6053

Please sign in to comment.