Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add examples for HF datasets #1371

Merged
merged 24 commits into from
Dec 12, 2024
Merged
Show file tree
Hide file tree
Changes from 23 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
182 changes: 182 additions & 0 deletions examples/nodes/hf_datasets_nodes_mnist.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,182 @@
{
"cells": [
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a title

{
"cell_type": "markdown",
"id": "1cb04783-4c83-43bc-91c6-b980d836de34",
"metadata": {},
"source": [
"### Loading and processing the MNIST dataset\n",
"In this example, we will load the MNIST dataset from Hugging Face, \n",
"use `torchdata.nodes` to process it and generate training batches."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "93026478-6dbd-4ac0-8507-360a3a2000c5",
"metadata": {},
"outputs": [],
"source": [
"from datasets import load_dataset\n",
"# Load the mnist dataset from HuggingFace datasets and convert the format to \"torch\"\n",
"dataset = load_dataset(\"ylecun/mnist\").with_format(\"torch\")\n",
"\n",
"# Getting the train dataset\n",
"dataset = dataset[\"train\"]"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "b0c46c4b-0194-4127-a218-e24ec54a3149",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"from torch.utils.data import default_collate, RandomSampler, SequentialSampler\n",
"\n",
"torch.manual_seed(42)\n",
"\n",
"# Defining samplers\n",
"# Since the dataset is a Map-style dataset, we can set up a sampler to shuffle the data\n",
"sampler = RandomSampler(dataset)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "6f643c48-c6fb-4e8a-9461-fdf96b45b04b",
"metadata": {},
"outputs": [],
"source": [
"# Now we can set up some torchdata.nodes to create our pre-proc pipeline\n",
"from torchdata.nodes import MapStyleWrapper, ParallelMapper, Batcher, PinMemory, Loader\n",
"\n",
"# All torchdata.nodes.BaseNode implementations are Iterators.\n",
"# MapStyleWrapper creates an Iterator that combines sampler and dataset to create an iterator.\n",
"#\n",
"# Under the hood, MapStyleWrapper just does:\n",
"# > node = IterableWrapper(sampler)\n",
"# > node = Mapper(node, map_fn=dataset.__getitem__) # You can parallelize this with ParallelMapper\n",
"\n",
"node = MapStyleWrapper(map_dataset=dataset, sampler=sampler)\n",
"\n",
"# Now we want to transform the raw inputs. We can just use another Mapper with\n",
"# a custom map_fn to perform this. Using ParallelMapper allows us to use multiple\n",
"# threads (or processes) to parallelize this work and have it run in the background\n",
"# We need a mapper function to convert a dtype and also normalize\n",
"def map_fn(item):\n",
" image = item[\"image\"].to(torch.float32)/255\n",
" label = item[\"label\"]\n",
"\n",
" return {\"image\":image, \"label\":label}\n",
" \n",
"node = ParallelMapper(node, map_fn=map_fn, num_workers=2) # output items are Dict[str, tensor]\n",
"\n",
"\n",
"# Hyperparameters\n",
"batch_size = 2 \n",
"\n",
"# Next we batch the inputs, and then apply a collate_fn with another Mapper\n",
"# to stack the tensors. We use torch.utils.data.default_collate for this\n",
"node = Batcher(node, batch_size=batch_size) # output items are List[Dict[str, tensor]]\n",
"node = ParallelMapper(node, map_fn=default_collate, num_workers=2) # outputs are Dict[str, tensor]\n",
"\n",
"# we can optionally apply pin_memory to the batches\n",
"if torch.cuda.is_available():\n",
" node = PinMemory(node)\n",
"\n",
"# Since nodes are iterators, they need to be manually .reset() between epochs.\n",
"# Instead, we can wrap the root node in Loader to convert it to a more conventional Iterable.\n",
"loader = Loader(node)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "c97a79ba-e6b3-4ac7-a4c5-edc8f9c58ff4",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'image': tensor([[[[0., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.],\n",
" ...,\n",
" [0., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.]]],\n",
"\n",
"\n",
" [[[0., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.],\n",
" ...,\n",
" [0., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.],\n",
" [0., 0., 0., ..., 0., 0., 0.]]]]), 'label': tensor([1, 4])}\n",
"There are 2 samples in this batch\n"
]
},
{
"data": {
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAOUAAAH4CAYAAAC19irnAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAUqElEQVR4nO3dbWyV9f3H8c+x3JQxBSmtDqKFWkjpUqezgwZLLN6sTsnSLui2LMPGhCUOXcdA0QdSdpOxTpkE8aaZE2x4BhbjBnHLYskyU1uJA4QJFEKHNA5a6lpYw013rv+Daf+ycl0t5fT008P7lfCA871+h98xvP0BV89pLAiCQABsXDXcGwBwIaIEzBAlYIYoATNECZghSsAMUQJmiBIwQ5SAGaI01dLSolgspmeffTZhz7ljxw7FYjHt2LEjYc+JxCPKBNq4caNisZh27tw53FsZEgcOHNDSpUs1d+5cpaenKxaLqaWlZbi3lXKIEgPW0NCgdevW6dSpU5o1a9ZwbydlESUG7Jvf/Kb+9a9/6YMPPtD3vve94d5OyiLKJDt37pxWrlyp2267TRMmTND48eM1b9481dfXh6557rnnlJ2drXHjxumOO+7Q3r17+1yzf/9+LVy4UJMmTVJ6eroKCwv15ptv9ruf7u5u7d+/X+3t7f1eO2nSJF199dX9XofLQ5RJ1tXVpVdeeUUlJSWqrq7WqlWr1NbWptLSUu3atavP9bW1tVq3bp2WLFmip556Snv37tWdd96p48eP916zb98+FRUV6cMPP9STTz6pNWvWaPz48SorK9PWrVsj99PU1KRZs2Zp/fr1iX6pGKRRw72BK821116rlpYWjRkzpvexxYsXKy8vT88//7x+97vfXXD9oUOH1NzcrKlTp0qS7r33Xs2ZM0fV1dX6zW9+I0mqrKzUjTfeqPfee09jx46VJP3whz9UcXGxVqxYofLy8iS9OiQCJ2WSpaWl9QYZj8fV0dGhnp4eFRYW6v333+9zfVlZWW+QkjR79mzNmTNH27dvlyR1dHTo7bff1oMPPqhTp06pvb1d7e3tOnnypEpLS9Xc3KzW1tbQ/ZSUlCgIAq1atSqxLxSDRpTD4LXXXtPNN9+s9PR0ZWRkKDMzU9u2bVNnZ2efa2fMmNHnsZkzZ/beijh06JCCINDTTz+tzMzMC35UVVVJkk6cODGkrweJxR9fk2zTpk2qqKhQWVmZHn/8cWVlZSktLU2rV6/W4cOHL/n54vG4JGn58uUqLS296DW5ubmXtWckF1Em2ZYtW5STk6O6ujrFYrHexz871f5Xc3Nzn8cOHjyoadOmSZJycnIkSaNHj9bdd9+d+A0j6fjja5KlpaVJkj7/eWWNjY1qaGi46PVvvPHGBX8nbGpqUmNjo77xjW9IkrKyslRSUqKamhp9/PHHfda3tbVF7udSbokgOTgph8Crr76qt956q8/jlZWVWrBggerq6lReXq77779fR44c0csvv6z8/HydPn26z5rc3FwVFxfrkUce0dmzZ7V27VplZGToiSee6L3mhRdeUHFxsQoKCrR48WLl5OTo+PHjamho0LFjx7R79+7QvTY1NWn+/Pmqqqrq9x97Ojs79fzzz0uS3nnnHUnS+vXrNXHiRE2cOFGPPvroQP7zoD8BEmbDhg2BpNAfH330URCPx4Nf/vKXQXZ2djB27Njg1ltvDf7whz8EDz30UJCdnd37XEeOHAkkBc8880ywZs2a4IYbbgjGjh0bzJs3L9i9e3efX/vw4cPBokWLguuvvz4YPXp0MHXq1GDBggXBli1beq+pr68PJAX19fV9Hquqqur39X22p4v9+PzecXliQcDnvgJO+DslYIYoATNECZghSsAMUQJmiBIwQ5SAmQF/Rc/nv04TwOAM5MsCOCkBM0QJmCFKwAxRAmaIEjBDlIAZ3uScQn7wgx+EzmpqaiLXTp8+PXTG9wtJLk5KwAxRAmaIEjBDlIAZogTMEC
VghigBM9ynTCHvvvtu6Ky/twzdfvvtoTPuUyYXJyVghigBM0QJmCFKwAxRAmaIEjDDLRFIkqZMmTLcW8CnOCkBM0QJmCFKwAxRAmaIEjBDlIAZogTMECVghigBM0QJmCFKwAxRAmaIEjBDlIAZogTMECVghigBM0QJmCFKwAxRAmaIEjBDlIAZogTMECVghigBM0QJmCFKwAxRAmaIEjBDlIAZogTMECVghigBM0QJmCFKwAxRAmaIEjAzarg3AA8HDhwY7i3gU5yUgBmiBMwQJWCGKAEzRAmYIUrADLdEUkheXt6g1+7ZsyeBO8Hl4KQEzBAlYIYoATNECZghSsAMUQJmiBIww33KFPLlL395uLeABOCkBMwQJWCGKAEzRAmYIUrADFECZrglkkK+/vWvD/cWkACclIAZogTMECVghigBM0QJmCFKwAxRAma4T5lCrrnmmtDZvn37Ite2trYmejsYJE5KwAxRAmaIEjBDlIAZogTMECVghlsiI8h1110XOc/KygqdtbS0RK49f/78YLaEIcBJCZghSsAMUQJmiBIwQ5SAGaIEzBAlYIb7lCNIRkbGoOcvvvhioreDIcJJCZghSsAMUQJmiBIwQ5SAGaIEzHBL5AqxefPm4d4CBoiTEjBDlIAZogTMECVghigBM0QJmCFKwAz3KUeQhx9+eNBrz5w5k8CdYChxUgJmiBIwQ5SAGaIEzBAlYIYoATPcEkmyoqKiyPnEiRNDZ+Xl5YP+defNmxc5X7JkSeiss7Mzcu2uXbtCZ9u3b49ce+7cucj5lYiTEjBDlIAZogTMECVghigBM0QJmCFKwEwsCIJgQBfGYkO9FyvZ2dmhs7Kyssi1CxcuDJ31d58yLS0tcj7S/Pvf/46cR/32+/3vfx+5dtGiRaGz//znP9EbGyYDyY2TEjBDlIAZogTMECVghigBM0QJmBnRb93q7/bBY489FjqrrKyMXDtlypTQ2ejRo6M3FqG5uTly3tjYGDr705/+NOhftz8nT54Mne3Zsydy7fz580NnCxYsiFw7ZsyY0Nmf//znyLXxeDxyPlJxUgJmiBIwQ5SAGaIEzBAlYIYoATNECZgZ0W/dys3NjZwfPHhw0M/d1dUVOnvooYci127bti101t+9tVS994b/4q1bwAhElIAZogTMECVghigBM0QJmBnRt0SAkYZbIsAIRJSAGaIEzBAlYIYoATNECZghSsAMUQJmiBIwQ5SAGaIEzBAlYIYoATNECZghSsAMUQJmiBIwQ5SAGaIEzBAlYIYoATNECZghSsAMUQJmiBIwQ5SAGaIEzBAlYIYoATNECZghSsAMUQJmiBIwQ5SAGaIEzBAlYIYoATNECZghSsAMUQJmiBIwQ5SAGaIEzBAlYIYoATNECZghSsAMUQJmiBIwQ5SAGaIEzBAlYGbUQC8MgmAo9wHgU5yUgBmiBMwQJWCGKAEzRAmYIUrADFECZogSMEOUgBmiBMwQJWCGKAEzRAmYIUrADFECZogSMEOUgBmiBMwQJWCGKE21tLQoFovp2WefTdhz7tixQ7FYTDt27EjYcyLxiDKBNm7cqFgspp07dw73VpLinnvuUSwW06OPPjrcW0kpRIlBqaurU0NDw3BvIyURJS7ZmTNntGzZMq1YsWK4t5KSiDLJzp07p5UrV+q2227ThAkTNH78eM2bN0/19fWha5577jllZ2dr3LhxuuOOO7R3794+1+zfv18LFy7UpEmTlJ6ersLCQr355pv97qe7u1v79+9Xe3v7gF/Dr3/9a8XjcS1fvnzAazBwRJlkXV1deuWVV1RSUqLq6mqtWrVKbW1tKi0t1a5du/pcX1tbq3Xr1mnJkiV66qmntHfvXt155506fvx47zX79u1TUVGRPvzwQz355JNas2aNxo8fr7KyMm3dujVyP01NTZo1a5bWr18/oP0fPXpUv/rVr1RdXa1x48Zd0mvHAAVImA0bNgSSgvfeey/0mp6enuDs2bMXPPbJJ5
8E1113XfDwww/3PnbkyJFAUjBu3Ljg2LFjvY83NjYGkoKlS5f2PnbXXXcFBQUFwZkzZ3ofi8fjwdy5c4MZM2b0PlZfXx9ICurr6/s8VlVVNaDXuHDhwmDu3Lm9P5cULFmyZEBrMTCclEmWlpamMWPGSJLi8bg6OjrU09OjwsJCvf/++32uLysr09SpU3t/Pnv2bM2ZM0fbt2+XJHV0dOjtt9/Wgw8+qFOnTqm9vV3t7e06efKkSktL1dzcrNbW1tD9lJSUKAgCrVq1qt+919fX6/XXX9fatWsv7UXjkhDlMHjttdd08803Kz09XRkZGcrMzNS2bdvU2dnZ59oZM2b0eWzmzJlqaWmRJB06dEhBEOjpp59WZmbmBT+qqqokSSdOnLjsPff09OhHP/qRvv/97+trX/vaZT8fwg34G/wgMTZt2qSKigqVlZXp8ccfV1ZWltLS0rR69WodPnz4kp8vHo9LkpYvX67S0tKLXpObm3tZe5b++3fbAwcOqKampvd/CJ85deqUWlpalJWVpS984QuX/Wtd6YgyybZs2aKcnBzV1dUpFov1Pv7Zqfa/mpub+zx28OBBTZs2TZKUk5MjSRo9erTuvvvuxG/4U0ePHtX58+d1++2395nV1taqtrZWW7duVVlZ2ZDt4UpBlEmWlpYm6b/fWvCzKBsbG9XQ0KAbb7yxz/VvvPGGWltbe/9e2dTUpMbGRv34xz+WJGVlZamkpEQ1NTV67LHH9KUvfemC9W1tbcrMzAzdT3d3t44eParJkydr8uTJodd95zvf0S233NLn8fLyct13331avHix5syZE/naMTBEOQReffVVvfXWW30er6ys1IIFC1RXV6fy8nLdf//9OnLkiF5++WXl5+fr9OnTfdbk5uaquLhYjzzyiM6ePau1a9cqIyNDTzzxRO81L7zwgoqLi1VQUKDFixcrJydHx48fV0NDg44dO6bdu3eH7rWpqUnz589XVVVV5D/25OXlKS8v76Kz6dOnc0ImEFEOgZdeeumij1dUVKiiokL//Oc/VVNToz/+8Y/Kz8/Xpk2btHnz5ot+ofiiRYt01VVXae3atTpx4oRmz56t9evXX3Ai5ufna+fOnfrpT3+qjRs36uTJk8rKytKtt96qlStXDtXLxBCJBQHfohlwwi0RwAxRAmaIEjBDlIAZogTMECVghigBMwP+4oHPf50mgMEZyJcFcFICZogSMEOUgBmiBMwQJWCGKAEzRAmYIUrADFECZogSMEOUgBmiBMwQJWCGKAEzRAmYIUrADFECZogSMEOUgBmiBMwQJWCGKAEzRAmYIUrADFECZogSMEOUgBmiBMwQJWCGKAEzRAmYIUrADFECZogSMEOUgBmiBMyMGu4NXI7CwsLI+bJly0Jn3/3udxO9nZS1YsWKyPlVV4X/v3316tWJ3k7K46QEzBAlYIYoATNECZghSsAMUQJmiBIwM6LvU/b09ETO77nnntBZf/feqqurB7WnVNTd3R05LygoSNJOrgyclIAZogTMECVghigBM0QJmCFKwMyIviWya9euyHltbW3orLS0NHItt0T+36hR0b9NiouLk7STKwMnJWCGKAEzRAmYIUrADFECZogSMEOUgJkRfZ+yP11dXcO9hZSwefPmyPkvfvGL0NnkyZMj17a3tw9qT6mMkxIwQ5SAGaIEzBAlYIYoATNECZhJ6VsiSIxjx45Fzs+dOxc6e+CBByLXvvTSS4PaUyrjpATMECVghigBM0QJmCFKwAxRAmaIEjCT0vcpz58/HzqLxWJJ3Elqa2xsDJ3NnTs3ci33KfvipATMECVghigBM0QJmCFKwAxRAmZS+pbItm3bQmd33XVXEneS2t55553Q2U9+8pMk7iQ1cFICZogSMEOUgBmiBMwQJWCGKAEzRAmYiQVBEAzowhH4VqdbbrkldPaXv/wlcm1+fn7orL+PXLzSTJgwIXTW0tISufbaa69N8G68DSQ3TkrADFECZogSMEOUgB
miBMwQJWAmpd+6dfr06dDZqFHRL/3b3/526GzNmjWD3lMq6uzsDJ1dc801kWuLiopCZ+++++6g9zSScVICZogSMEOUgBmiBMwQJWCGKAEzRAmYSem3bkV5/fXXI+dRr/db3/pWoreTsuLxeOT8mWeeCZ2tWLEi0dsZdrx1CxiBiBIwQ5SAGaIEzBAlYIYoATMp/datKFu2bImcX2lvz0pLSwudXX311YN+3r/97W+R85kzZ4bOsrOzI9d2dXWFzj755JPojRnjpATMECVghigBM0QJmCFKwAxRAmaIEjBzxd6n/PjjjyPn119/fejspptuilx7+PDh0FleXl7k2nvvvTdyHuWGG24InX31q1+NXPvFL34xdDZ9+vTItVH3C6dNmxa5tqCgIHT2la98JXLtz3/+89DZhg0bItc646QEzBAlYIYoATNECZghSsAMUQJmrthPs+vvn+obGhpCZ/297eujjz4KnVVVVUWu7ejoCJ3t2bMncm3Uvvp7C1WU1tbWyHlbW1vobOXKlZFrKyoqQmc5OTmRa0ciPs0OGIGIEjBDlIAZogTMECVghigBM0QJmLli71P2Z9myZaGzn/3sZ5Frm5qaQmeVlZWRa1taWkJnUW+RcjVlypTI+c6dO0NnUR8/KUmnT58e1J6GE/cpgRGIKAEzRAmYIUrADFECZogSMHPFfppdf/7617+Gzj744IPItUuXLg2d9ff2q1TT3d0dOY/61MAHHnggcu1I/sS6KJyUgBmiBMwQJWCGKAEzRAmYIUrADFECZrhPGaKxsTF0VlRUlMSdpLaotwRmZ2cncSc+OCkBM0QJmCFKwAxRAmaIEjBDlIAZbolgWEV9utsAP2gx5XBSAmaIEjBDlIAZogTMECVghigBM0QJmOG7bmFIpaenR87//ve/h87+8Y9/RK6dP3/+oPY0nPiuW8AIRJSAGaIEzBAlYIYoATNECZjhrVsYUmfOnImc//a3vw2d9fddt1IVJyVghigBM0QJmCFKwAxRAmaIEjBDlIAZ7lNiWG3evDl0xn1KABaIEjBDlIAZogTMECVghigBM3yaHZBEfJodMAIRJWCGKAEzRAmYIUrADFECZogSMEOUgBmiBMwQJWCGKAEzRAmYIUrADFECZogSMEOUgBmiBMwQJWCGKAEzRAmYIUrADFECZgb8XbcG+EmUAC4TJyVghigBM0QJmCFKwAxRAmaIEjBDlIAZogTMECVg5v8AWWLQZxmV+6kAAAAASUVORK5CYII=",
"text/plain": [
"<Figure size 800x600 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Once we have the loader, we can get batches from it over multiple epochs, to train the model\n",
"# Let us look at one batch \n",
"import matplotlib.pyplot as plt\n",
"fig, axs = plt.subplots(2, figsize=(8, 6))\n",
"\n",
"batch = next(iter(loader))\n",
" \n",
"\n",
"print(batch)\n",
"print(f\"There are {len(batch['image'])} samples in this batch\")\n",
"\n",
"# Since we used default_collate, each batch is a dictionary, with two keys: \"image\" and \"label\"\n",
"# The value of key \"image\" is a stacked tensor of images in the batch\n",
"# Similarly, the value of key \"label\" is a stacked tensor of labels in the batch\n",
"images = batch[\"image\"]\n",
"labels = batch[\"label\"]\n",
"\n",
"#let's also display the two items\n",
"for i in range(len(images)):\n",
" axs[i].imshow(images[i].squeeze(), cmap='gray')\n",
" axs[i].set_title(f\"Label: {labels[i]}\") \n",
" axs[i].set_axis_off()\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
161 changes: 161 additions & 0 deletions examples/nodes/hf_imdb_bert.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,161 @@
{
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add a title

"cells": [
{
"cell_type": "markdown",
"id": "d8513771-36ac-4d03-b890-35108bce2211",
"metadata": {},
"source": [
"### Loading and processing IMDB movie review dataset\n",
"In this example, we will load the IMDB dataset from Hugging Face, \n",
"use `torchdata.nodes` to process it and generate training batches."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "eb3b507c-2ad1-410d-a834-6847182de684",
"metadata": {},
"outputs": [],
"source": [
"from datasets import load_dataset\n",
"from transformers import BertTokenizer, BertForSequenceClassification"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "089f1126-7125-4274-9d71-5c949ccc7bbd",
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"from torch.utils.data import default_collate, RandomSampler, SequentialSampler"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "2afac7d9-3d66-4195-8647-dc7034d306f2",
"metadata": {},
"outputs": [],
"source": [
"# Load IMDB dataset from huggingface datasets and select the \"train\" split\n",
"dataset = load_dataset(\"imdb\", streaming=False)\n",
"dataset = dataset[\"train\"]\n",
"# Since the dataset is a Map-style dataset, we can set up a sampler to shuffle the data\n",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's link to the docs similar to other example

"# Please refer to the migration guide here https://pytorch.org/data/main/migrate_to_nodes_from_utils.html\n",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I gave you the wrong link, after nightly run this link is live now: https://pytorch.org/data/main/migrate_to_nodes_from_utils.html

"# to migrate from torch.utils.data to torchdata.nodes\n",
"\n",
"sampler = RandomSampler(dataset)\n",
"# Use a standard bert tokenizer\n",
"tokenizer = BertTokenizer.from_pretrained(\"bert-base-uncased\")\n",
"# Now we can set up some torchdata.nodes to create our pre-proc pipeline"
]
},
{
"cell_type": "markdown",
"id": "09e08a47-573c-4d32-9a02-36cd8150db60",
"metadata": {},
"source": [
"All torchdata.nodes.BaseNode implementations are Iterators.\n",
"MapStyleWrapper creates an Iterator that combines sampler and dataset to create an iterator.\n",
"Under the hood, MapStyleWrapper just does:\n",
"```python\n",
"node = IterableWrapper(sampler)\n",
"node = Mapper(node, map_fn=dataset.__getitem__) # You can parallelize this with ParallelMapper\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "02af5479-ee69-41d8-ab2d-bf154b84bc15",
"metadata": {},
"outputs": [],
"source": [
"from torchdata.nodes import MapStyleWrapper, ParallelMapper, Batcher, PinMemory, Loader\n",
"node = MapStyleWrapper(map_dataset=dataset, sampler=sampler)\n",
"\n",
"# Now we want to transform the raw inputs. We can just use another Mapper with\n",
"# a custom map_fn to perform this. Using ParallelMapper allows us to use multiple\n",
"# threads (or processes) to parallelize this work and have it run in the background\n",
"max_len = 512\n",
"batch_size = 2\n",
"def bert_transform(item):\n",
" encoding = tokenizer.encode_plus(\n",
" item[\"text\"],\n",
" add_special_tokens=True,\n",
" max_length=max_len,\n",
" padding=\"max_length\",\n",
" truncation=True,\n",
" return_attention_mask=True,\n",
" return_tensors=\"pt\",\n",
" )\n",
" return {\n",
" \"input_ids\": encoding[\"input_ids\"].flatten(),\n",
" \"attention_mask\": encoding[\"attention_mask\"].flatten(),\n",
" \"labels\": torch.tensor(item[\"label\"], dtype=torch.long),\n",
" }\n",
"node = ParallelMapper(node, map_fn=bert_transform, num_workers=2) # output items are Dict[str, tensor]\n",
"\n",
"# Next we batch the inputs, and then apply a collate_fn with another Mapper\n",
"# to stack the tensors. We use torch.utils.data.default_collate for this\n",
"node = Batcher(node, batch_size=batch_size) # output items are List[Dict[str, tensor]]\n",
"node = ParallelMapper(node, map_fn=default_collate, num_workers=2) # outputs are Dict[str, tensor]\n",
"\n",
"# we can optionally apply pin_memory to the batches\n",
"if torch.cuda.is_available():\n",
" node = PinMemory(node)\n",
"\n",
"# Since nodes are iterators, they need to be manually .reset() between epochs.\n",
"# We can wrap the root node in Loader to convert it to a more conventional Iterable.\n",
"loader = Loader(node)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "60fd54f3-62ef-47aa-a790-853cb4899f13",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{'input_ids': tensor([[ 101, 1045, 2572, ..., 2143, 2000, 102],\n",
" [ 101, 2004, 1037, ..., 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, ..., 1, 1, 1],\n",
" [1, 1, 1, ..., 0, 0, 0]]), 'labels': tensor([0, 1])}\n"
]
}
],
"source": [
"# Inspect a batch\n",
"batch = next(iter(loader))\n",
"print(batch)\n",
"# In a batch we get three keys, as defined in the method `bert_transform`.\n",
"# Since the batch size is 2, two samples are stacked together for each key."
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.6"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
Loading