文章详情

短信预约-IT技能 免费直播动态提醒

请输入下面的图形验证码

提交验证

短信预约提醒成功

Pytorch中的Broadcasting问题

2023-01-03 12:01

关注

Numpy、Pytorch中的broadcasting

写在前面

自己一直都不清楚numpy、pytorch里面不同维数的向量之间的element wise的计算究竟是按照什么规则来确认维数匹配和不匹配的情况的,比如

>>> b = np.ones((4,5))
>>> a = np.arange(5)
>>> c = a + b
>>> c.shape
(4, 5)
>>> c
array([[1., 2., 3., 4., 5.],
       [1., 2., 3., 4., 5.],
       [1., 2., 3., 4., 5.],
       [1., 2., 3., 4., 5.]])

上面这种情况就会自动让a和b的维数匹配,a加到了b的每一行上

>>> b = np.ones((5,4))
>>> a = np.arange(5)
>>> c = a + b
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
ValueError: operands could not be broadcast together with shapes (5,) (5,4)

这种情况就无法匹配,此时我们希望的是a能自动加到b的每一列上,但结果看来好像不行

虽然一直存在这种疑惑,但因为平时遇到的各种运算都比较简单,遇到这种不是直接匹配的array的加法第一直觉就是去console里面试一试,报错就换个姿势再试一试,总归问题可以快速地解决,但是最近在写模型的时候,遇到了绕不过去的问题,所以去查了文档,本文就以解决那个问题为目标,来解释清楚pytorch(numpy也是一样)中的broadcasting semantics的问题

问题描述

我有一个数据Tensor,维数是64 × 2048 64\times204864×2048,现在我想通过对这64 6464个2048 20482048维的向量做attention(也就是做一个加权和)来得到一个2048 20482048维的向量,因为模型的需要,我需要用五组不同的权值向量来计算出五个不同的加权结果,也就是我的计算结果应该是一个5 × 2048 5\times 20485×2048维的向量,因为在64 6464个向量上加权,所以一组权值向量是64 6464维,五组就是5 × 64 5\times 645×64维

尝试解决

现在我手头上有两个Tensor,一个是数据Tensor(64 × 2048 64\times 204864×2048)另一个是权值Tensor(5 × 64 5\times 645×64),我GAN!直到我写到了这里,我才发现这不是一个矩阵乘法就能解决的问题嘛+_+,当然,我想给自己正名,这里我简化了一下问题所以才发现原来这么容易就解决了,而原来我在写代码的时候因为还要考虑batch_size等问题才云里雾里不知道咋办,还好当时没想出来,所以去查了文档发现了新的东西,然后写文章的时候想到也算是完满了(不然也不会发现自己好涝)

以上都是题外话,现在,我们还是考虑用愚蠢的element wise的方法来解决,好在现在有两种方法可以解决问题,所以我们可以用来相互检验一下,element wise的解决方法就是,我希望这5个64维的权值向量分别和这64个2048维的向量进行element wise的乘法,也就是第一个64维权值向量先对64个2048维向量加权得到一个2048维的向量,然后第二个64维权值向量先对64个2048维向量加权得到一个2048维的向量…,以此类推总共五个,最终得到五个64 × 2048 64×204864×2048维的向量,然后求和得到最后的5 × 2048 5×20485×2048维的向量

那么按照平常的习惯,我就去先试试pytorch能不能直接地理解我的想法

>>> import torch
>>> bs = 10 # batch_size
>>> x = torch.randn(bs,64,2048)
>>> att = torch.randn(5,64)
>>> out = att * x
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
RuntimeError: The size of tensor a (64) must match the size of tensor b (2048) at non-singleton dimension 2

直接乘不行,因为维数是不匹配的,那怎样的维数才算匹配呢?

BROADCASTING SEMANTICS

以下内容主要来源于自官方文档

很多pytorch的运算是支持broadcasting semantics的,而简单来说,如果运算支持broadcast,则参与运算的Tensor会自动进行扩展来使得运算符左右的Tensor维数匹配,而无需人手动地去拷贝其中的某个Tensor,这就类似于我们开头的那个例子

>>> b = np.ones((4,5))
>>> a = np.arange(5)
>>> c = a + b
>>> c.shape
(4, 5)
>>> c
array([[1., 2., 3., 4., 5.],
       [1., 2., 3., 4., 5.],
       [1., 2., 3., 4., 5.],
       [1., 2., 3., 4., 5.]])

我们无需让a的维数和b一样,因为numpy自动帮我们做了

这里的另一个重要的概念是broadcastable,如果两个Tensor是broadcastable的,那么就可以对他俩使用支持broadcast的运算,比如直接加减乘除

而两个向量要是broadcast的话,必须满足以下两个条件

这是官方的例子解释

>>> x=torch.empty(5,7,3)
>>> y=torch.empty(5,7,3)
# 相同维数的tensor一定是broadcastable的

>>> x=torch.empty((0,))
>>> y=torch.empty(2,2)
# 不是broadcastable的,因为每个tensor维数至少要是1

>>> x=torch.empty(5,3,4,1)
>>> y=torch.empty(  3,1,1)
# 是broadcastable的,因为从后往前看,一定要注意是从后往前看!
# 第一个维度都是1,相等,满足第二个条件
# 第二个维度其中有一个是1,满足第二个条件
# 第三个维度都是3,相等,满足第二个条件
# 第四个维度其中有一个不存在,满足第二个条件

# 但是
>>> x=torch.empty(5,2,4,1)
>>> y=torch.empty(  3,1,1)
# 不是broadcastable的,因为从后往前看第三个维度是不match的 2!=3,且都不是1

如果x和y是broadcastable的,那么结果的tensor的size按照如下的规则计算

官方的例子

>>> x=torch.empty(5,1,4,1)
>>> y=torch.empty(  3,1,1)
>>> (x+y).size()
torch.Size([5, 3, 4, 1])

>>> x=torch.empty(1)
>>> y=torch.empty(3,1,7)
>>> (x+y).size()
torch.Size([3, 1, 7])

>>> x=torch.empty(5,2,4,1)
>>> y=torch.empty(3,1,1)
>>> (x+y).size()
RuntimeError: The size of tensor a (2) must match the size of tensor b (3) at non-singleton dimension 1

此外,关于broadcast导致的就地(in-place)操作和梯度运算的兼容性等问题,可以自行参考官方文档

解决问题

上面我们看到,要想两个Tensor支持element wise的运算,需要它们是broadcastable的,而要想它们是broadcastable的,就需要它们的维度自后向前逐一匹配,回到我们原来的问题中,我们有两个Tensor x(64 × 2048) att(5 × 64),为了让它们broadcastable,我们只需要

>>> import torch
>>> bs = 10 # batch_size
>>> x = torch.randn(bs,64,2048)
>>> att = torch.randn(5,64)
>>> x = x.unsqueeze(1)
>>> att = att.view(1,*att.shape,1)
>>> x.shape
torch.Size([10, 1, 64, 2048])
>>> att.shape
torch.Size([1, 5, 64, 1])
>>> out = x * att
>>> out.shape
torch.Size([10, 5, 64, 2048])

最后我们来验证两种方法是否结果相同

>>> import torch
>>> bs = 10 
>>> x = torch.randn(bs,64,2048)
>>> att = torch.randn(5,64)
>>> out1 = torch.matmul(att,x)  # 直接矩阵相乘
>>> out.shape
torch.Size([10, 5, 2048])

>>> x = x.unsqueeze(1)
>>> att = att.view(1,*att.shape,1)
>>> out2 = x * att  # element wise的方法
>>> out2 = out2.sum(dim=2)

>>> test = torch.sum((out1-out2)<0.00001)  # 浮点数有微小的误差
>>> test
tensor(102400)
>>> out1.numel()  # 最后表明两个out向量是相等的
102400

Reference

[1] https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html#module-numpy.doc.broadcasting

[2] https://pytorch.org/docs/stable/notes/broadcasting.html#broadcasting-semantics

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持编程网。

阅读原文内容投诉

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

软考中级精品资料免费领

  • 历年真题答案解析
  • 备考技巧名师总结
  • 高频考点精准押题
  • 2024年上半年信息系统项目管理师第二批次真题及答案解析(完整版)

    难度     813人已做
    查看
  • 【考后总结】2024年5月26日信息系统项目管理师第2批次考情分析

    难度     354人已做
    查看
  • 【考后总结】2024年5月25日信息系统项目管理师第1批次考情分析

    难度     318人已做
    查看
  • 2024年上半年软考高项第一、二批次真题考点汇总(完整版)

    难度     435人已做
    查看
  • 2024年上半年系统架构设计师考试综合知识真题

    难度     224人已做
    查看

相关文章

发现更多好内容

猜你喜欢

AI推送时光机
位置:首页-资讯-后端开发
咦!没有更多了?去看看其它编程学习网 内容吧
首页课程
资料下载
问答资讯