Skip to content

Commit

Permalink
[mgpu] foreach should not try to create an array if it didn't create …
Browse files Browse the repository at this point in the history
…the registers due to create_array=False.

PiperOrigin-RevId: 700955830
  • Loading branch information
cperivol authored and Google-ML-Automation committed Nov 28, 2024
1 parent bdee4c3 commit 34fe66b
Showing 1 changed file with 2 additions and 1 deletion.
3 changes: 2 additions & 1 deletion jax/experimental/mosaic/gpu/fragmented_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1265,7 +1265,8 @@ def foreach(
if create_array:
new_regs[reg_idx] = vector.insertelement(val, new_regs[reg_idx], position=i)

return FragmentedArray(_registers=new_regs, _layout=self.layout, _is_signed=is_signed)
if create_array:
return FragmentedArray(_registers=new_regs, _layout=self.layout, _is_signed=is_signed)

def store_untiled(self, ref: ir.Value):
if not ir.MemRefType.isinstance(ref.type):
Expand Down

0 comments on commit 34fe66b

Please sign in to comment.