# Skip to content

# Allow callable for vectorization over particles on SVI #3576

# Allow callable for vectorization over particles on SVI

# Allow callable for vectorization over particles on SVI #3576

# Workflow file for this run

# adapted from https://github.com/actions/starter-workflows/blob/master/ci/python-package.yml
name: CI

# Triggers: pushes to master, and pull requests that target master.
# The branch filters must be nested under their event; as sibling top-level
# keys they would be duplicates and would not filter anything.
on:
  push:
    branches: [master]
  pull_request:
    branches: [master]
jobs:
  # Gate job: lints the code, builds the docs and runs the doctests.
  # Every other job declares `needs: lint`, so they start only if this passes.
  lint:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        # Quoted so YAML does not read the version as the float 3.1.
        python-version: ["3.10"]
    steps:
      - uses: actions/checkout@v4
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          # Refresh the package index first; installing against a stale index
          # can fail on hosted runners.
          sudo apt-get update
          sudo apt-get install -y pandoc gsfonts
          python -m pip install --upgrade pip
          pip install jaxlib
          pip install jax
          # Quoted so the shell cannot glob-expand the extras brackets.
          pip install '.[doc,test]'
          pip install https://github.com/pyro-ppl/funsor/archive/master.zip
          pip install -r docs/requirements.txt
          pip freeze
      - name: Lint with ruff
        run: |
          make lint
      - name: Build documentation
        run: |
          make docs
      - name: Test documentation
        run: |
          make doctest
          python -m doctest -v README.md

  # Runs the test suite except the example tests, test/infer and test/contrib,
  # followed by a targeted run with JAX_ENABLE_X64=1.
  test-modeling:
    runs-on: ubuntu-latest
    needs: lint
    strategy:
      matrix:
        python-version: ["3.10"]
    steps:
      - uses: actions/checkout@v4
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          sudo apt-get update
          sudo apt-get install -y graphviz
          python -m pip install --upgrade pip
          # Keep track of pyro-api master branch
          pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip
          pip install jaxlib
          pip install jax
          pip install https://github.com/pyro-ppl/funsor/archive/master.zip
          # Quoted so the shell cannot glob-expand the extras brackets.
          pip install '.[dev,test]'
          pip freeze
      - name: Test with pytest
        run: |
          CI=1 pytest -vs -k "not test_example" --durations=100 --ignore=test/infer/ --ignore=test/contrib/
      - name: Test x64
        run: |
          JAX_ENABLE_X64=1 pytest -vs test/test_distributions.py -k powerLaw

  # Runs test/infer and test/contrib, then re-runs selected tests under
  # specific settings: x64, two forced host devices, custom PRNG.
  test-inference:
    runs-on: ubuntu-latest
    needs: lint
    strategy:
      matrix:
        python-version: ["3.10"]
    steps:
      - uses: actions/checkout@v4
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          # Keep track of pyro-api master branch
          pip install https://github.com/pyro-ppl/pyro-api/archive/master.zip
          pip install jaxlib
          pip install jax
          pip install https://github.com/pyro-ppl/funsor/archive/master.zip
          # Quoted so the shell cannot glob-expand the extras brackets.
          pip install '.[dev,test]'
          pip freeze
      # NOTE(review): --ignore=test/contrib/test_nested_sampling.py is passed to
      # the test/infer invocation, where it matches nothing, so the test/contrib
      # invocation still collects that file. Confirm whether the flag was meant
      # for the test/contrib line; behaviour is left unchanged here.
      - name: Test with pytest
        run: |
          pytest -vs --durations=20 test/infer/test_mcmc.py
          pytest -vs --durations=20 test/infer --ignore=test/infer/test_mcmc.py --ignore=test/contrib/test_nested_sampling.py
          pytest -vs --durations=20 test/contrib --ignore=test/contrib/stochastic_support/test_dcc.py
      - name: Test x64
        run: |
          JAX_ENABLE_X64=1 pytest -vs test/infer/test_mcmc.py -k x64
      # Runs with two forced host-platform devices. test_dcc.py is ignored in
      # the plain test/contrib run above and exercised here instead.
      - name: Test chains
        run: |
          XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs test/infer/test_mcmc.py -k "chain or pmap or vmap"
          XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs test/contrib/test_tfp.py -k "chain"
          XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs test/contrib/stochastic_support/test_dcc.py
          XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs test/infer/test_hmc_gibbs.py -k "chain"
      - name: Test custom prng
        run: |
          JAX_ENABLE_CUSTOM_PRNG=1 pytest -vs test/infer/test_mcmc.py
      - name: Test nested sampling
        run: |
          JAX_ENABLE_X64=1 pytest -vs test/contrib/test_nested_sampling.py

  # Runs only the tests selected by `-k test_example`, with two forced host
  # devices.
  examples:
    runs-on: ubuntu-latest
    needs: lint
    strategy:
      matrix:
        python-version: ["3.10"]
    steps:
      - uses: actions/checkout@v4
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v5
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install jaxlib
          pip install jax
          pip install https://github.com/pyro-ppl/funsor/archive/master.zip
          # Quoted so the shell cannot glob-expand the extras brackets.
          pip install '.[dev,examples,test]'
          pip freeze
      - name: Test with pytest
        run: |
          CI=1 XLA_FLAGS="--xla_force_host_platform_device_count=2" pytest -vs -k test_example