Skip to content

Commit 346d96a

Browse files
committed
examples/lapackdrivers_example.py: Python 3 compatibility; general code cleanup
1 parent f0efb70 commit 346d96a

File tree

1 file changed

+51
-55
lines changed

1 file changed

+51
-55
lines changed

examples/lapackdrivers_example.py

Lines changed: 51 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -1,27 +1,24 @@
11
# -*- coding: utf-8 -*-
22
#
3-
"""Performance benchmarking and usage examples for the wlsqm.lapackdrivers module.
3+
"""Performance benchmarking and usage examples for the wlsqm.utils.lapackdrivers module.
44
55
JJ 2016-11-02
66
"""
77

8-
from __future__ import division
9-
from __future__ import absolute_import
8+
from __future__ import division, print_function, absolute_import
109

1110
import time
12-
import math
1311

1412
import numpy as np
1513
from numpy.linalg import solve as numpy_solve # for comparison purposes
1614

17-
import pylab as pl
15+
import matplotlib.pyplot as plt
1816

1917
try:
2018
import wlsqm.utils.lapackdrivers as drivers
2119
except ImportError:
22-
print "WLSQM not found; is it installed?"
23-
from sys import exit
24-
exit(1)
20+
import sys
21+
sys.exit( "WLSQM not found; is it installed?" )
2522

2623
# from find_neighbors2.py
2724
class SimpleTimer:
@@ -37,7 +34,7 @@ def __exit__(self, errtype, errvalue, traceback):
3734
dt = time.time() - self.t0
3835
identifier = ("%s" % self.label) if len(self.label) else "time taken: "
3936
avg = (", avg. %gs per run" % (dt/self.n)) if self.n is not None else ""
40-
print "%s%gs%s" % (identifier, dt, avg)
37+
print( "%s%gs%s" % (identifier, dt, avg) )
4138

4239
# from util.py
4340
def f5(seq, idfun=None):
@@ -91,17 +88,17 @@ def main():
9188
# test that it works
9289

9390
x = numpy_solve(A, b)
94-
print "NumPy:", x
91+
print( "NumPy:", x )
9592

9693
A2 = A.copy(order='F')
9794
x2 = b.copy()
9895
drivers.symmetric(A2, x2)
99-
print "dsysv:", x2
96+
print( "dsysv:", x2 )
10097

10198
A3 = A.copy(order='F')
10299
x3 = b.copy()
103100
drivers.general(A3, x3)
104-
print "dgesv:", x3
101+
print( "dgesv:", x3 )
105102

106103
assert (np.abs(x - x3) < 1e-10).all(), "Something went wrong, solutions do not match" # check general solver first
107104
assert (np.abs(x - x2) < 1e-10).all(), "Something went wrong, solutions do not match" # then check symmetric solver
@@ -127,7 +124,7 @@ def main():
127124
sizes = f5( map( lambda x: int(x), np.ceil(3*np.logspace(0, 2, 21, dtype=int)) ) )
128125
reps = map( lambda x: int(x), 10.**(6 - np.log10(sizes)) )
129126

130-
print "performance test: %d tasks, sizes %s" % (ntasks, sizes)
127+
print( "performance test: %d tasks, sizes %s" % (ntasks, sizes) )
131128

132129
results1 = np.empty( (len(sizes),), dtype=np.float64 )
133130
results2 = np.empty( (len(sizes),), dtype=np.float64 )
@@ -153,11 +150,11 @@ def main():
153150

154151
for j,item in enumerate(zip(sizes,reps)):
155152
n,r = item
156-
print "testing size %d, reps = %d" % (n, r)
153+
print( "testing size %d, reps = %d" % (n, r) )
157154

158155
# same LHS, many different RHS
159156

160-
print " prep same LHS, many RHS..."
157+
print( " prep same LHS, many RHS..." )
161158

162159
A = np.random.sample( (n,n) )
163160
# symmetrize
@@ -169,24 +166,24 @@ def main():
169166
b = np.random.sample( (n,r) )
170167
b = np.array( b, dtype=np.float64, order='F' )
171168

172-
print " solve:"
169+
print( " solve:" )
173170

174171
# # for verification only - very slow (Python loop, serial!)
175172
# if use_numpy:
176173
# t0 = time.time()
177174
# x = np.empty( (n,r), dtype=np.float64 )
178-
# for k in xrange(r):
175+
# for k in range(r):
179176
# x[:,k] = numpy_solve(A, b[:,k])
180177
# results1[j] = (time.time() - t0) / r
181178

182-
print " symmetricsp"
179+
print( " symmetricsp" )
183180
t0 = time.time()
184181
A2 = A.copy(order='F')
185182
x2 = b.copy(order='F')
186183
drivers.symmetricsp(A2, x2, ntasks)
187184
results2[j] = (time.time() - t0) / r
188185

189-
print " generalsp"
186+
print( " generalsp" )
190187
t0 = time.time()
191188
A3 = A.copy(order='F')
192189
x3 = b.copy(order='F')
@@ -195,7 +192,7 @@ def main():
195192

196193
# different LHS for each problem
197194

198-
print " prep independent problems..."
195+
print( " prep independent problems..." )
199196

200197
A = np.random.sample( (n,n,r) )
201198
# symmetrize
@@ -207,53 +204,53 @@ def main():
207204
b = np.random.sample( (n,r) )
208205
b = np.array( b, dtype=np.float64, order='F' )
209206

210-
print " solve:"
207+
print( " solve:" )
211208

212209
# for verification only - very slow (Python loop, serial!)
213210
if use_numpy:
214-
print " NumPy"
211+
print( " NumPy" )
215212
t0 = time.time()
216213
x = np.empty( (n,r), dtype=np.float64, order='F' )
217-
for k in xrange(r):
214+
for k in range(r):
218215
x[:,k] = numpy_solve(A[:,:,k], b[:,k])
219216
results1[j] = (time.time() - t0) / r
220217

221-
print " msymmetricp"
218+
print( " msymmetricp" )
222219
t0 = time.time()
223220
A2 = A.copy(order='F')
224221
x2 = b.copy(order='F')
225222
drivers.msymmetricp(A2, x2, ntasks)
226223
results4[j] = (time.time() - t0) / r
227224

228-
print " mgeneralp"
225+
print( " mgeneralp" )
229226
t0 = time.time()
230227
A3 = A.copy(order='F')
231228
x3 = b.copy(order='F')
232229
drivers.mgeneralp(A3, x3, ntasks)
233230
results5[j] = (time.time() - t0) / r
234231

235-
print " msymmetricfactorp & msymmetricfactoredp" # factor once, then it is possible to solve multiple times (although we now test only once)
232+
print( " msymmetricfactorp & msymmetricfactoredp" ) # factor once, then it is possible to solve multiple times (although we now test only once)
236233
t0 = time.time()
237-
ipiv = np.empty( (n,r), dtype=np.int32, order='F' )
234+
ipiv = np.empty( (n,r), dtype=np.intc, order='F' )
238235
fact = A.copy(order='F')
239236
x4 = b.copy(order='F')
240237
drivers.msymmetricfactorp( fact, ipiv, ntasks )
241238
drivers.msymmetricfactoredp( fact, ipiv, x4, ntasks )
242239
results6[j] = (time.time() - t0) / r
243240

244-
print " mgeneralfactorp & mgeneralfactoredp" # factor once, then it is possible to solve multiple times (although we now test only once)
241+
print( " mgeneralfactorp & mgeneralfactoredp" ) # factor once, then it is possible to solve multiple times (although we now test only once)
245242
t0 = time.time()
246-
ipiv = np.empty( (n,r), dtype=np.int32, order='F' )
243+
ipiv = np.empty( (n,r), dtype=np.intc, order='F' )
247244
fact = A.copy(order='F')
248245
x5 = b.copy(order='F')
249246
drivers.mgeneralfactorp( fact, ipiv, ntasks )
250247
drivers.mgeneralfactoredp( fact, ipiv, x5, ntasks )
251248
results7[j] = (time.time() - t0) / r
252249

253250
if use_numpy:
254-
# print np.max(np.abs(x - x3)) # DEBUG
255-
# print np.max(np.abs(x - x5)) # DEBUG
256-
print np.max(np.abs(x2 - x4)) # DEBUG
251+
# print( np.max(np.abs(x - x3)) ) # DEBUG
252+
# print( np.max(np.abs(x - x5)) ) # DEBUG
253+
print( np.max(np.abs(x2 - x4)) ) # DEBUG
257254
assert (np.abs(x - x5) < 1e-10).all(), "Something went wrong, solutions do not match" # check general solver first
258255
assert (np.abs(x - x3) < 1e-10).all(), "Something went wrong, solutions do not match" # check general solver
259256
# assert (np.abs(x - x2) < 1e-5).all(), "Something went wrong, solutions do not match" # doesn't make sense to compare, DSYSV is more accurate for badly conditioned symmetric matrices
@@ -265,27 +262,27 @@ def main():
265262
#
266263
# for j,item in enumerate(zip(sizes,reps)):
267264
# n,r = item
268-
# print "testing size %d, reps = %d" % (n, r)
265+
# print( "testing size %d, reps = %d" % (n, r) )
269266
#
270267
# A = np.random.sample( (n,n) )
271268
# A = 0.5 * (A + A.T) # symmetrize
272269
# A = np.array( A, dtype=np.float64, order='F' )
273270
# b = np.random.sample( (n,) )
274271
#
275272
# t0 = time.time()
276-
# for k in xrange(r):
273+
# for k in range(r):
277274
# x = numpy_solve(A, b)
278275
# results1[j] = (time.time() - t0) / r
279276
#
280277
# t0 = time.time()
281-
# for k in xrange(r):
278+
# for k in range(r):
282279
# A2 = A.copy(order='F')
283280
# x2 = b.copy()
284281
# drivers.symmetric(A2, x2)
285282
# results2[j] = (time.time() - t0) / r
286283
#
287284
# t0 = time.time()
288-
# for k in xrange(r):
285+
# for k in range(r):
289286
# A3 = A.copy(order='F')
290287
# x3 = b.copy()
291288
# drivers.general(A3, x3)
@@ -294,27 +291,26 @@ def main():
294291

295292
# visualize
296293

297-
pl.figure(1)
298-
pl.clf()
294+
plt.figure(1)
295+
plt.clf()
299296
if use_numpy:
300-
pl.loglog(sizes, results1, 'k-', label='NumPy')
301-
pl.loglog(sizes, results2, 'b--', label='dsysv, same LHS, many RHS')
302-
pl.loglog(sizes, results3, 'b-', label='dgesv, same LHS, many RHS')
303-
pl.loglog(sizes, results4, 'r--', label='dsysv, independent problems')
304-
pl.loglog(sizes, results5, 'r-', label='dgesv, independent problems')
305-
pl.loglog(sizes, results6, 'g--', label='dsytrf+dsytrs, independent problems')
306-
pl.loglog(sizes, results7, 'g-', label='dgetrf+dgetrs, independent problems')
307-
pl.xlabel('n')
308-
pl.ylabel('t')
309-
pl.title('Average time per problem instance, %d parallel tasks' % (ntasks))
310-
pl.axis('tight')
311-
pl.grid(b=True, which='both')
312-
pl.legend(loc='best')
313-
314-
pl.savefig('figure1_latest.pdf')
297+
plt.loglog(sizes, results1, 'k-', label='NumPy')
298+
plt.loglog(sizes, results2, 'b--', label='dsysv, same LHS, many RHS')
299+
plt.loglog(sizes, results3, 'b-', label='dgesv, same LHS, many RHS')
300+
plt.loglog(sizes, results4, 'r--', label='dsysv, independent problems')
301+
plt.loglog(sizes, results5, 'r-', label='dgesv, independent problems')
302+
plt.loglog(sizes, results6, 'g--', label='dsytrf+dsytrs, independent problems')
303+
plt.loglog(sizes, results7, 'g-', label='dgetrf+dgetrs, independent problems')
304+
plt.xlabel('n')
305+
plt.ylabel('t')
306+
plt.title('Average time per problem instance, %d parallel tasks' % (ntasks))
307+
plt.axis('tight')
308+
plt.grid(b=True, which='both')
309+
plt.legend(loc='best')
310+
311+
plt.savefig('figure1_latest.pdf')
315312

316313

317314
if __name__ == '__main__':
318315
main()
319-
pl.show()
320-
316+
plt.show()

0 commit comments

Comments
 (0)