Tensor size issue #1

Open

arunraja-hub opened this issue Nov 28, 2024 · 11 comments

@arunraja-hub
When I was just trying to run the training using python train.py params_x1x3x4_diffusion_mosesaq_20240824 0, as suggested in the readme, I got the following error:

RuntimeError: Trying to resize storage that is not resizable

According to lucidrains/denoising-diffusion-pytorch#248, the solution is to set num_workers in the dataloader to 0, but that resulted in the following error:

RuntimeError: Sizes of tensors must match except in dimension 0. Expected size 1176 but got size 595 for tensor number 1 in the list.

Could you please provide some guidance on this?

@keiradams (Collaborator) commented Nov 28, 2024

Hi! I have not experienced this error, so I suspect it has something to do with our different training setups or package versions.

To help debug, can you try the following:

  • Make sure you can successfully run the inference code provided in the RUNME_{}.ipynb notebooks.
  • In train.py, make sure you can call dataset[0] after initializing dataset = HeteroDataset(...).
  • In train.py, make sure you can call next(iter(train_loader)) after initializing train_loader = torch_geometric.loader.DataLoader(...), with both batch_size = 1 and batch_size > 1.

If all of that works, then I would guess it is related to an issue with DDPM in Pytorch-Lightning with your particular system set-up. Are you trying to train with 1 GPU? On a CPU? On multiple GPUs? The parameters in parameters/params_x1x3x4_diffusion_mosesaq_20240824.py specify 'num_gpus': 2 and 'multiprocessing_spawn': True; either of those could be causing issues with your specific setup.

Also, does this error occur at the start of the training epochs? Or mid-way through training?

Additionally, make sure that the versions of your packages are the same as those listed in the README, particularly your Pytorch-Lightning, Pytorch, and PyG versions.

It would also help if you could provide the complete error traceback.

@arunraja-hub (Author)

Hi @keiradams, thanks for your quick reply. I did not make any changes to the code in the repo. I am able to run the RUNME notebooks without issue using a new virtual environment I have set up. However, for train.py I ran into the following error, which I think might be due to the pytorch geometric version: I had to choose slightly different pytorch and pytorch geometric versions from yours, as my cuda version is different.

Seed set to 0
Traceback (most recent call last):
  File "/mnt/data/slurm-storage/aruraj/opig/shepherd/train.py", line 98, in <module>
    dataset = HeteroDataset(
              ^^^^^^^^^^^^^^
TypeError: Can't instantiate abstract class HeteroDataset with abstract method get

Here is the YAML of my local virtual environment:

name: shepherd
channels:
  - pyg
  - pytorch
  - nvidia
  - anaconda
  - defaults
  - conda-forge
dependencies:
  - _libgcc_mutex=0.1=conda_forge
  - _openmp_mutex=4.5=2_gnu
  - aiofiles=22.1.0=py311h06a4308_0
  - aiosqlite=0.18.0=py311h06a4308_0
  - annotated-types=0.6.0=pyhd8ed1ab_0
  - anyio=3.5.0=py311h06a4308_0
  - appdirs=1.4.4=pyh9f0ad1d_0
  - argon2-cffi=21.3.0=pyhd3eb1b0_0
  - argon2-cffi-bindings=21.2.0=py311h5eee18b_0
  - arrow=1.3.0=pyhd8ed1ab_0
  - asttokens=2.0.5=pyhd3eb1b0_0
  - attrs=23.1.0=py311h06a4308_0
  - babel=2.11.0=py311h06a4308_0
  - backcall=0.2.0=pyhd3eb1b0_0
  - backoff=2.2.1=pyhd8ed1ab_0
  - beautifulsoup4=4.12.2=py311h06a4308_0
  - blas=1.0=mkl
  - bleach=4.1.0=pyhd3eb1b0_0
  - blessed=1.19.1=pyhe4f9e05_2
  - boto3=1.28.82=pyhd8ed1ab_0
  - bottleneck=1.3.5=py311hbed6279_0
  - brotli=1.0.9=h5eee18b_7
  - brotli-bin=1.0.9=h5eee18b_7
  - brotli-python=1.0.9=py311h6a678d5_7
  - bzip2=1.0.8=h7b6447c_0
  - ca-certificates=2023.08.22=h06a4308_0
  - cachecontrol=0.12.14=pyhd8ed1ab_0
  - celluloid=0.2.0=pyhd8ed1ab_0
  - certifi=2023.7.22=py311h06a4308_0
  - cffi=1.15.1=py311h5eee18b_3
  - charset-normalizer=2.0.4=pyhd3eb1b0_0
  - cleo=2.0.1=pyhd8ed1ab_0
  - click=8.1.7=unix_pyh707e725_0
  - colorama=0.4.6=pyhd8ed1ab_0
  - crashtest=0.4.1=pyhd8ed1ab_0
  - croniter=1.4.1=pyhd8ed1ab_0
  - cryptography=41.0.3=py311hdda0065_0
  - cuda-cudart=12.1.105=0
  - cuda-cupti=12.1.105=0
  - cuda-libraries=12.1.0=0
  - cuda-nvrtc=12.1.105=0
  - cuda-nvtx=12.1.105=0
  - cuda-opencl=12.3.52=0
  - cuda-runtime=12.1.0=0
  - cycler=0.11.0=pyhd3eb1b0_0
  - cyrus-sasl=2.1.28=h52b45da_1
  - dateutils=0.6.12=py_0
  - dbus=1.13.18=hb2f20db_0
  - debugpy=1.6.7=py311h6a678d5_0
  - decorator=5.1.1=pyhd3eb1b0_0
  - deepdiff=6.7.0=pyhd8ed1ab_0
  - defusedxml=0.7.1=pyhd3eb1b0_0
  - distlib=0.3.7=pyhd8ed1ab_0
  - docker-pycreds=0.4.0=py_0
  - dulwich=0.21.6=py311h459d7ec_2
  - entrypoints=0.4=py311h06a4308_0
  - executing=0.8.3=pyhd3eb1b0_0
  - expat=2.5.0=h6a678d5_0
  - fastapi=0.103.0=pyhd8ed1ab_0
  - ffmpeg=4.3=hf484d3e_0
  - filelock=3.9.0=py311h06a4308_0
  - fontconfig=2.14.1=h4c34cd2_2
  - fonttools=4.25.0=pyhd3eb1b0_0
  - freetype=2.12.1=h4a9f257_0
  - giflib=5.2.1=h5eee18b_3
  - gitdb=4.0.11=pyhd8ed1ab_0
  - gitpython=3.1.40=pyhd8ed1ab_0
  - glib=2.69.1=he621ea3_2
  - gmp=6.2.1=h295c915_3
  - gmpy2=2.1.2=py311hc9b5ff0_0
  - gnutls=3.6.15=he1e5248_0
  - gst-plugins-base=1.14.1=h6a678d5_1
  - gstreamer=1.14.1=h5eee18b_1
  - h11=0.14.0=pyhd8ed1ab_0
  - html5lib=1.1=pyh9f0ad1d_0
  - icu=58.2=he6710b0_3
  - idna=3.4=py311h06a4308_0
  - importlib-metadata=6.8.0=pyha770c72_0
  - importlib_metadata=6.8.0=hd8ed1ab_0
  - inquirer=3.1.3=pyhd8ed1ab_0
  - intel-openmp=2023.1.0=hdb19cb5_46305
  - ipykernel=6.25.0=py311h92b7b1e_0
  - ipython=8.15.0=py311h06a4308_0
  - ipython_genutils=0.2.0=pyhd3eb1b0_1
  - itsdangerous=2.1.2=pyhd8ed1ab_0
  - jaraco.classes=3.3.0=pyhd8ed1ab_0
  - jedi=0.18.1=py311h06a4308_1
  - jeepney=0.8.0=pyhd8ed1ab_0
  - jinja2=3.1.2=py311h06a4308_0
  - jmespath=1.0.1=pyhd8ed1ab_0
  - joblib=1.2.0=py311h06a4308_0
  - jpeg=9e=h5eee18b_1
  - json5=0.9.6=pyhd3eb1b0_0
  - jsonschema=4.17.3=py311h06a4308_0
  - jupyter=1.0.0=py311h06a4308_8
  - jupyter_client=7.4.9=py311h06a4308_0
  - jupyter_console=6.6.3=py311h06a4308_0
  - jupyter_core=5.3.0=py311h06a4308_0
  - jupyter_events=0.6.3=py311h06a4308_0
  - jupyter_server=1.23.4=py311h06a4308_0
  - jupyter_server_fileid=0.9.0=py311h06a4308_0
  - jupyter_server_ydoc=0.8.0=py311h06a4308_1
  - jupyter_ydoc=0.2.4=py311h06a4308_0
  - jupyterlab=3.6.3=py311h06a4308_0
  - jupyterlab_pygments=0.1.2=py_0
  - jupyterlab_server=2.22.0=py311h06a4308_0
  - keyring=23.13.1=py311h38be061_0
  - kiwisolver=1.4.4=py311h6a678d5_0
  - krb5=1.20.1=h143b758_1
  - lame=3.100=h7b6447c_0
  - lcms2=2.12=h3be6417_0
  - ld_impl_linux-64=2.38=h1181459_1
  - lerc=3.0=h295c915_0
  - libbrotlicommon=1.0.9=h5eee18b_7
  - libbrotlidec=1.0.9=h5eee18b_7
  - libbrotlienc=1.0.9=h5eee18b_7
  - libclang=14.0.6=default_hc6dbbc7_1
  - libclang13=14.0.6=default_he11475f_1
  - libcublas=12.1.0.26=0
  - libcufft=11.0.2.4=0
  - libcufile=1.8.0.34=0
  - libcups=2.4.2=h2d74bed_1
  - libcurand=10.3.4.52=0
  - libcusolver=11.4.4.55=0
  - libcusparse=12.0.2.55=0
  - libdeflate=1.17=h5eee18b_1
  - libedit=3.1.20221030=h5eee18b_0
  - libevent=2.1.12=hdbd6064_1
  - libffi=3.4.4=h6a678d5_0
  - libgcc-ng=13.2.0=h807b86a_2
  - libgfortran-ng=11.2.0=h00389a5_1
  - libgfortran5=11.2.0=h1234567_1
  - libgomp=13.2.0=h807b86a_2
  - libiconv=1.16=h7f8727e_2
  - libidn2=2.3.4=h5eee18b_0
  - libjpeg-turbo=2.0.0=h9bf148f_0
  - libllvm14=14.0.6=hdb19cb5_3
  - libnpp=12.0.2.50=0
  - libnvjitlink=12.1.105=0
  - libnvjpeg=12.1.1.14=0
  - libpng=1.6.39=h5eee18b_0
  - libpq=12.15=hdbd6064_1
  - libprotobuf=3.20.3=he621ea3_0
  - libsodium=1.0.18=h7b6447c_0
  - libstdcxx-ng=11.2.0=h1234567_1
  - libtasn1=4.19.0=h5eee18b_0
  - libtiff=4.5.1=h6a678d5_0
  - libunistring=0.9.10=h27cfd23_0
  - libuuid=1.41.5=h5eee18b_0
  - libwebp=1.3.2=h11a3e52_0
  - libwebp-base=1.3.2=h5eee18b_0
  - libxcb=1.15=h7f8727e_0
  - libxkbcommon=1.0.1=h5eee18b_1
  - libxml2=2.10.4=hcbfbd50_0
  - libxslt=1.1.37=h2085143_0
  - lightning=2.1.1=pyhd8ed1ab_0
  - lightning-cloud=0.5.50=pyhd8ed1ab_0
  - lightning-utilities=0.9.0=pyhd8ed1ab_0
  - llvm-openmp=14.0.6=h9e868ea_0
  - lockfile=0.12.2=py_1
  - lxml=4.9.3=py311hdbbb534_0
  - lz4-c=1.9.4=h6a678d5_0
  - markdown-it-py=3.0.0=pyhd8ed1ab_0
  - markupsafe=2.1.1=py311h5eee18b_0
  - matplotlib=3.7.2=py311h06a4308_0
  - matplotlib-base=3.7.2=py311ha02d727_0
  - matplotlib-inline=0.1.6=py311h06a4308_0
  - mdurl=0.1.0=pyhd8ed1ab_0
  - mistune=0.8.4=py311h5eee18b_1000
  - mkl=2023.1.0=h213fc3f_46343
  - mkl-service=2.4.0=py311h5eee18b_1
  - mkl_fft=1.3.8=py311h5eee18b_0
  - mkl_random=1.2.4=py311hdb19cb5_0
  - more-itertools=10.1.0=pyhd8ed1ab_0
  - mpc=1.1.0=h10f8cd9_1
  - mpfr=4.0.2=hb69a4c5_1
  - mpmath=1.3.0=py311h06a4308_0
  - msgpack-python=1.0.3=py311hdb19cb5_0
  - munkres=1.1.4=py_0
  - mysql=5.7.24=h721c034_2
  - nbclassic=0.5.5=py311h06a4308_0
  - nbclient=0.5.13=py311h06a4308_0
  - nbconvert=6.5.4=py311h06a4308_0
  - nbformat=5.9.2=py311h06a4308_0
  - ncurses=6.4=h6a678d5_0
  - nest-asyncio=1.5.6=py311h06a4308_0
  - nettle=3.7.3=hbbd107a_1
  - networkx=3.1=py311h06a4308_0
  - notebook=6.5.4=py311h06a4308_1
  - notebook-shim=0.2.2=py311h06a4308_0
  - nspr=4.35=h6a678d5_0
  - nss=3.89.1=h6a678d5_0
  - numexpr=2.8.7=py311h65dcdc2_0
  - openh264=2.1.1=h4ff587b_0
  - openjpeg=2.4.0=h3ad879b_0
  - openssl=3.1.4=hd590300_0
  - ordered-set=4.1.0=pyhd8ed1ab_0
  - orjson=3.9.10=py311h34b1e23_0
  - packaging=23.1=py311h06a4308_0
  - pandocfilters=1.5.0=pyhd3eb1b0_0
  - parso=0.8.3=pyhd3eb1b0_0
  - pathtools=0.1.2=py_1
  - pcre=8.45=h295c915_0
  - pexpect=4.8.0=pyhd3eb1b0_3
  - pickleshare=0.7.5=pyhd3eb1b0_1003
  - pillow=10.0.1=py311ha6cbd5a_0
  - pkginfo=1.9.6=pyhd8ed1ab_0
  - ply=3.11=py311h06a4308_0
  - poetry=1.4.2=linux_pyhd8ed1ab_0
  - poetry-core=1.5.2=pyhd8ed1ab_0
  - poetry-plugin-export=1.3.1=pyhd8ed1ab_0
  - prometheus_client=0.14.1=py311h06a4308_0
  - prompt-toolkit=3.0.36=py311h06a4308_0
  - prompt_toolkit=3.0.36=hd3eb1b0_0
  - protobuf=3.20.3=py311h6a678d5_0
  - psutil=5.9.0=py311h5eee18b_0
  - ptyprocess=0.7.0=pyhd3eb1b0_2
  - pure_eval=0.2.2=pyhd3eb1b0_0
  - pycparser=2.21=pyhd3eb1b0_0
  - pyg=2.4.0=py311_torch_2.1.0_cu121
  - pygments=2.15.1=py311h06a4308_1
  - pyjwt=2.8.0=pyhd8ed1ab_0
  - pyopenssl=23.2.0=py311h06a4308_0
  - pyparsing=3.0.9=py311h06a4308_0
  - pyproject_hooks=1.0.0=pyhd8ed1ab_0
  - pyqt=5.15.7=py311h6a678d5_0
  - pyqt5-sip=12.11.0=py311h6a678d5_0
  - pyrsistent=0.18.0=py311h5eee18b_0
  - pysocks=1.7.1=py311h06a4308_0
  - python=3.11.5=h955ad1f_0
  - python-build=0.10.0=pyhd8ed1ab_1
  - python-dateutil=2.8.2=pyhd3eb1b0_0
  - python-editor=1.0.4=py_0
  - python-fastjsonschema=2.16.2=py311h06a4308_0
  - python-installer=0.7.0=pyhd8ed1ab_0
  - python-json-logger=2.0.7=py311h06a4308_0
  - python-multipart=0.0.6=pyhd8ed1ab_0
  - python-tzdata=2023.3=pyhd3eb1b0_0
  - python_abi=3.11=2_cp311
  - pytorch=2.1.0=py3.11_cuda12.1_cudnn8.9.2_0
  - pytorch-cuda=12.1=ha16c6d3_5
  - pytorch-lightning=2.1.0=pyhd8ed1ab_0
  - pytorch-mutex=1.0=cuda
  - pytz=2023.3.post1=py311h06a4308_0
  - pyyaml=6.0.1=py311h5eee18b_0
  - pyzmq=23.2.0=py311h6a678d5_0
  - qt-main=5.15.2=h7358343_9
  - qt-webengine=5.15.9=h9ab4d14_7
  - qtconsole=5.4.2=py311h06a4308_0
  - qtpy=2.2.0=py311h06a4308_0
  - qtwebkit=5.212=h3fafdc1_5
  - rapidfuzz=2.13.7=py311ha02d727_0
  - readchar=4.0.5=pyhd8ed1ab_0
  - readline=8.2=h5eee18b_0
  - requests=2.31.0=py311h06a4308_0
  - requests-toolbelt=0.10.1=pyhd8ed1ab_0
  - rfc3339-validator=0.1.4=py311h06a4308_0
  - rfc3986-validator=0.1.1=py311h06a4308_0
  - rich=13.6.0=pyhd8ed1ab_0
  - s3transfer=0.7.0=pyhd8ed1ab_0
  - scikit-learn=1.2.2=py311h6a678d5_1
  - scipy=1.11.3=py311h08b1b3b_0
  - secretstorage=3.3.3=py311h38be061_2
  - send2trash=1.8.0=pyhd3eb1b0_1
  - setproctitle=1.3.3=py311h459d7ec_0
  - setuptools=68.0.0=py311h06a4308_0
  - shellingham=1.5.4=pyhd8ed1ab_0
  - sip=6.6.2=py311h6a678d5_0
  - six=1.16.0=pyhd3eb1b0_1
  - smmap=5.0.0=pyhd8ed1ab_0
  - sniffio=1.2.0=py311h06a4308_1
  - soupsieve=2.5=py311h06a4308_0
  - sqlite=3.41.2=h5eee18b_0
  - stack_data=0.2.0=pyhd3eb1b0_0
  - starlette=0.27.0=pyhd8ed1ab_0
  - starsessions=1.3.0=pyhd8ed1ab_0
  - sympy=1.11.1=py311h06a4308_0
  - tbb=2021.8.0=hdb19cb5_0
  - terminado=0.17.1=py311h06a4308_0
  - threadpoolctl=2.2.0=pyh0d69192_0
  - tinycss2=1.2.1=py311h06a4308_0
  - tk=8.6.12=h1ccaba5_0
  - toml=0.10.2=pyhd3eb1b0_0
  - tomli=2.0.1=pyhd8ed1ab_0
  - tomlkit=0.12.2=pyha770c72_0
  - torchaudio=2.1.0=py311_cu121
  - torchtriton=2.1.0=py311
  - torchvision=0.16.0=py311_cu121
  - tornado=6.3.3=py311h5eee18b_0
  - tqdm=4.65.0=py311h92b7b1e_0
  - traitlets=5.7.1=py311h06a4308_0
  - trove-classifiers=2023.11.7=pyhd8ed1ab_0
  - types-python-dateutil=2.8.19.14=pyhd8ed1ab_0
  - typing-extensions=4.7.1=py311h06a4308_0
  - typing_extensions=4.7.1=py311h06a4308_0
  - tzdata=2023c=h04d1e81_0
  - urllib3=1.26.18=py311h06a4308_0
  - uvicorn=0.24.0=py311h38be061_0
  - virtualenv=20.21.1=pyhd8ed1ab_0
  - wcwidth=0.2.5=pyhd3eb1b0_0
  - webencodings=0.5.1=py311h06a4308_1
  - websocket-client=0.58.0=py311h06a4308_4
  - websockets=12.0=py311h459d7ec_0
  - wheel=0.41.2=py311h06a4308_0
  - xz=5.4.2=h5eee18b_0
  - y-py=0.5.9=py311h52d8a92_0
  - yaml=0.2.5=h7b6447c_0
  - ypy-websocket=0.8.2=py311h06a4308_0
  - zeromq=4.3.4=h2531618_0
  - zipp=3.17.0=pyhd8ed1ab_0
  - zlib=1.2.13=h5eee18b_0
  - zstd=1.5.5=hc292b87_0
  - pip:
      - absl-py==2.1.0
      - addict==2.4.0
      - aiobotocore==2.12.3
      - aiohttp==3.9.3
      - aioitertools==0.11.0
      - aiosignal==1.3.1
      - alembic==1.12.1
      - anndata==0.10.7
      - antlr4-python3-runtime==4.9.3
      - array-api-compat==1.6
      - asciitree==0.3.3
      - ase==3.22.1
      - autograd==1.6.2
      - autograd-gamma==0.5.0
      - ax-platform==0.3.7
      - beartype==0.18.5
      - biopython==1.78
      - biothings-client==0.3.1
      - blinker==1.9.0
      - bokeh==3.4.1
      - botocore==1.34.69
      - botorch==0.10.0
      - bravado==11.0.3
      - bravado-core==6.1.0
      - cachetools==5.3.2
      - cattrs==23.2.3
      - cellxgene-census==1.10.2
      - chembl-webresource-client==0.10.9
      - clarabel==0.6.0
      - cloudpickle==3.0.0
      - colorlog==6.7.0
      - comm==0.2.2
      - configargparse==1.7
      - contourpy==1.2.1
      - cvxpy==1.4.1
      - cvxpylayers==0.1.6
      - dash==2.18.2
      - dash-core-components==2.0.0
      - dash-html-components==2.0.0
      - dash-table==5.0.0
      - dataclasses==0.6
      - datamol==0.12.5
      - debtcollector==3.0.0
      - deeppurpose==0.1.5
      - descriptastorus==2.6.1
      - dgllife==0.3.2
      - diffcp==1.0.23
      - dirsync==2.2.5
      - dscribe==2.1.1
      - e3nn==0.5.1
      - easydict==1.13
      - ecos==2.0.12
      - einops==0.7.0
      - einx==0.3.0
      - equiformer-pytorch==0.5.3
      - et-xmlfile==1.1.0
      - fasteners==0.19
      - flask==3.0.3
      - formulaic==1.0.1
      - fqdn==1.5.1
      - frozendict==2.4.4
      - frozenlist==1.4.0
      - fsspec==2023.12.2
      - future==0.18.3
      - fuzzywuzzy==0.18.0
      - gcsfs==2023.12.2.post1
      - gdown==5.2.0
      - gget==0.28.4
      - google-api-core==2.19.1
      - google-auth==2.27.0
      - google-auth-oauthlib==1.2.0
      - google-cloud-core==2.4.1
      - google-cloud-storage==2.18.2
      - google-crc32c==1.5.0
      - google-resumable-media==2.7.2
      - googleapis-common-protos==1.63.2
      - gpytorch==1.11
      - greenlet==3.0.1
      - grpcio==1.60.1
      - h5py==3.10.0
      - httpcore==1.0.5
      - httpx==0.27.0
      - huggingface-hub==0.20.3
      - hydra-colorlog==1.2.0
      - hydra-core==1.3.2
      - hyperopt==0.2.7
      - importlib-resources==6.4.3
      - interface-meta==1.3.0
      - ipywidgets==8.1.5
      - isoduration==20.11.0
      - jax==0.4.16
      - jaxlib==0.4.16+cuda12.cudnn89
      - jaxtyping==0.2.28
      - jsonpointer==2.4
      - jsonref==1.1.0
      - jupyterlab-widgets==3.0.13
      - lifelines==0.28.0
      - linear-operator==0.5.1
      - llvmlite==0.42.0
      - loguru==0.7.2
      - looseversion==1.3.0
      - mako==1.3.0
      - markdown==3.5.2
      - matscipy==1.0.0
      - metis==0.2a5
      - ml-dtypes==0.3.1
      - mock==5.1.0
      - molecular-rectifier==0.1.10.2
      - mols2grid==2.0.0
      - monotonic==1.6
      - multidict==6.0.5
      - multipledispatch==1.0.0
      - mygene==3.2.2
      - mypy-extensions==1.0.0
      - mysql-connector-python==8.0.29
      - natsort==8.4.0
      - neptune==1.8.2
      - numba==0.59.1
      - numcodecs==0.13.0
      - numpy==1.26.4
      - nvidia-cublas-cu12==12.3.2.9
      - nvidia-cuda-cupti-cu12==12.3.52
      - nvidia-cuda-nvcc-cu12==12.3.52
      - nvidia-cuda-nvrtc-cu12==12.3.52
      - nvidia-cuda-runtime-cu12==12.3.52
      - nvidia-cudnn-cu12==8.9.4.25
      - nvidia-cufft-cu12==11.0.11.19
      - nvidia-cusolver-cu12==11.5.3.52
      - nvidia-cusparse-cu12==12.1.3.153
      - nvidia-nccl-cu12==2.19.3
      - nvidia-nvjitlink-cu12==12.3.52
      - oauthlib==3.2.2
      - omegaconf==2.3.0
      - open3d==0.18.0
      - openpyxl==3.0.10
      - openqdc==0.1.2
      - opt-einsum==3.3.0
      - opt-einsum-fx==0.1.4
      - optuna==3.4.0
      - osqp==0.6.3
      - pandas==2.1.4
      - pandas-flavor==0.6.0
      - patsy==0.5.6
      - pip==23.3.1
      - platformdirs==4.2.1
      - plotly==5.21.0
      - prettytable==3.10.0
      - proto-plus==1.24.0
      - psikit==0.2.0
      - py4j==0.10.9.7
      - pyarrow==14.0.1
      - pyarrow-hotfix==0.6
      - pyasn1==0.5.1
      - pyasn1-modules==0.3.0
      - pybind11==2.11.1
      - pydantic==2.6.3
      - pydantic-core==2.16.3
      - pynndescent==0.5.12
      - pyquaternion==0.9.9
      - pyre-extensions==0.0.30
      - pyro-api==0.1.2
      - pyro-ppl==1.9.0
      - pytdc==0.4.12
      - python-dotenv==1.0.1
      - qdldl==0.1.7.post0
      - ray==2.8.0
      - rdkit==2023.9.5
      - rdkit-pypi==2022.9.5
      - regex==2023.10.3
      - requests-cache==1.2.0
      - requests-oauthlib==1.3.1
      - retrying==1.3.4
      - rfc3987==1.3.8
      - rotary-embedding-torch==0.6.4
      - rsa==4.9
      - s3fs==2023.12.2
      - safetensors==0.4.0
      - scanpy==1.9.6
      - schnetpack==2.0.4
      - scs==3.2.4.post1
      - seaborn==0.12.2
      - selfies==2.1.2
      - sentry-sdk==2.17.0
      - session-info==1.0.0
      - simplejson==3.19.2
      - somacore==1.0.7
      - sparse==0.15.4
      - sqlalchemy==2.0.23
      - statsmodels==0.14.2
      - stdlib-list==0.10.0
      - subword-nmt==0.3.8
      - svgutils==0.3.4
      - swagger-spec-validator==3.0.3
      - taylor-series-linear-attention==0.1.11
      - tenacity==8.2.3
      - tensorboard==2.15.1
      - tensorboard-data-server==0.7.2
      - tensorboardx==2.6.2.2
      - tiledb==0.25.1
      - tiledbsoma==1.7.2
      - timm==0.9.11
      - tokenizers==0.14.1
      - torch-cluster==1.6.3+pt21cu121
      - torch-ema==0.3
      - torch-scatter==2.1.2+pt21cu121
      - torch-sparse==0.6.18+pt21cu121
      - torcheval==0.0.7
      - torchmetrics==1.0.1
      - transformers==4.35.0
      - typeguard==2.13.3
      - typer==0.12.4
      - typing-inspect==0.9.0
      - umap-learn==0.5.6
      - uri-template==1.3.0
      - url-normalize==1.4.3
      - wandb==0.18.5
      - webcolors==1.13
      - werkzeug==3.0.1
      - wget==3.2
      - widgetsnbextension==4.0.13
      - wrapt==1.16.0
      - xarray==2024.3.0
      - xlsxwriter==3.2.0
      - xsmiles==0.2.2
      - xyzservices==2024.4.0
      - yapf==0.40.2
      - yarl==1.9.4
      - zarr==2.18.2

@keiradams (Collaborator)

Hi @arunraja-hub, sorry for the delay.

Can you try adding the line of code def get(self, k): return self.__getitem__(k) to the class HeteroDataset(torch_geometric.data.Dataset): definition in datasets.py in your local clone?

I've updated the file on this GitHub repository, for your reference. Pytorch / PyG changed the function names between versions, which may be causing this issue.

Let me know if this solves your issues, or if there are other fixes that need to be implemented!

@arunraja-hub (Author)

Hi @keiradams, this error has been resolved now, but I am still facing the same original tensor size issue. Here is the complete error traceback. I have had to change my lightning, torch, and PyG versions to fit my CUDA version (11.5).

Seed set to 0
/slurm-storage/aruraj/.conda/envs/airs/lib/python3.11/site-packages/lightning_fabric/plugins/environments/slurm.py:191: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python train.py params_x1x3x4_diffusion_mosesaq_20240824 0 ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/slurm-storage/aruraj/.conda/envs/airs/lib/python3.11/site-packages/torch/overrides.py:110: UserWarning: 'has_cuda' is deprecated, please use 'torch.backends.cuda.is_built()'
  torch.has_cuda,
/slurm-storage/aruraj/.conda/envs/airs/lib/python3.11/site-packages/torch/overrides.py:111: UserWarning: 'has_cudnn' is deprecated, please use 'torch.backends.cudnn.is_available()'
  torch.has_cudnn,
/slurm-storage/aruraj/.conda/envs/airs/lib/python3.11/site-packages/torch/overrides.py:117: UserWarning: 'has_mps' is deprecated, please use 'torch.backends.mps.is_built()'
  torch.has_mps,
/slurm-storage/aruraj/.conda/envs/airs/lib/python3.11/site-packages/torch/overrides.py:118: UserWarning: 'has_mkldnn' is deprecated, please use 'torch.backends.mkldnn.is_available()'
  torch.has_mkldnn,
/slurm-storage/aruraj/.conda/envs/airs/lib/python3.11/site-packages/torch/jit/_check.py:178: UserWarning: The TorchScript type system doesn't support instance-level annotations on empty non-base types in `__init__`. Instead, either 1) use a type annotation in the class body, or 2) wrap the type in `torch.jit.Attribute`.
  warnings.warn(
6010427
beginning to train...
You are using a CUDA device ('NVIDIA A10') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0,1,2,3]

  | Name  | Type  | Params
--------------------------------
0 | model | Model | 6.0 M 
--------------------------------
6.0 M     Trainable params
0         Non-trainable params
6.0 M     Total params
24.042    Total estimated model params size (MB)
/slurm-storage/aruraj/.conda/envs/airs/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=63` in the `DataLoader` to improve performance.
/slurm-storage/aruraj/.conda/envs/airs/lib/python3.11/site-packages/pytorch_lightning/loops/fit_loop.py:293: The number of training batches (143) is smaller than the logging interval Trainer(log_every_n_steps=1000). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.
Epoch 0:   0%|                                                                                                                                    | 0/143 [00:00<?, ?it/s]Traceback (most recent call last):
  File "/mnt/data/slurm-storage/aruraj/opig/shepherd/train.py", line 231, in <module>
    trainer.fit(model_pl, train_loader, ckpt_path = ckpt_path)
  File "/slurm-storage/aruraj/.conda/envs/airs/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 545, in fit
    call._call_and_handle_interrupt(
  File "/slurm-storage/aruraj/.conda/envs/airs/lib/python3.11/site-packages/pytorch_lightning/trainer/call.py", line 44, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/slurm-storage/aruraj/.conda/envs/airs/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 581, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/slurm-storage/aruraj/.conda/envs/airs/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 990, in _run
    results = self._run_stage()
              ^^^^^^^^^^^^^^^^^
  File "/slurm-storage/aruraj/.conda/envs/airs/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 1036, in _run_stage
    self.fit_loop.run()
  File "/slurm-storage/aruraj/.conda/envs/airs/lib/python3.11/site-packages/pytorch_lightning/loops/fit_loop.py", line 202, in run
    self.advance()
  File "/slurm-storage/aruraj/.conda/envs/airs/lib/python3.11/site-packages/pytorch_lightning/loops/fit_loop.py", line 359, in advance
    self.epoch_loop.run(self._data_fetcher)
  File "/slurm-storage/aruraj/.conda/envs/airs/lib/python3.11/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 136, in run
    self.advance(data_fetcher)
  File "/slurm-storage/aruraj/.conda/envs/airs/lib/python3.11/site-packages/pytorch_lightning/loops/training_epoch_loop.py", line 202, in advance
    batch, _, __ = next(data_fetcher)
                   ^^^^^^^^^^^^^^^^^^
  File "/slurm-storage/aruraj/.conda/envs/airs/lib/python3.11/site-packages/pytorch_lightning/loops/fetchers.py", line 127, in __next__
    batch = super().__next__()
            ^^^^^^^^^^^^^^^^^^
  File "/slurm-storage/aruraj/.conda/envs/airs/lib/python3.11/site-packages/pytorch_lightning/loops/fetchers.py", line 56, in __next__
    batch = next(self.iterator)
            ^^^^^^^^^^^^^^^^^^^
  File "/slurm-storage/aruraj/.conda/envs/airs/lib/python3.11/site-packages/pytorch_lightning/utilities/combined_loader.py", line 326, in __next__
    out = next(self._iterator)
          ^^^^^^^^^^^^^^^^^^^^
  File "/slurm-storage/aruraj/.conda/envs/airs/lib/python3.11/site-packages/pytorch_lightning/utilities/combined_loader.py", line 74, in __next__
    out[i] = next(self.iterators[i])
             ^^^^^^^^^^^^^^^^^^^^^^^
  File "/slurm-storage/aruraj/.conda/envs/airs/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 630, in __next__
    data = self._next_data()
           ^^^^^^^^^^^^^^^^^
  File "/slurm-storage/aruraj/.conda/envs/airs/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 674, in _next_data
    data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/slurm-storage/aruraj/.conda/envs/airs/lib/python3.11/site-packages/torch/utils/data/_utils/fetch.py", line 54, in fetch
    return self.collate_fn(data)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/slurm-storage/aruraj/.conda/envs/airs/lib/python3.11/site-packages/torch_geometric/loader/dataloader.py", line 55, in collate_fn
    return self(batch)
           ^^^^^^^^^^^
  File "/slurm-storage/aruraj/.conda/envs/airs/lib/python3.11/site-packages/torch_geometric/loader/dataloader.py", line 28, in __call__
    return Batch.from_data_list(
           ^^^^^^^^^^^^^^^^^^^^^
  File "/slurm-storage/aruraj/.conda/envs/airs/lib/python3.11/site-packages/torch_geometric/data/batch.py", line 93, in from_data_list
    batch, slice_dict, inc_dict = collate(
                                  ^^^^^^^^
  File "/slurm-storage/aruraj/.conda/envs/airs/lib/python3.11/site-packages/torch_geometric/data/collate.py", line 92, in collate
    value, slices, incs = _collate(attr, values, data_list, stores,
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/slurm-storage/aruraj/.conda/envs/airs/lib/python3.11/site-packages/torch_geometric/data/collate.py", line 177, in _collate
    value = torch.cat(values, dim=cat_dim or 0, out=out)
            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Sizes of tensors must match except in dimension 0. Expected size 1176 but got size 595 for tensor number 1 in the list.

@keiradams (Collaborator) commented Dec 10, 2024

@arunraja-hub Can you confirm that these steps work prior to calling trainer.fit()?

  • make sure you can call dataset[0] after initializing dataset = HeteroDataset(...)
  • make sure you can call next(iter(train_loader)) after initializing train_loader = torch_geometric.loader.DataLoader(...), with batch_size = 1.
  • make sure you can call next(iter(train_loader)) after initializing train_loader = torch_geometric.loader.DataLoader(...), with batch_size > 1


@arunraja-hub (Author)

@keiradams I can call dataset[0] and next(iter(train_loader)) when batch_size > 0, but, as expected, for batch_size = 0 I got the following error:

Traceback (most recent call last):
  File "/mnt/data/slurm-storage/aruraj/opig/shepherd/train.py", line 159, in <module>
    train_loader = torch_geometric.loader.DataLoader(
                   ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/slurm-storage/aruraj/.conda/envs/airs/lib/python3.11/site-packages/torch_geometric/loader/dataloader.py", line 98, in __init__
    super().__init__(
  File "/slurm-storage/aruraj/.conda/envs/airs/lib/python3.11/site-packages/torch/utils/data/dataloader.py", line 355, in __init__
    batch_sampler = BatchSampler(sampler, batch_size, drop_last)
                    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/slurm-storage/aruraj/.conda/envs/airs/lib/python3.11/site-packages/torch/utils/data/sampler.py", line 263, in __init__
    raise ValueError(f"batch_size should be a positive integer value, but got batch_size={batch_size}")
ValueError: batch_size should be a positive integer value, but got batch_size=0

@keiradams (Collaborator) commented Dec 10, 2024

Sorry, I meant batch_size = 1 and batch_size > 1.

@arunraja-hub

@arunraja-hub (Author)

Yes, batch_size = 1 and batch_size > 1 both work for me.
@keiradams

@keiradams (Collaborator) commented Dec 10, 2024

@arunraja-hub this error is quite odd to me, then. Can you train without an issue on a CPU with num_workers = 0? On a CPU with num_workers > 1? On 1 GPU with num_workers = 0 and with num_workers > 1?

You will have to adjust the parameters in trainer = pl.Trainer() to test these configurations.

@arunraja-hub (Author)

@keiradams the training seems to work when batch_size = 1. The tensor size issue might therefore be occurring during the batching of graphs of various sizes, though PyG should handle this, since it builds a batch-level (block-diagonal) adjacency matrix when collating graphs of varying sizes (https://pytorch-geometric.readthedocs.io/en/2.6.1/notes/batching.html).
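For what it's worth, the error can be reproduced with plain torch.cat, which PyG's collate ultimately calls: concatenating along the node dimension (dim 0) tolerates different graph sizes, while concatenating along any other mismatched dimension raises exactly this RuntimeError. A torch-only sketch, with the 1176/595 sizes borrowed from the traceback purely for illustration:

```python
import torch

x1 = torch.randn(1176, 64)  # node features of graph 1
x2 = torch.randn(595, 64)   # node features of graph 2

# PyG-style batching: concatenate along dim 0; differing node counts are fine
batch_x = torch.cat([x1, x2], dim=0)
assert batch_x.shape == (1771, 64)

# The reported error means some attribute is being concatenated along a
# dimension where the per-graph sizes differ:
try:
    torch.cat([x1.t(), x2.t()], dim=0)  # (64, 1176) vs (64, 595): mismatch
except RuntimeError as err:
    print(err)  # Sizes of tensors must match except in dimension 0 ...
```

So the failure suggests some attribute in the batch has an unexpected cat dimension or shape, rather than PyG's batching being broken in general.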

@keiradams (Collaborator)

@arunraja-hub If you can sample from the dataloader when batch_size > 1 (outside of training) by calling next(iter(train_loader)), then the issue shouldn't be with batching through PyG.

Can you confirm again whether you have tested this?
