文章详情

短信预约-IT技能 免费直播动态提醒

请输入下面的图形验证码

提交验证

短信预约提醒成功

pytorch中矩阵乘法和数组乘法怎么实现

2023-07-05 17:06

关注

本篇内容介绍了“pytorch中矩阵乘法和数组乘法怎么实现”的有关知识,在实际案例的操作过程中,不少人都会遇到这样的困境,接下来就让小编带领大家学习一下如何处理这些情况吧!希望大家仔细阅读,能够学有所成!

一、torch.mul

该乘法可简单理解为矩阵各位相乘,一个常见的例子为向量点乘,源码定义为torch.mul(input,other,out=None)。其中other可以为一个数也可以为一个张量,other为数即张量的数乘。

该函数可触发广播机制(broadcast)。只要mat1与other满足broadcast条件,就可可以进行逐元素相乘 。

tensor1 = 2*torch.ones(1,4)tensor2 = 3*torch.ones(4,1)print(torch.mul(tensor1, tensor2))#输出结果为:tensor([[6., 6., 6., 6.],        [6., 6., 6., 6.],        [6., 6., 6., 6.],        [6., 6., 6., 6.]])
# 生成指定张量c = torch.Tensor([[1, 2, 3], [4, 5 ,6]])print(c.shape)  # 2*3print(c) # 生成随机张量d = torch.randn(2,2,3) print(d)print(d.shape)  # 2*2*3 mul = torch.mul(c, d) # c会自动broadcast和d进行匹配print(mul.shape)      # 2*2*3print(mul)

二、torch.mm

该函数一般只能用来计算两个二维矩阵的矩阵乘法,而且不支持broadcast操作。该函数源码定义为torch.mm(input,mat2,out=None) ,参数与返回值均为tensor形式。

a=torch.ones(4,3)  b=2*torch.ones(3,2)  c=torch.empty(4,2)  torch.mm(a,b,out=c)  print(torch.mm(a,b))  print( c )#输出结果为tensor([[6., 6.],        [6., 6.],        [6., 6.],        [6., 6.]])tensor([[6., 6.],        [6., 6.],        [6., 6.],        [6., 6.]])

三、torch.matmul

这个矩阵乘法是在torch.mm的基础上增加了广播机制,源码定义为torch.matmul(input,other,out=None)。

其基本运算规则如下:

如果两个参数都为一维,则等价于torch.mul,需要注意的是:此时的out不接受任何参数

如果两个张量都为二维且符合矩阵相乘规则,或第一个参数为一维(长度为m,这里等价为大小为1* m),第二个参数为二维(大小为m* n)则运算等价于torch.mm

如果第一个参数为二维(大小m* n),第二个参数为一维(长度为n),这里第二个参数会进行转置成为n* 1的列向量,随后进行矩阵相乘,将得到的结果再进行转置,最终返回一个大小为1* m的向量

tensor1 = torch.tensor([[1,1,1,1],[2,2,2,2],[3,3,3,3]],dtype=torch.float32)tensor2 = torch.ones(4)print(tensor1.size())print(tensor2.size())print(torch.matmul(tensor1, tensor2).shape)#输出结果为:torch.Size([3, 4])torch.Size([4])torch.Size([3])

还有一种情况就是任意一个参数至少为3维, 当前面的维度相同且最后两个维度符合二维矩阵运算规则可进行计算,例如第一参数的大小为a* b * c * m,第二个参数的大小为a* b* m* d,则返回一个大小为a* b* c * d的张量,可触发广播机制。

tensor1 = torch.ones(1,4,3,2)tensor2 = torch.ones(2,6)print(torch.matmul(tensor1, tensor2).size())#输出结果为:torch.Size([1, 4, 3, 6])

四、三维带Batch矩阵乘法 torch.bmm()

torch.bmm(bmat1,bmat2), 其中bmat1(B×n×m),bmat2(B×m×d)输出out的维度是B×n×d,该函数两个输入必须三维矩阵中的第一维要要相同,不支持broadCast操作。

五、torch中tensor数组的广播计算

首先定义两个张量,x的形状是[1,2,1],y的形状是[1,2,2]。

当x与y相乘时,由于x.size(2)不等于y.size(2),x会被扩展为[1,2,2]形状,然后再与张量y进行乘法运算。

x = torch.rand(1,2,1)y = torch.rand(1,2,2)

pytorch中矩阵乘法和数组乘法怎么实现

“pytorch中矩阵乘法和数组乘法怎么实现”的内容就介绍到这里了,感谢大家的阅读。如果想了解更多行业相关的知识可以关注编程网网站,小编将为大家输出更多高质量的实用文章!

阅读原文内容投诉

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

软考中级精品资料免费领

  • 历年真题答案解析
  • 备考技巧名师总结
  • 高频考点精准押题
  • 2024年上半年信息系统项目管理师第二批次真题及答案解析(完整版)

    难度     807人已做
    查看
  • 【考后总结】2024年5月26日信息系统项目管理师第2批次考情分析

    难度     351人已做
    查看
  • 【考后总结】2024年5月25日信息系统项目管理师第1批次考情分析

    难度     314人已做
    查看
  • 2024年上半年软考高项第一、二批次真题考点汇总(完整版)

    难度     433人已做
    查看
  • 2024年上半年系统架构设计师考试综合知识真题

    难度     221人已做
    查看

相关文章

发现更多好内容

猜你喜欢

AI推送时光机
位置:首页-资讯-后端开发
咦!没有更多了?去看看其它编程学习网 内容吧
首页课程
资料下载
问答资讯