Intel® Extension for OpenXLA* 0.3.0 Release
Major Features
Intel® Extension for OpenXLA* is an Intel-optimized PyPI package that extends the official OpenXLA on Intel GPUs. It is based on the PJRT plugin mechanism and seamlessly runs JAX models on Intel® Data Center GPU Max Series and Intel® Data Center GPU Flex Series. This release contains the following major features:
- JAX Upgrade: Upgraded the supported JAX version to v0.4.24.
- Feature Support:
  - Supports the custom call registration mechanism via the new OpenXLA C API. This feature provides the ability to interact with third-party software, such as mpi4jax.
  - Continued improving JAX native distributed scale-up collectives; any number of devices fewer than 16 is now supported in a single node (see the sketch after this list).
  - Experimental support for Intel® Data Center GPU Flex Series.
- Bug Fix:
  - Fixed accuracy issues in the GEMM kernel when it is optimized by Intel® Xe Templates for Linear Algebra (XeTLA).
  - Fixed a crash when the input batch size is greater than 65535.
- Toolkit Support: Supports Intel® oneAPI Base Toolkit 2024.1.
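The improved scale-up collectives can be exercised with standard JAX collective primitives. Below is a minimal sketch, assuming multiple Intel GPU devices are visible through the PJRT plugin; device counts, shapes, and names are illustrative only.

```python
import functools

import jax
import jax.numpy as jnp

# Each visible device holds one shard of `x`; psum all-reduces across the
# mapped axis. Assumes the local device count stays under the 16-device
# single-node limit noted above.
@functools.partial(jax.pmap, axis_name="i")
def device_sum(x):
    return jax.lax.psum(x, axis_name="i")

x = jnp.arange(jax.local_device_count(), dtype=jnp.float32)
print(device_sum(x))  # every entry equals the sum over all devices
```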
Known Caveats
- Extension will crash when using binary operations (e.g. `Mul`, `MatMul`) together with the SPMD multi-device parallelism API `psum_scatter` under the same `partial` annotation. Please refer to the JAX unit test `test_matmul_reduce_scatter` to better understand the error scenario; a minimal sketch of this pattern also appears at the end of this section.
- JAX collectives fall into deadlock and hang the Extension when working with Toolkit 2024.1. It is recommended to use Toolkit 2024.0 if collectives are needed.
- The `clear_backends` API doesn't work and may cause an OOM exception as below when working with Toolkit 2024.0.
terminate called after throwing an instance of 'sycl::_V1::runtime_error'
what(): Native API failed. Native API returns: -5 (PI_ERROR_OUT_OF_RESOURCES) -5 (PI_ERROR_OUT_OF_RESOURCES)
Fatal Python error: Aborted
Note: the `clear_backends` API will be deprecated by JAX soon.
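For reference, the first caveat concerns a `MatMul` followed by `psum_scatter` over the same sharded axis. The sketch below is modeled loosely on JAX's `test_matmul_reduce_scatter` and is illustrative only; the mesh setup, shapes, and names are assumptions, not the exact reproducer.

```python
import numpy as np

import jax
import jax.numpy as jnp
from jax.experimental.shard_map import shard_map
from jax.sharding import Mesh, PartitionSpec as P

# One-dimensional device mesh over all visible devices.
mesh = Mesh(np.array(jax.devices()), axis_names=("i",))

def matmul_reduce_scatter(a_shard, b_shard):
    partial = a_shard @ b_shard  # binary MatMul on the local shards
    # Reduce-scatter the partial products across the "i" axis.
    return jax.lax.psum_scatter(partial, "i", scatter_dimension=0, tiled=True)

f = shard_map(matmul_reduce_scatter, mesh=mesh,
              in_specs=(P(None, "i"), P("i", None)),  # contract over the sharded axis
              out_specs=P("i", None))

n = 8 * mesh.size  # keep dimensions divisible by the device count
a = jnp.ones((n, n), jnp.float32)
b = jnp.ones((n, n), jnp.float32)
out = f(a, b)  # the kind of pattern described by the first caveat above
```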
Breaking Changes
- The previously supported JAX v0.4.20 is no longer supported. Please follow the JAX change log to update your application if you encounter version errors.
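As a quick sanity check, a minimal sketch (assuming JAX is importable in the target environment) to confirm the runtime matches the version this release expects:

```python
import jax

# This release pairs with JAX v0.4.24; applications pinned to v0.4.20
# may hit version errors until they are updated.
assert jax.__version__ == "0.4.24", f"unexpected JAX version: {jax.__version__}"
```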