Skip to content

Commit d6b115f

Browse files
committed
Bump multiversion
1 parent f0a3741 commit d6b115f

File tree

4 files changed

+31
-64
lines changed

4 files changed

+31
-64
lines changed

Cargo.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ codegen-units = 1
2323
[dependencies]
2424
rand = { version = "0.8.5", features = ["small_rng"] }
2525
rand_distr = "0.4.3"
26-
multiversion = "0.6.1"
26+
multiversion = "0.7.0"
2727
itertools = "0.10.3"
2828
crossbeam = "0.8.1"
2929
thiserror = "1.0.31"
@@ -34,7 +34,7 @@ ndarray = "0.15.4"
3434
proptest = "1.0.0"
3535
pretty_assertions = "1.2.1"
3636
criterion = "0.4.0"
37-
nix = "0.25.0"
37+
nix = "0.26.1"
3838
approx = "0.5.1"
3939

4040
[[bench]]

src/cpu_sampler.rs

Lines changed: 16 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,7 @@ impl InitPointFunc for JitterInitFunc {
241241
}
242242

243243
pub mod test_logps {
244-
use crate::{cpu_potential::CpuLogpFunc, nuts::LogpError};
244+
use crate::{cpu_potential::CpuLogpFunc, nuts::LogpError, CpuLogpFuncMaker};
245245
use multiversion::multiversion;
246246
use thiserror::Error;
247247

@@ -251,6 +251,18 @@ pub mod test_logps {
251251
mu: f64,
252252
}
253253

254+
impl CpuLogpFuncMaker for NormalLogp {
255+
type Func = Self;
256+
257+
fn make_logp_func(&self) -> Result<Self::Func, Box<dyn std::error::Error + Send + Sync>> {
258+
Ok(self.clone())
259+
}
260+
261+
fn dim(&self) -> usize {
262+
self.dim
263+
}
264+
}
265+
254266
impl NormalLogp {
255267
pub fn new(dim: usize, mu: f64) -> NormalLogp {
256268
NormalLogp { dim, mu }
@@ -276,9 +288,7 @@ pub mod test_logps {
276288
assert!(gradient.len() == n);
277289

278290
#[cfg(feature = "simd_support")]
279-
#[multiversion]
280-
#[clone(target = "[x64|x86_64]+avx+avx2+fma")]
281-
#[clone(target = "x86+sse")]
291+
#[multiversion(targets("x86_64+avx+avx2+fma", "arm+neon"))]
282292
fn logp_inner(mu: f64, position: &[f64], gradient: &mut [f64]) -> f64 {
283293
use std::simd::f64x4;
284294
use std::simd::SimdFloat;
@@ -313,9 +323,7 @@ pub mod test_logps {
313323
}
314324

315325
#[cfg(not(feature = "simd_support"))]
316-
#[multiversion]
317-
#[clone(target = "[x64|x86_64]+avx+avx2+fma")]
318-
#[clone(target = "x86+sse")]
326+
#[multiversion(targets("x86_64+avx+avx2+fma", "arm+neon"))]
319327
fn logp_inner(mu: f64, position: &[f64], gradient: &mut [f64]) -> f64 {
320328
let n = position.len();
321329
assert!(gradient.len() == n);
@@ -370,22 +378,7 @@ mod tests {
370378
.iter()
371379
.any(|(key, _)| *key == "index_in_trajectory"));
372380

373-
struct Maker {
374-
logp: NormalLogp,
375-
}
376-
impl CpuLogpFuncMaker for Maker {
377-
type Func = NormalLogp;
378-
379-
fn make_logp_func(&self) -> Result<Self::Func, Box<dyn Error + Send + Sync>> {
380-
Ok(self.logp.clone())
381-
}
382-
383-
fn dim(&self) -> usize {
384-
self.logp.dim()
385-
}
386-
}
387-
388-
let maker = Maker { logp };
381+
let maker = logp;
389382

390383
let (handles, chains) =
391384
sample_parallel(maker, &mut JitterInitFunc::new(), settings, 4, 100, 42, 10).unwrap();

src/mass_matrix.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,9 +38,7 @@ impl DiagMassMatrix {
3838
}
3939
}
4040

41-
#[multiversion]
42-
#[clone(target = "[x64|x86_64]+avx+avx2+fma")]
43-
#[clone(target = "x86+sse")]
41+
#[multiversion(targets("x86_64+avx+avx2+fma", "arm+neon"))]
4442
fn update_diag(
4543
variance_out: &mut [f64],
4644
inv_std_out: &mut [f64],

src/math.rs

Lines changed: 12 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,7 @@ pub(crate) fn logaddexp(a: f64, b: f64) -> f64 {
2020
}
2121

2222
#[cfg(feature = "simd_support")]
23-
#[multiversion]
24-
#[clone(target = "[x64|x86_64]+avx+avx2+fma")]
25-
#[clone(target = "x86+sse")]
23+
#[multiversion(targets("x86_64+avx+avx2+fma", "arm+neon"))]
2624
pub fn multiply(x: &[f64], y: &[f64], out: &mut [f64]) {
2725
let n = x.len();
2826
assert!(y.len() == n);
@@ -44,9 +42,7 @@ pub fn multiply(x: &[f64], y: &[f64], out: &mut [f64]) {
4442
}
4543

4644
#[cfg(not(feature = "simd_support"))]
47-
#[multiversion]
48-
#[clone(target = "[x64|x86_64]+avx+avx2+fma")]
49-
#[clone(target = "x86+sse")]
45+
#[multiversion(targets("x86_64+avx+avx2+fma", "arm+neon"))]
5046
pub fn multiply(x: &[f64], y: &[f64], out: &mut [f64]) {
5147
let n = x.len();
5248
assert!(y.len() == n);
@@ -58,9 +54,7 @@ pub fn multiply(x: &[f64], y: &[f64], out: &mut [f64]) {
5854
}
5955

6056
#[cfg(feature = "simd_support")]
61-
#[multiversion]
62-
#[clone(target = "[x84|x86_64]+avx+avx2+fma")]
63-
#[clone(target = "x86+sse")]
57+
#[multiversion(targets("x86_64+avx+avx2+fma", "arm+neon"))]
6458
pub fn scalar_prods2(positive1: &[f64], positive2: &[f64], x: &[f64], y: &[f64]) -> (f64, f64) {
6559
let n = positive1.len();
6660

@@ -99,9 +93,7 @@ pub fn scalar_prods2(positive1: &[f64], positive2: &[f64], x: &[f64], y: &[f64])
9993
}
10094

10195
#[cfg(not(feature = "simd_support"))]
102-
#[multiversion]
103-
#[clone(target = "[x84|x86_64]+avx+avx2+fma")]
104-
#[clone(target = "x86+sse")]
96+
#[multiversion(targets("x86_64+avx+avx2+fma", "arm+neon"))]
10597
pub fn scalar_prods2(positive1: &[f64], positive2: &[f64], x: &[f64], y: &[f64]) -> (f64, f64) {
10698
let n = positive1.len();
10799

@@ -116,9 +108,7 @@ pub fn scalar_prods2(positive1: &[f64], positive2: &[f64], x: &[f64], y: &[f64])
116108
}
117109

118110
#[cfg(feature = "simd_support")]
119-
#[multiversion]
120-
#[clone(target = "[x84|x86_64]+avx+avx2+fma")]
121-
#[clone(target = "x86+sse")]
111+
#[multiversion(targets("x86_64+avx+avx2+fma", "arm+neon"))]
122112
pub fn scalar_prods3(
123113
positive1: &[f64],
124114
negative1: &[f64],
@@ -167,9 +157,7 @@ pub fn scalar_prods3(
167157
}
168158

169159
#[cfg(not(feature = "simd_support"))]
170-
#[multiversion]
171-
#[clone(target = "[x84|x86_64]+avx+avx2+fma")]
172-
#[clone(target = "x86+sse")]
160+
#[multiversion(targets("x86_64+avx+avx2+fma", "arm+neon"))]
173161
pub fn scalar_prods3(
174162
positive1: &[f64],
175163
negative1: &[f64],
@@ -191,9 +179,7 @@ pub fn scalar_prods3(
191179
}
192180

193181
#[cfg(feature = "simd_support")]
194-
#[multiversion]
195-
#[clone(target = "[x86|x86_64]+avx+avx2+fma")]
196-
#[clone(target = "x86+sse")]
182+
#[multiversion(targets("x86_64+avx+avx2+fma", "arm+neon"))]
197183
pub fn vector_dot(a: &[f64], b: &[f64]) -> f64 {
198184
assert!(a.len() == b.len());
199185

@@ -216,9 +202,7 @@ pub fn vector_dot(a: &[f64], b: &[f64]) -> f64 {
216202
}
217203

218204
#[cfg(not(feature = "simd_support"))]
219-
#[multiversion]
220-
#[clone(target = "[x86|x86_64]+avx+avx2+fma")]
221-
#[clone(target = "x86+sse")]
205+
#[multiversion(targets("x86_64+avx+avx2+fma", "arm+neon"))]
222206
pub fn vector_dot(a: &[f64], b: &[f64]) -> f64 {
223207
assert!(a.len() == b.len());
224208

@@ -230,9 +214,7 @@ pub fn vector_dot(a: &[f64], b: &[f64]) -> f64 {
230214
}
231215

232216
#[cfg(feature = "simd_support")]
233-
#[multiversion]
234-
#[clone(target = "[x86|x86_64]+avx+avx2+fma")]
235-
#[clone(target = "x86+sse")]
217+
#[multiversion(targets("x86_64+avx+avx2+fma", "arm+neon"))]
236218
pub fn axpy(x: &[f64], y: &mut [f64], a: f64) {
237219
let n = x.len();
238220
assert!(y.len() == n);
@@ -255,9 +237,7 @@ pub fn axpy(x: &[f64], y: &mut [f64], a: f64) {
255237
}
256238

257239
#[cfg(not(feature = "simd_support"))]
258-
#[multiversion]
259-
#[clone(target = "[x86|x86_64]+avx+avx2+fma")]
260-
#[clone(target = "x86+sse")]
240+
#[multiversion(targets("x86_64+avx+avx2+fma", "arm+neon"))]
261241
pub fn axpy(x: &[f64], y: &mut [f64], a: f64) {
262242
let n = x.len();
263243
assert!(y.len() == n);
@@ -268,9 +248,7 @@ pub fn axpy(x: &[f64], y: &mut [f64], a: f64) {
268248
}
269249

270250
#[cfg(feature = "simd_support")]
271-
#[multiversion]
272-
#[clone(target = "[x86|x86_64]+avx+avx2+fma")]
273-
#[clone(target = "x86+sse+fma")]
251+
#[multiversion(targets("x86_64+avx+avx2+fma", "arm+neon"))]
274252
pub fn axpy_out(x: &[f64], y: &[f64], a: f64, out: &mut [f64]) {
275253
let n = x.len();
276254
assert!(y.len() == n);
@@ -297,9 +275,7 @@ pub fn axpy_out(x: &[f64], y: &[f64], a: f64, out: &mut [f64]) {
297275
}
298276

299277
#[cfg(not(feature = "simd_support"))]
300-
#[multiversion]
301-
#[clone(target = "[x86|x86_64]+avx+avx2+fma")]
302-
#[clone(target = "x86+sse+fma")]
278+
#[multiversion(targets("x86_64+avx+avx2+fma", "arm+neon"))]
303279
pub fn axpy_out(x: &[f64], y: &[f64], a: f64, out: &mut [f64]) {
304280
let n = x.len();
305281
assert!(y.len() == n);

0 commit comments

Comments
 (0)