@@ -440,10 +440,30 @@ def set_unsafe(self, name):
440
440
raise RuntimeError ("Backend not found" )
441
441
self .__name = name
442
442
443
- def __init__ (self ):
444
-
443
+ def _loadlibs (self ):
444
+ """
445
+ function that loads ArrayFire upstream libraries
446
+ """
445
447
more_info_str = "Please look at https://github.com/arrayfire/arrayfire-python/wiki for more information."
446
448
449
+ # Iterate in reverse order of preference
450
+ for name in ('cpu' , 'opencl' , 'cuda' , '' ):
451
+ libnames = self .__libname (name )
452
+ for libname in libnames :
453
+ try :
454
+ ct .cdll .LoadLibrary (libname )
455
+ __name = 'unified' if name == '' else name
456
+ self .__clibs [__name ] = ct .CDLL (libname )
457
+ self .__name = __name
458
+ break ;
459
+ except :
460
+ pass
461
+
462
+ if (self .__name is None ):
463
+ raise RuntimeError ("Could not load any ArrayFire libraries.\n " +
464
+ more_info_str )
465
+
466
+ def __init__ (self ):
447
467
pre , post , AF_PATH , CUDA_FOUND = _setup ()
448
468
449
469
self .__pre = pre
@@ -468,7 +488,6 @@ def __init__(self):
468
488
'cpu' : 1 ,
469
489
'cuda' : 2 ,
470
490
'opencl' : 4 }
471
-
472
491
# Try to pre-load forge library if it exists
473
492
libnames = self .__libname ('forge' , '' )
474
493
for libname in libnames :
@@ -477,30 +496,15 @@ def __init__(self):
477
496
except :
478
497
pass
479
498
480
- # Iterate in reverse order of preference
481
- for name in ('cpu' , 'opencl' , 'cuda' , '' ):
482
- libnames = self .__libname (name )
483
- for libname in libnames :
484
- try :
485
- ct .cdll .LoadLibrary (libname )
486
- __name = 'unified' if name == '' else name
487
- self .__clibs [__name ] = ct .CDLL (libname )
488
- self .__name = __name
489
- break ;
490
- except :
491
- pass
492
-
493
- if (self .__name is None ):
494
- raise RuntimeError ("Could not load any ArrayFire libraries.\n " +
495
- more_info_str )
496
-
497
499
def get_id (self , name ):
498
500
return self .__backend_name_map [name ]
499
501
500
502
def get_name (self , bk_id ):
501
503
return self .__backend_map [bk_id ]
502
504
503
505
def get (self ):
506
+ if (self .__clibs [self .__name ] is None ):
507
+ self ._loadlibs ()
504
508
return self .__clibs [self .__name ]
505
509
506
510
def name (self ):
0 commit comments