如何获得比numpy.dot更快的代码进行矩阵乘法?

如何获得比numpy.dot更快的代码进行矩阵乘法?,第1张

如何获得比numpy.dot更快的代码进行矩阵乘法

np.dot
何时分派到BLAS

  • NumPy已被编译为使用BLAS,
  • 在运行时可以使用BLAS实现,
  • 你的数据有dtypes之一
    float32
    float64
    complex32
    complex64
  • 数据在内存中适当对齐。

否则,它默认使用自己的慢速矩阵乘法例程。

这里介绍了检查BLAS链接的方法。简而言之,请检查

_dotblas.so
您的NumPy安装中是否有文件或类似文件。如果存在,请检查链接到哪个BLAS库;参考BLAS较慢,ATLAS较快,OpenBLAS和特定于供应商的版本(例如Intel
MKL)甚至更快。当心多线程BLAS实现,因为它们与Python的配合不好
multiprocessing

接下来,通过检查

flags
数组的来检查数据对齐方式。在1.7.2之前的NumPy版本中,to的两个参数都
np.dot
应按C顺序排列。在NumPy>
= 1.7.2中,这不再重要,因为已经引入了Fortran数组的特殊情况。

>>> X = np.random.randn(10, 4)>>> Y = np.random.randn(7, 4).T>>> X.flags  C_ConTIGUOUS : True  F_ConTIGUOUS : False  OWNdata: True  WRITEABLE : True  ALIGNED : True  UPDATeIFCOPY : False>>> Y.flags  C_ConTIGUOUS : False  F_ConTIGUOUS : True  OWNdata: False  WRITEABLE : True  ALIGNED : True  UPDATEIFCOPY : False

如果您的NumPy没有与BLAS链接,请(轻松)重新安装它,或(硬)使用

gemm
SciPy的BLAS (通用矩阵乘法)功能:

>>> from scipy.linalg import get_blas_funcs>>> gemm = get_blas_funcs("gemm", [X, Y])>>> np.all(gemm(1, X, Y) == np.dot(X, Y))True

这看起来很容易,但是几乎不会进行任何错误检查,因此您必须真正知道自己在做什么。



欢迎分享,转载请注明来源:内存溢出

原文地址: https://outofmemory.cn/zaji/5652949.html

(0)
打赏 微信扫一扫 微信扫一扫 支付宝扫一扫 支付宝扫一扫
上一篇 2022-12-16
下一篇 2022-12-16

发表评论

登录后才能评论

评论列表(0条)

保存