11# -*- coding: utf-8 -*-
22#
3- """Performance benchmarking and usage examples for the wlsqm.lapackdrivers module.
3+ """Performance benchmarking and usage examples for the wlsqm.utils. lapackdrivers module.
44
55JJ 2016-11-02
66"""
77
8- from __future__ import division
9- from __future__ import absolute_import
8+ from __future__ import division , print_function , absolute_import
109
1110import time
12- import math
1311
1412import numpy as np
1513from numpy .linalg import solve as numpy_solve # for comparison purposes
1614
17- import pylab as pl
15+ import matplotlib . pyplot as plt
1816
1917try :
2018 import wlsqm .utils .lapackdrivers as drivers
2119except ImportError :
22- print "WLSQM not found; is it installed?"
23- from sys import exit
24- exit (1 )
20+ import sys
21+ sys .exit ( "WLSQM not found; is it installed?" )
2522
2623# from find_neighbors2.py
2724class SimpleTimer :
@@ -37,7 +34,7 @@ def __exit__(self, errtype, errvalue, traceback):
3734 dt = time .time () - self .t0
3835 identifier = ("%s" % self .label ) if len (self .label ) else "time taken: "
3936 avg = (", avg. %gs per run" % (dt / self .n )) if self .n is not None else ""
40- print "%s%gs%s" % (identifier , dt , avg )
37+ print ( "%s%gs%s" % (identifier , dt , avg ) )
4138
4239# from util.py
4340def f5 (seq , idfun = None ):
@@ -91,17 +88,17 @@ def main():
9188 # test that it works
9289
9390 x = numpy_solve (A , b )
94- print "NumPy:" , x
91+ print ( "NumPy:" , x )
9592
9693 A2 = A .copy (order = 'F' )
9794 x2 = b .copy ()
9895 drivers .symmetric (A2 , x2 )
99- print "dsysv:" , x2
96+ print ( "dsysv:" , x2 )
10097
10198 A3 = A .copy (order = 'F' )
10299 x3 = b .copy ()
103100 drivers .general (A3 , x3 )
104- print "dgesv:" , x3
101+ print ( "dgesv:" , x3 )
105102
106103 assert (np .abs (x - x3 ) < 1e-10 ).all (), "Something went wrong, solutions do not match" # check general solver first
107104 assert (np .abs (x - x2 ) < 1e-10 ).all (), "Something went wrong, solutions do not match" # then check symmetric solver
@@ -127,7 +124,7 @@ def main():
127124 sizes = f5 ( map ( lambda x : int (x ), np .ceil (3 * np .logspace (0 , 2 , 21 , dtype = int )) ) )
128125 reps = map ( lambda x : int (x ), 10. ** (6 - np .log10 (sizes )) )
129126
130- print "performance test: %d tasks, sizes %s" % (ntasks , sizes )
127+ print ( "performance test: %d tasks, sizes %s" % (ntasks , sizes ) )
131128
132129 results1 = np .empty ( (len (sizes ),), dtype = np .float64 )
133130 results2 = np .empty ( (len (sizes ),), dtype = np .float64 )
@@ -153,11 +150,11 @@ def main():
153150
154151 for j ,item in enumerate (zip (sizes ,reps )):
155152 n ,r = item
156- print "testing size %d, reps = %d" % (n , r )
153+ print ( "testing size %d, reps = %d" % (n , r ) )
157154
158155 # same LHS, many different RHS
159156
160- print " prep same LHS, many RHS..."
157+ print ( " prep same LHS, many RHS..." )
161158
162159 A = np .random .sample ( (n ,n ) )
163160 # symmetrize
@@ -169,24 +166,24 @@ def main():
169166 b = np .random .sample ( (n ,r ) )
170167 b = np .array ( b , dtype = np .float64 , order = 'F' )
171168
172- print " solve:"
169+ print ( " solve:" )
173170
174171# # for verification only - very slow (Python loop, serial!)
175172# if use_numpy:
176173# t0 = time.time()
177174# x = np.empty( (n,r), dtype=np.float64 )
178- # for k in xrange (r):
175+ # for k in range (r):
179176# x[:,k] = numpy_solve(A, b[:,k])
180177# results1[j] = (time.time() - t0) / r
181178
182- print " symmetricsp"
179+ print ( " symmetricsp" )
183180 t0 = time .time ()
184181 A2 = A .copy (order = 'F' )
185182 x2 = b .copy (order = 'F' )
186183 drivers .symmetricsp (A2 , x2 , ntasks )
187184 results2 [j ] = (time .time () - t0 ) / r
188185
189- print " generalsp"
186+ print ( " generalsp" )
190187 t0 = time .time ()
191188 A3 = A .copy (order = 'F' )
192189 x3 = b .copy (order = 'F' )
@@ -195,7 +192,7 @@ def main():
195192
196193 # different LHS for each problem
197194
198- print " prep independent problems..."
195+ print ( " prep independent problems..." )
199196
200197 A = np .random .sample ( (n ,n ,r ) )
201198 # symmetrize
@@ -207,53 +204,53 @@ def main():
207204 b = np .random .sample ( (n ,r ) )
208205 b = np .array ( b , dtype = np .float64 , order = 'F' )
209206
210- print " solve:"
207+ print ( " solve:" )
211208
212209 # for verification only - very slow (Python loop, serial!)
213210 if use_numpy :
214- print " NumPy"
211+ print ( " NumPy" )
215212 t0 = time .time ()
216213 x = np .empty ( (n ,r ), dtype = np .float64 , order = 'F' )
217- for k in xrange (r ):
214+ for k in range (r ):
218215 x [:,k ] = numpy_solve (A [:,:,k ], b [:,k ])
219216 results1 [j ] = (time .time () - t0 ) / r
220217
221- print " msymmetricp"
218+ print ( " msymmetricp" )
222219 t0 = time .time ()
223220 A2 = A .copy (order = 'F' )
224221 x2 = b .copy (order = 'F' )
225222 drivers .msymmetricp (A2 , x2 , ntasks )
226223 results4 [j ] = (time .time () - t0 ) / r
227224
228- print " mgeneralp"
225+ print ( " mgeneralp" )
229226 t0 = time .time ()
230227 A3 = A .copy (order = 'F' )
231228 x3 = b .copy (order = 'F' )
232229 drivers .mgeneralp (A3 , x3 , ntasks )
233230 results5 [j ] = (time .time () - t0 ) / r
234231
235- print " msymmetricfactorp & msymmetricfactoredp" # factor once, then it is possible to solve multiple times (although we now test only once)
232+ print ( " msymmetricfactorp & msymmetricfactoredp" ) # factor once, then it is possible to solve multiple times (although we now test only once)
236233 t0 = time .time ()
237- ipiv = np .empty ( (n ,r ), dtype = np .int32 , order = 'F' )
234+ ipiv = np .empty ( (n ,r ), dtype = np .intc , order = 'F' )
238235 fact = A .copy (order = 'F' )
239236 x4 = b .copy (order = 'F' )
240237 drivers .msymmetricfactorp ( fact , ipiv , ntasks )
241238 drivers .msymmetricfactoredp ( fact , ipiv , x4 , ntasks )
242239 results6 [j ] = (time .time () - t0 ) / r
243240
244- print " mgeneralfactorp & mgeneralfactoredp" # factor once, then it is possible to solve multiple times (although we now test only once)
241+ print ( " mgeneralfactorp & mgeneralfactoredp" ) # factor once, then it is possible to solve multiple times (although we now test only once)
245242 t0 = time .time ()
246- ipiv = np .empty ( (n ,r ), dtype = np .int32 , order = 'F' )
243+ ipiv = np .empty ( (n ,r ), dtype = np .intc , order = 'F' )
247244 fact = A .copy (order = 'F' )
248245 x5 = b .copy (order = 'F' )
249246 drivers .mgeneralfactorp ( fact , ipiv , ntasks )
250247 drivers .mgeneralfactoredp ( fact , ipiv , x5 , ntasks )
251248 results7 [j ] = (time .time () - t0 ) / r
252249
253250 if use_numpy :
254- # print np.max(np.abs(x - x3)) # DEBUG
255- # print np.max(np.abs(x - x5)) # DEBUG
256- print np .max (np .abs (x2 - x4 )) # DEBUG
251+ # print( np.max(np.abs(x - x3)) ) # DEBUG
252+ # print( np.max(np.abs(x - x5)) ) # DEBUG
253+ print ( np .max (np .abs (x2 - x4 )) ) # DEBUG
257254 assert (np .abs (x - x5 ) < 1e-10 ).all (), "Something went wrong, solutions do not match" # check general solver first
258255 assert (np .abs (x - x3 ) < 1e-10 ).all (), "Something went wrong, solutions do not match" # check general solver
259256# assert (np.abs(x - x2) < 1e-5).all(), "Something went wrong, solutions do not match" # doesn't make sense to compare, DSYSV is more accurate for badly conditioned symmetric matrices
@@ -265,27 +262,27 @@ def main():
265262#
266263# for j,item in enumerate(zip(sizes,reps)):
267264# n,r = item
268- # print "testing size %d, reps = %d" % (n, r)
265+ # print( "testing size %d, reps = %d" % (n, r) )
269266#
270267# A = np.random.sample( (n,n) )
271268# A = 0.5 * (A + A.T) # symmetrize
272269# A = np.array( A, dtype=np.float64, order='F' )
273270# b = np.random.sample( (n,) )
274271#
275272# t0 = time.time()
276- # for k in xrange (r):
273+ # for k in range (r):
277274# x = numpy_solve(A, b)
278275# results1[j] = (time.time() - t0) / r
279276#
280277# t0 = time.time()
281- # for k in xrange (r):
278+ # for k in range (r):
282279# A2 = A.copy(order='F')
283280# x2 = b.copy()
284281# drivers.symmetric(A2, x2)
285282# results2[j] = (time.time() - t0) / r
286283#
287284# t0 = time.time()
288- # for k in xrange (r):
285+ # for k in range (r):
289286# A3 = A.copy(order='F')
290287# x3 = b.copy()
291288# drivers.general(A3, x3)
@@ -294,27 +291,26 @@ def main():
294291
295292 # visualize
296293
297- pl .figure (1 )
298- pl .clf ()
294+ plt .figure (1 )
295+ plt .clf ()
299296 if use_numpy :
300- pl .loglog (sizes , results1 , 'k-' , label = 'NumPy' )
301- pl .loglog (sizes , results2 , 'b--' , label = 'dsysv, same LHS, many RHS' )
302- pl .loglog (sizes , results3 , 'b-' , label = 'dgesv, same LHS, many RHS' )
303- pl .loglog (sizes , results4 , 'r--' , label = 'dsysv, independent problems' )
304- pl .loglog (sizes , results5 , 'r-' , label = 'dgesv, independent problems' )
305- pl .loglog (sizes , results6 , 'g--' , label = 'dsytrf+dsytrs, independent problems' )
306- pl .loglog (sizes , results7 , 'g-' , label = 'dgetrf+dgetrs, independent problems' )
307- pl .xlabel ('n' )
308- pl .ylabel ('t' )
309- pl .title ('Average time per problem instance, %d parallel tasks' % (ntasks ))
310- pl .axis ('tight' )
311- pl .grid (b = True , which = 'both' )
312- pl .legend (loc = 'best' )
313-
314- pl .savefig ('figure1_latest.pdf' )
297+ plt .loglog (sizes , results1 , 'k-' , label = 'NumPy' )
298+ plt .loglog (sizes , results2 , 'b--' , label = 'dsysv, same LHS, many RHS' )
299+ plt .loglog (sizes , results3 , 'b-' , label = 'dgesv, same LHS, many RHS' )
300+ plt .loglog (sizes , results4 , 'r--' , label = 'dsysv, independent problems' )
301+ plt .loglog (sizes , results5 , 'r-' , label = 'dgesv, independent problems' )
302+ plt .loglog (sizes , results6 , 'g--' , label = 'dsytrf+dsytrs, independent problems' )
303+ plt .loglog (sizes , results7 , 'g-' , label = 'dgetrf+dgetrs, independent problems' )
304+ plt .xlabel ('n' )
305+ plt .ylabel ('t' )
306+ plt .title ('Average time per problem instance, %d parallel tasks' % (ntasks ))
307+ plt .axis ('tight' )
308+ plt .grid (b = True , which = 'both' )
309+ plt .legend (loc = 'best' )
310+
311+ plt .savefig ('figure1_latest.pdf' )
315312
316313
317314if __name__ == '__main__' :
318315 main ()
319- pl .show ()
320-
316+ plt .show ()
0 commit comments