文章详情

短信预约-IT技能 免费直播动态提醒

请输入下面的图形验证码

提交验证

短信预约提醒成功

Pytorch如何实现常用乘法算子TensorRT

2023-06-30 18:26

关注

这篇文章主要介绍了Pytorch如何实现常用乘法算子TensorRT的相关知识,内容详细易懂,操作简单快捷,具有一定借鉴价值,相信大家阅读完这篇Pytorch如何实现常用乘法算子TensorRT文章都会有所收获,下面我们一起来看看吧。

1.乘法运算总览

先把 pytorch 中的一些常用的乘法运算进行一个总览:

如上进行了一些具体罗列,可以归纳出,常用的乘法无非两种:矩阵乘 和 点乘,所以下面分这两类进行介绍。

2.乘法算子实现

2.1矩阵乘算子实现

先来看看矩阵乘法的 pytorch 的实现 (以下实现在终端):

>>> import torch>>> # torch.mm>>> a = torch.randn(66, 99)>>> b = torch.randn(99, 88)>>> c = torch.mm(a, b)>>> c.shapetorch.size([66, 88])>>>>>> # torch.bmm>>> a = torch.randn(3, 66, 99)>>> b = torch.randn(3, 99, 77)>>> c = torch.bmm(a, b)>>> c.shapetorch.size([3, 66, 77])>>>>>> # torch.mv>>> a = torch.randn(66, 99)>>> b = torch.randn(99)>>> c = torch.mv(a, b)>>> c.shapetorch.size([66])>>>>>> # torch.matmul>>> a = torch.randn(32, 3, 66, 99)>>> b = torch.randn(32, 3, 99, 55)>>> c = torch.matmul(a, b)>>> c.shapetorch.size([32, 3, 66, 55])>>>>>> # @>>> d = a @ b>>> d.shapetorch.size([32, 3, 66, 55])

来看 TensorRT 的实现,以上乘法都可使用 addMatrixMultiply 方法覆盖,对应 torch.matmul,先来看该方法的定义:

//!//! \brief Add a MatrixMultiply layer to the network.//!//! \param input0 The first input tensor (commonly A).//! \param op0 The operation to apply to input0.//! \param input1 The second input tensor (commonly B).//! \param op1 The operation to apply to input1.//!//! \see IMatrixMultiplyLayer//!//! \warning Int32 tensors are not valid input tensors.//!//! \return The new matrix multiply layer, or nullptr if it could not be created.//!IMatrixMultiplyLayer* addMatrixMultiply(  ITensor& input0, MatrixOperation op0, ITensor& input1, MatrixOperation op1) noexcept{  return mImpl->addMatrixMultiply(input0, op0, input1, op1);}

可以看到这个方法有四个传参,对应两个张量和其 operation。来看这个算子在 TensorRT 中怎么添加:

// 构造张量 Tensor0nvinfer1::IConstantLayer *Constant_layer0 = m_network->addConstant(tensorShape0, value0);// 构造张量 Tensor1nvinfer1::IConstantLayer *Constant_layer1 = m_network->addConstant(tensorShape1, value1);// 添加矩阵乘法nvinfer1::IMatrixMultiplyLayer *Matmul_layer = m_network->addMatrixMultiply(Constant_layer0->getOutput(0), matrix0Type, Constant_layer1->getOutput(0), matrix2Type);// 获取输出matmulOutput = Matmul_layer->getOputput(0);

2.2点乘算子实现

再来看看点乘的 pytorch 的实现 (以下实现在终端):

>>> import torch>>> # torch.mul>>> a = torch.randn(66, 99)>>> b = torch.randn(66, 99)>>> c = torch.mul(a, b)>>> c.shapetorch.size([66, 99])>>> d = 0.125>>> e = torch.mul(a, d)>>> e.shapetorch.size([66, 99])>>> # *>>> f = a * b>>> f.shapetorch.size([66, 99])

来看 TensorRT 的实现,以上乘法都可使用 addScale 方法覆盖,这在图像预处理中十分常用,先来看该方法的定义:

//!//! \brief Add a Scale layer to the network.//!//! \param input The input tensor to the layer.//!              This tensor is required to have a minimum of 3 dimensions in implicit batch mode//!              and a minimum of 4 dimensions in explicit batch mode.//! \param mode The scaling mode.//! \param shift The shift value.//! \param scale The scale value.//! \param power The power value.//!//! If the weights are available, then the size of weights are dependent on the ScaleMode.//! For ::kUNIFORM, the number of weights equals 1.//! For ::kCHANNEL, the number of weights equals the channel dimension.//! For ::kELEMENTWISE, the number of weights equals the product of the last three dimensions of the input.//!//! \see addScaleNd//! \see IScaleLayer//! \warning Int32 tensors are not valid input tensors.//!//! \return The new Scale layer, or nullptr if it could not be created.//!IScaleLayer* addScale(ITensor& input, ScaleMode mode, Weights shift, Weights scale, Weights power) noexcept{  return mImpl->addScale(input, mode, shift, scale, power);}

 可以看到有三个模式:

再来看这个算子在 TensorRT 中怎么添加:

// 构造张量 inputnvinfer1::IConstantLayer *Constant_layer = m_network->addConstant(tensorShape, value);// scalemode选择,kUNIFORM、kCHANNEL、kELEMENTWISEscalemode = kUNIFORM;// 构建 Weights 类型的 shift、scale、power,其中 volume 为元素数量nvinfer1::Weights scaleShift{nvinfer1::DataType::kFLOAT, nullptr, volume };nvinfer1::Weights scaleScale{nvinfer1::DataType::kFLOAT, nullptr, volume };nvinfer1::Weights scalePower{nvinfer1::DataType::kFLOAT, nullptr, volume };// !! 注意这里还需要对 shift、scale、power 的 values 进行赋值,若只是乘法只需要对 scale 进行赋值就行// 添加张量乘法nvinfer1::IScaleLayer *Scale_layer = m_network->addScale(Constant_layer->getOutput(0), scalemode, scaleShift, scaleScale, scalePower);// 获取输出scaleOutput = Scale_layer->getOputput(0);

有一点你可能会比较疑惑,既然是点乘,那么输入只需要两个张量就可以了,为啥这里有 input、shift、scale、power 四个张量这么多呢。解释一下,input 不用说,就是输入张量,而 shift 表示加法参数、scale 表示乘法参数、power 表示指数参数,说到这里,你应该能发现,这个函数除了我们上面讲的点乘外还有其他更加丰富的运算功能。

关于“Pytorch如何实现常用乘法算子TensorRT”这篇文章的内容就介绍到这里,感谢各位的阅读!相信大家对“Pytorch如何实现常用乘法算子TensorRT”知识都有一定的了解,大家如果还想学习更多知识,欢迎关注编程网行业资讯频道。

阅读原文内容投诉

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

软考中级精品资料免费领

  • 历年真题答案解析
  • 备考技巧名师总结
  • 高频考点精准押题
  • 2024年上半年信息系统项目管理师第二批次真题及答案解析(完整版)

    难度     813人已做
    查看
  • 【考后总结】2024年5月26日信息系统项目管理师第2批次考情分析

    难度     354人已做
    查看
  • 【考后总结】2024年5月25日信息系统项目管理师第1批次考情分析

    难度     318人已做
    查看
  • 2024年上半年软考高项第一、二批次真题考点汇总(完整版)

    难度     435人已做
    查看
  • 2024年上半年系统架构设计师考试综合知识真题

    难度     224人已做
    查看

相关文章

发现更多好内容

猜你喜欢

AI推送时光机
位置:首页-资讯-后端开发
咦!没有更多了?去看看其它编程学习网 内容吧
首页课程
资料下载
问答资讯