Skip to content

Commit 32173b0

Browse files
committed
* python/benchmarks/bench.py: Add support for setup= and check= arguments.
1 parent c6499e2 commit 32173b0

File tree

7 files changed

+100
-131
lines changed

7 files changed

+100
-131
lines changed

python/benchmarks/bench.py

Lines changed: 42 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -78,22 +78,35 @@ def get_package_version(package):
7878
case "DEAP_er":
7979
# Requires version >= 0.2.0
8080
from deap_er import __version__ as version
81-
case "desdeo":
81+
case "desdeo" | "seqme":
8282
# It does not provide __version__ !
8383
from importlib.metadata import version as get_version
8484

85-
version = get_version("desdeo")
85+
version = get_version(package)
8686
case _:
8787
raise ValueError(f"unknown package {package}")
8888

8989
return version
9090

9191

92+
def check_float_values(a, b, what, n, name):
93+
assert np.isclose(a, b), (
94+
f"In {name}, maxrow={n}, {what}={b} not equal to moocore={a}"
95+
)
96+
97+
9298
class Bench:
9399
cpu_model = cpuinfo.get_cpu_info()["brand_raw"]
94100

95101
def __init__(
96-
self, name, n, bench, report_values=None, return_all_values=False
102+
self,
103+
name,
104+
n,
105+
bench,
106+
setup=None,
107+
check=None,
108+
report_values=None,
109+
return_all_values=False,
97110
):
98111
self.name = name
99112
self.n = n
@@ -112,14 +125,20 @@ def __init__(
112125
else:
113126
self.values = None
114127
self.value_label = None
128+
self.setup = setup
129+
self.check = check
115130

116131
def keys(self):
117132
return self.bench.keys()
118133

119-
def __call__(self, what, n, *args, **kwargs):
120-
# FIXME: Ideally, bench() would call fun for each value in self.n
121-
assert n in self.n
122-
# FIXME: Allow passing a setup() function.
134+
def bench1(self, what, n, *args, **kwargs):
135+
if self.setup:
136+
if isinstance(self.setup, dict):
137+
setup = self.setup.get(what)
138+
if setup:
139+
args = (setup(*args, **kwargs),)
140+
kwargs = {}
141+
123142
fun = self.bench[what]
124143
duration, value = timeit.Timer(lambda: fun(*args, **kwargs)).timeit(
125144
number=3
@@ -130,6 +149,22 @@ def __call__(self, what, n, *args, **kwargs):
130149
print(f"{self.name}:{n}:{what}:{duration}")
131150
return value
132151

152+
def __call__(self, n, *args, **kwargs):
153+
# FIXME: Ideally, bench() would call fun for each value in self.n
154+
assert n in self.n
155+
values = {
156+
what: self.bench1(what, n, *args, **kwargs) for what in self.keys()
157+
}
158+
if self.check:
159+
a = values["moocore"]
160+
for what in self.keys():
161+
if what == "moocore":
162+
continue
163+
b = values[what]
164+
self.check(a, b, what=what, n=n, name=self.name)
165+
166+
return values
167+
133168
def plots(self, title, file_prefix, log="y", relative=False):
134169
for what in self.keys():
135170
self.times[what] = np.asarray(self.times[what])

python/benchmarks/bench_epsilon.py

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
66
"""
77

8-
from bench import Bench, read_data
8+
from bench import Bench, read_data, check_float_values
99

1010
import numpy as np
1111
import moocore
@@ -43,25 +43,12 @@
4343
"moocore": lambda z, ref=ref: moocore.epsilon_additive(z, ref=ref),
4444
"jMetalPy": lambda z, eps=jmetal_EPS(ref): eps.compute(z),
4545
},
46+
check=check_float_values,
4647
)
4748

48-
values = {}
4949
for maxrow in n:
50-
z = x[:maxrow, :]
51-
for what in bench.keys():
52-
values[what] = bench(what, maxrow, z)
50+
values = bench(maxrow, x[:maxrow, :])
5351

54-
# Check values
55-
for what in bench.keys():
56-
if what == "moocore":
57-
continue
58-
a = values["moocore"]
59-
b = values[what]
60-
assert np.isclose(a, b), (
61-
f"In {name}, maxrow={maxrow}, {what}={b} not equal to moocore={a}"
62-
)
63-
64-
del values
6552
bench.plots(file_prefix=file_prefix, title=title)
6653

6754
if "__file__" not in globals(): # Running interactively.

python/benchmarks/bench_hv.py

Lines changed: 19 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,12 @@
4747
}
4848

4949

50+
def check_values(a, b, what, n, name):
51+
assert np.isclose(a, b), (
52+
f"In {name}, maxrow={n}, {what}={b} not equal to moocore={a}"
53+
)
54+
55+
5056
title = "HV computation"
5157
file_prefix = "hv"
5258
names = files.keys()
@@ -72,30 +78,20 @@
7278
ref_point=torch.from_numpy(-ref)
7379
): hv.compute(z)
7480

75-
bench = Bench(name=name, n=n, bench=benchmarks)
76-
values = {}
81+
bench = Bench(
82+
name=name,
83+
n=n,
84+
# Exclude the conversion to torch from the timing.
85+
setup={"botorch": lambda z: torch.from_numpy(-z)},
86+
# elif what == "trieste":
87+
# zz = tf.convert_to_tensor(z)
88+
bench=benchmarks,
89+
check=check_values,
90+
)
91+
7792
for maxrow in n:
78-
z = x[:maxrow, :]
79-
for what in bench.keys():
80-
if what == "botorch":
81-
zz = torch.from_numpy(-z)
82-
# elif what == "trieste":
83-
# zz = tf.convert_to_tensor(z)
84-
else:
85-
zz = z
86-
values[what] = bench(what, maxrow, zz)
87-
88-
# Check values
89-
for what in bench.keys():
90-
if what == "moocore":
91-
continue
92-
a = values["moocore"]
93-
b = values[what]
94-
assert np.isclose(a, b), (
95-
f"In {name}, maxrow={maxrow}, {what}={b} not equal to moocore={a}"
96-
)
97-
98-
del values
93+
values = bench(maxrow, x[:maxrow, :])
94+
9995
bench.plots(file_prefix=file_prefix, title=title)
10096

10197
if "__file__" not in globals(): # Running interactively.

python/benchmarks/bench_hvapprox.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -107,14 +107,11 @@ def time_hv_exact(name, maxrow, z, ref):
107107
bench = Bench(
108108
name=name, n=n, bench=benchmarks, report_values="HV Relative Error"
109109
)
110-
res = {}
111110
for maxrow in n:
112111
z = x[:maxrow, :]
113112
exact = time_hv_exact(name, maxrow, z, ref)
114-
for what in bench.keys():
115-
res[what] = bench(what, maxrow, z, exact=exact)
113+
res = bench(maxrow, z, exact=exact)
116114

117-
del res
118115
bench.plots(file_prefix=file_prefix, title=title)
119116

120117
if "__file__" not in globals(): # Running interactively.

python/benchmarks/bench_igdplus.py

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
66
"""
77

8-
from bench import Bench, read_data
8+
from bench import Bench, read_data, check_float_values
99

1010
import numpy as np
1111
import moocore
@@ -47,25 +47,12 @@
4747
# FIXME: Currently DESDEO is a thousand times slower than moocore, so it is not worth running it.
4848
# "desdeo": lambda z, ref=ref: desdeo_igd_plus(z, reference_set=ref).igd_plus,
4949
},
50+
check=check_float_values,
5051
)
5152

52-
values = {}
5353
for maxrow in n:
54-
z = x[:maxrow, :]
55-
for what in bench.keys():
56-
values[what] = bench(what, maxrow, z)
54+
values = bench(maxrow, x[:maxrow, :])
5755

58-
# Check values
59-
for what in bench.keys():
60-
if what == "moocore":
61-
continue
62-
a = values["moocore"]
63-
b = values[what]
64-
assert np.isclose(a, b), (
65-
f"In {name}, maxrow={maxrow}, {what}={b} not equal to moocore={a}"
66-
)
67-
68-
del values
6956
bench.plots(file_prefix=file_prefix, title=title)
7057

7158
if "__file__" not in globals(): # Running interactively.

python/benchmarks/bench_ndom.py

Lines changed: 25 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,8 @@
2525

2626
from paretoset import paretoset
2727

28+
from seqme.core.rank import is_pareto_front as seqme_is_pareto_front
29+
2830
# See https://github.com/multi-objective/testsuite/tree/main/data
2931
files = {
3032
# range are the (start, stop, num) parameters for np.geomspace()
@@ -50,6 +52,18 @@ def get_dataset(name):
5052
return moocore.get_dataset(files[name]["file"])[:, :-1]
5153

5254

55+
# Exclude the conversion to torch from the timing.
56+
setup = {"botorch": lambda z: torch.from_numpy(z)}
57+
58+
59+
def check_values(a, b, what, n, name):
60+
np.testing.assert_allclose(
61+
a,
62+
b,
63+
err_msg=f"In {name}, maxrow={n}, {what}={b} not equal to moocore={a}",
64+
)
65+
66+
5367
title = "is_non_dominated(keep_weakly=False)"
5468
file_prefix = "ndom"
5569
names = files.keys()
@@ -60,6 +74,7 @@ def get_dataset(name):
6074
bench = Bench(
6175
name=name,
6276
n=n,
77+
setup=setup,
6378
bench={
6479
"moocore": lambda z: moocore.is_nondominated(
6580
z, maximise=True, keep_weakly=False
@@ -69,38 +84,17 @@ def get_dataset(name):
6984
z, sense=z.shape[1] * ["max"], distinct=True, use_numba=True
7085
),
7186
},
87+
check=check_values,
7288
)
7389

74-
values = {}
7590
for maxrow in n:
76-
z = x[:maxrow, :]
77-
for what in bench.keys():
78-
if what == "botorch":
79-
# Exclude the conversion to torch from the timing.
80-
zz = torch.from_numpy(z)
81-
else:
82-
zz = z
83-
values[what] = bench(what, maxrow, zz)
84-
85-
# Check values
86-
for what in bench.keys():
87-
if what == "moocore":
88-
continue
89-
a = values["moocore"]
90-
b = values[what]
91-
np.testing.assert_array_equal(
92-
a,
93-
b,
94-
err_msg=f"In {name}, maxrow={maxrow}, {what}={b} not equal to moocore={a}",
95-
)
96-
97-
del values
91+
values = bench(maxrow, x[:maxrow, :])
92+
9893
bench.plots(file_prefix=file_prefix, title=title, log="xy")
9994

10095

10196
title = "is_non_dominated(keep_weakly=True)"
10297
file_prefix = "wndom"
103-
10498
names = files.keys()
10599
for name in names:
106100
x = get_dataset(name)
@@ -109,6 +103,7 @@ def get_dataset(name):
109103
bench = Bench(
110104
name=name,
111105
n=n,
106+
setup=setup,
112107
bench={
113108
"moocore": lambda z: bool2pos(
114109
moocore.is_nondominated(z, maximise=True, keep_weakly=True)
@@ -128,33 +123,16 @@ def get_dataset(name):
128123
-z, only_non_dominated_front=True
129124
),
130125
"desdeo": lambda z: bool2pos(desdeo_is_nondominated(-z)),
126+
"seqme": lambda z: bool2pos(
127+
seqme_is_pareto_front(-z, assume_unique_lexsorted=True)
128+
),
131129
},
130+
check=check_values,
132131
)
133132

134-
values = {}
135133
for maxrow in n:
136-
z = x[:maxrow, :]
137-
for what in bench.keys():
138-
if what == "botorch":
139-
# Exclude the conversion to torch from the timing.
140-
zz = torch.from_numpy(z)
141-
else:
142-
zz = z
143-
values[what] = bench(what, maxrow, zz)
144-
145-
# Check values
146-
for what in bench.keys():
147-
if what == "moocore":
148-
continue
149-
a = values["moocore"]
150-
b = values[what]
151-
np.testing.assert_allclose(
152-
a,
153-
b,
154-
err_msg=f"In {name}, maxrow={maxrow}, {what}={b} not equal to moocore={a}",
155-
)
156-
157-
del values
134+
values = bench(maxrow, x[:maxrow, :])
135+
158136
bench.plots(file_prefix=file_prefix, title=title, log="xy")
159137

160138

python/benchmarks/bench_ndsort.py

Lines changed: 7 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -57,27 +57,16 @@ def get_dataset(name):
5757
"pymoo": lambda z, nds=pymoo_NDS(): nds.do(z, return_rank=True)[1],
5858
"desdeo": lambda z: desdeo_nds(z).argmax(axis=0),
5959
},
60+
check=lambda a, b, what, n, name: np.testing.assert_array_equal(
61+
a,
62+
b,
63+
err_msg=f"In {name}, maxrow={n}, {what}={b} not equal to moocore={a}",
64+
),
6065
)
6166

62-
values = {}
6367
for maxrow in n:
64-
z = x[:maxrow, :]
65-
for what in bench.keys():
66-
values[what] = bench(what, maxrow, z)
67-
68-
# Check values
69-
for what in bench.keys():
70-
if what == "moocore":
71-
continue
72-
a = values["moocore"]
73-
b = values[what]
74-
np.testing.assert_array_equal(
75-
a,
76-
b,
77-
err_msg=f"In {name}, maxrow={maxrow}, {what}={b} not equal to moocore={a}",
78-
)
79-
80-
del values
68+
values = bench(maxrow, x[:maxrow, :])
69+
8170
bench.plots(file_prefix=file_prefix, title=title, log="xy")
8271

8372
if "__file__" not in globals(): # Running interactively.

0 commit comments

Comments
 (0)