2828from abc import ABC
2929from typing import List
3030
31- import cupy
3231import torch
3332
3433from ucm .store .cache .connector import UcmCacheStore
@@ -43,6 +42,7 @@ def __init__(
4342 layer_size : int ,
4443 chunk_size : int ,
4544 storage_backends : List [str ],
45+ device_id : int ,
4646 ):
4747 super ().__init__ ()
4848 chunk_block_size = tensor_size * layer_size * chunk_size
@@ -57,7 +57,7 @@ def __init__(
5757 cache_config = {}
5858 cache_config ["backend" ] = self .posix .cc_store ()
5959 cache_config ["engine_id" ] = secrets .token_hex (8 )
60- cache_config ["device_id" ] = 1
60+ cache_config ["device_id" ] = device_id
6161 cache_config ["tensor_size" ] = tensor_size
6262 cache_config ["shard_size" ] = chunk_block_size
6363 cache_config ["block_size" ] = chunk_block_size
@@ -111,38 +111,64 @@ def check(self, task: Task) -> bool:
111111 return self .cache .check (task )
112112
113113
114- def main ():
115- tensor_size = 262144
116- layer_size = 64
117- chunk_size = 4
118- request_size = chunk_size * 16
119- storage_backends = ["." ]
120- store = HierarchicalStore (tensor_size , layer_size , chunk_size , storage_backends )
def cmp_and_print_diff(a, b, rtol=0.0, atol=0.0):
    """Compare two nested (row x col) collections of tensors element-wise.

    On the first differing tensor pair, prints the mismatching values and
    raises AssertionError. Also fails when the nesting shapes differ, which
    a plain zip() would otherwise silently truncate.

    Args:
        a: nested sequence of tensors, indexed as a[r][c].
        b: nested sequence of tensors with the same nesting shape as ``a``.
        rtol: relative tolerance forwarded to torch.allclose/isclose.
        atol: absolute tolerance forwarded to torch.allclose/isclose.
            The defaults (0.0, 0.0) require exact equality.

    Raises:
        AssertionError: on any shape or value mismatch.
    """
    if len(a) != len(b):
        raise AssertionError(f"row count mismatch: {len(a)} vs {len(b)}")
    for r, (row_a, row_b) in enumerate(zip(a, b)):
        if len(row_a) != len(row_b):
            raise AssertionError(
                f"col count mismatch at row {r}: {len(row_a)} vs {len(row_b)}"
            )
        for c, (ta, tb) in enumerate(zip(row_a, row_b)):
            if torch.allclose(ta, tb, rtol=rtol, atol=atol):
                continue
            mask = ~torch.isclose(ta, tb, rtol=rtol, atol=atol)
            # .cpu() so values can be printed even for device tensors
            diff_a = ta[mask].cpu()
            diff_b = tb[mask].cpu()
            print(f"DIFF at [{r}][{c}] total {mask.sum().item()} element(s)")
            print(" a val:", diff_a.flatten())
            print(" b val:", diff_b.flatten())
            # raise explicitly (not bare `assert False`) so the check
            # still fires under `python -O`, which strips asserts
            raise AssertionError(f"tensor mismatch at [{r}][{c}]")
125+
126+
def e2e_test(
    store: HierarchicalStore,
    tensor_size: int,
    layer_size: int,
    chunk_size: int,
    request_size: int,
    device_id: int,
):
    """Round-trip one batch of random tensors through the store and verify.

    Dumps freshly generated bf16 CUDA tensors under random block ids, loads
    them back into empty buffers, and asserts the loaded data matches the
    source bit-for-bit.
    """
    block_ids = [secrets.token_bytes(16) for _ in range(request_size)]
    # Fresh random ids must not already be present in the store.
    assert not all(store.lookup(block_ids))
    shard_indexes = [0] * request_size
    tensors_per_request = layer_size * chunk_size
    device = "cuda:{}".format(device_id)
    src_tensors = []
    for _ in range(request_size):
        # tensor_size is a byte count; bf16 is 2 bytes/element, hence // 2
        row = [
            torch.rand([tensor_size // 2], dtype=torch.bfloat16, device=device)
            for _ in range(tensors_per_request)
        ]
        src_tensors.append(row)
    dump_task = store.dump(block_ids, shard_indexes, src_tensors)
    store.wait(dump_task)
    time.sleep(1)
    dst_tensors = [[torch.empty_like(src) for src in row] for row in src_tensors]
    load_task = store.load(block_ids, shard_indexes, dst_tensors)
    store.wait(load_task)
    cmp_and_print_diff(src_tensors, dst_tensors)
157+
158+
def main():
    """Run repeated end-to-end dump/load round trips against one store."""
    tensor_bytes = 262144
    layers = 64
    chunk = 4
    requests = chunk * 16
    backends = ["."]
    gpu = 1
    rounds = 64
    store = HierarchicalStore(tensor_bytes, layers, chunk, backends, gpu)
    # Reuse the same store across rounds to exercise repeated traffic.
    for _ in range(rounds):
        e2e_test(store, tensor_bytes, layers, chunk, requests, gpu)
146172
147173
if __name__ == "__main__":
    # Entry point; the call line was garbled by extraction and is restored here.
    main()