From 5076b8a42733c6fd1286bcc82b8a48893f0aae03 Mon Sep 17 00:00:00 2001 From: dnth Date: Fri, 11 Oct 2024 10:33:31 +0800 Subject: [PATCH] update readme --- README.md | 2 ++ nbs/example.ipynb | 42 ++++++++++++---------------------------- pyproject.toml | 1 + xinfer/model_factory.py | 13 ++----------- xinfer/model_registry.py | 13 +++++++++++-- 5 files changed, 28 insertions(+), 43 deletions(-) diff --git a/README.md b/README.md index ab298b5..b125703 100644 --- a/README.md +++ b/README.md @@ -24,6 +24,8 @@ xinfer is a modular Python framework that provides a unified interface for perfo - Ultralytics YOLO: State-of-the-art real-time object detection models. - Custom Models: Support for your own machine learning models and architectures. +## Prerequisites +Install [PyTorch](https://pytorch.org/get-started/locally/). ## Installation Install xinfer using pip: diff --git a/nbs/example.ipynb b/nbs/example.ipynb index 807441c..22593b7 100644 --- a/nbs/example.ipynb +++ b/nbs/example.ipynb @@ -5,6 +5,14 @@ "execution_count": 1, "metadata": {}, "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/dnth/mambaforge-pypy3/envs/xinfer-test/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", + " from .autonotebook import tqdm as notebook_tqdm\n" + ] + }, { "data": { "text/html": [ @@ -49,24 +57,11 @@ "execution_count": 2, "metadata": {}, "outputs": [ - { - "data": { - "application/vnd.jupyter.widget-view+json": { - "model_id": "a59fae37d6c24825ba4afafa57451156", - "version_major": 2, - "version_minor": 0 - }, - "text/plain": [ - "Loading checkpoint shards: 0%| | 0/2 [00:00 text" - IMAGE_TEXT_TO_TEXT = "image-text --> text" - TEXT_TO_TEXT = "text --> text" - IMAGE_TO_BBOX = "image --> bbox" - IMAGE_TO_CLASS = "image --> class" - - def register_models(): ModelRegistry.register( "transformers", @@ -34,6 +24,7 @@ def create_model(model_id: str, backend: str, **kwargs): return ModelRegistry.get_model(model_id, backend, **kwargs) +# TODO: list by backend or wildcard def list_models(): console = Console() table = Table(title="Available Models") diff --git a/xinfer/model_registry.py b/xinfer/model_registry.py index defeada..c67bfcf 100644 --- a/xinfer/model_registry.py +++ b/xinfer/model_registry.py @@ -1,13 +1,22 @@ from dataclasses import dataclass, field +from enum import Enum from typing import Dict, List, Type from .base_model import BaseModel +class InputOutput(Enum): + IMAGE_TO_TEXT = "image --> text" + IMAGE_TEXT_TO_TEXT = "image-text --> text" + TEXT_TO_TEXT = "text --> text" + IMAGE_TO_BBOX = "image --> bbox" + IMAGE_TO_CLASS = "image --> class" + + @dataclass class ModelInfo: model_class: Type[BaseModel] - input_output: str = "" + input_output: InputOutput @dataclass @@ -25,7 +34,7 @@ def register( backend: str, model_id: str, model_class: Type[BaseModel], - input_output: str = "", + input_output: InputOutput, ): if backend not in cls._registry: cls._registry[backend] = BackendRegistry(backend_name=backend)