Skip to content

Commit

Permalink
Merge pull request #78 from neuro-ml/develop
Browse files Browse the repository at this point in the history
Convert to fp16 in `inference_step`, `multi_inference_step`
  • Loading branch information
vovaf709 authored Jul 25, 2023
2 parents 56c3658 + 41f280c commit 21870d4
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 2 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ jobs:
runs-on: ubuntu-20.04
strategy:
matrix:
python-version: [ '3.6', '3.7', '3.8', '3.9', '3.10' ]
python-version: [ '3.6', '3.7', '3.8', '3.9', '3.10', '3.11' ]

steps:
- uses: actions/checkout@v2
Expand Down
2 changes: 1 addition & 1 deletion dpipe/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.2.5'
__version__ = '0.2.6'
14 changes: 14 additions & 0 deletions dpipe/torch/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,8 +168,15 @@ def inference_step(*inputs: np.ndarray, architecture: Module, activation: Callab
-----
Note that both input and output are **not** of type ``torch.Tensor`` - the conversion
to and from ``torch.Tensor`` is made inside this function.
Inputs will be converted to fp16 if ``amp`` is True.
"""
architecture.eval()

# NumPy >= 1.24 warns about underflow during cast which is really insignificant
if amp:
with np.errstate(under='ignore'):
inputs = tuple(np.asarray(x, dtype=np.float16) for x in inputs)

with torch.no_grad():
with torch.cuda.amp.autocast(amp or torch.is_autocast_enabled()):
return to_np(activation(architecture(*sequence_to_var(*inputs, device=architecture))))
Expand All @@ -188,8 +195,15 @@ def multi_inference_step(*inputs: np.ndarray, architecture: Module,
-----
Note that both input and output are **not** of type ``torch.Tensor`` - the conversion
to and from ``torch.Tensor`` is made inside this function.
Inputs will be converted to fp16 if ``amp`` is True.
"""
architecture.eval()

# NumPy >= 1.24 warns about underflow during cast which is really insignificant
if amp:
with np.errstate(under='ignore'):
inputs = tuple(np.asarray(x, dtype=np.float16) for x in inputs)

with torch.no_grad():
with torch.cuda.amp.autocast(amp or torch.is_autocast_enabled()):
results = architecture(*sequence_to_var(*inputs, device=architecture))
Expand Down

0 comments on commit 21870d4

Please sign in to comment.