Skip to content

Commit

Permalink
Add SMEM as a supported Pallas output memory space.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 712144883
  • Loading branch information
Google-ML-Automation committed Jan 5, 2025
1 parent 9af2970 commit 54fd738
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 0 deletions.
2 changes: 2 additions & 0 deletions jax/_src/pallas/mosaic/pallas_call_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,8 @@ def _get_memory_space_from_aval(
return None
case tpu_core.TPUMemorySpace.VMEM:
return tpu_custom_call.MemorySpace.VMEM
case tpu_core.TPUMemorySpace.SMEM:
return tpu_custom_call.MemorySpace.SMEM
case tpu_core.TPUMemorySpace.SEMAPHORE:
return tpu_custom_call.MemorySpace.SEMAPHORE_MEM
return None
Expand Down
3 changes: 3 additions & 0 deletions jax/_src/tpu_custom_call.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ class MemorySpace(enum.Enum):
HBM = enum.auto()
VMEM = enum.auto()
SEMAPHORE_MEM = enum.auto()
SMEM = enum.auto()

@property
def color(self) -> int:
Expand All @@ -92,6 +93,8 @@ def color(self) -> int:
return 1
elif self == MemorySpace.SEMAPHORE_MEM:
return 2
elif self == MemorySpace.SMEM:
return 4
else:
raise ValueError("invalid memory space: " + str(self))

Expand Down

0 comments on commit 54fd738

Please sign in to comment.