From 34fe66b08b3720bd37ca12d3527cbf31dabbd9b8 Mon Sep 17 00:00:00 2001 From: Christos Perivolaropoulos Date: Thu, 28 Nov 2024 02:19:59 -0800 Subject: [PATCH] [mgpu] foreach should not try to create an array if it didn't create the registers due to create_array=False. PiperOrigin-RevId: 700955830 --- jax/experimental/mosaic/gpu/fragmented_array.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/jax/experimental/mosaic/gpu/fragmented_array.py b/jax/experimental/mosaic/gpu/fragmented_array.py index 6b288906e967..c2cd8c21c132 100644 --- a/jax/experimental/mosaic/gpu/fragmented_array.py +++ b/jax/experimental/mosaic/gpu/fragmented_array.py @@ -1265,7 +1265,8 @@ def foreach( if create_array: new_regs[reg_idx] = vector.insertelement(val, new_regs[reg_idx], position=i) - return FragmentedArray(_registers=new_regs, _layout=self.layout, _is_signed=is_signed) + if create_array: + return FragmentedArray(_registers=new_regs, _layout=self.layout, _is_signed=is_signed) def store_untiled(self, ref: ir.Value): if not ir.MemRefType.isinstance(ref.type):