From f6c1eaf735d94d1cca0484b4316cfb2d05313571 Mon Sep 17 00:00:00 2001 From: dnth Date: Fri, 11 Oct 2024 21:54:04 +0800 Subject: [PATCH] streamline all transformers models --- nbs/demo.ipynb | 53 ++++++++----------- xinfer/model_factory.py | 40 ++++++++------ xinfer/transformers/__init__.py | 1 + .../{blip2.py => transformers_model.py} | 25 +++++---- 4 files changed, 60 insertions(+), 59 deletions(-) create mode 100644 xinfer/transformers/__init__.py rename xinfer/transformers/{blip2.py => transformers_model.py} (75%) diff --git a/nbs/demo.ipynb b/nbs/demo.ipynb index 6e1a793..6283b92 100644 --- a/nbs/demo.ipynb +++ b/nbs/demo.ipynb @@ -5,36 +5,24 @@ "execution_count": 1, "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/home/dnth/mambaforge-pypy3/envs/xinfer-test/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" - ] - }, { "data": { "text/html": [ - "
                             Available Models                             \n",
-       "┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓\n",
-       "┃ Backend       Model ID                           Input --> Output    ┃\n",
-       "┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩\n",
-       "│ transformers  Salesforce/blip2-opt-2.7b          image-text --> text │\n",
-       "│ transformers  sashakunitsyn/vlrm-blip2-opt-2.7b  image-text --> text │\n",
-       "│ transformers  vikhyatk/moondream2                image-text --> text │\n",
-       "└──────────────┴───────────────────────────────────┴─────────────────────┘\n",
+       "
                         Available Models                          \n",
+       "┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓\n",
+       "┃ Backend              Model ID             Input --> Output    ┃\n",
+       "┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩\n",
+       "│ custom-transformers  vikhyatk/moondream2  image-text --> text │\n",
+       "└─────────────────────┴─────────────────────┴─────────────────────┘\n",
        "
\n" ], "text/plain": [ - "\u001b[3m Available Models \u001b[0m\n", - "┏━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓\n", - "┃\u001b[1m \u001b[0m\u001b[1mBackend \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mModel ID \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mInput --> Output \u001b[0m\u001b[1m \u001b[0m┃\n", - "┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩\n", - "│\u001b[36m \u001b[0m\u001b[36mtransformers\u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35mSalesforce/blip2-opt-2.7b \u001b[0m\u001b[35m \u001b[0m│\u001b[32m \u001b[0m\u001b[32mimage-text --> text\u001b[0m\u001b[32m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36mtransformers\u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35msashakunitsyn/vlrm-blip2-opt-2.7b\u001b[0m\u001b[35m \u001b[0m│\u001b[32m \u001b[0m\u001b[32mimage-text --> text\u001b[0m\u001b[32m \u001b[0m│\n", - "│\u001b[36m \u001b[0m\u001b[36mtransformers\u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35mvikhyatk/moondream2 \u001b[0m\u001b[35m \u001b[0m│\u001b[32m \u001b[0m\u001b[32mimage-text --> text\u001b[0m\u001b[32m \u001b[0m│\n", - "└──────────────┴───────────────────────────────────┴─────────────────────┘\n" + "\u001b[3m Available Models \u001b[0m\n", + "┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓\n", + "┃\u001b[1m \u001b[0m\u001b[1mBackend \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mModel ID \u001b[0m\u001b[1m \u001b[0m┃\u001b[1m \u001b[0m\u001b[1mInput --> Output \u001b[0m\u001b[1m \u001b[0m┃\n", + "┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩\n", + "│\u001b[36m \u001b[0m\u001b[36mcustom-transformers\u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35mvikhyatk/moondream2\u001b[0m\u001b[35m \u001b[0m│\u001b[32m \u001b[0m\u001b[32mimage-text --> text\u001b[0m\u001b[32m \u001b[0m│\n", + "└─────────────────────┴─────────────────────┴─────────────────────┘\n" ] }, "metadata": {}, @@ -44,12 +32,12 @@ "source": [ "import xinfer\n", "\n", - "xinfer.list_models()" + "xinfer.list_models(\"moondream\")" ] }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -64,14 +52,15 @@ } ], "source": [ - "model = xinfer.create_model(\"vikhyatk/moondream2\", \"transformers\")\n", + "model = xinfer.create_model(\"vikhyatk/moondream2\", \"custom-transformers\")\n", "# model = xinfer.create_model(\"Salesforce/blip2-opt-2.7b\", \"transformers\")\n", - "# model = xinfer.create_model(\"sashakunitsyn/vlrm-blip2-opt-2.7b\", \"transformers\")\n" + "# model = xinfer.create_model(\"sashakunitsyn/vlrm-blip2-opt-2.7b\", \"transformers\")\n", + "# model = xinfer.create_model(\"microsoft/Florence-2\", \"transformers\")" ] }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -80,14 +69,14 @@ "'An animated character with long hair and a serious expression is eating a large burger at a table, with other characters in the background.'" ] }, - "execution_count": 3, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "image = \"https://raw.githubusercontent.com/vikhyat/moondream/main/assets/demo-1.jpg\"\n", - "prompt = \"Describe this image. 
\"\n", + "prompt = \"Describe this image\"\n", "\n", "model.inference(image, prompt, max_new_tokens=50)" ] @@ -116,7 +105,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.10" + "version": "3.10.15" } }, "nbformat": 4, diff --git a/xinfer/model_factory.py b/xinfer/model_factory.py index 9f7ac2a..0006f7f 100644 --- a/xinfer/model_factory.py +++ b/xinfer/model_factory.py @@ -1,29 +1,33 @@ +from huggingface_hub import HfApi from rich.console import Console from rich.table import Table from .model_registry import InputOutput, ModelRegistry from .timm import TimmModel, timm_models -from .transformers.blip2 import BLIP2, VLRMBlip2 from .transformers.moondream import Moondream +from .transformers.transformers_model import TransformerVisionLanguageModel from .ultralytics import UltralyticsYoloModel, ultralytics_models +def get_vision_language_models(): + api = HfApi() + models = api.list_models(filter="image-to-text") + return [model.modelId for model in models] + + def register_models(): - ModelRegistry.register( - "transformers", - "Salesforce/blip2-opt-2.7b", - BLIP2, - input_output=InputOutput.IMAGE_TEXT_TO_TEXT, - ) - ModelRegistry.register( - "transformers", - "sashakunitsyn/vlrm-blip2-opt-2.7b", - VLRMBlip2, - input_output=InputOutput.IMAGE_TEXT_TO_TEXT, - ) + hf_vision_language_models = get_vision_language_models() + + for model in hf_vision_language_models: + ModelRegistry.register( + "transformers", + model, + TransformerVisionLanguageModel, + input_output=InputOutput.IMAGE_TEXT_TO_TEXT, + ) ModelRegistry.register( - "transformers", + "custom-transformers", "vikhyatk/moondream2", Moondream, input_output=InputOutput.IMAGE_TEXT_TO_TEXT, @@ -51,10 +55,12 @@ def create_model(model_id: str, backend: str, **kwargs): kwargs["model_name"] = model_id if backend == "ultralytics": kwargs["model_name"] = model_id + ".pt" + if backend == "transformers": + kwargs["model_name"] = model_id return ModelRegistry.get_model(model_id, backend, **kwargs) -def list_models(wildcard: str = None, limit: int = 20): +def list_models(wildcard: str = None, backend: str = None, limit: int = 20): console = Console() table = Table(title="Available Models") table.add_column("Backend", style="cyan") @@ -63,7 +69,9 @@ def list_models(wildcard: str = None, limit: int = 20): rows = [] for model in ModelRegistry.list_models(): - if wildcard is None or wildcard.lower() in model["model_id"].lower(): + if (wildcard is None or wildcard.lower() in model["model_id"].lower()) and ( + backend is None or backend.lower() == model["backend"].lower() + ): rows.append( ( model["backend"], diff --git a/xinfer/transformers/__init__.py b/xinfer/transformers/__init__.py new file mode 100644 index 0000000..9a7acf0 --- /dev/null +++ b/xinfer/transformers/__init__.py @@ -0,0 +1 @@ +from .transformers_model import TransformerVisionLanguageModel diff --git a/xinfer/transformers/blip2.py b/xinfer/transformers/transformers_model.py similarity index 75% rename from xinfer/transformers/blip2.py rename to xinfer/transformers/transformers_model.py index 9ee95b8..a58ed04 100644 --- a/xinfer/transformers/blip2.py +++ b/xinfer/transformers/transformers_model.py @@ -1,20 +1,20 @@ import requests import torch from PIL import Image -from transformers import Blip2ForConditionalGeneration, Blip2Processor +from transformers import AutoModelForVision2Seq, AutoProcessor from ..base_model import BaseModel -class BLIP2(BaseModel): - def __init__(self, model_name: str = "Salesforce/blip2-opt-2.7b", **kwargs): 
+class TransformerVisionLanguageModel(BaseModel): + def __init__(self, model_name: str, **kwargs): self.model_name = model_name self.device = "cuda" if torch.cuda.is_available() else "cpu" self.load_model(**kwargs) def load_model(self, **kwargs): - self.processor = Blip2Processor.from_pretrained(self.model_name, **kwargs) - self.model = Blip2ForConditionalGeneration.from_pretrained( + self.processor = AutoProcessor.from_pretrained(self.model_name, **kwargs) + self.model = AutoModelForVision2Seq.from_pretrained( self.model_name, **kwargs ).to(self.device, torch.bfloat16) @@ -26,11 +26,9 @@ def preprocess(self, image: str | Image.Image, prompt: str = None): if image.startswith(("http://", "https://")): image = Image.open(requests.get(image, stream=True).raw).convert("RGB") else: - raise ValueError("Input string must be an image URL for BLIP2") - else: - raise ValueError( - "Input must be either an image URL or a PIL Image for BLIP2" - ) + raise ValueError("Input string must be an image URL") + elif not isinstance(image, Image.Image): + raise ValueError("Input must be either an image URL or a PIL Image") return self.processor(images=image, text=prompt, return_tensors="pt").to( self.device @@ -52,7 +50,12 @@ def inference(self, image, prompt, **generate_kwargs): return self.postprocess(prediction) -class VLRMBlip2(BLIP2): +# class BLIP2(TransformerVisionLanguageModel): +# def __init__(self, model_name: str = "Salesforce/blip2-opt-2.7b", **kwargs): +# super().__init__(model_name, **kwargs) + + +class VLRMBlip2(TransformerVisionLanguageModel): def __init__(self, model_name: str = "sashakunitsyn/vlrm-blip2-opt-2.7b", **kwargs): super().__init__(model_name, **kwargs) self.load_vlrm_weights()
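
A minimal usage sketch of the streamlined flow, assembled from the notebook cells and the model_factory.py hunks above. It assumes the Hub "image-to-text" listing returned by get_vision_language_models() includes the BLIP-2 checkpoints, so they resolve through the generic "transformers" backend, while vikhyatk/moondream2 keeps its bespoke wrapper under "custom-transformers":

    import xinfer

    # Hub image-to-text checkpoints are auto-registered under the generic
    # "transformers" backend (AutoProcessor + AutoModelForVision2Seq);
    # models with hand-written wrappers stay under "custom-transformers".
    xinfer.list_models("blip2", backend="transformers")

    # Generic path: any registered Hub model id is loaded by
    # TransformerVisionLanguageModel.
    model = xinfer.create_model("Salesforce/blip2-opt-2.7b", "transformers")

    # Custom path (bespoke Moondream wrapper):
    # model = xinfer.create_model("vikhyatk/moondream2", "custom-transformers")

    image = "https://raw.githubusercontent.com/vikhyat/moondream/main/assets/demo-1.jpg"
    prompt = "Describe this image"

    print(model.inference(image, prompt, max_new_tokens=50))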