Add examples for HF datasets (#1371)
* initial commit

* more examples

* run precommit

* hf mnist

* add complete MNIST example

* remove old py file

* add imdb bert

* update mnist example

* update hf_bert example

* add mds example

* update bert example

* update function name

* run precommit

* update bert example

* update mnist notebook

* update mds

* delete ipynb ckpts

* remove mds and simplify examples

* fix some typos

* simplify and remove test/train mentions

* remove headings

* add titles

* fix typo

* update url
ramanishsingh authored and Andrew Ho committed Dec 12, 2024
1 parent ed90168 commit d64cfc1
Showing 2 changed files with 343 additions and 0 deletions.
182 changes: 182 additions & 0 deletions examples/nodes/hf_datasets_nodes_mnist.ipynb
@@ -0,0 +1,182 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "1cb04783-4c83-43bc-91c6-b980d836de34",
"metadata": {},
"source": [
"### Loading and processing the MNIST dataset\n",
"In this example, we will load the MNIST dataset from Hugging Face, \n",
"use `torchdata.nodes` to process it and generate training batches."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "93026478-6dbd-4ac0-8507-360a3a2000c5",
"metadata": {},
"outputs": [],
"source": [
"from datasets import load_dataset\n",
"# Load the mnist dataset from HuggingFace datasets and convert the format to \"torch\"\n",
"dataset = load_dataset(\"ylecun/mnist\").with_format(\"torch\")\n",
"\n",
"# Getting the train dataset\n",
"dataset = dataset[\"train\"]"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "b0c46c4b-0194-4127-a218-e24ec54a3149",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"from torch.utils.data import default_collate, RandomSampler, SequentialSampler\n",
"\n",
"torch.manual_seed(42)\n",
"\n",
"# Defining samplers\n",
"# Since datasets is a Map-style dataset, we can setup a sampler to shuffle the data\n",
"sampler = RandomSampler(dataset)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "6f643c48-c6fb-4e8a-9461-fdf96b45b04b",
"metadata": {},
"outputs": [],
"source": [
"# Now we can set up some torchdata.nodes to create our pre-proc pipeline\n",
"from torchdata.nodes import MapStyleWrapper, ParallelMapper, Batcher, PinMemory, Loader\n",
"\n",
"# All torchdata.nodes.BaseNode implementations are Iterators.\n",
"# MapStyleWrapper creates an Iterator that combines sampler and dataset to create an iterator.\n",
"#\n",
"# Under the hood, MapStyleWrapper just does:\n",
"# > node = IterableWrapper(sampler)\n",
"# > node = Mapper(node, map_fn=dataset.__getitem__) # You can parallelize this with ParallelMapper\n",
"\n",
"node = MapStyleWrapper(map_dataset=dataset, sampler=sampler)\n",
"\n",
"# Now we want to transform the raw inputs. We can just use another Mapper with\n",
"# a custom map_fn to perform this. Using ParallelMapper allows us to use multiple\n",
"# threads (or processes) to parallelize this work and have it run in the background\n",
"# We need a mapper function to convert a dtype and also normalize\n",
"def map_fn(item):\n",
" image = item[\"image\"].to(torch.float32)/255\n",
" label = item[\"label\"]\n",
"\n",
" return {\"image\":image, \"label\":label}\n",
" \n",
"node = ParallelMapper(node, map_fn=map_fn, num_workers=2) # output items are Dict[str, tensor]\n",
"\n",
"\n",
"# Hyperparameters\n",
"batch_size = 2 \n",
"\n",
"# Next we batch the inputs, and then apply a collate_fn with another Mapper\n",
"# to stack the tensor. We use torch.utils.data.default_collate for this\n",
"node = Batcher(node, batch_size=batch_size) # output items are List[Dict[str, tensor]]\n",
"node = ParallelMapper(node, map_fn=default_collate, num_workers=2) # outputs are Dict[str, tensor]\n",
"\n",
"# we can optionally apply pin_memory to the batches\n",
"if torch.cuda.is_available():\n",
" node = PinMemory(node)\n",
"\n",
"# Since nodes are iterators, they need to be manually .reset() between epochs.\n",
"# Instead, we can wrap the root node in Loader to convert it to a more conventional Iterable.\n",
"loader = Loader(node)"
]
},
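{
"cell_type": "markdown",
"id": "added-epoch-reset-note",
"metadata": {},
"source": [
"A note on epochs (an added sketch, not part of the original example): if you iterate the raw `node` instead of the `loader`, you must call `node.reset()` yourself at the start of every epoch. A minimal sketch of both patterns follows, relying on nodes being Iterators as noted above; `num_epochs` and the early `break` are assumptions for illustration."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-epoch-reset-sketch",
"metadata": {},
"outputs": [],
"source": [
"num_epochs = 1  # assumed value, for illustration only\n",
"\n",
"# Pattern 1: iterate the raw node, resetting it manually between epochs\n",
"for epoch in range(num_epochs):\n",
"    node.reset()\n",
"    for step, batch in enumerate(node):\n",
"        if step == 3:  # stop early here; a real loop would consume the full epoch\n",
"            break\n",
"\n",
"# Pattern 2: Loader handles the reset, so a plain loop per epoch works\n",
"for epoch in range(num_epochs):\n",
"    for step, batch in enumerate(loader):\n",
"        if step == 3:\n",
"            break"
]
},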
{
"cell_type": "code",
"execution_count": 4,
"id": "c97a79ba-e6b3-4ac7-a4c5-edc8f9c58ff4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'image': tensor([[[[0., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.],\n",
" ...,\n",
" [0., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.]]],\n",
"\n",
"\n",
" [[[0., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.],\n",
" ...,\n",
" [0., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.]]]]), 'label': tensor([1, 4])}\n",
"There are 2 samples in this batch\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAOUAAAH4CAYAAAC19irnAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAUqElEQVR4nO3dbWyV9f3H8c+x3JQxBSmtDqKFWkjpUqezgwZLLN6sTsnSLui2LMPGhCUOXcdA0QdSdpOxTpkE8aaZE2x4BhbjBnHLYskyU1uJA4QJFEKHNA5a6lpYw013rv+Daf+ycl0t5fT008P7lfCA871+h98xvP0BV89pLAiCQABsXDXcGwBwIaIEzBAlYIYoATNECZghSsAMUQJmiBIwQ5SAGaI01dLSolgspmeffTZhz7ljxw7FYjHt2LEjYc+JxCPKBNq4caNisZh27tw53FsZEgcOHNDSpUs1d+5cpaenKxaLqaWlZbi3lXKIEgPW0NCgdevW6dSpU5o1a9ZwbydlESUG7Jvf/Kb+9a9/6YMPPtD3vve94d5OyiLKJDt37pxWrlyp2267TRMmTND48eM1b9481dfXh6557rnnlJ2drXHjxumOO+7Q3r17+1yzf/9+LVy4UJMmTVJ6eroKCwv15ptv9ruf7u5u7d+/X+3t7f1eO2nSJF199dX9XofLQ5RJ1tXVpVdeeUUlJSWqrq7WqlWr1NbWptLSUu3atavP9bW1tVq3bp2WLFmip556Snv37tWdd96p48eP916zb98+FRUV6cMPP9STTz6pNWvWaPz48SorK9PWrVsj99PU1KRZs2Zp/fr1iX6pGKRRw72BK821116rlpYWjRkzpvexxYsXKy8vT88//7x+97vfXXD9oUOH1NzcrKlTp0qS7r33Xs2ZM0fV1dX6zW9+I0mqrKzUjTfeqPfee09jx46VJP3whz9UcXGxVqxYofLy8iS9OiQCJ2WSpaWl9QYZj8fV0dGhnp4eFRYW6v333+9zfVlZWW+QkjR79mzNmTNH27dvlyR1dHTo7bff1oMPPqhTp06pvb1d7e3tOnnypEpLS9Xc3KzW1tbQ/ZSUlCgIAq1atSqxLxSDRpTD4LXXXtPNN9+s9PR0ZWRkKDMzU9u2bVNnZ2efa2fMmNHnsZkzZ/beijh06JCCINDTTz+tzMzMC35UVVVJkk6cODGkrweJxR9fk2zTpk2qqKhQWVmZHn/8cWVlZSktLU2rV6/W4cOHL/n54vG4JGn58uUqLS296DW5ubmXtWckF1Em2ZYtW5STk6O6ujrFYrHexz871f5Xc3Nzn8cOHjyoadOmSZJycnIkSaNHj9bdd9+d+A0j6fjja5KlpaVJkj7/eWWNjY1qaGi46PVvvPHGBX8nbGpqUmNjo77xjW9IkrKyslRSUqKamhp9/PHHfda3tbVF7udSbokgOTgph8Crr76qt956q8/jlZWVWrBggerq6lReXq77779fR44c0csvv6z8/HydPn26z5rc3FwVFxfrkUce0dmzZ7V27VplZGToiSee6L3mhRdeUHFxsQoKCrR48WLl5OTo+PHjamho0LFjx7R79+7QvTY1NWn+/Pmqqqrq9x97Ojs79fzzz0uS3nnnHUnS+vXrNXHiRE2cOFGPPvroQP7zoD8BEmbDhg2BpNAfH330URCPx4Nf/vKXQXZ2djB27Njg1ltvDf7whz8EDz30UJCdnd37XEeOHAkkBc8880ywZs2a4IYbbgjGjh0bzJs3L9i9e3efX/vw4cPBokWLguuvvz4YPXp0MHXq1GDBggXBli1beq+pr68PJAX19fV9Hquqqur39X22p4v9+PzecXliQcDnvgJO+DslYIYoATNECZghSsAMUQJmiBIwQ5SAmQF/Rc/nv04TwOAM5MsCOCkBM0QJmCFKwAxRAmaIEjBDlIAZ3uScQn7wgx+EzmpqaiLXTp8+PXTG9wtJLk5KwAxRAmaIEjBDlIAZogTMECVghigBM9ynTCHvvvtu6Ky/twzdfvvtoTPuUyYXJyVghigBM0QJmCFKwAxRAmaIEjDDLRFIkqZMmTLcW8CnOCkBM0QJmCFKwAxRAmaIEjBDlIAZogTMECVghigBM0QJmCFKwAxRAmaIEjBDlIAZogTMECVghigBM0QJmCFKwAxRAmaIEjBDlIAZogTMECVghigBM0QJmCFKwAxRAmaIEjBDlIAZogTMECVghigBM0QJmCFKwAxRAmaIEjAzarg3AA8HDhwY7i3gU5yUgBmiBMwQJWCGKAEzRAmYIUrADLdEUkheXt6g1+7ZsyeBO8Hl4KQEzBAlYIYoATNECZghSsAMUQJmiBIww33KFPLlL395uLeABOCkBMwQJWCGKAEzRAmYIUrADFECZrglkkK+/vWvD/cWkACclIAZogTMECVghigBM0QJmCFKwAxRAma4T5lCrrnmmtDZvn37Ite2trYmejsYJE5KwAxRAmaIEjBDlIAZogTMECVghlsiI8h1110XOc/KygqdtbS0RK49f/78YLaEIcBJCZghSsAMUQJmiBIwQ5SAGaIEzBAlYIb7lCNIRkbGoOcvvvhioreDIcJJCZghSsAMUQJmiBIwQ5SAGaIEzHBL5AqxefPm4d4CBoiTEjBDlIAZogTMECVghigBM0QJmCFKwAz3KUeQhx9+eNBrz5w5k8CdYChxUgJmiBIwQ5SAGaIEzBAlYIYoATPcEkmyoqKiyPnEiRNDZ+Xl5YP+defNmxc5X7JkSeiss7Mzcu2uXbtCZ9u3b49ce+7cucj5lYiTEjBDlIAZogTMECVghigBM0QJmCFKwEwsCIJgQBfGYkO9FyvZ2dmhs7Kyssi1CxcuDJ31d58yLS0tcj7S/Pvf/46cR/32+/3vfx+5dtGiRaGz//znP9EbGyYDyY2TEjBDlIAZogTMECVghigBM0QJmBnRb93q7/bBY489FjqrrKyMXDtlypTQ2ejRo6M3FqG5uTly3tjYGDr705/+NOhftz8nT54Mne3Zsydy7fz580NnCxYsiFw7ZsyY0Nmf//znyLXxeDxyPlJxUgJmiBIwQ5SAGaIEzBAlYIYoATNECZgZ0W/dys3NjZwfPHhw0M/d1dUVOnvooYci127bti101t+9tVS994b/4q1bwAhElIAZogTMECVghigBM0QJmBnRt0SAkYZbIsAIRJSAGaIEzBAlYIYoATNECZghSsAMUQJmiBIwQ5SAGaIEzBAlYIYoATNECZghSsAMUQJmiBIwQ5SAGaIEzBAlYIYoATNECZghSsAMUQJmiBIwQ5SAGaIEzBAlYIYoATNECZghSsAMUQJmiBIwQ5SAGaIEzBAlYIYoATNECZghSsAMUQJmiBIwQ5SAGaIEzBAlYIYoATNECZghSsAMUQJmiBIwQ5SAGaIEzBAlYGbUQC8MgmAo9wHgU5yUgBmiBMwQJWCGKAEzRAmYIUrADFECZogSMEOUgBmiBMwQJWCGKAEzRAmYIUrADFECZogSMEOUgBmiBMwQJWCGKE21tLQoFovp2WefTdhz7tixQ7FYT
Dt27EjYcyLxiDKBNm7cqFgspp07dw73VpLinnvuUSwW06OPPjrcW0kpRIlBqaurU0NDw3BvIyURJS7ZmTNntGzZMq1YsWK4t5KSiDLJzp07p5UrV+q2227ThAkTNH78eM2bN0/19fWha5577jllZ2dr3LhxuuOOO7R3794+1+zfv18LFy7UpEmTlJ6ersLCQr355pv97qe7u1v79+9Xe3v7gF/Dr3/9a8XjcS1fvnzAazBwRJlkXV1deuWVV1RSUqLq6mqtWrVKbW1tKi0t1a5du/pcX1tbq3Xr1mnJkiV66qmntHfvXt155506fvx47zX79u1TUVGRPvzwQz355JNas2aNxo8fr7KyMm3dujVyP01NTZo1a5bWr18/oP0fPXpUv/rVr1RdXa1x48Zd0mvHAAVImA0bNgSSgvfeey/0mp6enuDs2bMXPPbJJ58E1113XfDwww/3PnbkyJFAUjBu3Ljg2LFjvY83NjYGkoKlS5f2PnbXXXcFBQUFwZkzZ3ofi8fjwdy5c4MZM2b0PlZfXx9ICurr6/s8VlVVNaDXuHDhwmDu3Lm9P5cULFmyZEBrMTCclEmWlpamMWPGSJLi8bg6OjrU09OjwsJCvf/++32uLysr09SpU3t/Pnv2bM2ZM0fbt2+XJHV0dOjtt9/Wgw8+qFOnTqm9vV3t7e06efKkSktL1dzcrNbW1tD9lJSUKAgCrVq1qt+919fX6/XXX9fatWsv7UXjkhDlMHjttdd08803Kz09XRkZGcrMzNS2bdvU2dnZ59oZM2b0eWzmzJlqaWmRJB06dEhBEOjpp59WZmbmBT+qqqokSSdOnLjsPff09OhHP/qRvv/97+trX/vaZT8fwg34G/wgMTZt2qSKigqVlZXp8ccfV1ZWltLS0rR69WodPnz4kp8vHo9LkpYvX67S0tKLXpObm3tZe5b++3fbAwcOqKampvd/CJ85deqUWlpalJWVpS984QuX/Wtd6YgyybZs2aKcnBzV1dUpFov1Pv7Zqfa/mpub+zx28OBBTZs2TZKUk5MjSRo9erTuvvvuxG/4U0ePHtX58+d1++2395nV1taqtrZWW7duVVlZ2ZDt4UpBlEmWlpYm6b/fWvCzKBsbG9XQ0KAbb7yxz/VvvPGGWltbe/9e2dTUpMbGRv34xz+WJGVlZamkpEQ1NTV67LHH9KUvfemC9W1tbcrMzAzdT3d3t44eParJkydr8uTJodd95zvf0S233NLn8fLyct13331avHix5syZE/naMTBEOQReffVVvfXWW30er6ys1IIFC1RXV6fy8nLdf//9OnLkiF5++WXl5+fr9OnTfdbk5uaquLhYjzzyiM6ePau1a9cqIyNDTzzxRO81L7zwgoqLi1VQUKDFixcrJydHx48fV0NDg44dO6bdu3eH7rWpqUnz589XVVVV5D/25OXlKS8v76Kz6dOnc0ImEFEOgZdeeumij1dUVKiiokL//Oc/VVNToz/+8Y/Kz8/Xpk2btHnz5ot+ofiiRYt01VVXae3atTpx4oRmz56t9evXX3Ai5ufna+fOnfrpT3+qjRs36uTJk8rKytKtt96qlStXDtXLxBCJBQHfohlwwi0RwAxRAmaIEjBDlIAZogTMECVghigBMwP+4oHPf50mgMEZyJcFcFICZogSMEOUgBmiBMwQJWCGKAEzRAmYIUrADFECZogSMEOUgBmiBMwQJWCGKAEzRAmYIUrADFECZogSMEOUgBmiBMwQJWCGKAEzRAmYIUrADFECZogSMEOUgBmiBMwQJWCGKAEzRAmYIUrADFECZogSMEOUgBmiBMyMGu4NXI7CwsLI+bJly0Jn3/3udxO9nZS1YsWKyPlVV4X/v3316tWJ3k7K46QEzBAlYIYoATNECZghSsAMUQJmiBIwM6LvU/b09ETO77nnntBZf/feqqurB7WnVNTd3R05LygoSNJOrgyclIAZogTMECVghigBM0QJmCFKwMyIviWya9euyHltbW3orLS0NHItt0T+36hR0b9NiouLk7STKwMnJWCGKAEzRAmYIUrADFECZogSMEOUgJkRfZ+yP11dXcO9hZSwefPmyPkvfvGL0NnkyZMj17a3tw9qT6mMkxIwQ5SAGaIEzBAlYIYoATNECZhJ6VsiSIxjx45Fzs+dOxc6e+CBByLXvvTSS4PaUyrjpATMECVghigBM0QJmCFKwAxRAmaIEjCT0vcpz58/HzqLxWJJ3Elqa2xsDJ3NnTs3ci33KfvipATMECVghigBM0QJmCFKwAxRAmZS+pbItm3bQmd33XVXEneS2t55553Q2U9+8pMk7iQ1cFICZogSMEOUgBmiBMwQJWCGKAEzRAmYiQVBEAzowhH4VqdbbrkldPaXv/wlcm1+fn7orL+PXLzSTJgwIXTW0tISufbaa69N8G68DSQ3TkrADFECZogSMEOUgBmiBMwQJWAmpd+6dfr06dDZqFHRL/3b3/526GzNmjWD3lMq6uzsDJ1dc801kWuLiopCZ+++++6g9zSScVICZogSMEOUgBmiBMwQJWCGKAEzRAmYSem3bkV5/fXXI+dRr/db3/pWoreTsuLxeOT8mWeeCZ2tWLEi0dsZdrx1CxiBiBIwQ5SAGaIEzBAlYIYoATMp/datKFu2bImcX2lvz0pLSwudXX311YN+3r/97W+R85kzZ4bOsrOzI9d2dXWFzj755JPojRnjpATMECVghigBM0QJmCFKwAxRAmaIEjBzxd6n/PjjjyPn119/fejspptuilx7+PDh0FleXl7k2nvvvTdyHuWGG24InX31q1+NXPvFL34xdDZ9+vTItVH3C6dNmxa5tqCgIHT2la98JXLtz3/+89DZhg0bItc646QEzBAlYIYoATNECZghSsAMUQJmrthPs+vvn+obGhpCZ/297eujjz4KnVVVVUWu7ejoCJ3t2bMncm3Uvvp7C1WU1tbWyHlbW1vobOXKlZFrKyoqQmc5OTmRa0ciPs0OGIGIEjBDlIAZogTMECVghigBM0QJmLli71P2Z9myZaGzn/3sZ5Frm5qaQmeVlZWRa1taWkJnUW+RcjVlypTI+c6dO0NnUR8/KUmnT58e1J6GE/cpgRGIKAEzRAmYIUrADFECZogSMHPFfppdf/7617+Gzj744IPItUuXLg2d9ff2q1TT3d0dOY/61MAHHnggcu1I/sS6KJyUgBmiBMwQJWCGKAEzRAmYIUrADFECZrhPGaKxsTF0VlRUlMSdpLaotwRmZ2cncSc+OCkBM0QJmCFKwAxRAmaIEjBDlIAZbolgWEV9utsAP2gx5XBSAmaIEjBDlIAZogTMECVghigBM0QJmOG7bmFIpaenR87//ve/h87+8Y9/RK6dP3/+oPY0nPiuW8AIRJSAGaIEzBAlYIYoATNECZjhrVsYUmfOnImc//a3vw2d9fddt1IVJyVghigBM0QJmCFKwAxRAmaIEjBDlIAZ7lNiWG3evDl0xn1KABaIEjBDlIAZogTMECVghigBM3yaHZBEfJodMAIRJWCGKAEzRAmYIUrADFECZogSMEOU
gBmiBMwQJWCGKAEzRAmYIUrADFECZogSMEOUgBmiBMwQJWCGKAEzRAmYIUrADFECZgb8XbcG+EmUAC4TJyVghigBM0QJmCFKwAxRAmaIEjBDlIAZogTMECVg5v8AWWLQZxmV+6kAAAAASUVORK5CYII=",
"text/plain": [
"<Figure size 800x600 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Once we have the loader, we can get batches from it over multiple epochs, to train the model\n",
"# Let us look at one batch \n",
"import matplotlib.pyplot as plt\n",
"fig, axs = plt.subplots(2, figsize=(8, 6))\n",
"\n",
"batch = next(iter(loader))\n",
" \n",
"\n",
"print(batch)\n",
"print(f\"There are {len(batch)} samples in this batch\")\n",
"\n",
"# Since we used default_collate, each batch is a dictionary, with two keys: \"image\" and \"label\"\n",
"# The value of key \"image\" is a stacked tensor of images in the batch\n",
"# Similarly, the value of key \"label\" is a stacked tensor of labels in the batch\n",
"images = batch[\"image\"]\n",
"labels = batch[\"label\"]\n",
"\n",
"#let's also display the two items\n",
"for i in range(len(images)):\n",
" axs[i].imshow(images[i].squeeze(), cmap='gray')\n",
" axs[i].set_title(f\"Label: {labels[i]}\") \n",
" axs[i].set_axis_off()\n"
]
}
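,
{
"cell_type": "markdown",
"id": "added-train-step-note",
"metadata": {},
"source": [
"Finally, a minimal sketch of a single optimization step on the batch above, using the `torch.nn` import from earlier. This is an added illustration, not part of the original example: the tiny linear model, optimizer, and learning rate are all assumptions."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-train-step-sketch",
"metadata": {},
"outputs": [],
"source": [
"# Assumed model and hyperparameters, for illustration only\n",
"model = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10))\n",
"optimizer = torch.optim.SGD(model.parameters(), lr=0.01)\n",
"loss_fn = nn.CrossEntropyLoss()\n",
"\n",
"logits = model(batch[\"image\"])  # images have shape [batch_size, 1, 28, 28]\n",
"loss = loss_fn(logits, batch[\"label\"])  # labels have shape [batch_size]\n",
"loss.backward()\n",
"optimizer.step()\n",
"optimizer.zero_grad()\n",
"print(f\"loss: {loss.item():.4f}\")"
]
}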
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
161 changes: 161 additions & 0 deletions examples/nodes/hf_imdb_bert.ipynb
@@ -0,0 +1,161 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "d8513771-36ac-4d03-b890-35108bce2211",
"metadata": {},
"source": [
"### Loading and processing IMDB movie review dataset\n",
"In this example, we will load the IMDB dataset from Hugging Face, \n",
"use `torchdata.nodes` to process it and generate training batches."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "eb3b507c-2ad1-410d-a834-6847182de684",
"metadata": {},
"outputs": [],
"source": [
"from datasets import load_dataset\n",
"from transformers import BertTokenizer, BertForSequenceClassification"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "089f1126-7125-4274-9d71-5c949ccc7bbd",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch.utils.data import default_collate, RandomSampler, SequentialSampler"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "2afac7d9-3d66-4195-8647-dc7034d306f2",
"metadata": {},
"outputs": [],
"source": [
"# Load IMDB dataset from huggingface datasets and select the \"train\" split\n",
"dataset = load_dataset(\"imdb\", streaming=False)\n",
"dataset = dataset[\"train\"]\n",
"# Since dataset is a Map-style dataset, we can setup a sampler to shuffle the data\n",
"# Please refer to the migration guide here https://pytorch.org/data/main/migrate_to_nodes_from_utils.html\n",
"# to migrate from torch.utils.data to torchdata.nodes\n",
"\n",
"sampler = RandomSampler(dataset)\n",
"# Use a standard bert tokenizer\n",
"tokenizer = BertTokenizer.from_pretrained(\"bert-base-uncased\")\n",
"# Now we can set up some torchdata.nodes to create our pre-proc pipeline"
]
},
{
"cell_type": "markdown",
"id": "09e08a47-573c-4d32-9a02-36cd8150db60",
"metadata": {},
"source": [
"All torchdata.nodes.BaseNode implementations are Iterators.\n",
"MapStyleWrapper creates an Iterator that combines sampler and dataset to create an iterator.\n",
"Under the hood, MapStyleWrapper just does:\n",
"```python\n",
"node = IterableWrapper(sampler)\n",
"node = Mapper(node, map_fn=dataset.__getitem__) # You can parallelize this with ParallelMapper\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "02af5479-ee69-41d8-ab2d-bf154b84bc15",
"metadata": {},
"outputs": [],
"source": [
"from torchdata.nodes import MapStyleWrapper, ParallelMapper, Batcher, PinMemory, Loader\n",
"node = MapStyleWrapper(map_dataset=dataset, sampler=sampler)\n",
"\n",
"# Now we want to transform the raw inputs. We can just use another Mapper with\n",
"# a custom map_fn to perform this. Using ParallelMapper allows us to use multiple\n",
"# threads (or processes) to parallelize this work and have it run in the background\n",
"max_len = 512\n",
"batch_size = 2\n",
"def bert_transform(item):\n",
" encoding = tokenizer.encode_plus(\n",
" item[\"text\"],\n",
" add_special_tokens=True,\n",
" max_length=max_len,\n",
" padding=\"max_length\",\n",
" truncation=True,\n",
" return_attention_mask=True,\n",
" return_tensors=\"pt\",\n",
" )\n",
" return {\n",
" \"input_ids\": encoding[\"input_ids\"].flatten(),\n",
" \"attention_mask\": encoding[\"attention_mask\"].flatten(),\n",
" \"labels\": torch.tensor(item[\"label\"], dtype=torch.long),\n",
" }\n",
"node = ParallelMapper(node, map_fn=bert_transform, num_workers=2) # output items are Dict[str, tensor]\n",
"\n",
"# Next we batch the inputs, and then apply a collate_fn with another Mapper\n",
"# to stack the tensors between. We use torch.utils.data.default_collate for this\n",
"node = Batcher(node, batch_size=batch_size) # output items are List[Dict[str, tensor]]\n",
"node = ParallelMapper(node, map_fn=default_collate, num_workers=2) # outputs are Dict[str, tensor]\n",
"\n",
"# we can optionally apply pin_memory to the batches\n",
"if torch.cuda.is_available():\n",
" node = PinMemory(node)\n",
"\n",
"# Since nodes are iterators, they need to be manually .reset() between epochs.\n",
"# We can wrap the root node in Loader to convert it to a more conventional Iterable.\n",
"loader = Loader(node)"
]
},
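{
"cell_type": "markdown",
"id": "added-dataloader-comparison-note",
"metadata": {},
"source": [
"For comparison with the migration guide linked above, here is a rough `torch.utils.data.DataLoader` equivalent of the pipeline just built. This is an added sketch, not part of the original example; the wrapper `collate` function is a hypothetical name introduced here."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-dataloader-comparison-sketch",
"metadata": {},
"outputs": [],
"source": [
"from torch.utils.data import DataLoader\n",
"\n",
"# Apply bert_transform per item, then stack, mirroring the two Mapper stages above\n",
"def collate(items):\n",
"    return default_collate([bert_transform(item) for item in items])\n",
"\n",
"dataloader = DataLoader(\n",
"    dataset,\n",
"    batch_size=batch_size,\n",
"    sampler=RandomSampler(dataset),\n",
"    num_workers=2,  # uses worker processes rather than threads\n",
"    collate_fn=collate,\n",
"    pin_memory=torch.cuda.is_available(),\n",
")"
]
},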
{
"cell_type": "code",
"execution_count": 5,
"id": "60fd54f3-62ef-47aa-a790-853cb4899f13",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'input_ids': tensor([[ 101, 1045, 2572, ..., 2143, 2000, 102],\n",
" [ 101, 2004, 1037, ..., 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, ..., 1, 1, 1],\n",
" [1, 1, 1, ..., 0, 0, 0]]), 'labels': tensor([0, 1])}\n"
]
}
],
"source": [
"# Inspect a batch\n",
"batch = next(iter(loader))\n",
"print(batch)\n",
"# In a batch we get three keys, as defined in the method `bert_transform`.\n",
"# Since the batch size is 2, two samples are stacked together for each key."
]
}
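,
{
"cell_type": "markdown",
"id": "added-bert-step-note",
"metadata": {},
"source": [
"As a final illustration, here is one fine-tuning step with the `BertForSequenceClassification` imported at the top. This is an added sketch, not part of the original example; the optimizer choice and learning rate are assumptions. Passing `labels` makes the model return its loss."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "added-bert-step-sketch",
"metadata": {},
"outputs": [],
"source": [
"# Assumed optimizer and learning rate, for illustration only\n",
"model = BertForSequenceClassification.from_pretrained(\"bert-base-uncased\", num_labels=2)\n",
"optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)\n",
"\n",
"model.train()\n",
"outputs = model(\n",
"    input_ids=batch[\"input_ids\"],\n",
"    attention_mask=batch[\"attention_mask\"],\n",
"    labels=batch[\"labels\"],\n",
")\n",
"outputs.loss.backward()\n",
"optimizer.step()\n",
"optimizer.zero_grad()\n",
"print(f\"loss: {outputs.loss.item():.4f}\")"
]
}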
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
