-
Notifications
You must be signed in to change notification settings - Fork 926
/
Copy pathclip.py
31 lines (24 loc) · 1022 Bytes
/
clip.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
from typing import Tuple
from image_processor import CLIPImageProcessor
from model import CLIPModel
from tokenizer import CLIPTokenizer
def load(model_dir: str) -> Tuple[CLIPModel, CLIPTokenizer, CLIPImageProcessor]:
    """Restore a CLIP model, tokenizer, and image processor from one checkpoint.

    All three components are loaded from the same pretrained directory
    (*model_dir*) so they remain mutually consistent.

    Returns:
        A ``(model, tokenizer, image_processor)`` tuple.
    """
    components = (
        CLIPModel.from_pretrained(model_dir),
        CLIPTokenizer.from_pretrained(model_dir),
        CLIPImageProcessor.from_pretrained(model_dir),
    )
    return components
if __name__ == "__main__":
    from PIL import Image

    # Demo: embed two captions and two images with the same CLIP checkpoint,
    # then report the shapes of the resulting embedding matrices.
    model, tokenizer, img_processor = load("mlx_model")

    captions = ["a photo of a cat", "a photo of a dog"]
    images = [Image.open("assets/cat.jpeg"), Image.open("assets/dog.jpeg")]

    output = model(
        input_ids=tokenizer(captions),
        pixel_values=img_processor(images),
    )

    # Get text and image embeddings:
    print("Text embeddings shape:", output.text_embeds.shape)
    print("Image embeddings shape:", output.image_embeds.shape)