-
Notifications
You must be signed in to change notification settings - Fork 155
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add examples for HF datasets (#1371)
* intial_commit * more examples * run precommit * hf mnist * add complete MNIST example * remove old py file * add imdb bert * update mnist example * update hf_bert example * add mds example * update bert example * update function name * run precommit * update bert example * update mnist notebook * update mds * delete ipynb ckpts * remove mds and simplify examples * fix some typos * simplify and remove test train mentiond * remove headings * add titles * fix typo * update url
- Loading branch information
1 parent
ed90168
commit d64cfc1
Showing
2 changed files
with
343 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,182 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "1cb04783-4c83-43bc-91c6-b980d836de34", | ||
"metadata": {}, | ||
"source": [ | ||
"### Loading and processing the MNIST dataset\n", | ||
"In this example, we will load the MNIST dataset from Hugging Face, \n", | ||
"use `torchdata.nodes` to process it and generate training batches." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"id": "93026478-6dbd-4ac0-8507-360a3a2000c5", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from datasets import load_dataset\n", | ||
"# Load the mnist dataset from HuggingFace datasets and convert the format to \"torch\"\n", | ||
"dataset = load_dataset(\"ylecun/mnist\").with_format(\"torch\")\n", | ||
"\n", | ||
"# Getting the train dataset\n", | ||
"dataset = dataset[\"train\"]" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"id": "b0c46c4b-0194-4127-a218-e24ec54a3149", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import torch\n", | ||
"import torch.nn as nn\n", | ||
"from torch.utils.data import default_collate, RandomSampler, SequentialSampler\n", | ||
"\n", | ||
"torch.manual_seed(42)\n", | ||
"\n", | ||
"# Defining samplers\n", | ||
"# Since datasets is a Map-style dataset, we can setup a sampler to shuffle the data\n", | ||
"sampler = RandomSampler(dataset)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"id": "6f643c48-c6fb-4e8a-9461-fdf96b45b04b", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Now we can set up some torchdata.nodes to create our pre-proc pipeline\n", | ||
"from torchdata.nodes import MapStyleWrapper, ParallelMapper, Batcher, PinMemory, Loader\n", | ||
"\n", | ||
"# All torchdata.nodes.BaseNode implementations are Iterators.\n", | ||
"# MapStyleWrapper creates an Iterator that combines sampler and dataset to create an iterator.\n", | ||
"#\n", | ||
"# Under the hood, MapStyleWrapper just does:\n", | ||
"# > node = IterableWrapper(sampler)\n", | ||
"# > node = Mapper(node, map_fn=dataset.__getitem__) # You can parallelize this with ParallelMapper\n", | ||
"\n", | ||
"node = MapStyleWrapper(map_dataset=dataset, sampler=sampler)\n", | ||
"\n", | ||
"# Now we want to transform the raw inputs. We can just use another Mapper with\n", | ||
"# a custom map_fn to perform this. Using ParallelMapper allows us to use multiple\n", | ||
"# threads (or processes) to parallelize this work and have it run in the background\n", | ||
"# We need a mapper function to convert a dtype and also normalize\n", | ||
"def map_fn(item):\n", | ||
" image = item[\"image\"].to(torch.float32)/255\n", | ||
" label = item[\"label\"]\n", | ||
"\n", | ||
" return {\"image\":image, \"label\":label}\n", | ||
" \n", | ||
"node = ParallelMapper(node, map_fn=map_fn, num_workers=2) # output items are Dict[str, tensor]\n", | ||
"\n", | ||
"\n", | ||
"# Hyperparameters\n", | ||
"batch_size = 2 \n", | ||
"\n", | ||
"# Next we batch the inputs, and then apply a collate_fn with another Mapper\n", | ||
"# to stack the tensor. We use torch.utils.data.default_collate for this\n", | ||
"node = Batcher(node, batch_size=batch_size) # output items are List[Dict[str, tensor]]\n", | ||
"node = ParallelMapper(node, map_fn=default_collate, num_workers=2) # outputs are Dict[str, tensor]\n", | ||
"\n", | ||
"# we can optionally apply pin_memory to the batches\n", | ||
"if torch.cuda.is_available():\n", | ||
" node = PinMemory(node)\n", | ||
"\n", | ||
"# Since nodes are iterators, they need to be manually .reset() between epochs.\n", | ||
"# Instead, we can wrap the root node in Loader to convert it to a more conventional Iterable.\n", | ||
"loader = Loader(node)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"id": "c97a79ba-e6b3-4ac7-a4c5-edc8f9c58ff4", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"{'image': tensor([[[[0., 0., 0., ..., 0., 0., 0.],\n", | ||
" [0., 0., 0., ..., 0., 0., 0.],\n", | ||
" [0., 0., 0., ..., 0., 0., 0.],\n", | ||
" ...,\n", | ||
" [0., 0., 0., ..., 0., 0., 0.],\n", | ||
" [0., 0., 0., ..., 0., 0., 0.],\n", | ||
" [0., 0., 0., ..., 0., 0., 0.]]],\n", | ||
"\n", | ||
"\n", | ||
" [[[0., 0., 0., ..., 0., 0., 0.],\n", | ||
" [0., 0., 0., ..., 0., 0., 0.],\n", | ||
" [0., 0., 0., ..., 0., 0., 0.],\n", | ||
" ...,\n", | ||
" [0., 0., 0., ..., 0., 0., 0.],\n", | ||
" [0., 0., 0., ..., 0., 0., 0.],\n", | ||
" [0., 0., 0., ..., 0., 0., 0.]]]]), 'label': tensor([1, 4])}\n", | ||
"There are 2 samples in this batch\n" | ||
] | ||
}, | ||
{ | ||
"data": { | ||
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAOUAAAH4CAYAAAC19irnAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAUqElEQVR4nO3dbWyV9f3H8c+x3JQxBSmtDqKFWkjpUqezgwZLLN6sTsnSLui2LMPGhCUOXcdA0QdSdpOxTpkE8aaZE2x4BhbjBnHLYskyU1uJA4QJFEKHNA5a6lpYw013rv+Daf+ycl0t5fT008P7lfCA871+h98xvP0BV89pLAiCQABsXDXcGwBwIaIEzBAlYIYoATNECZghSsAMUQJmiBIwQ5SAGaI01dLSolgspmeffTZhz7ljxw7FYjHt2LEjYc+JxCPKBNq4caNisZh27tw53FsZEgcOHNDSpUs1d+5cpaenKxaLqaWlZbi3lXKIEgPW0NCgdevW6dSpU5o1a9ZwbydlESUG7Jvf/Kb+9a9/6YMPPtD3vve94d5OyiLKJDt37pxWrlyp2267TRMmTND48eM1b9481dfXh6557rnnlJ2drXHjxumOO+7Q3r17+1yzf/9+LVy4UJMmTVJ6eroKCwv15ptv9ruf7u5u7d+/X+3t7f1eO2nSJF199dX9XofLQ5RJ1tXVpVdeeUUlJSWqrq7WqlWr1NbWptLSUu3atavP9bW1tVq3bp2WLFmip556Snv37tWdd96p48eP916zb98+FRUV6cMPP9STTz6pNWvWaPz48SorK9PWrVsj99PU1KRZs2Zp/fr1iX6pGKRRw72BK821116rlpYWjRkzpvexxYsXKy8vT88//7x+97vfXXD9oUOH1NzcrKlTp0qS7r33Xs2ZM0fV1dX6zW9+I0mqrKzUjTfeqPfee09jx46VJP3whz9UcXGxVqxYofLy8iS9OiQCJ2WSpaWl9QYZj8fV0dGhnp4eFRYW6v333+9zfVlZWW+QkjR79mzNmTNH27dvlyR1dHTo7bff1oMPPqhTp06pvb1d7e3tOnnypEpLS9Xc3KzW1tbQ/ZSUlCgIAq1atSqxLxSDRpTD4LXXXtPNN9+s9PR0ZWRkKDMzU9u2bVNnZ2efa2fMmNHnsZkzZ/beijh06JCCINDTTz+tzMzMC35UVVVJkk6cODGkrweJxR9fk2zTpk2qqKhQWVmZHn/8cWVlZSktLU2rV6/W4cOHL/n54vG4JGn58uUqLS296DW5ubmXtWckF1Em2ZYtW5STk6O6ujrFYrHexz871f5Xc3Nzn8cOHjyoadOmSZJycnIkSaNHj9bdd9+d+A0j6fjja5KlpaVJkj7/eWWNjY1qaGi46PVvvPHGBX8nbGpqUmNjo77xjW9IkrKyslRSUqKamhp9/PHHfda3tbVF7udSbokgOTgph8Crr76qt956q8/jlZWVWrBggerq6lReXq77779fR44c0csvv6z8/HydPn26z5rc3FwVFxfrkUce0dmzZ7V27VplZGToiSee6L3mhRdeUHFxsQoKCrR48WLl5OTo+PHjamho0LFjx7R79+7QvTY1NWn+/Pmqqqrq9x97Ojs79fzzz0uS3nnnHUnS+vXrNXHiRE2cOFGPPvroQP7zoD8BEmbDhg2BpNAfH330URCPx4Nf/vKXQXZ2djB27Njg1ltvDf7whz8EDz30UJCdnd37XEeOHAkkBc8880ywZs2a4IYbbgjGjh0bzJs3L9i9e3efX/vw4cPBokWLguuvvz4YPXp0MHXq1GDBggXBli1beq+pr68PJAX19fV9Hquqqur39X22p4v9+PzecXliQcDnvgJO+DslYIYoATNECZghSsAMUQJmiBIwQ5SAmQF/Rc/nv04TwOAM5MsCOCkBM0QJmCFKwAxRAmaIEjBDlIAZ3uScQn7wgx+EzmpqaiLXTp8+PXTG9wtJLk5KwAxRAmaIEjBDlIAZogTMECVghigBM9ynTCHvvvtu6Ky/twzdfvvtoTPuUyYXJyVghigBM0QJmCFKwAxRAmaIEjDDLRFIkqZMmTLcW8CnOCkBM0QJmCFKwAxRAmaIEjBDlIAZogTMECVghigBM0QJmCFKwAxRAmaIEjBDlIAZogTMECVghigBM0QJmCFKwAxRAmaIEjBDlIAZogTMECVghigBM0QJmCFKwAxRAmaIEjBDlIAZogTMECVghigBM0QJmCFKwAxRAmaIEjAzarg3AA8HDhwY7i3gU5yUgBmiBMwQJWCGKAEzRAmYIUrADLdEUkheXt6g1+7ZsyeBO8Hl4KQEzBAlYIYoATNECZghSsAMUQJmiBIww33KFPLlL395uLeABOCkBMwQJWCGKAEzRAmYIUrADFECZrglkkK+/vWvD/cWkACclIAZogTMECVghigBM0QJmCFKwAxRAma4T5lCrrnmmtDZvn37Ite2trYmejsYJE5KwAxRAmaIEjBDlIAZogTMECVghlsiI8h1110XOc/KygqdtbS0RK49f/78YLaEIcBJCZghSsAMUQJmiBIwQ5SAGaIEzBAlYIb7lCNIRkbGoOcvvvhioreDIcJJCZghSsAMUQJmiBIwQ5SAGaIEzHBL5AqxefPm4d4CBoiTEjBDlIAZogTMECVghigBM0QJmCFKwAz3KUeQhx9+eNBrz5w5k8CdYChxUgJmiBIwQ5SAGaIEzBAlYIYoATPcEkmyoqKiyPnEiRNDZ+Xl5YP+defNmxc5X7JkSeiss7Mzcu2uXbtCZ9u3b49ce+7cucj5lYiTEjBDlIAZogTMECVghigBM0QJmCFKwEwsCIJgQBfGYkO9FyvZ2dmhs7Kyssi1CxcuDJ31d58yLS0tcj7S/Pvf/46cR/32+/3vfx+5dtGiRaGz//znP9EbGyYDyY2TEjBDlIAZogTMECVghigBM0QJmBnRb93q7/bBY489FjqrrKyMXDtlypTQ2ejRo6M3FqG5uTly3tjYGDr705/+NOhftz8nT54Mne3Zsydy7fz580NnCxYsiFw7ZsyY0Nmf//znyLXxeDxyPlJxUgJmiBIwQ5SAGaIEzBAlYIYoATNECZgZ0W/dys3NjZwfPHhw0M/d1dUVOnvooYci127bti101t+9tVS994b/4q1bwAhElIAZogTMECVghigBM0QJmBnRt0SAkYZbIsAIRJSAGaIEzBAlYIYoATNECZghSsAMUQJmiBIwQ5SAGaIEzBAlYIYoATNECZghSsAMUQJmiBIwQ5SAGaIEzBAlYIYoATNECZghSsAMUQJmiBIwQ5SAGaIEzBAlYIYoATNECZghSsAMUQJmiBIwQ5SAGaIEzBAlYIYoATNECZghSsAMUQJmiBIwQ5SAGaIEzBAlYIYoATNECZghSsAMUQJmiBIwQ5SAGaIEzBAlYGbUQC8MgmAo9wHgU5yUgBmiBMwQJWCGKAEzRAmYIUrADFECZogSMEOUgBmiBMwQJWCGKAEzRAmYIUrADFECZogSMEOUgBmiBMwQJWCGKE21tLQoFovp2WefTdhz7tixQ7FYTDt27EjYcyLxiDKBNm7cqFgspp07dw73VpLinnvuUSwW06OPPjrcW0kpRIlBqaurU0NDw3BvIyURJS7ZmTNntGzZMq1YsWK4t5KSiDLJzp07p5UrV+q2227ThAkTNH78eM2bN0/19fWha5577jllZ2dr3LhxuuOOO7R3794+1+zfv18LFy7UpEmTlJ6ersLCQr355pv97qe7u1v79+9Xe3v7gF/Dr3/9a8XjcS1fvnzAazBwRJlkXV1deuWVV1RSUqLq6mqtWrVKbW1tKi0t1a5du/pcX1tbq3Xr1mnJkiV66qmntHfvXt155506fvx47zX79u1TUVGRPvzwQz355JNas2aNxo8fr7KyMm3dujVyP01NTZo1a5bWr18/oP0fPXpUv/rVr1RdXa1x48Zd0mvHAAVImA0bNgSSgvfeey/0mp6enuDs2bMXPPbJJ58E1113XfDwww/3PnbkyJFAUjBu3Ljg2LFjvY83NjYGkoKlS5f2PnbXXXcFBQUFwZkzZ3ofi8fjwdy5c4MZM2b0PlZfXx9ICurr6/s8VlVVNaDXuHDhwmDu3Lm9P5cULFmyZEBrMTCclEmWlpamMWPGSJLi8bg6OjrU09OjwsJCvf/++32uLysr09SpU3t/Pnv2bM2ZM0fbt2+XJHV0dOjtt9/Wgw8+qFOnTqm9vV3t7e06efKkSktL1dzcrNbW1tD9lJSUKAgCrVq1qt+919fX6/XXX9fatWsv7UXjkhDlMHjttdd08803Kz09XRkZGcrMzNS2bdvU2dnZ59oZM2b0eWzmzJlqaWmRJB06dEhBEOjpp59WZmbmBT+qqqokSSdOnLjsPff09OhHP/qRvv/97+trX/vaZT8fwg34G/wgMTZt2qSKigqVlZXp8ccfV1ZWltLS0rR69WodPnz4kp8vHo9LkpYvX67S0tKLXpObm3tZe5b++3fbAwcOqKampvd/CJ85deqUWlpalJWVpS984QuX/Wtd6YgyybZs2aKcnBzV1dUpFov1Pv7Zqfa/mpub+zx28OBBTZs2TZKUk5MjSRo9erTuvvvuxG/4U0ePHtX58+d1++2395nV1taqtrZWW7duVVlZ2ZDt4UpBlEmWlpYm6b/fWvCzKBsbG9XQ0KAbb7yxz/VvvPGGWltbe/9e2dTUpMbGRv34xz+WJGVlZamkpEQ1NTV67LHH9KUvfemC9W1tbcrMzAzdT3d3t44eParJkydr8uTJodd95zvf0S233NLn8fLyct13331avHix5syZE/naMTBEOQReffVVvfXWW30er6ys1IIFC1RXV6fy8nLdf//9OnLkiF5++WXl5+fr9OnTfdbk5uaquLhYjzzyiM6ePau1a9cqIyNDTzzxRO81L7zwgoqLi1VQUKDFixcrJydHx48fV0NDg44dO6bdu3eH7rWpqUnz589XVVVV5D/25OXlKS8v76Kz6dOnc0ImEFEOgZdeeumij1dUVKiiokL//Oc/VVNToz/+8Y/Kz8/Xpk2btHnz5ot+ofiiRYt01VVXae3atTpx4oRmz56t9evXX3Ai5ufna+fOnfrpT3+qjRs36uTJk8rKytKtt96qlStXDtXLxBCJBQHfohlwwi0RwAxRAmaIEjBDlIAZogTMECVghigBMwP+4oHPf50mgMEZyJcFcFICZogSMEOUgBmiBMwQJWCGKAEzRAmYIUrADFECZogSMEOUgBmiBMwQJWCGKAEzRAmYIUrADFECZogSMEOUgBmiBMwQJWCGKAEzRAmYIUrADFECZogSMEOUgBmiBMwQJWCGKAEzRAmYIUrADFECZogSMEOUgBmiBMyMGu4NXI7CwsLI+bJly0Jn3/3udxO9nZS1YsWKyPlVV4X/v3316tWJ3k7K46QEzBAlYIYoATNECZghSsAMUQJmiBIwM6LvU/b09ETO77nnntBZf/feqqurB7WnVNTd3R05LygoSNJOrgyclIAZogTMECVghigBM0QJmCFKwMyIviWya9euyHltbW3orLS0NHItt0T+36hR0b9NiouLk7STKwMnJWCGKAEzRAmYIUrADFECZogSMEOUgJkRfZ+yP11dXcO9hZSwefPmyPkvfvGL0NnkyZMj17a3tw9qT6mMkxIwQ5SAGaIEzBAlYIYoATNECZhJ6VsiSIxjx45Fzs+dOxc6e+CBByLXvvTSS4PaUyrjpATMECVghigBM0QJmCFKwAxRAmaIEjCT0vcpz58/HzqLxWJJ3Elqa2xsDJ3NnTs3ci33KfvipATMECVghigBM0QJmCFKwAxRAmZS+pbItm3bQmd33XVXEneS2t55553Q2U9+8pMk7iQ1cFICZogSMEOUgBmiBMwQJWCGKAEzRAmYiQVBEAzowhH4VqdbbrkldPaXv/wlcm1+fn7orL+PXLzSTJgwIXTW0tISufbaa69N8G68DSQ3TkrADFECZogSMEOUgBmiBMwQJWAmpd+6dfr06dDZqFHRL/3b3/526GzNmjWD3lMq6uzsDJ1dc801kWuLiopCZ+++++6g9zSScVICZogSMEOUgBmiBMwQJWCGKAEzRAmYSem3bkV5/fXXI+dRr/db3/pWoreTsuLxeOT8mWeeCZ2tWLEi0dsZdrx1CxiBiBIwQ5SAGaIEzBAlYIYoATMp/datKFu2bImcX2lvz0pLSwudXX311YN+3r/97W+R85kzZ4bOsrOzI9d2dXWFzj755JPojRnjpATMECVghigBM0QJmCFKwAxRAmaIEjBzxd6n/PjjjyPn119/fejspptuilx7+PDh0FleXl7k2nvvvTdyHuWGG24InX31q1+NXPvFL34xdDZ9+vTItVH3C6dNmxa5tqCgIHT2la98JXLtz3/+89DZhg0bItc646QEzBAlYIYoATNECZghSsAMUQJmrthPs+vvn+obGhpCZ/297eujjz4KnVVVVUWu7ejoCJ3t2bMncm3Uvvp7C1WU1tbWyHlbW1vobOXKlZFrKyoqQmc5OTmRa0ciPs0OGIGIEjBDlIAZogTMECVghigBM0QJmLli71P2Z9myZaGzn/3sZ5Frm5qaQmeVlZWRa1taWkJnUW+RcjVlypTI+c6dO0NnUR8/KUmnT58e1J6GE/cpgRGIKAEzRAmYIUrADFECZogSMHPFfppdf/7617+Gzj744IPItUuXLg2d9ff2q1TT3d0dOY/61MAHHnggcu1I/sS6KJyUgBmiBMwQJWCGKAEzRAmYIUrADFECZrhPGaKxsTF0VlRUlMSdpLaotwRmZ2cncSc+OCkBM0QJmCFKwAxRAmaIEjBDlIAZbolgWEV9utsAP2gx5XBSAmaIEjBDlIAZogTMECVghigBM0QJmOG7bmFIpaenR87//ve/h87+8Y9/RK6dP3/+oPY0nPiuW8AIRJSAGaIEzBAlYIYoATNECZjhrVsYUmfOnImc//a3vw2d9fddt1IVJyVghigBM0QJmCFKwAxRAmaIEjBDlIAZ7lNiWG3evDl0xn1KABaIEjBDlIAZogTMECVghigBM3yaHZBEfJodMAIRJWCGKAEzRAmYIUrADFECZogSMEOUgBmiBMwQJWCGKAEzRAmYIUrADFECZogSMEOUgBmiBMwQJWCGKAEzRAmYIUrADFECZgb8XbcG+EmUAC4TJyVghigBM0QJmCFKwAxRAmaIEjBDlIAZogTMECVg5v8AWWLQZxmV+6kAAAAASUVORK5CYII=", | ||
"text/plain": [ | ||
"<Figure size 800x600 with 2 Axes>" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
} | ||
], | ||
"source": [ | ||
"# Once we have the loader, we can get batches from it over multiple epochs, to train the model\n", | ||
"# Let us look at one batch \n", | ||
"import matplotlib.pyplot as plt\n", | ||
"fig, axs = plt.subplots(2, figsize=(8, 6))\n", | ||
"\n", | ||
"batch = next(iter(loader))\n", | ||
" \n", | ||
"\n", | ||
"print(batch)\n", | ||
"print(f\"There are {len(batch)} samples in this batch\")\n", | ||
"\n", | ||
"# Since we used default_collate, each batch is a dictionary, with two keys: \"image\" and \"label\"\n", | ||
"# The value of key \"image\" is a stacked tensor of images in the batch\n", | ||
"# Similarly, the value of key \"label\" is a stacked tensor of labels in the batch\n", | ||
"images = batch[\"image\"]\n", | ||
"labels = batch[\"label\"]\n", | ||
"\n", | ||
"#let's also display the two items\n", | ||
"for i in range(len(images)):\n", | ||
" axs[i].imshow(images[i].squeeze(), cmap='gray')\n", | ||
" axs[i].set_title(f\"Label: {labels[i]}\") \n", | ||
" axs[i].set_axis_off()\n" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.12.6" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,161 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "d8513771-36ac-4d03-b890-35108bce2211", | ||
"metadata": {}, | ||
"source": [ | ||
"### Loading and processing IMDB movie review dataset\n", | ||
"In this example, we will load the IMDB dataset from Hugging Face, \n", | ||
"use `torchdata.nodes` to process it and generate training batches." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"id": "eb3b507c-2ad1-410d-a834-6847182de684", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from datasets import load_dataset\n", | ||
"from transformers import BertTokenizer, BertForSequenceClassification" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"id": "089f1126-7125-4274-9d71-5c949ccc7bbd", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import torch\n", | ||
"from torch.utils.data import default_collate, RandomSampler, SequentialSampler" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"id": "2afac7d9-3d66-4195-8647-dc7034d306f2", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Load IMDB dataset from huggingface datasets and select the \"train\" split\n", | ||
"dataset = load_dataset(\"imdb\", streaming=False)\n", | ||
"dataset = dataset[\"train\"]\n", | ||
"# Since dataset is a Map-style dataset, we can setup a sampler to shuffle the data\n", | ||
"# Please refer to the migration guide here https://pytorch.org/data/main/migrate_to_nodes_from_utils.html\n", | ||
"# to migrate from torch.utils.data to torchdata.nodes\n", | ||
"\n", | ||
"sampler = RandomSampler(dataset)\n", | ||
"# Use a standard bert tokenizer\n", | ||
"tokenizer = BertTokenizer.from_pretrained(\"bert-base-uncased\")\n", | ||
"# Now we can set up some torchdata.nodes to create our pre-proc pipeline" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "09e08a47-573c-4d32-9a02-36cd8150db60", | ||
"metadata": {}, | ||
"source": [ | ||
"All torchdata.nodes.BaseNode implementations are Iterators.\n", | ||
"MapStyleWrapper creates an Iterator that combines sampler and dataset to create an iterator.\n", | ||
"Under the hood, MapStyleWrapper just does:\n", | ||
"```python\n", | ||
"node = IterableWrapper(sampler)\n", | ||
"node = Mapper(node, map_fn=dataset.__getitem__) # You can parallelize this with ParallelMapper\n", | ||
"```" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"id": "02af5479-ee69-41d8-ab2d-bf154b84bc15", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from torchdata.nodes import MapStyleWrapper, ParallelMapper, Batcher, PinMemory, Loader\n", | ||
"node = MapStyleWrapper(map_dataset=dataset, sampler=sampler)\n", | ||
"\n", | ||
"# Now we want to transform the raw inputs. We can just use another Mapper with\n", | ||
"# a custom map_fn to perform this. Using ParallelMapper allows us to use multiple\n", | ||
"# threads (or processes) to parallelize this work and have it run in the background\n", | ||
"max_len = 512\n", | ||
"batch_size = 2\n", | ||
"def bert_transform(item):\n", | ||
" encoding = tokenizer.encode_plus(\n", | ||
" item[\"text\"],\n", | ||
" add_special_tokens=True,\n", | ||
" max_length=max_len,\n", | ||
" padding=\"max_length\",\n", | ||
" truncation=True,\n", | ||
" return_attention_mask=True,\n", | ||
" return_tensors=\"pt\",\n", | ||
" )\n", | ||
" return {\n", | ||
" \"input_ids\": encoding[\"input_ids\"].flatten(),\n", | ||
" \"attention_mask\": encoding[\"attention_mask\"].flatten(),\n", | ||
" \"labels\": torch.tensor(item[\"label\"], dtype=torch.long),\n", | ||
" }\n", | ||
"node = ParallelMapper(node, map_fn=bert_transform, num_workers=2) # output items are Dict[str, tensor]\n", | ||
"\n", | ||
"# Next we batch the inputs, and then apply a collate_fn with another Mapper\n", | ||
"# to stack the tensors between. We use torch.utils.data.default_collate for this\n", | ||
"node = Batcher(node, batch_size=batch_size) # output items are List[Dict[str, tensor]]\n", | ||
"node = ParallelMapper(node, map_fn=default_collate, num_workers=2) # outputs are Dict[str, tensor]\n", | ||
"\n", | ||
"# we can optionally apply pin_memory to the batches\n", | ||
"if torch.cuda.is_available():\n", | ||
" node = PinMemory(node)\n", | ||
"\n", | ||
"# Since nodes are iterators, they need to be manually .reset() between epochs.\n", | ||
"# We can wrap the root node in Loader to convert it to a more conventional Iterable.\n", | ||
"loader = Loader(node)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"id": "60fd54f3-62ef-47aa-a790-853cb4899f13", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"{'input_ids': tensor([[ 101, 1045, 2572, ..., 2143, 2000, 102],\n", | ||
" [ 101, 2004, 1037, ..., 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, ..., 1, 1, 1],\n", | ||
" [1, 1, 1, ..., 0, 0, 0]]), 'labels': tensor([0, 1])}\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# Inspect a batch\n", | ||
"batch = next(iter(loader))\n", | ||
"print(batch)\n", | ||
"# In a batch we get three keys, as defined in the method `bert_transform`.\n", | ||
"# Since the batch size is 2, two samples are stacked together for each key." | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.12.6" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |