文章详情

短信预约-IT技能 免费直播动态提醒

请输入下面的图形验证码

提交验证

短信预约提醒成功

pytorch分类模型绘制混淆矩阵及可视化的方法

2023-06-29 22:24

关注

本文小编为大家详细介绍“pytorch分类模型绘制混淆矩阵及可视化的方法”,内容详细,步骤清晰,细节处理妥当,希望这篇“pytorch分类模型绘制混淆矩阵及可视化的方法”文章能帮助大家解决疑惑,下面跟着小编的思路慢慢深入,一起来学习新知识吧。

Step 1. 获取混淆矩阵

#首先定义一个 分类数*分类数 的空混淆矩阵 conf_matrix = torch.zeros(Emotion_kinds, Emotion_kinds) # 使用torch.no_grad()可以显著降低测试用例的GPU占用    with torch.no_grad():        for step, (imgs, targets) in enumerate(test_loader):            # imgs:     torch.Size([50, 3, 200, 200])   torch.FloatTensor            # targets:  torch.Size([50, 1]),     torch.LongTensor  多了一维,所以我们要把其去掉            targets = targets.squeeze()  # [50,1] ----->  [50]            # 将变量转为gpu            targets = targets.cuda()            imgs = imgs.cuda()            # print(step,imgs.shape,imgs.type(),targets.shape,targets.type())                        out = model(imgs)            #记录混淆矩阵参数            conf_matrix = confusion_matrix(out, targets, conf_matrix)            conf_matrix=conf_matrix.cpu()

混淆矩阵的求取用到了confusion_matrix函数,其定义如下:

def confusion_matrix(preds, labels, conf_matrix):    preds = torch.argmax(preds, 1)    for p, t in zip(preds, labels):        conf_matrix[p, t] += 1    return conf_matrix

在当我们的程序执行结束 test_loader 后,我们可以得到本次数据的 混淆矩阵,接下来就要计算其 识别正确的个数以及混淆矩阵可视化:

conf_matrix=np.array(conf_matrix.cpu())# 将混淆矩阵从gpu转到cpu再转到npcorrects=conf_matrix.diagonal(offset=0)#抽取对角线的每种分类的识别正确个数per_kinds=conf_matrix.sum(axis=1)#抽取每个分类数据总的测试条数 print("混淆矩阵总元素个数:{0},测试集总个数:{1}".format(int(np.sum(conf_matrix)),test_num)) print(conf_matrix) # 获取每种Emotion的识别准确率 print("每种情感总个数:",per_kinds) print("每种情感预测正确的个数:",corrects) print("每种情感的识别准确率为:{0}".format([rate*100 for rate in corrects/per_kinds]))

执行此步的输出结果如下所示:

pytorch分类模型绘制混淆矩阵及可视化的方法

Step 2. 混淆矩阵可视化

对上边求得的混淆矩阵可视化

# 绘制混淆矩阵Emotion=8#这个数值是具体的分类数,大家可以自行修改labels = ['neutral', 'calm', 'happy', 'sad', 'angry', 'fearful', 'disgust', 'surprised']#每种类别的标签# 显示数据plt.imshow(conf_matrix, cmap=plt.cm.Blues)# 在图中标注数量/概率信息thresh = conf_matrix.max() / 2#数值颜色阈值,如果数值超过这个,就颜色加深。for x in range(Emotion_kinds):    for y in range(Emotion_kinds):        # 注意这里的matrix[y, x]不是matrix[x, y]        info = int(conf_matrix[y, x])        plt.text(x, y, info,                 verticalalignment='center',                 horizontalalignment='center',                 color="white" if info > thresh else "black")                 plt.tight_layout()#保证图不重叠plt.yticks(range(Emotion_kinds), labels)plt.xticks(range(Emotion_kinds), labels,rotation=45)#X轴字体倾斜45°plt.show()plt.close()

好了,以下就是最终的可视化的混淆矩阵啦:

pytorch分类模型绘制混淆矩阵及可视化的方法

其它分类指标的获取

例如 F1分数、TP、TN、FP、FN、精确率、召回率 等指标, 待补充哈(因为暂时还没用到)~

pytorch分类模型绘制混淆矩阵及可视化的方法

读到这里,这篇“pytorch分类模型绘制混淆矩阵及可视化的方法”文章已经介绍完毕,想要掌握这篇文章的知识点还需要大家自己动手实践使用过才能领会,如果想了解更多相关内容的文章,欢迎关注编程网行业资讯频道。

阅读原文内容投诉

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

软考中级精品资料免费领

  • 历年真题答案解析
  • 备考技巧名师总结
  • 高频考点精准押题
  • 2024年上半年信息系统项目管理师第二批次真题及答案解析(完整版)

    难度     813人已做
    查看
  • 【考后总结】2024年5月26日信息系统项目管理师第2批次考情分析

    难度     354人已做
    查看
  • 【考后总结】2024年5月25日信息系统项目管理师第1批次考情分析

    难度     318人已做
    查看
  • 2024年上半年软考高项第一、二批次真题考点汇总(完整版)

    难度     435人已做
    查看
  • 2024年上半年系统架构设计师考试综合知识真题

    难度     224人已做
    查看

相关文章

发现更多好内容

猜你喜欢

AI推送时光机
位置:首页-资讯-后端开发
咦!没有更多了?去看看其它编程学习网 内容吧
首页课程
资料下载
问答资讯