2020
2121from jax import numpy as jnp
2222from jax ._src import dtypes
23- from jax ._src .pallas .mosaic import core
2423from jax ._src import util as jax_util
24+ from jax ._src .pallas .mosaic import core
2525
2626
2727class ChipVersionBase :
@@ -41,12 +41,15 @@ class ChipVersion(ChipVersionBase, enum.Enum):
4141 def __str__ (self ) -> str :
4242 return self .value
4343
44+
4445@dataclasses .dataclass (frozen = True , kw_only = True )
4546class SparseCoreInfo :
4647 """SparseCore-specific information."""
48+
4749 num_cores : int
4850 num_subcores : int
4951 num_lanes : int
52+ dma_granule_size_bytes : int
5053
5154
5255@dataclasses .dataclass (frozen = True , kw_only = True )
@@ -122,10 +125,7 @@ def is_matmul_supported(
122125 or (lhs_dt in {U4 , S4 } and rhs_dt in {U4 , S4 })
123126 )
124127 case 7 :
125- return (
126- lhs_dt in {F32 , BF16 }
127- and rhs_dt in {F32 , BF16 }
128- ) or (
128+ return (lhs_dt in {F32 , BF16 } and rhs_dt in {F32 , BF16 }) or (
129129 lhs_dt in {F32 , BF16 , F8E5M2 , F8E4M3FN }
130130 and rhs_dt in {F8E5M2 , F8E4M3FN }
131131 )
@@ -172,6 +172,7 @@ def is_tpu_device() -> bool:
172172
173173registry : dict [str , Callable [[], TpuInfo ]] = {}
174174
175+
175176@jax_util .cache (trace_context_in_key = True )
176177def get_tpu_info () -> TpuInfo :
177178 """Returns the TPU hardware information for the current device.
@@ -302,7 +303,12 @@ def get_tpu_info() -> TpuInfo:
302303 int8_ops_per_second = int (9.18e14 // num_chip_cores ),
303304 fp8_ops_per_second = 0 , # Not Available
304305 int4_ops_per_second = int (1.84e15 // num_chip_cores ),
305- sparse_core = SparseCoreInfo (num_cores = 4 , num_subcores = 16 , num_lanes = 8 ),
306+ sparse_core = SparseCoreInfo (
307+ num_cores = 4 ,
308+ num_subcores = 16 ,
309+ num_lanes = 8 ,
310+ dma_granule_size_bytes = 32 ,
311+ ),
306312 )
307313 case "TPU v6 lite" | "TPU v6e" : # 1 TensorCore per chip
308314 return TpuInfo (
@@ -321,29 +327,39 @@ def get_tpu_info() -> TpuInfo:
321327 int8_ops_per_second = int (1.84e15 ),
322328 fp8_ops_per_second = int (9.20e14 ),
323329 int4_ops_per_second = int (3.68e15 ),
324- sparse_core = SparseCoreInfo (num_cores = 2 , num_subcores = 16 , num_lanes = 8 ),
330+ sparse_core = SparseCoreInfo (
331+ num_cores = 2 ,
332+ num_subcores = 16 ,
333+ num_lanes = 8 ,
334+ dma_granule_size_bytes = 32 ,
335+ ),
325336 )
326337 case "TPU7x" :
327338 num_cores = core .get_num_device_cores ()
328339 num_chip_cores = 2
329340 return TpuInfo (
330- chip_version = ChipVersion .TPU_7X ,
331- generation = 7 ,
332- num_cores = num_cores ,
333- num_lanes = 128 ,
334- num_sublanes = 8 ,
335- mxu_column_size = 256 ,
336- vmem_capacity_bytes = 64 * 1024 * 1024 , # 64 MiB per core
337- cmem_capacity_bytes = 0 ,
338- smem_capacity_bytes = 1024 * 1024 , # 1 MiB per core
339- hbm_capacity_bytes = 206_000_000_000 // num_chip_cores ,
340- mem_bw_bytes_per_second = int (7.40e12 // num_chip_cores ),
341- bf16_ops_per_second = int (2.31e15 // num_chip_cores ),
342- int8_ops_per_second = 0 , # Not Available
343- fp8_ops_per_second = int (4.60e15 // num_chip_cores ),
344- int4_ops_per_second = 0 , # Not Available
345- sparse_core = SparseCoreInfo (num_cores = 4 , num_subcores = 16 , num_lanes = 16 ),
346- )
341+ chip_version = ChipVersion .TPU_7X ,
342+ generation = 7 ,
343+ num_cores = num_cores ,
344+ num_lanes = 128 ,
345+ num_sublanes = 8 ,
346+ mxu_column_size = 256 ,
347+ vmem_capacity_bytes = 64 * 1024 * 1024 , # 64 MiB per core
348+ cmem_capacity_bytes = 0 ,
349+ smem_capacity_bytes = 1024 * 1024 , # 1 MiB per core
350+ hbm_capacity_bytes = 206_000_000_000 // num_chip_cores ,
351+ mem_bw_bytes_per_second = int (7.40e12 // num_chip_cores ),
352+ bf16_ops_per_second = int (2.31e15 // num_chip_cores ),
353+ int8_ops_per_second = 0 , # Not Available
354+ fp8_ops_per_second = int (4.60e15 // num_chip_cores ),
355+ int4_ops_per_second = 0 , # Not Available
356+ sparse_core = SparseCoreInfo (
357+ num_cores = 4 ,
358+ num_subcores = 16 ,
359+ num_lanes = 16 ,
360+ dma_granule_size_bytes = 64 ,
361+ ),
362+ )
347363 case _ as d :
348364 if d in registry :
349365 return registry [d ]()
0 commit comments