Updating PaliGemma notebooks #543

Merged
merged 2 commits on Dec 5, 2024
92 changes: 62 additions & 30 deletions site/en/gemma/docs/paligemma/fine-tuning-paligemma.ipynb
@@ -6,8 +6,8 @@
"id": "G3MMAcssHTML"
},
"source": [
"<link rel=\"stylesheet\" href=\"/site-assets/css/gemma.css\">\n",
"<link rel=\"stylesheet\" href=\"https://fonts.googleapis.com/css2?family=Google+Symbols:opsz,wght,FILL,[email protected],100..700,0..1,-50..200\" />"
"<link rel=\"stylesheet\" href=\"/site-assets/css/style.css\">\n",
"<link rel=\"stylesheet\" href=\"/site-assets/css/gemma.css\">\n"
]
},
{
@@ -59,15 +59,8 @@
"<td>\n",
"<a target=\"_blank\" href=\"https://github.com/google/generative-ai-docs/blob/main/site/en/gemma/docs/paligemma/fine-tuning-paligemma.ipynb\"><img src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a>\n",
"</td>\n",
"</table>\n"
]
},
{
"cell_type": "markdown",
"metadata": {
"id": "wR53lePHuiP-"
},
"source": [
"</table>\n",
"\n",
"This notebook shows how to fine-tune [PaliGemma](https://ai.google.dev/gemma/docs/paligemma) on a vision-language task with [JAX](https://jax.readthedocs.io/en/latest/index.html). *Fine-tuning* is a process that can improve your model's performance on specific tasks or help the model adhere to specific output requirements when instructions aren't sufficient and you have a set of examples that demonstrate the outputs you want. Gemma-based models like PaliGemma require fine-tuning to produce expected results.\n",
"\n",
"### What's in this notebook\n",
@@ -128,7 +121,8 @@
"\n",
"To generate a Kaggle API key, open your [**Settings** page in Kaggle](https://www.kaggle.com/settings) and click **Create New Token**. This triggers the download of a `kaggle.json` file containing your API credentials.\n",
"\n",
"Then, in Colab, select **Secrets** (🔑) in the left pane and add your Kaggle username and Kaggle API key. Store your username under the name `KAGGLE_USERNAME` and your API key under the name `KAGGLE_KEY`.\n"
"Then, in Colab, select **Secrets** (🔑) in the left pane and add your Kaggle username and Kaggle API key. Store your username under the name `KAGGLE_USERNAME` and your API key under the name `KAGGLE_KEY`.\n",
"\n"
]
},
{
@@ -172,7 +166,11 @@
"# vars as appropriate or make your credentials available in ~/.kaggle/kaggle.json\n",
"\n",
"os.environ[\"KAGGLE_USERNAME\"] = userdata.get('KAGGLE_USERNAME')\n",
"os.environ[\"KAGGLE_KEY\"] = userdata.get('KAGGLE_KEY')"
"os.environ[\"KAGGLE_KEY\"] = userdata.get('KAGGLE_KEY')\n",
"\n",
"# The T4 runtime has barely enough memory to fine-tune this model. Preallocate\n",
"# all of it up front to avoid out-of-memory errors caused by fragmentation.\n",
"os.environ[\"XLA_PYTHON_CLIENT_MEM_FRACTION\"] = \"1.0\""
]
},
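The XLA flag set at the end of this cell only takes effect if it is exported before JAX initializes; a minimal standalone sketch of that ordering (the commented-out `XLA_PYTHON_CLIENT_PREALLOCATE` line is a standard JAX/XLA option shown as an alternative, not something this notebook sets):

```python
import os

# XLA client flags are read once, when JAX first initializes, so set them
# before the first `import jax` anywhere in the process.
# A fraction of "1.0" preallocates the whole GPU memory pool up front,
# which avoids out-of-memory failures caused by fragmentation on a T4.
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.0"

# Alternative: skip preallocation and let XLA grow memory on demand
# (slower allocation, but friendlier when sharing the GPU).
# os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"
```

Setting the flag after JAX has already created its backend has no effect, which is why the notebook places it in the setup cell.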
{
@@ -265,7 +263,7 @@
"tf.config.set_visible_devices([], \"GPU\")\n",
"tf.config.set_visible_devices([], \"TPU\")\n",
"\n",
"backend = jax.lib.xla_bridge.get_backend()\n",
"backend = jax.extend.backend.get_backend()\n",
"print(f\"JAX version: {jax.__version__}\")\n",
"print(f\"JAX platform: {backend.platform}\")\n",
"print(f\"JAX devices: {jax.device_count()}\")"
@@ -292,7 +290,7 @@
"\n",
"PaliGemma includes several model variations. For this tutorial, you'll use the base [JAX/FLAX PaliGemma 3B weight model](https://www.kaggle.com/models/google/paligemma/jax/paligemma-3b-pt-224).\n",
"\n",
"Download the `float16` version of the model checkpoint from Kaggle by running the following code. This process takes several minutes to complete."
"Download the model checkpoint from Kaggle by running the following code. This process takes several minutes to complete."
]
},
{
@@ -306,12 +304,19 @@
"import os\n",
"import kagglehub\n",
"\n",
"MODEL_PATH = \"./pt_224_128.params.f16.npz\"\n",
"# Use these for PaliGemma-2 3B 224px²:\n",
"LLM_VARIANT = \"gemma2_2b\"\n",
"MODEL_PATH = \"./paligemma2-3b-pt-224.b16.npz\"\n",
"KAGGLE_HANDLE = \"google/paligemma-2/jax/paligemma2-3b-pt-224\" # Path to fetch from Kaggle.\n",
"\n",
"# Use these for PaliGemma 1:\n",
"# LLM_VARIANT = \"gemma_2b\"\n",
"# MODEL_PATH = \"./paligemma-3b-pt-224.f16.npz\"\n",
"# KAGGLE_HANDLE = \"google/paligemma/jax/paligemma-3b-pt-224\"\n",
"\n",
"if not os.path.exists(MODEL_PATH):\n",
" print(\"Downloading the checkpoint from Kaggle, this could take a few minutes....\")\n",
" # Note: kaggle archive contains the same checkpoint in multiple formats.\n",
" # Download only the float16 model.\n",
" MODEL_PATH = kagglehub.model_download('google/paligemma/jax/paligemma-3b-pt-224', 'paligemma-3b-pt-224.f16.npz')\n",
" MODEL_PATH = kagglehub.model_download(KAGGLE_HANDLE, MODEL_PATH)\n",
" print(f\"Model path: {MODEL_PATH}\")\n",
"\n",
"TOKENIZER_PATH = \"./paligemma_tokenizer.model\"\n",
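The three variant-specific variables above (`LLM_VARIANT`, `MODEL_PATH`, `KAGGLE_HANDLE`) have to be switched together; one hedged way to keep them in sync, using a lookup table whose dict and key names are invented for this sketch while the values mirror the cell above:

```python
# Illustrative lookup table: the dict and its keys are invented, but the
# values come from the variant comments in the cell above.
CHECKPOINTS = {
    "paligemma2": {
        "llm_variant": "gemma2_2b",
        "model_path": "./paligemma2-3b-pt-224.b16.npz",
        "kaggle_handle": "google/paligemma-2/jax/paligemma2-3b-pt-224",
    },
    "paligemma1": {
        "llm_variant": "gemma_2b",
        "model_path": "./paligemma-3b-pt-224.f16.npz",
        "kaggle_handle": "google/paligemma/jax/paligemma-3b-pt-224",
    },
}

cfg = CHECKPOINTS["paligemma2"]  # switch variants by changing one key
LLM_VARIANT = cfg["llm_variant"]
MODEL_PATH = cfg["model_path"]
KAGGLE_HANDLE = cfg["kaggle_handle"]
```

This avoids the failure mode where one of the three commented lines is toggled but the others are not.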
@@ -360,8 +365,11 @@
"outputs": [],
"source": [
"# Define model\n",
"\n",
"# IMPORTANT: Gemma-2 has a \"final_logits_softcap\" property; we set it to 0.0\n",
"# for better transfer results.\n",
"model_config = ml_collections.FrozenConfigDict({\n",
" \"llm\": {\"vocab_size\": 257_152},\n",
" \"llm\": {\"vocab_size\": 257_152, \"variant\": LLM_VARIANT, \"final_logits_softcap\": 0.0},\n",
" \"img\": {\"variant\": \"So400m/14\", \"pool_type\": \"none\", \"scan\": True, \"dtype_mm\": \"float16\"}\n",
"})\n",
"model = paligemma.Model(**model_config)\n",
@@ -420,7 +428,9 @@
"\n",
"@functools.partial(jax.jit, donate_argnums=(0,), static_argnums=(1,))\n",
"def maybe_cast_to_f32(params, trainable):\n",
" return jax.tree.map(lambda p, m: p.astype(jnp.float32) if m else p,\n",
" # Cast others to float16, since some GPUs don't support bf16.\n",
" return jax.tree.map(lambda p, m: p.astype(jnp.float32)\n",
" if m else p.astype(jnp.float16),\n",
" params, trainable)\n",
"\n",
"# Loading all params simultaneously - albeit much faster and more succinct -\n",
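The selective cast in `maybe_cast_to_f32` above can be sketched without JAX, using a flat dict of NumPy arrays in place of the parameter pytree (all names here are illustrative):

```python
import numpy as np

# Trainable leaves are widened to float32 for stable optimizer updates;
# frozen leaves are narrowed to float16 to save accelerator memory.
def maybe_cast(params, trainable):
    return {
        name: p.astype(np.float32 if trainable[name] else np.float16)
        for name, p in params.items()
    }

params = {"attn": np.zeros(4, np.float16), "img": np.zeros(4, np.float32)}
casted = maybe_cast(params, {"attn": True, "img": False})
```

The real version does the same leaf-wise dispatch via `jax.tree.map` over the nested parameter tree.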
@@ -492,7 +502,7 @@
"\n",
" image = tf.constant(image)\n",
" image = tf.image.resize(image, (size, size), method='bilinear', antialias=True)\n",
" return image.numpy() / 127.5 - 1.0 # [0, 255]->[-1,1]\n",
" return image.numpy() / 127.5 - 1.0 # [0, 255]-\u003e[-1,1]\n",
"\n",
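The `/ 127.5 - 1.0` step above rescales pixel values from [0, 255] into [-1, 1]; a quick NumPy check of the endpoints and midpoint:

```python
import numpy as np

# 0 maps to -1.0, 127.5 to 0.0, and 255 to 1.0, matching the
# "[0, 255] -> [-1, 1]" comment in the preprocessing function.
pixels = np.array([0.0, 127.5, 255.0])
normalized = pixels / 127.5 - 1.0
```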
"def preprocess_tokens(prefix, suffix=None, seqlen=None):\n",
" # Model has been trained to handle tokenized text composed of a prefix with\n",
@@ -632,12 +642,12 @@
" return f\"data:image/jpeg;base64,{image_b64}\"\n",
"\n",
"def render_example(image, caption):\n",
" image = ((image + 1)/2 * 255).astype(np.uint8) # [-1,1] -> [0, 255]\n",
" image = ((image + 1)/2 * 255).astype(np.uint8) # [-1,1] -\u003e [0, 255]\n",
" return f\"\"\"\n",
" <div style=\"display: inline-flex; align-items: center; justify-content: center;\">\n",
" <img style=\"width:128px; height:128px;\" src=\"{render_inline(image, resize=(64,64))}\" />\n",
" <p style=\"width:256px; margin:10px; font-size:small;\">{html.escape(caption)}</p>\n",
" </div>\n",
" \u003cdiv style=\"display: inline-flex; align-items: center; justify-content: center;\"\u003e\n",
" \u003cimg style=\"width:128px; height:128px;\" src=\"{render_inline(image, resize=(64,64))}\" /\u003e\n",
" \u003cp style=\"width:256px; margin:10px; font-size:small;\"\u003e{html.escape(caption)}\u003c/p\u003e\n",
" \u003c/div\u003e\n",
" \"\"\"\n",
"\n",
"html_out = \"\"\n",
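The `render_inline` helper above encodes JPEG bytes into a base64 data URI; a self-contained sketch with a stand-in payload (the three bytes below are just the JPEG magic prefix, not a real image):

```python
import base64

def to_data_uri(jpeg_bytes):
    # Browsers decode data URIs directly, so no file or server is needed.
    b64 = base64.b64encode(jpeg_bytes).decode("ascii")
    return f"data:image/jpeg;base64,{b64}"

uri = to_data_uri(b"\xff\xd8\xff")  # JPEG streams begin with these bytes
```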
@@ -754,7 +764,7 @@
" # Append to html output.\n",
" for example, response in zip(examples, responses):\n",
" outputs.append((example[\"image\"], response))\n",
" if num_examples and len(outputs) >= num_examples:\n",
" if num_examples and len(outputs) \u003e= num_examples:\n",
" return outputs"
]
},
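The early-exit pattern in the prediction loop above (accumulate outputs batch by batch, return as soon as `num_examples` is reached) can be sketched as:

```python
# Illustrative stand-in: `batches` replaces the real (image, response)
# pairs; the truncation logic mirrors the loop above.
def collect(batches, num_examples=None):
    outputs = []
    for batch in batches:
        for item in batch:
            outputs.append(item)
            if num_examples and len(outputs) >= num_examples:
                return outputs
    return outputs

first_three = collect([[1, 2], [3, 4], [5]], num_examples=3)
```

Returning from inside the loop avoids running inference on batches whose outputs would be discarded anyway.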
@@ -862,14 +872,36 @@
],
"metadata": {
"colab": {
"name": "fine-tuning-paligemma.ipynb",
"gpuType": "T4",
"last_runtime": {
"build_target": "//learning/grp/tools/ml_python:ml_notebook",
"kind": "private"
},
"private_outputs": true,
"provenance": [
{
"file_id": "17AiK8gRY7oiquQGkBH0d08PFQo3Kyx1I",
"timestamp": 1715287187925
},
{
"file_id": "1qZlJfPyfKRrNcz2shxQ93HnnE5Ge1LLn",
"timestamp": 1715019972450
},
{
"file_id": "1JFnlD2kSiTNexdPw_NYRtuW6uuSTI0kD",
"timestamp": 1714585741026
}
],
"toc_visible": true
},
"kernelspec": {
"display_name": "Python 3",
"name": "python3"
},
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 0
}
}