Skip to content

Commit 952e0d6

Browse files
sharadmvGoogle-ML-Automation
authored andcommitted
[Pallas TPU] Add dma_granule_size_bytes to SC info
PiperOrigin-RevId: 843361419
1 parent f3d83cd commit 952e0d6

File tree

2 files changed

+41
-25
lines changed

2 files changed

+41
-25
lines changed

jax/_src/pallas/mosaic/sc_core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,7 @@ class BlockMapping(pallas_core.BlockMapping):
152152
def get_sparse_core_info() -> tpu_info.SparseCoreInfo:
153153
"""Returns the SparseCore information for the current device."""
154154
return tpu_info.get_tpu_info().sparse_core or tpu_info.SparseCoreInfo(
155-
num_cores=0, num_subcores=0, num_lanes=0
155+
num_cores=0, num_subcores=0, num_lanes=0, dma_granule_size_bytes=0,
156156
)
157157

158158

jax/_src/pallas/mosaic/tpu_info.py

Lines changed: 40 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,8 @@
2020

2121
from jax import numpy as jnp
2222
from jax._src import dtypes
23-
from jax._src.pallas.mosaic import core
2423
from jax._src import util as jax_util
24+
from jax._src.pallas.mosaic import core
2525

2626

2727
class ChipVersionBase:
@@ -41,12 +41,15 @@ class ChipVersion(ChipVersionBase, enum.Enum):
4141
def __str__(self) -> str:
4242
return self.value
4343

44+
4445
@dataclasses.dataclass(frozen=True, kw_only=True)
4546
class SparseCoreInfo:
4647
"""SparseCore-specific information."""
48+
4749
num_cores: int
4850
num_subcores: int
4951
num_lanes: int
52+
dma_granule_size_bytes: int
5053

5154

5255
@dataclasses.dataclass(frozen=True, kw_only=True)
@@ -122,10 +125,7 @@ def is_matmul_supported(
122125
or (lhs_dt in {U4, S4} and rhs_dt in {U4, S4})
123126
)
124127
case 7:
125-
return (
126-
lhs_dt in {F32, BF16}
127-
and rhs_dt in {F32, BF16}
128-
) or (
128+
return (lhs_dt in {F32, BF16} and rhs_dt in {F32, BF16}) or (
129129
lhs_dt in {F32, BF16, F8E5M2, F8E4M3FN}
130130
and rhs_dt in {F8E5M2, F8E4M3FN}
131131
)
@@ -172,6 +172,7 @@ def is_tpu_device() -> bool:
172172

173173
registry: dict[str, Callable[[], TpuInfo]] = {}
174174

175+
175176
@jax_util.cache(trace_context_in_key=True)
176177
def get_tpu_info() -> TpuInfo:
177178
"""Returns the TPU hardware information for the current device.
@@ -302,7 +303,12 @@ def get_tpu_info() -> TpuInfo:
302303
int8_ops_per_second=int(9.18e14 // num_chip_cores),
303304
fp8_ops_per_second=0, # Not Available
304305
int4_ops_per_second=int(1.84e15 // num_chip_cores),
305-
sparse_core=SparseCoreInfo(num_cores=4, num_subcores=16, num_lanes=8),
306+
sparse_core=SparseCoreInfo(
307+
num_cores=4,
308+
num_subcores=16,
309+
num_lanes=8,
310+
dma_granule_size_bytes=32,
311+
),
306312
)
307313
case "TPU v6 lite" | "TPU v6e": # 1 TensorCore per chip
308314
return TpuInfo(
@@ -321,29 +327,39 @@ def get_tpu_info() -> TpuInfo:
321327
int8_ops_per_second=int(1.84e15),
322328
fp8_ops_per_second=int(9.20e14),
323329
int4_ops_per_second=int(3.68e15),
324-
sparse_core=SparseCoreInfo(num_cores=2, num_subcores=16, num_lanes=8),
330+
sparse_core=SparseCoreInfo(
331+
num_cores=2,
332+
num_subcores=16,
333+
num_lanes=8,
334+
dma_granule_size_bytes=32,
335+
),
325336
)
326337
case "TPU7x":
327338
num_cores = core.get_num_device_cores()
328339
num_chip_cores = 2
329340
return TpuInfo(
330-
chip_version=ChipVersion.TPU_7X,
331-
generation=7,
332-
num_cores=num_cores,
333-
num_lanes=128,
334-
num_sublanes=8,
335-
mxu_column_size=256,
336-
vmem_capacity_bytes=64 * 1024 * 1024, # 64 MiB per core
337-
cmem_capacity_bytes=0,
338-
smem_capacity_bytes=1024 * 1024, # 1 MiB per core
339-
hbm_capacity_bytes=206_000_000_000 // num_chip_cores,
340-
mem_bw_bytes_per_second=int(7.40e12 // num_chip_cores),
341-
bf16_ops_per_second=int(2.31e15 // num_chip_cores),
342-
int8_ops_per_second=0, # Not Available
343-
fp8_ops_per_second=int(4.60e15 // num_chip_cores),
344-
int4_ops_per_second=0, # Not Available
345-
sparse_core=SparseCoreInfo(num_cores=4, num_subcores=16, num_lanes=16),
346-
)
341+
chip_version=ChipVersion.TPU_7X,
342+
generation=7,
343+
num_cores=num_cores,
344+
num_lanes=128,
345+
num_sublanes=8,
346+
mxu_column_size=256,
347+
vmem_capacity_bytes=64 * 1024 * 1024, # 64 MiB per core
348+
cmem_capacity_bytes=0,
349+
smem_capacity_bytes=1024 * 1024, # 1 MiB per core
350+
hbm_capacity_bytes=206_000_000_000 // num_chip_cores,
351+
mem_bw_bytes_per_second=int(7.40e12 // num_chip_cores),
352+
bf16_ops_per_second=int(2.31e15 // num_chip_cores),
353+
int8_ops_per_second=0, # Not Available
354+
fp8_ops_per_second=int(4.60e15 // num_chip_cores),
355+
int4_ops_per_second=0, # Not Available
356+
sparse_core=SparseCoreInfo(
357+
num_cores=4,
358+
num_subcores=16,
359+
num_lanes=16,
360+
dma_granule_size_bytes=64,
361+
),
362+
)
347363
case _ as d:
348364
if d in registry:
349365
return registry[d]()

0 commit comments

Comments
 (0)