Skip to content

Commit

Permalink
More torch versions
Browse files Browse the repository at this point in the history
  • Loading branch information
frostedoyster committed Feb 14, 2024
1 parent 165d583 commit b12db15
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 12 deletions.
8 changes: 4 additions & 4 deletions .github/workflows/build-cuda-wheels.yml
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,12 @@ jobs:
strategy:
matrix:
os: [ubuntu-20.04]
torch: ["2.2.0"]
torch: ["1.13.0", "2.0.0", "2.1.0", "2.2.0"]
include:
- cuda: "11.8"
torch_cuda: "118"
torch_index_cuda: "118"
- cuda: "12.1"
torch_cuda: "121"
torch_index_cuda: "121"
defaults:
run:
shell: pwsh
Expand Down Expand Up @@ -73,7 +73,7 @@ jobs:
CIBW_SKIP: "*-musllinux* *-win32 *-manylinux_i686"
CIBW_MANYLINUX_X86_64_IMAGE: manylinux2014
# set environment variables for sphericart-torch build
CIBW_ENVIRONMENT: SPHERICART_ARCH_NATIVE=OFF CUDACXX=/usr/local/cuda/bin/nvcc TORCH_CUDA_ARCH_LIST=All CUDAARCHS=all SPHERICART_TORCH_TORCH_VERSION=${{ matrix.torch }} PIP_EXTRA_INDEX_URL=https://download.pytorch.org/whl/cu${{ matrix.torch_cuda }}
CIBW_ENVIRONMENT: SPHERICART_ARCH_NATIVE=OFF CUDACXX=/usr/local/cuda/bin/nvcc TORCH_CUDA_ARCH_LIST=All CUDAARCHS=all SPHERICART_TORCH_TORCH_VERSION=${{ matrix.torch }} PIP_EXTRA_INDEX_URL=https://download.pytorch.org/whl/cu${{ matrix.torch_index_cuda }}
# do not complain for missing libtorch.so in sphericart-torch wheel
CIBW_REPAIR_WHEEL_COMMAND_LINUX: |
auditwheel repair --exclude libtorch.so --exclude libtorch_cpu.so --exclude libtorch_cuda.so --exclude libc10.so --exclude libc10_cuda.so -w {dest_dir} {wheel}
Expand Down
11 changes: 3 additions & 8 deletions sphericart-torch/build-backend/backend.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,14 @@
# this is a custom Python build backend wrapping setuptool's to add a build-time
# dependencies to metatensor-core, using the local version if it exists, and
# otherwise falling back to the one on PyPI.
# this is a custom Python build backend wrapping setuptool's to set a
# specific torch version as a build dependency, based on an environment
# variable
import os

from setuptools import build_meta

TORCH_VERSION = os.environ.get("SPHERICART_TORCH_TORCH_VERSION")
# CUDA_VERSION = os.environ.get("SPHERICART_TORCH_CUDA_VERSION")

if TORCH_VERSION is not None:
# force a specific version of torch+cuda
TORCH_DEP = f"torch =={TORCH_VERSION}"
# if CUDA_VERSION is not None:
# extra_index_url = f" --index-url https://download.pytorch.org/whl/cu{CUDA_VERSION.replace('.', '')}"
# TORCH_DEP += extra_index_url
else:
TORCH_DEP = "torch >=1.13"

Expand Down

0 comments on commit b12db15

Please sign in to comment.