\n",
- "\n"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "wR53lePHuiP-"
- },
- "source": [
+ "\n",
+ "\n",
"This notebook shows how to fine-tune [PaliGemma](https://ai.google.dev/gemma/docs/paligemma) on a vision-language task with [JAX](https://jax.readthedocs.io/en/latest/index.html). *Fine-tuning* is a process that can improve your model's performance on specific tasks or help the model adhere to specific output requirements when instructions aren't sufficient and you have a set of examples that demonstrate the outputs you want. Gemma-based models like PaliGemma require fine-tuning to produce expected results.\n",
"\n",
"### What's in this notebook\n",
@@ -172,7 +165,11 @@
"# vars as appropriate or make your credentials available in ~/.kaggle/kaggle.json\n",
"\n",
"os.environ[\"KAGGLE_USERNAME\"] = userdata.get('KAGGLE_USERNAME')\n",
- "os.environ[\"KAGGLE_KEY\"] = userdata.get('KAGGLE_KEY')"
+ "os.environ[\"KAGGLE_KEY\"] = userdata.get('KAGGLE_KEY')\n",
+ "\n",
+ "# The T4 runtime is tight on memory to finetune this model. Preallocate\n",
+ "# all memory ahead of time to avoid out-of-memory due to fragmentation.\n",
+ "os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"] = \"1.0\""
]
},
{
@@ -265,7 +262,7 @@
"tf.config.set_visible_devices([], \"GPU\")\n",
"tf.config.set_visible_devices([], \"TPU\")\n",
"\n",
- "backend = jax.lib.xla_bridge.get_backend()\n",
+ "backend = jax.extend.backend.get_backend()\n",
"print(f\"JAX version: {jax.__version__}\")\n",
"print(f\"JAX platform: {backend.platform}\")\n",
"print(f\"JAX devices: {jax.device_count()}\")"
@@ -292,7 +289,7 @@
"\n",
"PaliGemma includes several model variations. For this tutorial, you'll use the base [JAX/FLAX PaliGemma 3B weight model](https://www.kaggle.com/models/google/paligemma/jax/paligemma-3b-pt-224).\n",
"\n",
- "Download the `float16` version of the model checkpoint from Kaggle by running the following code. This process takes several minutes to complete."
+ "Download the model checkpoint from Kaggle by running the following code. This process takes several minutes to complete."
]
},
{
@@ -306,12 +303,19 @@
"import os\n",
"import kagglehub\n",
"\n",
- "MODEL_PATH = \"./pt_224_128.params.f16.npz\"\n",
+ "# Use these for PaliGemma-2 3B 224px²\n",
+ "LLM_VARIANT = \"gemma2_2b\"\n",
+ "MODEL_PATH = \"./paligemma2-3b-pt-224.b16.npz\"\n",
+ "KAGGLE_HANDLE = \"google/paligemma-2/jax/paligemma2-3b-pt-224\" # Path to fetch from Kaggle.\n",
+ "\n",
+ "# Use these for PaliGemma 1:\n",
+ "# LLM_VARIANT = \"gemma_2b\"\n",
+ "# MODEL_PATH = \"./paligemma-3b-pt-224.f16.npz\"\n",
+ "# KAGGLE_HANDLE = \"google/paligemma/jax/paligemma-3b-pt-224\"\n",
+ "\n",
"if not os.path.exists(MODEL_PATH):\n",
" print(\"Downloading the checkpoint from Kaggle, this could take a few minutes....\")\n",
- " # Note: kaggle archive contains the same checkpoint in multiple formats.\n",
- " # Download only the float16 model.\n",
- " MODEL_PATH = kagglehub.model_download('google/paligemma/jax/paligemma-3b-pt-224', 'paligemma-3b-pt-224.f16.npz')\n",
+ " MODEL_PATH = kagglehub.model_download(KAGGLE_HANDLE, MODEL_PATH)\n",
" print(f\"Model path: {MODEL_PATH}\")\n",
"\n",
"TOKENIZER_PATH = \"./paligemma_tokenizer.model\"\n",
@@ -360,8 +364,11 @@
"outputs": [],
"source": [
"# Define model\n",
+ "\n",
+ "# IMPORTANT: Gemma-2 has a \"final_logits_softcap\" property, we set it to 0.0\n",
+ "# for better transfer results.\n",
"model_config = ml_collections.FrozenConfigDict({\n",
- " \"llm\": {\"vocab_size\": 257_152},\n",
+ " \"llm\": {\"vocab_size\": 257_152, \"variant\": LLM_VARIANT, \"final_logits_softcap\": 0.0},\n",
" \"img\": {\"variant\": \"So400m/14\", \"pool_type\": \"none\", \"scan\": True, \"dtype_mm\": \"float16\"}\n",
"})\n",
"model = paligemma.Model(**model_config)\n",
@@ -420,7 +427,9 @@
"\n",
"@functools.partial(jax.jit, donate_argnums=(0,), static_argnums=(1,))\n",
"def maybe_cast_to_f32(params, trainable):\n",
- " return jax.tree.map(lambda p, m: p.astype(jnp.float32) if m else p,\n",
+ " # Cast others to float16, since some GPUs don't support bf16.\n",
+ " return jax.tree.map(lambda p, m: p.astype(jnp.float32)\n",
+ " if m else p.astype(jnp.float16),\n",
" params, trainable)\n",
"\n",
"# Loading all params in simultaneous - albeit much faster and more succinct -\n",
diff --git a/site/en/gemma/docs/paligemma/inference-with-keras.ipynb b/site/en/gemma/docs/paligemma/inference-with-keras.ipynb
index 32581fb4b..43b5498ab 100644
--- a/site/en/gemma/docs/paligemma/inference-with-keras.ipynb
+++ b/site/en/gemma/docs/paligemma/inference-with-keras.ipynb
@@ -41,24 +41,17 @@
"# limitations under the License."
]
},
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "etcMXWCUJApZ"
- },
- "source": [
- "# Inference with Keras\n"
- ]
- },
{
"cell_type": "markdown",
"metadata": {
"id": "Q5_nIe-8gdJV"
},
"source": [
+ "# Generate PaliGemma output with Keras\n",
+ "\n",
"
\n",
- "\n",
- "When your AI model produces a conclusion or a prediction, it goes through a process called *inference*. This tutorial goes over how to use PaliGemma with Keras to set up a simple model that can infer information about supplied images and answer questions about them."
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "9hhIuS9sEKHx"
+ },
+ "source": [
+ "PaliGemma models have *multimodal* capabilities, allowing you to generate output using both text and image input data. You can use image data with these models to provide additional context for your requests, or use the model to analyze the content of images. This tutorial shows you how to use PaliGemma with Keras to can analyze images and answer questions about them."
]
},
{
@@ -223,7 +223,7 @@
"outputs": [],
"source": [
"import keras\n",
- "import keras_nlp\n",
+ "import keras_hub\n",
"import numpy as np\n",
"import PIL\n",
"import requests\n",
@@ -237,26 +237,17 @@
"keras.config.set_floatx(\"bfloat16\")"
]
},
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "ftjt5DiueVkL"
- },
- "source": [
- "## Create your model\n",
- "\n",
- "Now that you've set everything up, you can download the pre-trained model and create some utility methods to help your model generate its responses."
- ]
- },
{
"cell_type": "markdown",
"metadata": {
"id": "X-LE2E1uiSpP"
},
"source": [
- "### Download the model checkpoint\n",
+ "## Load the model\n",
+ "\n",
+ "Now that you've set everything up, you can download the pre-trained model and create some utility methods to help your model generate its responses.\n",
+ "In this step, you download a model using `PaliGemmaCausalLM` from Keras Hub. This class helps you manage and run the causal visual language model structure of PaliGemma. A *causal visual language model* predicts the next token based on previous tokens. Keras Hub provides implementations of many popular [model architectures](https://keras.io/keras_hub/api/models/).\n",
"\n",
- "KerasNLP provides implementations of many popular [model architectures](https://keras.io/api/keras_nlp/models/). In this notebook, you'll create a model using `PaliGemmaCausalLM`, an end-to-end PaliGemma model for *causal visual language modeling*. A causal visual language model predicts the next token based on previous tokens.\n",
"\n",
"Create the model using the `from_preset` method and print its summary. This process will take about a minute to complete."
]
@@ -269,7 +260,7 @@
},
"outputs": [],
"source": [
- "paligemma = keras_nlp.models.PaliGemmaCausalLM.from_preset(\"pali_gemma_3b_mix_224\")\n",
+ "paligemma = keras_hub.models.PaliGemmaCausalLM.from_preset(\"paligemma_3b_mix_224\")\n",
"paligemma.summary()"
]
},
@@ -279,7 +270,7 @@
"id": "FBsWvKEvoGMe"
},
"source": [
- "### Create utility methods\n",
+ "## Create utility methods\n",
"\n",
"To help you generate responses from your model, create two utility methods:\n",
"\n",
@@ -352,26 +343,46 @@
"\n",
" plt.show()\n",
"\n",
- "def display_segment_output(image, segment_mask, target_image_size):\n",
- " # Calculate scaling factors\n",
- " h, w = target_image_size\n",
- " x_scale = w / 64\n",
- " y_scale = h / 64\n",
- "\n",
- " # Create coordinate grids for the new image\n",
- " x_coords = np.arange(w)\n",
- " y_coords = np.arange(h)\n",
- " x_coords = (x_coords / x_scale).astype(int)\n",
- " y_coords = (y_coords / y_scale).astype(int)\n",
- " resized_array = segment_mask[y_coords[:, np.newaxis], x_coords]\n",
- " # Create a figure and axis\n",
- " fig, ax = plt.subplots()\n",
- "\n",
- " # Display the image\n",
- " ax.imshow(image)\n",
- "\n",
- " # Overlay the mask with transparency\n",
- " ax.imshow(resized_array, cmap='jet', alpha=0.5)"
+ "def display_segment_output(image, bounding_box, segment_mask, target_image_size):\n",
+ " # Initialize a full mask with the target size\n",
+ " full_mask = np.zeros(target_image_size, dtype=np.uint8)\n",
+ " target_width, target_height = target_image_size\n",
+ "\n",
+ " for bbox, mask in zip(bounding_box, segment_mask):\n",
+ " y1, x1, y2, x2 = bbox\n",
+ " x1 = int(x1 * target_width)\n",
+ " y1 = int(y1 * target_height)\n",
+ " x2 = int(x2 * target_width)\n",
+ " y2 = int(y2 * target_height)\n",
+ "\n",
+ " # Ensure mask is 2D before converting to Image\n",
+ " if mask.ndim == 3:\n",
+ " mask = mask.squeeze(axis=-1)\n",
+ " mask = Image.fromarray(mask)\n",
+ " mask = mask.resize((x2 - x1, y2 - y1), resample=Image.NEAREST)\n",
+ " mask = np.array(mask)\n",
+ " binary_mask = (mask > 0.5).astype(np.uint8)\n",
+ "\n",
+ "\n",
+ " # Place the binary mask onto the full mask\n",
+ " full_mask[y1:y2, x1:x2] = np.maximum(full_mask[y1:y2, x1:x2], binary_mask)\n",
+ " cmap = plt.get_cmap('jet')\n",
+ " colored_mask = cmap(full_mask / 1.0)\n",
+ " colored_mask = (colored_mask[:, :, :3] * 255).astype(np.uint8)\n",
+ " if isinstance(image, Image.Image):\n",
+ " image = np.array(image)\n",
+ " blended_image = image.copy()\n",
+ " mask_indices = full_mask > 0\n",
+ " alpha = 0.5\n",
+ "\n",
+ " for c in range(3):\n",
+ " blended_image[:, :, c] = np.where(mask_indices,\n",
+ " (1 - alpha) * image[:, :, c] + alpha * colored_mask[:, :, c],\n",
+ " image[:, :, c])\n",
+ "\n",
+ " fig, ax = plt.subplots()\n",
+ " ax.imshow(blended_image)\n",
+ " plt.show()"
]
},
{
@@ -380,11 +391,11 @@
"id": "AeVUHA_zP8ZF"
},
"source": [
- "## Test your model\n",
+ "## Generate output\n",
"\n",
- "Now you're ready to give an image and prompt to your model and have it infer the response.\n",
+ "After loading the model and creating utility methods, you can prompt the model with image and text data to generate a responses. PaliGemma models are trained with specific prompt syntax for specific tasks, such as `answer`, `caption`, and `detect`. For more information about PaliGemma prompt task syntax, see [PaliGemma prompt and system instructions](https://ai.google.com/gemma/docs/paligemma/prompt-system-instructions##prompt_task_syntax).\n",
"\n",
- "Lets look at our test image and read it\n"
+ "Prepare an image for use in a generation prompt by using the following code to load a test image into an object:"
]
},
{
@@ -407,9 +418,9 @@
"id": "R3FE63iMqb2G"
},
"source": [
- "Here's a generation call with a single image and prompt. The prompts have to end with a `\\n`.\n",
+ "### Answer in a specific language\n",
"\n",
- "We've supplied you with several example prompts — play around with it! Comment and uncomment the prompt variables to change what prompt you supply the model with."
+ "The following example code shows how to prompt the PaliGemma model for information about an object appearing in a provided image. This example uses the `answer {lang}` syntax and shows additional questions in other languages:"
]
},
{
@@ -421,9 +432,10 @@
"outputs": [],
"source": [
"prompt = 'answer en where is the cow standing?\\n'\n",
- "# prompt = 'svar no hvor står kuen?'\n",
- "# prompt = 'answer fr quelle couleur est le ciel?'\n",
- "# prompt = 'responda pt qual a cor do animal?'\n",
+ "# prompt = 'svar no hvor står kuen?\\n'\n",
+ "# prompt = 'answer fr quelle couleur est le ciel?\\n'\n",
+ "# prompt = 'responda pt qual a cor do animal?\\n'\n",
+ "\n",
"output = paligemma.generate(\n",
" inputs={\n",
" \"images\": cow_image,\n",
@@ -436,69 +448,10 @@
{
"cell_type": "markdown",
"metadata": {
- "id": "MCMqS0vI6S36"
- },
- "source": [
- "Here's a generation call with batched inputs."
- ]
- },
- {
- "cell_type": "code",
- "execution_count": null,
- "metadata": {
- "id": "UZle4sJP6YwB"
+ "id": "tMPv8tfISvif"
},
- "outputs": [],
"source": [
- "prompts = [\n",
- " 'answer en where is the cow standing?\\n',\n",
- " 'answer en what color is the cow?\\n',\n",
- " 'describe en\\n',\n",
- " 'detect cow\\n',\n",
- " 'segment cow\\n',\n",
- "]\n",
- "images = [cow_image, cow_image, cow_image, cow_image, cow_image]\n",
- "outputs = paligemma.generate(\n",
- " inputs={\n",
- " \"images\": images,\n",
- " \"prompts\": prompts,\n",
- " }\n",
- ")\n",
- "for output in outputs:\n",
- " print(output)"
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "qJKvH8to6kb6"
- },
- "source": [
- "We've supplied you with several example prompts — play around with it! Comment and uncomment the `prompt` variables to change what prompt you supply the model with."
- ]
- },
- {
- "cell_type": "markdown",
- "metadata": {
- "id": "-uqMwZQY31mu"
- },
- "source": [
- "### Other styles of prompts\n",
- "\n",
- "You may have noticed in the previous step that the provided examples are in several different languages. PaliGemma supports language recognition for 34 different languages. You can find the list of supported languages on [GitHub](https://github.com/google/crossmodal-3600/blob/main/web-data/README.md).\n",
- "\n",
- "PaliGemma can handle several other prompt styles:\n",
- "\n",
- "* **`\"cap {lang}\\n\"`:** Very raw short caption (from WebLI-alt)\n",
- "* **`\"caption {lang}\\n\"`:** Nice, COCO-like short captions\n",
- "* **`\"describe {lang}\\n\"`:** Somewhat longer, more descriptive captions\n",
- "* **`\"ocr\"`:** Optical character recognition\n",
- "* **`\"answer en {question}\\n\"`:** Question answering about the image contents\n",
- "* **`\"question {lang} {answer}\\n\"`:** Question generation for a given answer\n",
- "* **`\"detect {object} ; {object}\\n\"`:** Count objects in a scene and return the bounding boxes for the objects\n",
- "* **`\"segment {object}\\n\"`:** Do image segmentation of the object in the scene\n",
- "\n",
- "Try them out!"
+ "Note: Prompts that use PaliGemma prompt command syntax must end with an \"`\\n`\" character."
]
},
{
@@ -507,7 +460,9 @@
"id": "eUgPcUGDpsdx"
},
"source": [
- "### Parse detect output"
+ "### Use `detect` prompt\n",
+ "\n",
+ "The following example code uses the `detect` prompt syntax to locate an object in the provided image. The code uses the previously defined `parse_bbox_and_labels()` and `display_boxes()` functions to interpret the model output and display the generated bounding boxes."
]
},
{
@@ -532,22 +487,24 @@
{
"cell_type": "markdown",
"metadata": {
- "id": "mUwNeaV0sGau"
+ "id": "TET0Bz10mzxD"
},
"source": [
- "### Parse segment output"
+ "### Use `segment` prompt\n",
+ "\n",
+ "The following example code uses the `segment` prompt syntax to locate the area of an image occupied by an object. It uses the Google `big_vision` library to interpret the model output and generate a mask for the segemented object.\n",
+ "\n",
+ "Before getting started, install the `big_vision` library and its dependencies, as shown in this code example:\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
- "cellView": "form",
"id": "tuxyIMAvsM4B"
},
"outputs": [],
"source": [
- "# @title Fetch big_vision code and install dependencies.\n",
"import os\n",
"import sys\n",
"\n",
@@ -576,7 +533,7 @@
"id": "18S4uHWqutps"
},
"source": [
- "Let's take a look at another example image."
+ "For this segmentation example, load and prepare a different image that includes a cat."
]
},
{
@@ -668,8 +625,47 @@
},
"outputs": [],
"source": [
- "_, seg_output = parse_segments(output)\n",
- "display_segment_output(cat, seg_output[0], target_size)"
+ "bboxes, seg_masks = parse_segments(output)\n",
+ "display_segment_output(cat, bboxes, seg_masks, target_size)"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "metadata": {
+ "id": "jI9YAgLuGyAb"
+ },
+ "source": [
+ "### Batch prompts\n",
+ "\n",
+ "You can provide more than one prompt command within a single prompt as a batch of instructions. The following example demonstrates how to structure your prompt text to provide multiple instructions.\n",
+ "\n",
+ "Important: Each prompt command must end with an \"\\n\" character, as shown."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "id": "UZle4sJP6YwB"
+ },
+ "outputs": [],
+ "source": [
+ "prompts = [\n",
+ " 'answer en where is the cow standing?\\n',\n",
+ " 'answer en what color is the cow?\\n',\n",
+ " 'describe en\\n',\n",
+ " 'detect cow\\n',\n",
+ " 'segment cow\\n',\n",
+ "]\n",
+ "images = [cow_image, cow_image, cow_image, cow_image, cow_image]\n",
+ "outputs = paligemma.generate(\n",
+ " inputs={\n",
+ " \"images\": images,\n",
+ " \"prompts\": prompts,\n",
+ " }\n",
+ ")\n",
+ "for output in outputs:\n",
+ " print(output)"
]
}
],