@@ -142,6 +142,171 @@ def create_args_from_json(data):
142142 return grid , args_dict
143143
144144
def _apply_stride_and_offset(tensor, shape, stride, storage_offset):
    """
    Re-lay-out *tensor* with a custom stride/storage offset when one is requested.

    Args:
        tensor: The base contiguous tensor.
        shape: The desired shape.
        stride: The desired stride, or None to keep the tensor contiguous.
        storage_offset: The desired storage offset into the backing storage.

    Returns:
        torch.Tensor: A strided view holding the tensor's data, or the
        original tensor when the requested layout is already contiguous.
    """
    # No custom layout requested — nothing to do.
    if stride is None:
        return tensor

    # Row-major (contiguous) strides for this shape, built back-to-front.
    ndim = len(shape)
    contiguous_stride = [1] * ndim
    acc = 1
    for i in range(ndim - 1, -1, -1):
        contiguous_stride[i] = acc
        acc *= shape[i]

    # Requested layout is exactly the contiguous one — reuse the tensor.
    if tuple(stride) == tuple(contiguous_stride) and storage_offset == 0:
        return tensor

    # Smallest storage that the strided view can address: offset plus the
    # largest reachable element index (empty shapes/dims contribute nothing).
    last_index = storage_offset + sum(
        dim_stride * (dim_size - 1)
        for dim_stride, dim_size in zip(stride, shape)
        if dim_size > 0
    )
    storage_size = last_index + 1

    # Fresh backing storage, then a view with the requested layout.
    backing = torch.empty(storage_size, dtype=tensor.dtype, device=tensor.device)
    strided = backing.as_strided(size=shape, stride=stride, storage_offset=storage_offset)

    # Fill the strided layout with the base tensor's data.
    flat_source = tensor.flatten()[: strided.numel()]
    strided.copy_(flat_source.view(shape))

    return strided
194+
195+
def _generate_from_stats(shape, torch_dtype, device, mean, std, min_val, max_val, as_int):
    """Generate a tensor matching captured statistics: a constant tensor when
    the data was constant (std == 0 or min == max), otherwise a clamped
    normal distribution; integer dtypes are rounded before casting."""
    if std == 0 or min_val == max_val:
        fill = int(mean) if as_int else mean
        return torch.full(shape, fill, dtype=torch_dtype, device=device)
    tensor = torch.randn(shape, dtype=torch.float32, device=device) * std + mean
    tensor = torch.clamp(tensor, min=min_val, max=max_val)
    if as_int:
        tensor = torch.round(tensor)
    return tensor.to(torch_dtype)


def _create_base_tensor(arg_info) -> torch.Tensor:
    """
    Materialize the base (contiguous) tensor described by a JSON arg schema.

    If a serialized blob is recorded, it is loaded directly. Otherwise a
    synthetic tensor is generated, preferring captured statistics
    (mean/std/min/max) so the data distribution resembles the original.

    Args:
        arg_info: Dict with "dtype", "shape", "device" and optionally
            "blob_path", "mean", "std", "min", "max",
            "tensor_capture_error", "name".

    Returns:
        torch.Tensor: The materialized tensor.

    Raises:
        NotImplementedError: If random generation is not supported for the dtype.
    """
    if arg_info.get("blob_path"):
        return load_tensor(arg_info.get("blob_path"), arg_info.get("device"))

    # Resolve "torch.xxx" dtype strings; fall back to float32 on unknown names
    # (None dtype also lands here via AttributeError on .split).
    dtype_str = arg_info.get("dtype")
    try:
        torch_dtype = getattr(torch, dtype_str.split(".")[-1])
    except AttributeError:
        logging.error(f"Unsupported dtype: {dtype_str}. Defaulting to float32.")
        torch_dtype = torch.float32

    shape = arg_info.get("shape", [])
    device = arg_info.get("device", "cpu")

    # Statistics are only usable when the capture recorded all four of them.
    mean = arg_info.get("mean")
    std = arg_info.get("std")
    min_val = arg_info.get("min")
    max_val = arg_info.get("max")
    has_stats = (
        mean is not None
        and std is not None
        and min_val is not None
        and max_val is not None
    )

    if arg_info.get("tensor_capture_error", False):
        logging.error(
            f"Error: Tensor '{arg_info.get('name', '')}' had capture error. Generating random tensor instead."
        )

    # Use a dummy tensor to query properties of the dtype.
    tensor_props = torch.empty(0, dtype=torch_dtype)

    # Case 1: Floating point types
    if tensor_props.is_floating_point():
        if has_stats:
            return _generate_from_stats(
                shape, torch_dtype, device, mean, std, min_val, max_val, as_int=False
            )
        # float8 types cannot be filled directly; generate in fp32 and cast.
        if torch_dtype in [torch.float8_e4m3fn, torch.float8_e5m2]:
            return torch.rand(shape, dtype=torch.float32, device=device).to(torch_dtype)
        return torch.empty(shape, dtype=torch_dtype, device=device).random_()

    # Case 2: Integer types (bool never uses stats — random_() handles it)
    if torch_dtype in [
        torch.int8,
        torch.int16,
        torch.int32,
        torch.int64,
        torch.uint8,
        torch.bool,
    ]:
        if has_stats and torch_dtype != torch.bool:
            return _generate_from_stats(
                shape, torch_dtype, device, mean, std, min_val, max_val, as_int=True
            )
        return torch.empty(shape, dtype=torch_dtype, device=device).random_()

    # Case 3: Complex numbers need special handling
    # TODO: Could be improved to use statistical info if available.
    if tensor_props.is_complex():
        float_dtype = torch.float32 if torch_dtype == torch.complex64 else torch.float64
        real_part = torch.rand(shape, dtype=float_dtype, device=device)
        imag_part = torch.rand(shape, dtype=float_dtype, device=device)
        return torch.complex(real_part, imag_part)

    # Case 4: Other unsigned integers (like uint32) which fail with random_()
    if "uint" in str(torch_dtype):
        if has_stats:
            return _generate_from_stats(
                shape, torch_dtype, device, mean, std, min_val, max_val, as_int=True
            )
        return torch.randint(0, 1000, shape, dtype=torch_dtype, device=device)

    # Case 5: Unknown dtype — refuse rather than fabricate garbage.
    raise NotImplementedError(
        f"Random data generation not implemented for dtype: {torch_dtype}"
    )
298+
299+
def _create_tensor(arg_info) -> torch.Tensor:
    """
    Build the tensor for *arg_info*, honoring any non-contiguous layout.

    A contiguous base tensor is materialized first; when the schema records a
    custom stride / storage offset, the data is re-laid-out accordingly.
    """
    base = _create_base_tensor(arg_info)

    layout_shape = arg_info.get("shape", [])
    layout_stride = arg_info.get("stride")
    offset = arg_info.get("storage_offset", 0)
    return _apply_stride_and_offset(base, layout_shape, layout_stride, offset)
308+
309+
145310def _create_arg_from_info (arg_info ):
146311 """
147312 Recursively construct a kernel argument from its JSON schema.
@@ -166,121 +331,7 @@ def _create_arg_from_info(arg_info):
166331 return arg_info .get ("value" )
167332
168333 elif arg_type == "tensor" :
169- if arg_info .get ("blob_path" ):
170- return load_tensor (arg_info .get ("blob_path" ), arg_info .get ("device" ))
171-
172- # Extract basic tensor properties
173- dtype_str = arg_info .get ("dtype" )
174- try :
175- torch_dtype = getattr (torch , dtype_str .split ("." )[- 1 ])
176- except AttributeError :
177- logging .error (f"Unsupported dtype: { dtype_str } . Defaulting to float32." )
178- torch_dtype = torch .float32
179-
180- shape = arg_info .get ("shape" , [])
181- device = arg_info .get ("device" , "cpu" )
182-
183- # Extract statistical information if available
184- mean = arg_info .get ("mean" )
185- std = arg_info .get ("std" )
186- min_val = arg_info .get ("min" )
187- max_val = arg_info .get ("max" )
188- has_stats = (
189- mean is not None
190- and std is not None
191- and min_val is not None
192- and max_val is not None
193- )
194-
195- if arg_info .get ("tensor_capture_error" , False ):
196- logging .error (
197- f"Error: Tensor '{ arg_info .get ('name' , '' )} ' had capture error. Generating random tensor instead."
198- )
199-
200- # Use a dummy tensor to check properties of the dtype
201- tensor_props = torch .empty (0 , dtype = torch_dtype )
202-
203- # Case 1: Floating point types
204- if tensor_props .is_floating_point ():
205- if has_stats :
206- # Generate tensor with statistical properties matching original data
207- if std == 0 or min_val == max_val :
208- # Constant tensor
209- return torch .full (shape , mean , dtype = torch_dtype , device = device )
210- # Generate normal distribution with mean and std, then clamp to [min, max]
211- tensor = (
212- torch .randn (shape , dtype = torch .float32 , device = device ) * std + mean
213- )
214- tensor = torch .clamp (tensor , min = min_val , max = max_val )
215- return tensor .to (torch_dtype )
216- else :
217- # Fallback to original random generation
218- if torch_dtype in [torch .float8_e4m3fn , torch .float8_e5m2 ]:
219- tmp = torch .rand (shape , dtype = torch .float32 , device = device )
220- return tmp .to (torch_dtype )
221- else :
222- return torch .empty (
223- shape , dtype = torch_dtype , device = device
224- ).random_ ()
225-
226- # Case 2: Integer types
227- elif torch_dtype in [
228- torch .int8 ,
229- torch .int16 ,
230- torch .int32 ,
231- torch .int64 ,
232- torch .uint8 ,
233- torch .bool ,
234- ]:
235- if has_stats and torch_dtype != torch .bool :
236- # Generate tensor with statistical properties, then round for integers
237- if std == 0 or min_val == max_val :
238- # Constant tensor
239- return torch .full (
240- shape , int (mean ), dtype = torch_dtype , device = device
241- )
242- tensor = (
243- torch .randn (shape , dtype = torch .float32 , device = device ) * std + mean
244- )
245- tensor = torch .clamp (tensor , min = min_val , max = max_val )
246- return torch .round (tensor ).to (torch_dtype )
247- else :
248- # Fallback to original random generation
249- return torch .empty (shape , dtype = torch_dtype , device = device ).random_ ()
250-
251- # Case 3: Complex numbers need special handling
252- elif tensor_props .is_complex ():
253- # Complex types: fallback to original logic for now
254- # TODO: Could be improved to use statistical info if available
255- float_dtype = (
256- torch .float32 if torch_dtype == torch .complex64 else torch .float64
257- )
258- real_part = torch .rand (shape , dtype = float_dtype , device = device )
259- imag_part = torch .rand (shape , dtype = float_dtype , device = device )
260- return torch .complex (real_part , imag_part )
261-
262- # Case 4: Handle other unsigned integers (like uint32) which fail with random_()
263- elif "uint" in str (torch_dtype ):
264- if has_stats :
265- # Generate tensor with statistical properties for unsigned integers
266- if std == 0 or min_val == max_val :
267- return torch .full (
268- shape , int (mean ), dtype = torch_dtype , device = device
269- )
270- tensor = (
271- torch .randn (shape , dtype = torch .float32 , device = device ) * std + mean
272- )
273- tensor = torch .clamp (tensor , min = min_val , max = max_val )
274- return torch .round (tensor ).to (torch_dtype )
275- else :
276- # Fallback to original random generation
277- return torch .randint (0 , 1000 , shape , dtype = torch_dtype , device = device )
278-
279- # Case 5: If we don't know how to handle the type, raise an error
280- else :
281- raise NotImplementedError (
282- f"Random data generation not implemented for dtype: { torch_dtype } "
283- )
334+ return _create_tensor (arg_info )
284335
285336 elif arg_type == "triton_kernels.tensor.Tensor" :
286337 if not TRITON_KERNELS_CUSTOM_TYPES :
0 commit comments