文章详情

短信预约-IT技能 免费直播动态提醒

请输入下面的图形验证码

提交验证

短信预约提醒成功

如何从PyTorch中获取过程特征图实例详解

2023-01-10 12:03

关注

一、获取Tensor

神经网络在运算过程中实际上是以Tensor为格式进行计算的,我们只需稍稍改动一下forward函数即可从运算过程中抓到Tensor

代码如下:

base_feature = self.extractor.forward(x)    #正常的前向传递
feature=base_feature.detach()               #抓取tensor
feature_imshow(feature)                     #展示函数(关键代码)

通过将过程张量赋值给一个临时变量,即可将其从前向传递中分离出来且不影响原来的前向传递函数,这种方法远比复杂的hook函数更实用。

将Tensor数据取到后到可视化还需要进行以下几步:

①类型转换

如果网络是在cuda中进行运算,则需要将提取到的tensor转换为cpu类型才能进行接下来的运算

inp = inp.cpu()        #类型转换

②张量拆解

网络中的张量一般是高维度的,需要对其进行降维,一般降至两维即可进行显示。这里以Faster R-CNN中的resnet50特征提取网络为例:输出其特征图尺寸为:[1,1024,68,38],可以很明显的看出,第一维实际上是batch_size,在图像显示中不需要,可以直接去除;第二维1024则是网络提取到的特征图张数,故可以对第二维进行遍历;而第3,4维是特征图的尺寸,直接显示即可。

inp=inp.squeeze(0)    #除去第一维
 
for i in range(len(inp)):
    plt.imshow(transforms.ToPILImage()(inp[i]))    #遍历第二维并将其转换为图像

③图像展示

选取你需要的特征图像,进行保存或使用plt进展示

完整的展示函数如下:

def feature_imshow(inp, title=None):
    inp = inp.cpu()
    inp=inp.squeeze(0)
    print(inp.shape)
    plt.figure(figsize=(12, 7))
    for i in range(len(inp)):
        plt.subplot(4, 5, i+1)    #第一二个参数为图像个数,第三参数为图像位置
        plt.imshow(transforms.ToPILImage()(inp[i]))
        i+=1
    plt.show()
    plt.pause(0.001)

总结

到此这篇关于如何从PyTorch中获取过程特征图的文章就介绍到这了,更多相关PyTorch获取过程特征图内容请搜索编程网以前的文章或继续浏览下面的相关文章希望大家以后多多支持编程网!

阅读原文内容投诉

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

软考中级精品资料免费领

  • 历年真题答案解析
  • 备考技巧名师总结
  • 高频考点精准押题
  • 2024年上半年信息系统项目管理师第二批次真题及答案解析(完整版)

    难度     807人已做
    查看
  • 【考后总结】2024年5月26日信息系统项目管理师第2批次考情分析

    难度     351人已做
    查看
  • 【考后总结】2024年5月25日信息系统项目管理师第1批次考情分析

    难度     314人已做
    查看
  • 2024年上半年软考高项第一、二批次真题考点汇总(完整版)

    难度     433人已做
    查看
  • 2024年上半年系统架构设计师考试综合知识真题

    难度     221人已做
    查看

相关文章

发现更多好内容

猜你喜欢

AI推送时光机
位置:首页-资讯-后端开发
咦!没有更多了?去看看其它编程学习网 内容吧
首页课程
资料下载
问答资讯