@@ -206,7 +206,7 @@ def test_device(self, xp: ModuleType, device: Device):
206206 @given (
207207 n_arrays = st .integers (min_value = 1 , max_value = 3 ),
208208 rng_seed = st .integers (min_value = 1000000000 , max_value = 9999999999 ),
209- dtype = st . sampled_from (( np . float32 , np . float64 )),
209+ dtype = npst . floating_dtypes ( sizes = ( 32 , 64 )),
210210 p = st .floats (min_value = 0 , max_value = 1 ),
211211 data = st .data (),
212212 )
@@ -223,7 +223,7 @@ def test_hypothesis(
223223 if (
224224 library .like (Backend .NUMPY )
225225 and NUMPY_VERSION < (2 , 0 )
226- and dtype is np .float32
226+ and dtype . type is np .float32
227227 ):
228228 pytest .xfail (reason = "NumPy 1.x dtype promotion for scalars" )
229229
@@ -236,17 +236,17 @@ def test_hypothesis(
236236 elements = {"allow_subnormal" : not library .like (Backend .CUPY , Backend .JAX )}
237237
238238 fill_value = xp .asarray (
239- data .draw (npst .arrays (dtype = dtype , shape = (), elements = elements ))
239+ data .draw (npst .arrays (dtype = dtype . type , shape = (), elements = elements ))
240240 )
241241 float_fill_value = float (fill_value )
242- if library is Backend .CUPY and dtype is np .float32 :
242+ if library is Backend .CUPY and dtype . type is np .float32 :
243243 # Avoid data-dependent dtype promotion when encountering subnormals
244244 # close to the max float32 value
245245 float_fill_value = float (np .clip (float_fill_value , - 1e38 , 1e38 ))
246246
247247 arrays = tuple (
248248 xp .asarray (
249- data .draw (npst .arrays (dtype = dtype , shape = shape , elements = elements ))
249+ data .draw (npst .arrays (dtype = dtype . type , shape = shape , elements = elements ))
250250 )
251251 for shape in shapes
252252 )
0 commit comments