Commit

streamline all transformers models
dnth committed Oct 11, 2024
1 parent 382737c commit f6c1eaf
Showing 4 changed files with 60 additions and 59 deletions.
53 changes: 21 additions & 32 deletions nbs/demo.ipynb
@@ -5,36 +5,24 @@
"execution_count": 1,
"metadata": {},
"outputs": [
(old cell output: a tqdm "IProgress not found" warning on stderr, plus the "Available Models" table:)
Backend      │ Model ID                          │ Input --> Output
transformers │ Salesforce/blip2-opt-2.7b         │ image-text --> text
transformers │ sashakunitsyn/vlrm-blip2-opt-2.7b │ image-text --> text
transformers │ vikhyatk/moondream2               │ image-text --> text
(new cell output: the "Available Models" table:)
Backend             │ Model ID            │ Input --> Output
custom-transformers │ vikhyatk/moondream2 │ image-text --> text
"metadata": {},
@@ -44,12 +32,12 @@
"source": [
"import xinfer\n",
"\n",
"xinfer.list_models()"
"xinfer.list_models(\"moondream\")"
]
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 3,
"metadata": {},
"outputs": [
{
@@ -64,14 +52,15 @@
}
],
"source": [
"model = xinfer.create_model(\"vikhyatk/moondream2\", \"transformers\")\n",
"model = xinfer.create_model(\"vikhyatk/moondream2\", \"custom-transformers\")\n",
"# model = xinfer.create_model(\"Salesforce/blip2-opt-2.7b\", \"transformers\")\n",
"# model = xinfer.create_model(\"sashakunitsyn/vlrm-blip2-opt-2.7b\", \"transformers\")\n"
"# model = xinfer.create_model(\"sashakunitsyn/vlrm-blip2-opt-2.7b\", \"transformers\")\n",
"# model = xinfer.create_model(\"microsoft/Florence-2\", \"transformers\")"
]
},
{
"cell_type": "code",
"execution_count": 3,
"execution_count": 4,
"metadata": {},
"outputs": [
{
@@ -80,14 +69,14 @@
"'An animated character with long hair and a serious expression is eating a large burger at a table, with other characters in the background.'"
]
},
"execution_count": 3,
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"image = \"https://raw.githubusercontent.com/vikhyat/moondream/main/assets/demo-1.jpg\"\n",
"prompt = \"Describe this image. \"\n",
"prompt = \"Describe this image\"\n",
"\n",
"model.inference(image, prompt, max_new_tokens=50)"
]
@@ -116,7 +105,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.10"
"version": "3.10.15"
}
},
"nbformat": 4,
40 changes: 24 additions & 16 deletions xinfer/model_factory.py
@@ -1,29 +1,33 @@
from huggingface_hub import HfApi
from rich.console import Console
from rich.table import Table

from .model_registry import InputOutput, ModelRegistry
from .timm import TimmModel, timm_models
from .transformers.blip2 import BLIP2, VLRMBlip2
from .transformers.moondream import Moondream
from .transformers.transformers_model import TransformerVisionLanguageModel
from .ultralytics import UltralyticsYoloModel, ultralytics_models


def get_vision_language_models():
api = HfApi()
models = api.list_models(filter="image-to-text")
return [model.modelId for model in models]


def register_models():
ModelRegistry.register(
"transformers",
"Salesforce/blip2-opt-2.7b",
BLIP2,
input_output=InputOutput.IMAGE_TEXT_TO_TEXT,
)
ModelRegistry.register(
"transformers",
"sashakunitsyn/vlrm-blip2-opt-2.7b",
VLRMBlip2,
input_output=InputOutput.IMAGE_TEXT_TO_TEXT,
)
hf_vision_language_models = get_vision_language_models()

for model in hf_vision_language_models:
ModelRegistry.register(
"transformers",
model,
TransformerVisionLanguageModel,
input_output=InputOutput.IMAGE_TEXT_TO_TEXT,
)

ModelRegistry.register(
"transformers",
"custom-transformers",
"vikhyatk/moondream2",
Moondream,
input_output=InputOutput.IMAGE_TEXT_TO_TEXT,
@@ -51,10 +55,12 @@ def create_model(model_id: str, backend: str, **kwargs):
kwargs["model_name"] = model_id
if backend == "ultralytics":
kwargs["model_name"] = model_id + ".pt"
if backend == "transformers":
kwargs["model_name"] = model_id
return ModelRegistry.get_model(model_id, backend, **kwargs)


def list_models(wildcard: str = None, limit: int = 20):
def list_models(wildcard: str = None, backend: str = None, limit: int = 20):
console = Console()
table = Table(title="Available Models")
table.add_column("Backend", style="cyan")
@@ -63,7 +69,9 @@ def list_models(wildcard: str = None, limit: int = 20):

rows = []
for model in ModelRegistry.list_models():
if wildcard is None or wildcard.lower() in model["model_id"].lower():
if (wildcard is None or wildcard.lower() in model["model_id"].lower()) and (
backend is None or backend.lower() == model["backend"].lower()
):
rows.append(
(
model["backend"],
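For reference, a minimal usage sketch of the updated factory API. This is hedged: it relies only on the signatures visible in this diff, and the set of auto-registered Hugging Face image-to-text models depends on what HfApi returns at runtime.

import xinfer

# List registered models, filtered by a model-ID substring and/or by backend;
# the `backend` parameter is new in this commit.
xinfer.list_models("blip2", backend="transformers")

# Moondream is now registered under the separate "custom-transformers" backend.
model = xinfer.create_model("vikhyatk/moondream2", "custom-transformers")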
1 change: 1 addition & 0 deletions xinfer/transformers/__init__.py
@@ -0,0 +1 @@
from .transformers_model import TransformerVisionLanguageModel
25 changes: 14 additions & 11 deletions xinfer/transformers/transformers_model.py
@@ -1,20 +1,20 @@
import requests
import torch
from PIL import Image
from transformers import Blip2ForConditionalGeneration, Blip2Processor
from transformers import AutoModelForVision2Seq, AutoProcessor

from ..base_model import BaseModel


class BLIP2(BaseModel):
def __init__(self, model_name: str = "Salesforce/blip2-opt-2.7b", **kwargs):
class TransformerVisionLanguageModel(BaseModel):
def __init__(self, model_name: str, **kwargs):
self.model_name = model_name
self.device = "cuda" if torch.cuda.is_available() else "cpu"
self.load_model(**kwargs)

def load_model(self, **kwargs):
self.processor = Blip2Processor.from_pretrained(self.model_name, **kwargs)
self.model = Blip2ForConditionalGeneration.from_pretrained(
self.processor = AutoProcessor.from_pretrained(self.model_name, **kwargs)
self.model = AutoModelForVision2Seq.from_pretrained(
self.model_name, **kwargs
).to(self.device, torch.bfloat16)

@@ -26,11 +26,9 @@ def preprocess(self, image: str | Image.Image, prompt: str = None):
if image.startswith(("http://", "https://")):
image = Image.open(requests.get(image, stream=True).raw).convert("RGB")
else:
raise ValueError("Input string must be an image URL for BLIP2")
else:
raise ValueError(
"Input must be either an image URL or a PIL Image for BLIP2"
)
raise ValueError("Input string must be an image URL")
elif not isinstance(image, Image.Image):
raise ValueError("Input must be either an image URL or a PIL Image")

return self.processor(images=image, text=prompt, return_tensors="pt").to(
self.device
@@ -52,7 +50,12 @@ def inference(self, image, prompt, **generate_kwargs):
return self.postprocess(prediction)


class VLRMBlip2(BLIP2):
# class BLIP2(TransformerVisionLanguageModel):
# def __init__(self, model_name: str = "Salesforce/blip2-opt-2.7b", **kwargs):
# super().__init__(model_name, **kwargs)


class VLRMBlip2(TransformerVisionLanguageModel):
def __init__(self, model_name: str = "sashakunitsyn/vlrm-blip2-opt-2.7b", **kwargs):
super().__init__(model_name, **kwargs)
self.load_vlrm_weights()
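For illustration, a hedged sketch of driving the new generic wrapper directly. The checkpoint, image URL, prompt, and max_new_tokens value are taken from the demo notebook above, and any checkpoint loadable via AutoProcessor / AutoModelForVision2Seq is assumed to work:

from xinfer.transformers import TransformerVisionLanguageModel

# Generic vision-language wrapper: loads the processor and model with the
# transformers Auto classes and moves the model to CUDA when available.
model = TransformerVisionLanguageModel("Salesforce/blip2-opt-2.7b")

caption = model.inference(
    "https://raw.githubusercontent.com/vikhyat/moondream/main/assets/demo-1.jpg",
    "Describe this image",
    max_new_tokens=50,
)
print(caption)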
