Commit: update api (#1)
dnth authored Oct 11, 2024
1 parent 5076b8a commit 61b076e
Showing 7 changed files with 128 additions and 166 deletions.
59 changes: 23 additions & 36 deletions README.md
@@ -45,87 +45,74 @@ Install PyTorch and transformers in your environment.
Here's a quick example demonstrating how to use xinfer with a Transformers model:

```python
-from xinfer import get_model
+import xinfer

# Instantiate a Transformers model
-model = get_model("Salesforce/blip2-opt-2.7b", backend="transformers")
+model = xinfer.create_model("vikhyatk/moondream2", "transformers")

# Input data
-image = "https://img.freepik.com/free-photo/adorable-black-white-kitty-with-monochrome-wall-her_23-2148955182.jpg"
-prompt = "What's in this image? Answer:"
+image = "https://raw.githubusercontent.com/vikhyat/moondream/main/assets/demo-1.jpg"
+prompt = "Describe this image. "

# Run inference
-processed_input = model.preprocess(image, prompt)
-prediction = model.predict(processed_input)
-output = model.postprocess(prediction)
-print(output)
->>> A cat on a yellow background

-image = "https://img.freepik.com/free-photo/adorable-black-white-kitty-with-monochrome-wall-her_23-2148955182.jpg"
-prompt = "Describe this image in concise detail. Answer:"
-processed_input = model.preprocess(image, prompt)
-# Change the max_new_tokens to 200
-prediction = model.predict(processed_input, max_new_tokens=200)
-output = model.postprocess(prediction)
+output = model.inference(image, prompt, max_new_tokens=50)

print(output)
->>> a black and white cat sitting on a table looking up at the camera
+>>> An animated character with long hair and a serious expression is eating a large burger at a table, with other characters in the background.
```

See [example.ipynb](nbs/example.ipynb) for more examples.


## Supported Models
Transformers:
- [Salesforce/blip2-opt-2.7b](https://huggingface.co/Salesforce/blip2-opt-2.7b)
- [sashakunitsyn/vlrm-blip2-opt-2.7b](https://huggingface.co/sashakunitsyn/vlrm-blip2-opt-2.7b)
- [vikhyatk/moondream2](https://huggingface.co/vikhyatk/moondream2)

Get a list of available models:
```python
-from xinfer import list_models
+import xinfer

-list_models()
+xinfer.list_models()
```

<table>
<thead>
<tr>
-<th colspan="2">Available Models</th>
+<th colspan="3">Available Models</th>
</tr>
<tr>
-<th>backend</th>
-<th>Model Type</th>
+<th>Backend</th>
+<th>Model ID</th>
+<th>Input/Output</th>
</tr>
</thead>
<tbody>
<tr>
<td>transformers</td>
<td>Salesforce/blip2-opt-2.7b</td>
<td>image-text --> text</td>
</tr>
<tr>
<td>transformers</td>
<td>sashakunitsyn/vlrm-blip2-opt-2.7b</td>
<td>image-text --> text</td>
</tr>
+<tr>
+<td>transformers</td>
+<td>vikhyatk/moondream2</td>
+<td>image-text --> text</td>
+</tr>
</tbody>
</table>

See [example.ipynb](nbs/example.ipynb) for more examples.


## Adding New Models

+ Step 1: Create a new model class that implements the `BaseModel` interface.

-+ Step 2: Implement the required abstract methods:
-  - `load_model`
-  - `preprocess`
-  - `predict`
-  - `postprocess`
++ Step 2: Implement the required abstract methods `load_model` and `inference`.

+ Step 3: Update `register_models` in `model_factory.py` to import the new model class and register it (see the sketch below).
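
For illustration, here is a minimal sketch of what steps 1–3 could look like. The class name `MyModel`, the model id `my-org/my-model`, and the method bodies are placeholders, not part of this repository; only the `BaseModel` methods and the `ModelRegistry.register` call follow the signatures shown in this commit.

```python
# Sketch only: MyModel and "my-org/my-model" are hypothetical placeholders.
from xinfer.base_model import BaseModel
from xinfer.model_registry import InputOutput, ModelRegistry


class MyModel(BaseModel):
    def load_model(self, **kwargs):
        # Step 2a: load the weights/processor for the model here.
        ...

    def inference(self, image, prompt, **generate_kwargs):
        # Step 2b: preprocess the image and prompt, run generation,
        # and return the generated text.
        ...


# Step 3: this registration call belongs inside register_models() in model_factory.py.
ModelRegistry.register(
    "transformers",
    "my-org/my-model",
    MyModel,
    input_output=InputOutput.IMAGE_TEXT_TO_TEXT,
)
```

Once registered, the new model id should appear in `xinfer.list_models()` and be loadable with `xinfer.create_model("my-org/my-model", backend="transformers")`.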

34 changes: 14 additions & 20 deletions docs/usage.md
@@ -23,45 +23,39 @@ This will display a table of available models and their backends and input/output
┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩
│ transformers │ Salesforce/blip2-opt-2.7b │ image-text --> text │
│ transformers │ sashakunitsyn/vlrm-blip2-opt-2.7b │ image-text --> text │
│ transformers │ vikhyatk/moondream2 │ image-text --> text │
└──────────────┴───────────────────────────────────┴─────────────────────┘
```

## Loading and Using a Model

-### BLIP2 Model
-
-Here's an example of how to load and use the BLIP2 model:
+You can load and use any of the available models. Here's an example using the Moondream2 model:

```python
# Instantiate a Transformers model
-model = xinfer.create_model("Salesforce/blip2-opt-2.7b", backend="transformers")
+model = xinfer.create_model("vikhyatk/moondream2", backend="transformers")

# Input data
-image = "https://example.com/path/to/image.jpg"
-prompt = "What's in this image? Answer:"
+image = "https://raw.githubusercontent.com/vikhyat/moondream/main/assets/demo-1.jpg"
+prompt = "Describe this image."

# Run inference
-processed_input = model.preprocess(image, prompt)
-prediction = model.predict(processed_input)
-output = model.postprocess(prediction)
+output = model.inference(image, prompt, max_new_tokens=50)

print(output)
```

-You can also customize the generation parameters:
-
-```python
-prediction = model.predict(processed_input, max_new_tokens=200)
-```
+This will produce a description of the image, such as:
+"An animated character with long hair and a serious expression is eating a large burger at a table, with other characters in the background."

-### VLRM-finetuned BLIP2 Model
-
-Similarly, you can use the VLRM-finetuned BLIP2 model:
+You can use the same pattern for other models like BLIP2 or VLRM-finetuned BLIP2:

```python
-model = xinfer.create_model("sashakunitsyn/vlrm-blip2-opt-2.7b", backend="transformers")
+# For BLIP2
+model = xinfer.create_model("Salesforce/blip2-opt-2.7b", backend="transformers")

-# Use the model in the same way as the BLIP2 model
+# For VLRM-finetuned BLIP2
+model = xinfer.create_model("sashakunitsyn/vlrm-blip2-opt-2.7b", backend="transformers")
```

-Both models can be used for tasks like image description and visual question answering.
+Use the models in the same way as demonstrated with the Moondream2 model.
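
As a rough sketch of that pattern (the prompt and the `max_new_tokens` value here are illustrative; the call mirrors the Moondream2 example above):

```python
import xinfer

# Same interface, different model id
model = xinfer.create_model("Salesforce/blip2-opt-2.7b", backend="transformers")

image = "https://raw.githubusercontent.com/vikhyat/moondream/main/assets/demo-1.jpg"
prompt = "What's in this image? Answer:"

output = model.inference(image, prompt, max_new_tokens=50)
print(output)
```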
121 changes: 21 additions & 100 deletions nbs/example.ipynb → nbs/demo.ipynb
@@ -22,6 +22,7 @@
"┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩\n",
"│<span style=\"color: #008080; text-decoration-color: #008080\"> transformers </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> Salesforce/blip2-opt-2.7b </span>│<span style=\"color: #008000; text-decoration-color: #008000\"> image-text --&gt; text </span>│\n",
"│<span style=\"color: #008080; text-decoration-color: #008080\"> transformers </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> sashakunitsyn/vlrm-blip2-opt-2.7b </span>│<span style=\"color: #008000; text-decoration-color: #008000\"> image-text --&gt; text </span>│\n",
"│<span style=\"color: #008080; text-decoration-color: #008080\"> transformers </span>│<span style=\"color: #800080; text-decoration-color: #800080\"> vikhyatk/moondream2 </span>│<span style=\"color: #008000; text-decoration-color: #008000\"> image-text --&gt; text </span>│\n",
"└──────────────┴───────────────────────────────────┴─────────────────────┘\n",
"</pre>\n"
],
@@ -32,6 +33,7 @@
"┡━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩\n",
"\u001b[36m \u001b[0m\u001b[36mtransformers\u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35mSalesforce/blip2-opt-2.7b \u001b[0m\u001b[35m \u001b[0m│\u001b[32m \u001b[0m\u001b[32mimage-text --> text\u001b[0m\u001b[32m \u001b[0m│\n",
"\u001b[36m \u001b[0m\u001b[36mtransformers\u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35msashakunitsyn/vlrm-blip2-opt-2.7b\u001b[0m\u001b[35m \u001b[0m│\u001b[32m \u001b[0m\u001b[32mimage-text --> text\u001b[0m\u001b[32m \u001b[0m│\n",
"\u001b[36m \u001b[0m\u001b[36mtransformers\u001b[0m\u001b[36m \u001b[0m│\u001b[35m \u001b[0m\u001b[35mvikhyatk/moondream2 \u001b[0m\u001b[35m \u001b[0m│\u001b[32m \u001b[0m\u001b[32mimage-text --> text\u001b[0m\u001b[32m \u001b[0m│\n",
"└──────────────┴───────────────────────────────────┴─────────────────────┘\n"
]
},
@@ -45,13 +47,6 @@
"xinfer.list_models()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load the BLIP2 Model and Run Inference"
]
},
{
"cell_type": "code",
"execution_count": 2,
@@ -61,76 +56,17 @@
"name": "stderr",
"output_type": "stream",
"text": [
"Loading checkpoint shards: 100%|██████████| 2/2 [00:00<00:00, 6.83it/s]\n",
"Expanding inputs for image tokens in BLIP-2 should be done in processing. Please follow instruction here (https://gist.github.com/zucchini-nlp/e9f20b054fa322f84ac9311d9ab67042) to update your BLIP-2 model. Using processors without these attributes in the config is deprecated and will throw an error in v4.47.\n",
"Expanding inputs for image tokens in BLIP-2 should be done in processing. Please follow instruction here (https://gist.github.com/zucchini-nlp/e9f20b054fa322f84ac9311d9ab67042) to update your BLIP-2 model. Using processors without these attributes in the config is deprecated and will throw an error in v4.47.\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" A cat sitting on a table, looking up at\n"
]
}
],
"source": [
"# Instantiate a Transformers model\n",
"model = xinfer.create_model(\"Salesforce/blip2-opt-2.7b\", backend=\"transformers\")\n",
"\n",
"# Input data\n",
"image = \"https://img.freepik.com/free-photo/adorable-black-white-kitty-with-monochrome-wall-her_23-2148955182.jpg\"\n",
"prompt = \"What's in this image? Answer:\"\n",
"\n",
"# Run inference\n",
"processed_input = model.preprocess(image, prompt)\n",
"\n",
"prediction = model.predict(processed_input)\n",
"output = model.postprocess(prediction)\n",
"\n",
"print(output)\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Both `max_new_tokens` (=200) and `max_length`(=51) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" a black and white cat sitting on a table looking up at the camera\n",
"\n"
"PhiForCausalLM has generative capabilities, as `prepare_inputs_for_generation` is explicitly overwritten. However, it doesn't directly inherit from `GenerationMixin`. From 👉v4.50👈 onwards, `PreTrainedModel` will NOT inherit from `GenerationMixin`, and this model will lose the ability to call `generate` and other related functions.\n",
" - If you're using `trust_remote_code=True`, you can get rid of this warning by loading the model with an auto class. See https://huggingface.co/docs/transformers/en/model_doc/auto#auto-classes\n",
" - If you are the owner of the model architecture code, please modify your model class such that it inherits from `GenerationMixin` (after `PreTrainedModel`, otherwise you'll get an exception).\n",
" - If you are not the owner of the model architecture class, please contact the model code owner to update it.\n"
]
}
],
"source": [
"# Input data\n",
"image = \"https://img.freepik.com/free-photo/adorable-black-white-kitty-with-monochrome-wall-her_23-2148955182.jpg\"\n",
"prompt = \"Describe this image in concise detail. Answer:\"\n",
"\n",
"# Run inference\n",
"processed_input = model.preprocess(image, prompt)\n",
"\n",
"prediction = model.predict(processed_input, max_new_tokens=200)\n",
"output = model.postprocess(prediction)\n",
"\n",
"print(output)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Load the VLRM-finetuned BLIP2 model"
"model = xinfer.create_model(\"vikhyatk/moondream2\", \"transformers\")\n",
"# model = xinfer.create_model(\"Salesforce/blip2-opt-2.7b\", \"transformers\")\n",
"# model = xinfer.create_model(\"sashakunitsyn/vlrm-blip2-opt-2.7b\", \"transformers\")\n"
]
},
{
@@ -139,36 +75,21 @@
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Loading checkpoint shards: 100%|██████████| 8/8 [00:02<00:00, 3.21it/s]\n",
"Both `max_new_tokens` (=200) and `max_length`(=51) seem to have been set. `max_new_tokens` will take precedence. Please refer to the documentation for more information. (https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" a black and white cat with white eyes sitting on a table in front of a yellow background\n",
"\n"
]
"data": {
"text/plain": [
"'An animated character with long hair and a serious expression is eating a large burger at a table, with other characters in the background.'"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"model = xinfer.create_model(\"sashakunitsyn/vlrm-blip2-opt-2.7b\", backend=\"transformers\")\n",
"\n",
"# Input data\n",
"image = \"https://img.freepik.com/free-photo/adorable-black-white-kitty-with-monochrome-wall-her_23-2148955182.jpg\"\n",
"prompt = \"What's in this image? Answer:\"\n",
"\n",
"# Run inference\n",
"processed_input = model.preprocess(image, prompt)\n",
"\n",
"prediction = model.predict(processed_input, max_new_tokens=200)\n",
"output = model.postprocess(prediction)\n",
"image = \"https://raw.githubusercontent.com/vikhyat/moondream/main/assets/demo-1.jpg\"\n",
"prompt = \"Describe this image. \"\n",
"\n",
"print(output)\n"
"model.inference(image, prompt, max_new_tokens=50)"
]
},
{
@@ -181,7 +102,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "xinfer",
"display_name": "xinfer-test",
"language": "python",
"name": "python3"
},
12 changes: 3 additions & 9 deletions xinfer/base_model.py
@@ -1,19 +1,13 @@
from abc import ABC, abstractmethod

-from PIL import Image


class BaseModel(ABC):
    @abstractmethod
    def load_model(self, **kwargs):
        pass

-    @abstractmethod
-    def preprocess(self, input_data):
-        pass

-    @abstractmethod
-    def predict(self, processed_data):
-        pass

    @abstractmethod
-    def postprocess(self, prediction):
+    def inference(self, image, prompt):
        pass
8 changes: 8 additions & 0 deletions xinfer/model_factory.py
@@ -3,6 +3,7 @@

from .model_registry import InputOutput, ModelRegistry
from .transformers.blip2 import BLIP2, VLRMBlip2
from .transformers.moondream import Moondream


def register_models():
@@ -19,6 +20,13 @@ def register_models():
        input_output=InputOutput.IMAGE_TEXT_TO_TEXT,
    )

    ModelRegistry.register(
        "transformers",
        "vikhyatk/moondream2",
        Moondream,
        input_output=InputOutput.IMAGE_TEXT_TO_TEXT,
    )


def create_model(model_id: str, backend: str, **kwargs):
    return ModelRegistry.get_model(model_id, backend, **kwargs)