-
Notifications
You must be signed in to change notification settings - Fork 155
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add examples for HF datasets #1371
Changes from 23 commits
5520cbd
2af0274
304ac41
a085263
5852720
36e1b5d
d906a44
f4e320b
c796b7b
06ac590
290e8e3
84013b5
5d37733
652d178
55b1472
4efcc85
3afd46c
b16c0b3
9d8115b
05df26e
e2cea9b
a479f5f
f7abbc6
c82b5fc
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,182 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "1cb04783-4c83-43bc-91c6-b980d836de34", | ||
"metadata": {}, | ||
"source": [ | ||
"### Loading and processing the MNIST dataset\n", | ||
"In this example, we will load the MNIST dataset from Hugging Face, \n", | ||
"use `torchdata.nodes` to process it and generate training batches." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"id": "93026478-6dbd-4ac0-8507-360a3a2000c5", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from datasets import load_dataset\n", | ||
"# Load the mnist dataset from HuggingFace datasets and convert the format to \"torch\"\n", | ||
"dataset = load_dataset(\"ylecun/mnist\").with_format(\"torch\")\n", | ||
"\n", | ||
"# Getting the train dataset\n", | ||
"dataset = dataset[\"train\"]" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"id": "b0c46c4b-0194-4127-a218-e24ec54a3149", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import torch\n", | ||
"import torch.nn as nn\n", | ||
"from torch.utils.data import default_collate, RandomSampler, SequentialSampler\n", | ||
"\n", | ||
"torch.manual_seed(42)\n", | ||
"\n", | ||
"# Defining samplers\n", | ||
"# Since datasets is a Map-style dataset, we can setup a sampler to shuffle the data\n", | ||
"sampler = RandomSampler(dataset)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"id": "6f643c48-c6fb-4e8a-9461-fdf96b45b04b", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Now we can set up some torchdata.nodes to create our pre-proc pipeline\n", | ||
"from torchdata.nodes import MapStyleWrapper, ParallelMapper, Batcher, PinMemory, Loader\n", | ||
"\n", | ||
"# All torchdata.nodes.BaseNode implementations are Iterators.\n", | ||
"# MapStyleWrapper creates an Iterator that combines sampler and dataset to create an iterator.\n", | ||
"#\n", | ||
"# Under the hood, MapStyleWrapper just does:\n", | ||
"# > node = IterableWrapper(sampler)\n", | ||
"# > node = Mapper(node, map_fn=dataset.__getitem__) # You can parallelize this with ParallelMapper\n", | ||
"\n", | ||
"node = MapStyleWrapper(map_dataset=dataset, sampler=sampler)\n", | ||
"\n", | ||
"# Now we want to transform the raw inputs. We can just use another Mapper with\n", | ||
"# a custom map_fn to perform this. Using ParallelMapper allows us to use multiple\n", | ||
"# threads (or processes) to parallelize this work and have it run in the background\n", | ||
"# We need a mapper function to convert a dtype and also normalize\n", | ||
"def map_fn(item):\n", | ||
" image = item[\"image\"].to(torch.float32)/255\n", | ||
" label = item[\"label\"]\n", | ||
"\n", | ||
" return {\"image\":image, \"label\":label}\n", | ||
" \n", | ||
"node = ParallelMapper(node, map_fn=map_fn, num_workers=2) # output items are Dict[str, tensor]\n", | ||
"\n", | ||
"\n", | ||
"# Hyperparameters\n", | ||
"batch_size = 2 \n", | ||
"\n", | ||
"# Next we batch the inputs, and then apply a collate_fn with another Mapper\n", | ||
"# to stack the tensor. We use torch.utils.data.default_collate for this\n", | ||
"node = Batcher(node, batch_size=batch_size) # output items are List[Dict[str, tensor]]\n", | ||
"node = ParallelMapper(node, map_fn=default_collate, num_workers=2) # outputs are Dict[str, tensor]\n", | ||
"\n", | ||
"# we can optionally apply pin_memory to the batches\n", | ||
"if torch.cuda.is_available():\n", | ||
" node = PinMemory(node)\n", | ||
"\n", | ||
"# Since nodes are iterators, they need to be manually .reset() between epochs.\n", | ||
"# Instead, we can wrap the root node in Loader to convert it to a more conventional Iterable.\n", | ||
"loader = Loader(node)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"id": "c97a79ba-e6b3-4ac7-a4c5-edc8f9c58ff4", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"{'image': tensor([[[[0., 0., 0., ..., 0., 0., 0.],\n", | ||
" [0., 0., 0., ..., 0., 0., 0.],\n", | ||
" [0., 0., 0., ..., 0., 0., 0.],\n", | ||
" ...,\n", | ||
" [0., 0., 0., ..., 0., 0., 0.],\n", | ||
" [0., 0., 0., ..., 0., 0., 0.],\n", | ||
" [0., 0., 0., ..., 0., 0., 0.]]],\n", | ||
"\n", | ||
"\n", | ||
" [[[0., 0., 0., ..., 0., 0., 0.],\n", | ||
" [0., 0., 0., ..., 0., 0., 0.],\n", | ||
" [0., 0., 0., ..., 0., 0., 0.],\n", | ||
" ...,\n", | ||
" [0., 0., 0., ..., 0., 0., 0.],\n", | ||
" [0., 0., 0., ..., 0., 0., 0.],\n", | ||
" [0., 0., 0., ..., 0., 0., 0.]]]]), 'label': tensor([1, 4])}\n", | ||
"There are 2 samples in this batch\n" | ||
] | ||
}, | ||
{ | ||
"data": { | ||
"image/png": "iVBORw0KGgoAAAANSUhEUgAAAOUAAAH4CAYAAAC19irnAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjkuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8hTgPZAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAUqElEQVR4nO3dbWyV9f3H8c+x3JQxBSmtDqKFWkjpUqezgwZLLN6sTsnSLui2LMPGhCUOXcdA0QdSdpOxTpkE8aaZE2x4BhbjBnHLYskyU1uJA4QJFEKHNA5a6lpYw013rv+Daf+ycl0t5fT008P7lfCA871+h98xvP0BV89pLAiCQABsXDXcGwBwIaIEzBAlYIYoATNECZghSsAMUQJmiBIwQ5SAGaI01dLSolgspmeffTZhz7ljxw7FYjHt2LEjYc+JxCPKBNq4caNisZh27tw53FsZEgcOHNDSpUs1d+5cpaenKxaLqaWlZbi3lXKIEgPW0NCgdevW6dSpU5o1a9ZwbydlESUG7Jvf/Kb+9a9/6YMPPtD3vve94d5OyiLKJDt37pxWrlyp2267TRMmTND48eM1b9481dfXh6557rnnlJ2drXHjxumOO+7Q3r17+1yzf/9+LVy4UJMmTVJ6eroKCwv15ptv9ruf7u5u7d+/X+3t7f1eO2nSJF199dX9XofLQ5RJ1tXVpVdeeUUlJSWqrq7WqlWr1NbWptLSUu3atavP9bW1tVq3bp2WLFmip556Snv37tWdd96p48eP916zb98+FRUV6cMPP9STTz6pNWvWaPz48SorK9PWrVsj99PU1KRZs2Zp/fr1iX6pGKRRw72BK821116rlpYWjRkzpvexxYsXKy8vT88//7x+97vfXXD9oUOH1NzcrKlTp0qS7r33Xs2ZM0fV1dX6zW9+I0mqrKzUjTfeqPfee09jx46VJP3whz9UcXGxVqxYofLy8iS9OiQCJ2WSpaWl9QYZj8fV0dGhnp4eFRYW6v333+9zfVlZWW+QkjR79mzNmTNH27dvlyR1dHTo7bff1oMPPqhTp06pvb1d7e3tOnnypEpLS9Xc3KzW1tbQ/ZSUlCgIAq1atSqxLxSDRpTD4LXXXtPNN9+s9PR0ZWRkKDMzU9u2bVNnZ2efa2fMmNHnsZkzZ/beijh06JCCINDTTz+tzMzMC35UVVVJkk6cODGkrweJxR9fk2zTpk2qqKhQWVmZHn/8cWVlZSktLU2rV6/W4cOHL/n54vG4JGn58uUqLS296DW5ubmXtWckF1Em2ZYtW5STk6O6ujrFYrHexz871f5Xc3Nzn8cOHjyoadOmSZJycnIkSaNHj9bdd9+d+A0j6fjja5KlpaVJkj7/eWWNjY1qaGi46PVvvPHGBX8nbGpqUmNjo77xjW9IkrKyslRSUqKamhp9/PHHfda3tbVF7udSbokgOTgph8Crr76qt956q8/jlZWVWrBggerq6lReXq77779fR44c0csvv6z8/HydPn26z5rc3FwVFxfrkUce0dmzZ7V27VplZGToiSee6L3mhRdeUHFxsQoKCrR48WLl5OTo+PHjamho0LFjx7R79+7QvTY1NWn+/Pmqqqrq9x97Ojs79fzzz0uS3nnnHUnS+vXrNXHiRE2cOFGPPvroQP7zoD8BEmbDhg2BpNAfH330URCPx4Nf/vKXQXZ2djB27Njg1ltvDf7whz8EDz30UJCdnd37XEeOHAkkBc8880ywZs2a4IYbbgjGjh0bzJs3L9i9e3efX/vw4cPBokWLguuvvz4YPXp0MHXq1GDBggXBli1beq+pr68PJAX19fV9Hquqqur39X22p4v9+PzecXliQcDnvgJO+DslYIYoATNECZghSsAMUQJmiBIwQ5SAmQF/Rc/nv04TwOAM5MsCOCkBM0QJmCFKwAxRAmaIEjBDlIAZ3uScQn7wgx+EzmpqaiLXTp8+PXTG9wtJLk5KwAxRAmaIEjBDlIAZogTMECVghigBM9ynTCHvvvtu6Ky/twzdfvvtoTPuUyYXJyVghigBM0QJmCFKwAxRAmaIEjDDLRFIkqZMmTLcW8CnOCkBM0QJmCFKwAxRAmaIEjBDlIAZogTMECVghigBM0QJmCFKwAxRAmaIEjBDlIAZogTMECVghigBM0QJmCFKwAxRAmaIEjBDlIAZogTMECVghigBM0QJmCFKwAxRAmaIEjBDlIAZogTMECVghigBM0QJmCFKwAxRAmaIEjAzarg3AA8HDhwY7i3gU5yUgBmiBMwQJWCGKAEzRAmYIUrADLdEUkheXt6g1+7ZsyeBO8Hl4KQEzBAlYIYoATNECZghSsAMUQJmiBIww33KFPLlL395uLeABOCkBMwQJWCGKAEzRAmYIUrADFECZrglkkK+/vWvD/cWkACclIAZogTMECVghigBM0QJmCFKwAxRAma4T5lCrrnmmtDZvn37Ite2trYmejsYJE5KwAxRAmaIEjBDlIAZogTMECVghlsiI8h1110XOc/KygqdtbS0RK49f/78YLaEIcBJCZghSsAMUQJmiBIwQ5SAGaIEzBAlYIb7lCNIRkbGoOcvvvhioreDIcJJCZghSsAMUQJmiBIwQ5SAGaIEzHBL5AqxefPm4d4CBoiTEjBDlIAZogTMECVghigBM0QJmCFKwAz3KUeQhx9+eNBrz5w5k8CdYChxUgJmiBIwQ5SAGaIEzBAlYIYoATPcEkmyoqKiyPnEiRNDZ+Xl5YP+defNmxc5X7JkSeiss7Mzcu2uXbtCZ9u3b49ce+7cucj5lYiTEjBDlIAZogTMECVghigBM0QJmCFKwEwsCIJgQBfGYkO9FyvZ2dmhs7Kyssi1CxcuDJ31d58yLS0tcj7S/Pvf/46cR/32+/3vfx+5dtGiRaGz//znP9EbGyYDyY2TEjBDlIAZogTMECVghigBM0QJmBnRb93q7/bBY489FjqrrKyMXDtlypTQ2ejRo6M3FqG5uTly3tjYGDr705/+NOhftz8nT54Mne3Zsydy7fz580NnCxYsiFw7ZsyY0Nmf//znyLXxeDxyPlJxUgJmiBIwQ5SAGaIEzBAlYIYoATNECZgZ0W/dys3NjZwfPHhw0M/d1dUVOnvooYci127bti101t+9tVS994b/4q1bwAhElIAZogTMECVghigBM0QJmBnRt0SAkYZbIsAIRJSAGaIEzBAlYIYoATNECZghSsAMUQJmiBIwQ5SAGaIEzBAlYIYoATNECZghSsAMUQJmiBIwQ5SAGaIEzBAlYIYoATNECZghSsAMUQJmiBIwQ5SAGaIEzBAlYIYoATNECZghSsAMUQJmiBIwQ5SAGaIEzBAlYIYoATNECZghSsAMUQJmiBIwQ5SAGaIEzBAlYIYoATNECZghSsAMUQJmiBIwQ5SAGaIEzBAlYGbUQC8MgmAo9wHgU5yUgBmiBMwQJWCGKAEzRAmYIUrADFECZogSMEOUgBmiBMwQJWCGKAEzRAmYIUrADFECZogSMEOUgBmiBMwQJWCGKE21tLQoFovp2WefTdhz7tixQ7FYTDt27EjYcyLxiDKBNm7cqFgspp07dw73VpLinnvuUSwW06OPPjrcW0kpRIlBqaurU0NDw3BvIyURJS7ZmTNntGzZMq1YsWK4t5KSiDLJzp07p5UrV+q2227ThAkTNH78eM2bN0/19fWha5577jllZ2dr3LhxuuOOO7R3794+1+zfv18LFy7UpEmTlJ6ersLCQr355pv97qe7u1v79+9Xe3v7gF/Dr3/9a8XjcS1fvnzAazBwRJlkXV1deuWVV1RSUqLq6mqtWrVKbW1tKi0t1a5du/pcX1tbq3Xr1mnJkiV66qmntHfvXt155506fvx47zX79u1TUVGRPvzwQz355JNas2aNxo8fr7KyMm3dujVyP01NTZo1a5bWr18/oP0fPXpUv/rVr1RdXa1x48Zd0mvHAAVImA0bNgSSgvfeey/0mp6enuDs2bMXPPbJJ58E1113XfDwww/3PnbkyJFAUjBu3Ljg2LFjvY83NjYGkoKlS5f2PnbXXXcFBQUFwZkzZ3ofi8fjwdy5c4MZM2b0PlZfXx9ICurr6/s8VlVVNaDXuHDhwmDu3Lm9P5cULFmyZEBrMTCclEmWlpamMWPGSJLi8bg6OjrU09OjwsJCvf/++32uLysr09SpU3t/Pnv2bM2ZM0fbt2+XJHV0dOjtt9/Wgw8+qFOnTqm9vV3t7e06efKkSktL1dzcrNbW1tD9lJSUKAgCrVq1qt+919fX6/XXX9fatWsv7UXjkhDlMHjttdd08803Kz09XRkZGcrMzNS2bdvU2dnZ59oZM2b0eWzmzJlqaWmRJB06dEhBEOjpp59WZmbmBT+qqqokSSdOnLjsPff09OhHP/qRvv/97+trX/vaZT8fwg34G/wgMTZt2qSKigqVlZXp8ccfV1ZWltLS0rR69WodPnz4kp8vHo9LkpYvX67S0tKLXpObm3tZe5b++3fbAwcOqKampvd/CJ85deqUWlpalJWVpS984QuX/Wtd6YgyybZs2aKcnBzV1dUpFov1Pv7Zqfa/mpub+zx28OBBTZs2TZKUk5MjSRo9erTuvvvuxG/4U0ePHtX58+d1++2395nV1taqtrZWW7duVVlZ2ZDt4UpBlEmWlpYm6b/fWvCzKBsbG9XQ0KAbb7yxz/VvvPGGWltbe/9e2dTUpMbGRv34xz+WJGVlZamkpEQ1NTV67LHH9KUvfemC9W1tbcrMzAzdT3d3t44eParJkydr8uTJodd95zvf0S233NLn8fLyct13331avHix5syZE/naMTBEOQReffVVvfXWW30er6ys1IIFC1RXV6fy8nLdf//9OnLkiF5++WXl5+fr9OnTfdbk5uaquLhYjzzyiM6ePau1a9cqIyNDTzzxRO81L7zwgoqLi1VQUKDFixcrJydHx48fV0NDg44dO6bdu3eH7rWpqUnz589XVVVV5D/25OXlKS8v76Kz6dOnc0ImEFEOgZdeeumij1dUVKiiokL//Oc/VVNToz/+8Y/Kz8/Xpk2btHnz5ot+ofiiRYt01VVXae3atTpx4oRmz56t9evXX3Ai5ufna+fOnfrpT3+qjRs36uTJk8rKytKtt96qlStXDtXLxBCJBQHfohlwwi0RwAxRAmaIEjBDlIAZogTMECVghigBMwP+4oHPf50mgMEZyJcFcFICZogSMEOUgBmiBMwQJWCGKAEzRAmYIUrADFECZogSMEOUgBmiBMwQJWCGKAEzRAmYIUrADFECZogSMEOUgBmiBMwQJWCGKAEzRAmYIUrADFECZogSMEOUgBmiBMwQJWCGKAEzRAmYIUrADFECZogSMEOUgBmiBMyMGu4NXI7CwsLI+bJly0Jn3/3udxO9nZS1YsWKyPlVV4X/v3316tWJ3k7K46QEzBAlYIYoATNECZghSsAMUQJmiBIwM6LvU/b09ETO77nnntBZf/feqqurB7WnVNTd3R05LygoSNJOrgyclIAZogTMECVghigBM0QJmCFKwMyIviWya9euyHltbW3orLS0NHItt0T+36hR0b9NiouLk7STKwMnJWCGKAEzRAmYIUrADFECZogSMEOUgJkRfZ+yP11dXcO9hZSwefPmyPkvfvGL0NnkyZMj17a3tw9qT6mMkxIwQ5SAGaIEzBAlYIYoATNECZhJ6VsiSIxjx45Fzs+dOxc6e+CBByLXvvTSS4PaUyrjpATMECVghigBM0QJmCFKwAxRAmaIEjCT0vcpz58/HzqLxWJJ3Elqa2xsDJ3NnTs3ci33KfvipATMECVghigBM0QJmCFKwAxRAmZS+pbItm3bQmd33XVXEneS2t55553Q2U9+8pMk7iQ1cFICZogSMEOUgBmiBMwQJWCGKAEzRAmYiQVBEAzowhH4VqdbbrkldPaXv/wlcm1+fn7orL+PXLzSTJgwIXTW0tISufbaa69N8G68DSQ3TkrADFECZogSMEOUgBmiBMwQJWAmpd+6dfr06dDZqFHRL/3b3/526GzNmjWD3lMq6uzsDJ1dc801kWuLiopCZ+++++6g9zSScVICZogSMEOUgBmiBMwQJWCGKAEzRAmYSem3bkV5/fXXI+dRr/db3/pWoreTsuLxeOT8mWeeCZ2tWLEi0dsZdrx1CxiBiBIwQ5SAGaIEzBAlYIYoATMp/datKFu2bImcX2lvz0pLSwudXX311YN+3r/97W+R85kzZ4bOsrOzI9d2dXWFzj755JPojRnjpATMECVghigBM0QJmCFKwAxRAmaIEjBzxd6n/PjjjyPn119/fejspptuilx7+PDh0FleXl7k2nvvvTdyHuWGG24InX31q1+NXPvFL34xdDZ9+vTItVH3C6dNmxa5tqCgIHT2la98JXLtz3/+89DZhg0bItc646QEzBAlYIYoATNECZghSsAMUQJmrthPs+vvn+obGhpCZ/297eujjz4KnVVVVUWu7ejoCJ3t2bMncm3Uvvp7C1WU1tbWyHlbW1vobOXKlZFrKyoqQmc5OTmRa0ciPs0OGIGIEjBDlIAZogTMECVghigBM0QJmLli71P2Z9myZaGzn/3sZ5Frm5qaQmeVlZWRa1taWkJnUW+RcjVlypTI+c6dO0NnUR8/KUmnT58e1J6GE/cpgRGIKAEzRAmYIUrADFECZogSMHPFfppdf/7617+Gzj744IPItUuXLg2d9ff2q1TT3d0dOY/61MAHHnggcu1I/sS6KJyUgBmiBMwQJWCGKAEzRAmYIUrADFECZrhPGaKxsTF0VlRUlMSdpLaotwRmZ2cncSc+OCkBM0QJmCFKwAxRAmaIEjBDlIAZbolgWEV9utsAP2gx5XBSAmaIEjBDlIAZogTMECVghigBM0QJmOG7bmFIpaenR87//ve/h87+8Y9/RK6dP3/+oPY0nPiuW8AIRJSAGaIEzBAlYIYoATNECZjhrVsYUmfOnImc//a3vw2d9fddt1IVJyVghigBM0QJmCFKwAxRAmaIEjBDlIAZ7lNiWG3evDl0xn1KABaIEjBDlIAZogTMECVghigBM3yaHZBEfJodMAIRJWCGKAEzRAmYIUrADFECZogSMEOUgBmiBMwQJWCGKAEzRAmYIUrADFECZogSMEOUgBmiBMwQJWCGKAEzRAmYIUrADFECZgb8XbcG+EmUAC4TJyVghigBM0QJmCFKwAxRAmaIEjBDlIAZogTMECVg5v8AWWLQZxmV+6kAAAAASUVORK5CYII=", | ||
"text/plain": [ | ||
"<Figure size 800x600 with 2 Axes>" | ||
] | ||
}, | ||
"metadata": {}, | ||
"output_type": "display_data" | ||
} | ||
], | ||
"source": [ | ||
"# Once we have the loader, we can get batches from it over multiple epochs, to train the model\n", | ||
"# Let us look at one batch \n", | ||
"import matplotlib.pyplot as plt\n", | ||
"fig, axs = plt.subplots(2, figsize=(8, 6))\n", | ||
"\n", | ||
"batch = next(iter(loader))\n", | ||
" \n", | ||
"\n", | ||
"print(batch)\n", | ||
"print(f\"There are {len(batch)} samples in this batch\")\n", | ||
"\n", | ||
"# Since we used default_collate, each batch is a dictionary, with two keys: \"image\" and \"label\"\n", | ||
"# The value of key \"image\" is a stacked tensor of images in the batch\n", | ||
"# Similarly, the value of key \"label\" is a stacked tensor of labels in the batch\n", | ||
"images = batch[\"image\"]\n", | ||
"labels = batch[\"label\"]\n", | ||
"\n", | ||
"#let's also display the two items\n", | ||
"for i in range(len(images)):\n", | ||
" axs[i].imshow(images[i].squeeze(), cmap='gray')\n", | ||
" axs[i].set_title(f\"Label: {labels[i]}\") \n", | ||
" axs[i].set_axis_off()\n" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.12.6" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,161 @@ | ||
{ | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Add a title |
||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"id": "d8513771-36ac-4d03-b890-35108bce2211", | ||
"metadata": {}, | ||
"source": [ | ||
"### Loading and processing IMDB movie review dataset\n", | ||
"In this example, we will load the IMDB dataset from Hugging Face, \n", | ||
"use `torchdata.nodes` to process it and generate training batches." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 1, | ||
"id": "eb3b507c-2ad1-410d-a834-6847182de684", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from datasets import load_dataset\n", | ||
"from transformers import BertTokenizer, BertForSequenceClassification" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 2, | ||
"id": "089f1126-7125-4274-9d71-5c949ccc7bbd", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import torch\n", | ||
"from torch.utils.data import default_collate, RandomSampler, SequentialSampler" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 3, | ||
"id": "2afac7d9-3d66-4195-8647-dc7034d306f2", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Load IMDB dataset from huggingface datasets and select the \"train\" split\n", | ||
"dataset = load_dataset(\"imdb\", streaming=False)\n", | ||
"dataset = dataset[\"train\"]\n", | ||
"# Since dataset is a Map-style dataset, we can setup a sampler to shuffle the data\n", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's link to the docs similar to other example |
||
"# Please refer to the migration guide here https://pytorch.org/data/docs/build/html/migrate_to_nodes_from_utils.html\n", | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I gave you the wrong link, after nightly run this link is live now: https://pytorch.org/data/main/migrate_to_nodes_from_utils.html |
||
"# to migrate from torch.utils.data to torchdata.nodes\n", | ||
"\n", | ||
"sampler = RandomSampler(dataset)\n", | ||
"# Use a standard bert tokenizer\n", | ||
"tokenizer = BertTokenizer.from_pretrained(\"bert-base-uncased\")\n", | ||
"# Now we can set up some torchdata.nodes to create our pre-proc pipeline" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"id": "09e08a47-573c-4d32-9a02-36cd8150db60", | ||
"metadata": {}, | ||
"source": [ | ||
"All torchdata.nodes.BaseNode implementations are Iterators.\n", | ||
"MapStyleWrapper creates an Iterator that combines sampler and dataset to create an iterator.\n", | ||
"Under the hood, MapStyleWrapper just does:\n", | ||
"```python\n", | ||
"node = IterableWrapper(sampler)\n", | ||
"node = Mapper(node, map_fn=dataset.__getitem__) # You can parallelize this with ParallelMapper\n", | ||
"```" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 4, | ||
"id": "02af5479-ee69-41d8-ab2d-bf154b84bc15", | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"from torchdata.nodes import MapStyleWrapper, ParallelMapper, Batcher, PinMemory, Loader\n", | ||
"node = MapStyleWrapper(map_dataset=dataset, sampler=sampler)\n", | ||
"\n", | ||
"# Now we want to transform the raw inputs. We can just use another Mapper with\n", | ||
"# a custom map_fn to perform this. Using ParallelMapper allows us to use multiple\n", | ||
"# threads (or processes) to parallelize this work and have it run in the background\n", | ||
"max_len = 512\n", | ||
"batch_size = 2\n", | ||
"def bert_transform(item):\n", | ||
" encoding = tokenizer.encode_plus(\n", | ||
" item[\"text\"],\n", | ||
" add_special_tokens=True,\n", | ||
" max_length=max_len,\n", | ||
" padding=\"max_length\",\n", | ||
" truncation=True,\n", | ||
" return_attention_mask=True,\n", | ||
" return_tensors=\"pt\",\n", | ||
" )\n", | ||
" return {\n", | ||
" \"input_ids\": encoding[\"input_ids\"].flatten(),\n", | ||
" \"attention_mask\": encoding[\"attention_mask\"].flatten(),\n", | ||
" \"labels\": torch.tensor(item[\"label\"], dtype=torch.long),\n", | ||
" }\n", | ||
"node = ParallelMapper(node, map_fn=bert_transform, num_workers=2) # output items are Dict[str, tensor]\n", | ||
"\n", | ||
"# Next we batch the inputs, and then apply a collate_fn with another Mapper\n", | ||
"# to stack the tensors between. We use torch.utils.data.default_collate for this\n", | ||
"node = Batcher(node, batch_size=batch_size) # output items are List[Dict[str, tensor]]\n", | ||
"node = ParallelMapper(node, map_fn=default_collate, num_workers=2) # outputs are Dict[str, tensor]\n", | ||
"\n", | ||
"# we can optionally apply pin_memory to the batches\n", | ||
"if torch.cuda.is_available():\n", | ||
" node = PinMemory(node)\n", | ||
"\n", | ||
"# Since nodes are iterators, they need to be manually .reset() between epochs.\n", | ||
"# We can wrap the root node in Loader to convert it to a more conventional Iterable.\n", | ||
"loader = Loader(node)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 5, | ||
"id": "60fd54f3-62ef-47aa-a790-853cb4899f13", | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"{'input_ids': tensor([[ 101, 1045, 2572, ..., 2143, 2000, 102],\n", | ||
" [ 101, 2004, 1037, ..., 0, 0, 0]]), 'attention_mask': tensor([[1, 1, 1, ..., 1, 1, 1],\n", | ||
" [1, 1, 1, ..., 0, 0, 0]]), 'labels': tensor([0, 1])}\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"# Inspect a batch\n", | ||
"batch = next(iter(loader))\n", | ||
"print(batch)\n", | ||
"# In a batch we get three keys, as defined in the method `bert_transform`.\n", | ||
"# Since the batch size is 2, two samples are stacked together for each key." | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.12.6" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 5 | ||
} |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Add a title