文章详情

短信预约-IT技能 免费直播动态提醒

请输入下面的图形验证码

提交验证

短信预约提醒成功

PyTorch中的matmul函数详解

2023-09-07 06:08

关注

PyTorch中的两个张量的乘法可以分为两种:

  1. 两个张量对应的元素相乘(element-wise),在PyTorch中可以通过torch.mul函数(或者 ∗ * 运算符)实现

  2. 两个张量矩阵相乘(Matrix product),在PyTorch中可以通过torch.matmul函数实现

本文主要介绍两个张量的矩阵相乘。

语法为:

torch.matmul(input, other, out = None)

函数对input和other两个张量进行矩阵相乘。为了方便后续的讲解,将input记为a,将other记为b。

点积在数学中,又称数量积,是指接受在实数R上的两个1D张量并返回一个实数值0D张量的二元运算。
若1D张量a=[1,2],1D张量b=[3,4],则:
a ⋅ \cdot b=1 × \times × 3 + 2 × \times × 4 = 11

  1. 若a为1D张量,b为1D张量,则返回两个张量的点积,则返回两个张量的点积(此时的torch.matmul不支持out参数)

举例如下:

import torcha = torch.tensor([1, 2])b = torch.tensor([3, 4])result = torch.matmul(a, b)print(result)

结果为:

(PyTorch) D:\Code Project>D:/Anaconda/envs/PyTorch/python.exe "d:/Code Project/demo.py"tensor(11)
  1. 若a为2D张量,b为2D张量,则返回两个张量的矩阵乘积。

矩阵相乘最重要的方法是一般矩阵乘积,它只有在第一个2D张量(矩阵)的列数(column)和第二个2D张量(矩阵)的行数(row)相同时才有意义。
若2D张量a=[[1,2],[3,4]],2D张量b=[[5,6,7],[8,9,10]],则:
a × \times × b=[[21,24,27],[47,54,61]],2D张量a的形状为(2,2),而2D张量b的形状(2,3)。矩阵乘积的运算规则:
在这里插入图片描述

举例为:

import torcha = torch.tensor([[1, 2],[3,4]])b = torch.tensor([[5,6,7],[8,9,10]])result = torch.matmul(a, b)print(result)

结果展示为:

(PyTorch) D:\Code Project>D:/Anaconda/envs/PyTorch/python.exe "d:/Code Project/demo.py"tensor([[21, 24, 27],        [47, 54, 61]])
  1. 若a为1D张量,b为2D张量,torch.matmul函数:

首先,在1D张量a的前面插入一个长度为1的新维度变成2D张量;

然后,在满足第一个2D张量(矩阵)的列数(column)和第二个2D张量(矩阵)的行数(row)相同的条件下,两个2D张量矩阵乘积,否则会抛出错误;

最后,将矩阵乘积结果中长度为1的维度(前面插入的长度为1的新维度)删除作为最终torch.matmul函数返回的结果。

import torcha = torch.tensor([1, 2])b = torch.tensor([[5, 6, 7],[8, 9, 10]])result = torch.matmul(a, b)print(result, result.shape)

结果为:

(PyTorch) D:\Code Project>D:/Anaconda/envs/PyTorch/python.exe "d:/Code Project/demo.py"tensor([21, 24, 27]) torch.Size([3])

简单来说,先将1D张量a扩展成2D张量,满足矩阵乘积的条件下,将两个2D张量进行矩阵乘积的运算。
在这里插入图片描述
此时得到的形状是(1,3)的2D张量,最后将前面插入长度为1的新维度删除即为最终torch.matmul(a, b)函数返回的结果。

  1. 若a为2D张量,b为1D张量,torch.matmul函数:

首先,在1D张量b的后面插入一个长度为1的新维度变成2D张量;

然后,在满足第一个2D张量(矩阵)的列数(column)和第二个2D张量(矩阵)的行数(row)相同的条件下,两个2D张量矩阵乘积,否则会抛出错误;

最后,将矩阵乘积结果中长度为1的维度(后面插入的长度为1的新维度)删除作为最终torch.matmul函数返回的结果;

import torchb = torch.tensor([1, 2, 3])a = torch.tensor([[5, 6, 7],[8, 9, 10]])result = torch.matmul(a, b)print(result, result.shape)

结果展示为:

(PyTorch) D:\Code Project>D:/Anaconda/envs/PyTorch/python.exe "d:/Code Project/demo.py"tensor([38, 56]) torch.Size([2])

其中:

38 = 15+26+3*7

56 = 18+29+3*10

来源地址:https://blog.csdn.net/wzk4869/article/details/127932435

阅读原文内容投诉

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

软考中级精品资料免费领

  • 历年真题答案解析
  • 备考技巧名师总结
  • 高频考点精准押题
  • 2024年上半年信息系统项目管理师第二批次真题及答案解析(完整版)

    难度     801人已做
    查看
  • 【考后总结】2024年5月26日信息系统项目管理师第2批次考情分析

    难度     348人已做
    查看
  • 【考后总结】2024年5月25日信息系统项目管理师第1批次考情分析

    难度     311人已做
    查看
  • 2024年上半年软考高项第一、二批次真题考点汇总(完整版)

    难度     432人已做
    查看
  • 2024年上半年系统架构设计师考试综合知识真题

    难度     220人已做
    查看

相关文章

发现更多好内容

猜你喜欢

AI推送时光机
位置:首页-资讯-后端开发
咦!没有更多了?去看看其它编程学习网 内容吧
首页课程
资料下载
问答资讯