Commit

Merge branch 'dev' into oom-observer
cli99 authored Feb 2, 2024
2 parents f2f94d3 + e75914f commit 43118ca
Showing 10 changed files with 644 additions and 36 deletions.
62 changes: 56 additions & 6 deletions composer/loggers/mlflow_logger.py
@@ -6,12 +6,13 @@
from __future__ import annotations

import fnmatch
import logging
import os
import pathlib
import textwrap
import time
import warnings
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional, Sequence, Union

import numpy as np
import torch
@@ -24,6 +25,8 @@
if TYPE_CHECKING:
from mlflow import ModelVersion # pyright: ignore[reportGeneralTypeIssues]

log = logging.getLogger(__name__)

__all__ = ['MLFlowLogger']

DEFAULT_MLFLOW_EXPERIMENT_NAME = 'my-mlflow-experiment'
@@ -262,28 +265,75 @@ def register_model(
tags=tags,
)

def save_model(self, flavor: str, **kwargs):
def save_model(self, flavor: Literal['transformers', 'peft'], **kwargs):
"""Save a model to MLflow.
Note: The ``'peft'`` flavor is experimental and the API is subject to change without warning.
Args:
flavor (str): The MLflow model flavor to use. Currently only ``'transformers'`` is supported.
flavor (Literal['transformers', 'peft']): The MLflow model flavor to use. Currently only ``'transformers'`` and ``'peft'`` are supported.
**kwargs: Keyword arguments to pass to the MLflow model saving function.
Raises:
NotImplementedError: If ``flavor`` is not ``'transformers'``.
NotImplementedError: If ``flavor`` is not ``'transformers'`` or ``'peft'``.
"""
if self._enabled:
import mlflow
if flavor == 'transformers':
mlflow.transformers.save_model(**kwargs,)
elif flavor == 'peft':
import transformers

# TODO: Remove after mlflow fixes the bug that makes this necessary
mlflow.store._unity_catalog.registry.rest_store.get_feature_dependencies = lambda *args, **kwargs: '' # type: ignore

# This is a temporary workaround until MLflow adds full support for saving PEFT models.
# https://github.com/mlflow/mlflow/issues/9256
log.warning(
'Saving PEFT models using MLflow is experimental and the API is subject to change without warning.')
expected_keys = {'path', 'save_pretrained_dir'}
if not expected_keys.issubset(kwargs.keys()):
raise ValueError(f'Expected keys {expected_keys} but got {kwargs.keys()}')

# This does not implement predict for now, as we will wait for the full MLflow support
# for PEFT models.
class PeftModel(mlflow.pyfunc.PythonModel):

def load_context(self, context):
self.model = transformers.AutoModelForCausalLM.from_pretrained(
context.artifacts['lora_checkpoint'])
self.tokenizer = transformers.AutoTokenizer.from_pretrained(
context.artifacts['lora_checkpoint'])

from mlflow.models.signature import ModelSignature
from mlflow.types import ColSpec, DataType, Schema

# This is faked for now, until MLflow adds full support for saving PEFT models.
input_schema = Schema([
ColSpec(DataType.string, 'fake_input'),
])
output_schema = Schema([ColSpec(DataType.string)])
signature = ModelSignature(inputs=input_schema, outputs=output_schema)

# Symlink the directory so that we control the path that MLflow saves the model under
os.symlink(kwargs['save_pretrained_dir'], 'lora_checkpoint')

mlflow.pyfunc.save_model(
path=kwargs['path'],
artifacts={'lora_checkpoint': 'lora_checkpoint'},
python_model=PeftModel(),
signature=signature,
)

os.unlink('lora_checkpoint')
else:
raise NotImplementedError(f'flavor {flavor} not supported.')

def log_model(self, flavor: str, **kwargs):
def log_model(self, flavor: Literal['transformers'], **kwargs):
"""Log a model to MLflow.
Args:
flavor (str): The MLflow model flavor to use. Currently only ``'transformers'`` is supported.
flavor (Literal['transformers']): The MLflow model flavor to use. Currently only ``'transformers'`` is supported.
**kwargs: Keyword arguments to pass to the MLflow model logging function.
Raises:
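
For reference, a hedged usage sketch of the new ``save_model`` PEFT path (not part of the diff). The experiment name and directories below are placeholders; only the two required kwargs from the ``expected_keys`` check are passed.

# Hedged usage sketch for the new ``save_model`` PEFT flavor added above.
# The experiment name and directories are placeholders; the two required kwargs
# ('path' and 'save_pretrained_dir') come from the expected_keys check in the diff.
from composer.loggers import MLFlowLogger

mlflow_logger = MLFlowLogger(experiment_name='my-mlflow-experiment')

# 'transformers' forwards kwargs to mlflow.transformers.save_model; 'peft'
# (experimental) wraps an already-saved PEFT checkpoint in a pyfunc model.
mlflow_logger.save_model(
    flavor='peft',
    path='mlflow_peft_model',               # where MLflow writes the pyfunc model
    save_pretrained_dir='peft_checkpoint',  # directory produced by model.save_pretrained(...)
)

Note that, per the workaround in the diff, the PEFT path temporarily symlinks ``save_pretrained_dir`` to ``lora_checkpoint`` in the working directory while the model is saved.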
77 changes: 75 additions & 2 deletions examples/TPU_Training_in_composer.ipynb
@@ -58,7 +58,7 @@
"# %pip install 'mosaicml @ git+https://github.com/mosaicml/composer.git'\"\n",
"\n",
"from composer import Trainer\n",
"from composer import models"
"from composer.models import ComposerClassifier"
]
},
{
@@ -88,9 +88,82 @@
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"import torch_xla.core.xla_model as xm\n",
"\n",
"model = models.composer_resnet_cifar(model_name='resnet_20', num_classes=10)\n",
"class Block(nn.Module):\n",
" \"\"\"A ResNet block.\"\"\"\n",
"\n",
" def __init__(self, f_in: int, f_out: int, downsample: bool = False):\n",
" super(Block, self).__init__()\n",
"\n",
" stride = 2 if downsample else 1\n",
" self.conv1 = nn.Conv2d(f_in, f_out, kernel_size=3, stride=stride, padding=1, bias=False)\n",
" self.bn1 = nn.BatchNorm2d(f_out)\n",
" self.conv2 = nn.Conv2d(f_out, f_out, kernel_size=3, stride=1, padding=1, bias=False)\n",
" self.bn2 = nn.BatchNorm2d(f_out)\n",
" self.relu = nn.ReLU(inplace=True)\n",
"\n",
" # No parameters for shortcut connections.\n",
" if downsample or f_in != f_out:\n",
" self.shortcut = nn.Sequential(\n",
" nn.Conv2d(f_in, f_out, kernel_size=1, stride=2, bias=False),\n",
" nn.BatchNorm2d(f_out),\n",
" )\n",
" else:\n",
" self.shortcut = nn.Sequential()\n",
"\n",
" def forward(self, x: torch.Tensor):\n",
" out = self.relu(self.bn1(self.conv1(x)))\n",
" out = self.bn2(self.conv2(out))\n",
" out += self.shortcut(x)\n",
" return self.relu(out)\n",
"\n",
"class ResNetCIFAR(nn.Module):\n",
" \"\"\"A residual neural network as originally designed for CIFAR-10.\"\"\"\n",
"\n",
" def __init__(self, outputs: int = 10):\n",
" super(ResNetCIFAR, self).__init__()\n",
"\n",
" depth = 20\n",
" width = 16\n",
" num_blocks = (depth - 2) // 6\n",
"\n",
" plan = [(width, num_blocks), (2 * width, num_blocks), (4 * width, num_blocks)]\n",
"\n",
" self.num_classes = outputs\n",
"\n",
" # Initial convolution.\n",
" current_filters = plan[0][0]\n",
" self.conv = nn.Conv2d(3, current_filters, kernel_size=3, stride=1, padding=1, bias=False)\n",
" self.bn = nn.BatchNorm2d(current_filters)\n",
" self.relu = nn.ReLU(inplace=True)\n",
"\n",
" # The subsequent blocks of the ResNet.\n",
" blocks = []\n",
" for segment_index, (filters, num_blocks) in enumerate(plan):\n",
" for block_index in range(num_blocks):\n",
" downsample = segment_index > 0 and block_index == 0\n",
" blocks.append(Block(current_filters, filters, downsample))\n",
" current_filters = filters\n",
"\n",
" self.blocks = nn.Sequential(*blocks)\n",
"\n",
" # Final fc layer. Size = number of filters in last segment.\n",
" self.fc = nn.Linear(plan[-1][0], outputs)\n",
" self.criterion = nn.CrossEntropyLoss()\n",
"\n",
" def forward(self, x: torch.Tensor):\n",
" out = self.relu(self.bn(self.conv(x)))\n",
" out = self.blocks(out)\n",
" out = F.avg_pool2d(out, out.size()[3])\n",
" out = out.view(out.size(0), -1)\n",
" out = self.fc(out)\n",
" return out\n",
"\n",
"model = ComposerClassifier(module=ResNetCIFAR(), num_classes=10)\n",
"\n",
"model = model.to(xm.xla_device())\n",
"\n",
"optimizer = torch.optim.SGD(\n",
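
For context, a hedged sketch of how the ``ComposerClassifier`` built in this notebook is typically handed to the Trainer on a TPU core. The dataloader, hyperparameters, and the ``device='tpu'`` string are assumptions, not taken from the truncated diff.

# Hedged sketch: `model` (already moved to the XLA device above) and `optimizer`
# come from the cell above; the dataloader and the 'tpu' device string are
# assumptions, not taken from the truncated diff.
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor

from composer import Trainer

train_dataloader = DataLoader(
    CIFAR10('./data', train=True, download=True, transform=ToTensor()),
    batch_size=1024,
)

trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    optimizers=optimizer,
    max_duration='1ep',
    device='tpu',   # assumption: Composer's TPU device string
)
trainer.fit()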
78 changes: 76 additions & 2 deletions examples/auto_microbatching.ipynb
@@ -101,8 +101,82 @@
"metadata": {},
"outputs": [],
"source": [
"from composer import models\n",
"model = models.composer_resnet_cifar(model_name='resnet_56', num_classes=10)\n",
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"from composer.models import ComposerClassifier\n",
"\n",
"class Block(nn.Module):\n",
" \"\"\"A ResNet block.\"\"\"\n",
"\n",
" def __init__(self, f_in: int, f_out: int, downsample: bool = False):\n",
" super(Block, self).__init__()\n",
"\n",
" stride = 2 if downsample else 1\n",
" self.conv1 = nn.Conv2d(f_in, f_out, kernel_size=3, stride=stride, padding=1, bias=False)\n",
" self.bn1 = nn.BatchNorm2d(f_out)\n",
" self.conv2 = nn.Conv2d(f_out, f_out, kernel_size=3, stride=1, padding=1, bias=False)\n",
" self.bn2 = nn.BatchNorm2d(f_out)\n",
" self.relu = nn.ReLU(inplace=True)\n",
"\n",
" # No parameters for shortcut connections.\n",
" if downsample or f_in != f_out:\n",
" self.shortcut = nn.Sequential(\n",
" nn.Conv2d(f_in, f_out, kernel_size=1, stride=2, bias=False),\n",
" nn.BatchNorm2d(f_out),\n",
" )\n",
" else:\n",
" self.shortcut = nn.Sequential()\n",
"\n",
" def forward(self, x: torch.Tensor):\n",
" out = self.relu(self.bn1(self.conv1(x)))\n",
" out = self.bn2(self.conv2(out))\n",
" out += self.shortcut(x)\n",
" return self.relu(out)\n",
"\n",
"class ResNetCIFAR(nn.Module):\n",
" \"\"\"A residual neural network as originally designed for CIFAR-10.\"\"\"\n",
"\n",
" def __init__(self, outputs: int = 10):\n",
" super(ResNetCIFAR, self).__init__()\n",
"\n",
" depth = 56\n",
" width = 16\n",
" num_blocks = (depth - 2) // 6\n",
"\n",
" plan = [(width, num_blocks), (2 * width, num_blocks), (4 * width, num_blocks)]\n",
"\n",
" self.num_classes = outputs\n",
"\n",
" # Initial convolution.\n",
" current_filters = plan[0][0]\n",
" self.conv = nn.Conv2d(3, current_filters, kernel_size=3, stride=1, padding=1, bias=False)\n",
" self.bn = nn.BatchNorm2d(current_filters)\n",
" self.relu = nn.ReLU(inplace=True)\n",
"\n",
" # The subsequent blocks of the ResNet.\n",
" blocks = []\n",
" for segment_index, (filters, num_blocks) in enumerate(plan):\n",
" for block_index in range(num_blocks):\n",
" downsample = segment_index > 0 and block_index == 0\n",
" blocks.append(Block(current_filters, filters, downsample))\n",
" current_filters = filters\n",
"\n",
" self.blocks = nn.Sequential(*blocks)\n",
"\n",
" # Final fc layer. Size = number of filters in last segment.\n",
" self.fc = nn.Linear(plan[-1][0], outputs)\n",
" self.criterion = nn.CrossEntropyLoss()\n",
"\n",
" def forward(self, x: torch.Tensor):\n",
" out = self.relu(self.bn(self.conv(x)))\n",
" out = self.blocks(out)\n",
" out = F.avg_pool2d(out, out.size()[3])\n",
" out = out.view(out.size(0), -1)\n",
" out = self.fc(out)\n",
" return out\n",
"\n",
"model = ComposerClassifier(module=ResNetCIFAR(), num_classes=10)\n",
"\n",
"optimizer = composer.optim.DecoupledSGDW(\n",
" model.parameters(), # Model parameters to update\n",
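
A hedged sketch of the automatic microbatching this notebook demonstrates: the key piece is ``device_train_microbatch_size='auto'`` on the Trainer. The dataloader, batch size, and duration below are illustrative; ``model`` and ``optimizer`` come from the cell above.

# Hedged sketch: `model` and `optimizer` come from the notebook cell above; the
# dataloader, batch size, and duration are illustrative assumptions.
from torch.utils.data import DataLoader
from torchvision.datasets import CIFAR10
from torchvision.transforms import ToTensor

from composer import Trainer

train_dataloader = DataLoader(
    CIFAR10('./data', train=True, download=True, transform=ToTensor()),
    batch_size=2048,   # deliberately large, so automatic microbatching has something to do
)

trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    optimizers=optimizer,
    max_duration='1ep',
    device='gpu',
    device_train_microbatch_size='auto',   # Composer searches for the largest microbatch that fits in memory
)
trainer.fit()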
39 changes: 37 additions & 2 deletions examples/checkpoint_autoresume.ipynb
@@ -71,6 +71,41 @@
"Simply configure the instance to start Composer with the same command every time until training has finished!"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"\n",
"class MNISTModel(nn.Module):\n",
" \"\"\"Toy convolutional neural network architecture in pytorch for MNIST.\"\"\"\n",
"\n",
" def __init__(self, num_classes: int = 10):\n",
" super().__init__()\n",
" self.num_classes = num_classes\n",
" self.conv1 = nn.Conv2d(1, 16, (3, 3), padding=0)\n",
" self.conv2 = nn.Conv2d(16, 32, (3, 3), padding=0)\n",
" self.bn = nn.BatchNorm2d(32)\n",
" self.fc1 = nn.Linear(32 * 16, 32)\n",
" self.fc2 = nn.Linear(32, num_classes)\n",
"\n",
" def forward(self, x):\n",
" out = self.conv1(x)\n",
" out = F.relu(out)\n",
" out = self.conv2(out)\n",
" out = self.bn(out)\n",
" out = F.relu(out)\n",
" out = F.adaptive_avg_pool2d(out, (4, 4))\n",
" out = torch.flatten(out, 1, -1)\n",
" out = self.fc1(out)\n",
" out = F.relu(out)\n",
" return self.fc2(out)"
]
},
{
"cell_type": "code",
"execution_count": null,
@@ -83,10 +118,10 @@
"from torchvision.transforms import ToTensor\n",
"\n",
"from composer import Trainer\n",
"from composer.models.classify_mnist import mnist_model\n",
"from composer.models import ComposerClassifier\n",
"\n",
"# Configure the trainer -- here, we train a simple MNIST classifier\n",
"model = mnist_model(num_classes=10)\n",
"model = ComposerClassifier(module=MNISTModel(num_classes=10), num_classes=10)\n",
"optimizer = SGD(model.parameters(), lr=0.01)\n",
"train_dataloader = torch.utils.data.DataLoader(\n",
" dataset=MNIST('~/datasets', train=True, download=True, transform=ToTensor()),\n",
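
To make the autoresume flow concrete, a hedged sketch continuing the cell above: rerunning the same script resumes from the latest checkpoint under ``save_folder``. The folder, run name, and intervals are assumptions; ``model``, ``optimizer``, and ``train_dataloader`` come from the cell above.

# Hedged sketch of the autoresume pattern this notebook describes; folder name,
# run name, and intervals are illustrative assumptions.
trainer = Trainer(
    model=model,
    train_dataloader=train_dataloader,
    optimizers=optimizer,
    max_duration='2ep',
    save_folder='./checkpoints',
    save_interval='1ep',
    run_name='autoresume_example',  # a stable run_name lets Composer find earlier checkpoints
    autoresume=True,                # resume from the latest checkpoint if one exists
)
trainer.fit()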
(The remaining 6 changed files are not shown.)
