文章详情

短信预约-IT技能 免费直播动态提醒

请输入下面的图形验证码

提交验证

短信预约提醒成功

python中的torch.nn.Softmax() 用法和例子 dim=1 dim=2

2023-09-04 14:17

关注

用法

torch.nn.Softmax() 是 PyTorch 中的一个类,用于实现 softmax 函数。softmax 函数是一种常用的激活函数,它可以将一个向量转换成一个概率分布,使得每个元素都是非负数且和为 1。softmax 函数通常在分类问题中使用,可以将一个多分类问题转换成多个二分类问题,从而得到每个类别的概率分布。

语法格式

torch.nn.Softmax(dim=None)

其中,dim 是要进行 softmax 的维度,缺省值为 None,表示对最后一维进行 softmax。

例子dim=1

import torchx = torch.randn(2, 3)print('x:', x)softmax = torch.nn.Softmax(dim=1)y = softmax(x)print('y:', y)

输出

x: tensor([[ 1.3551,  0.3739,  0.5962],            [-0.3465,  1.4536,  0.4576]])y: tensor([[0.4989, 0.2238, 0.2773],            [0.1018, 0.7325, 0.1656]])

在这个例子中,我们先使用 torch.randn() 生成一个大小为 (2, 3) 的张量 x。然后,我们定义一个 torch.nn.Softmax() 对象 softmax,将维度 dim=1 作为参数传入。接着,我们将张量 x 作为输入,调用 softmax() 方法,得到一个大小为 (2, 3) 的张量 y,表示经过 softmax 函数处理后的结果。可以看到,每行元素都是非负数且和为 1。

需要注意的是,torch.nn.Softmax() 在实际使用中通常与交叉熵损失函数一起使用,用于多分类问题的训练。

例子dim=2

dim=2 表示在第二个维度上进行 softmax 计算。

import torch# 创建一个3D张量,形状为(2, 3, 4)x = torch.randn(2, 3, 4)# 使用dim=2进行softmax计算softmax = torch.nn.Softmax(dim=2)y = softmax(x)print("Original tensor:")print(x)print("\nSoftmax tensor:")print(y)

输出

Original tensor:tensor([[[ 0.4769, -0.1835, -0.3167, -1.1385],         [-0.5912,  0.4781, -0.6784, -0.4377],         [-0.9624, -0.0528, -1.4899, -1.5107]],        [[ 0.1033, -0.0107, -0.4888, -1.5489],         [ 0.4071,  0.2163, -0.3167, -0.1252],         [-1.7984, -1.1394, -1.5384, -0.3176]]])Softmax tensor:tensor([[[0.4669, 0.1745, 0.1527, 0.2060],         [0.1668, 0.4647, 0.1311, 0.2374],         [0.3005, 0.5028, 0.1452, 0.0515]],        [[0.4474, 0.2594, 0.1248, 0.1684],         [0.3616, 0.2983, 0.1426, 0.1975],         [0.1055, 0.1555, 0.1084, 0.6307]]])

可以看到,原始张量中的每个值都经过了 softmax 计算,第二个维度上的值都被归一化到了 0 到 1 之间,并且在每个样本上的值之和都为 1。

总结

当张量的形状为二维时,dim=1 和 dim=2 的效果类似,因为此时张量的行数等于时间步数,列数等于特征数。在这种情况下,dim=1 和 dim=2 都将每一行的值进行归一化,输出的结果相同。

但是当张量的形状为三维及以上时,dim=1 和 dim=2 的效果就不同了。在序列到序列的任务中,通常需要对每个时间步上的输出进行归一化,因此需要使用 torch.nn.Softmax(dim=2)。在分类任务中,通常需要对每个样本的输出进行归一化,因此需要使用 torch.nn.Softmax(dim=1)。

总之,dim 参数的选择应该根据具体的任务需求来进行选择,而不是根据形状的维数来确定。

来源地址:https://blog.csdn.net/longjiaxin1314/article/details/129689984

阅读原文内容投诉

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

软考中级精品资料免费领

  • 历年真题答案解析
  • 备考技巧名师总结
  • 高频考点精准押题
  • 2024年上半年信息系统项目管理师第二批次真题及答案解析(完整版)

    难度     807人已做
    查看
  • 【考后总结】2024年5月26日信息系统项目管理师第2批次考情分析

    难度     351人已做
    查看
  • 【考后总结】2024年5月25日信息系统项目管理师第1批次考情分析

    难度     314人已做
    查看
  • 2024年上半年软考高项第一、二批次真题考点汇总(完整版)

    难度     433人已做
    查看
  • 2024年上半年系统架构设计师考试综合知识真题

    难度     221人已做
    查看

相关文章

发现更多好内容

猜你喜欢

AI推送时光机
位置:首页-资讯-后端开发
咦!没有更多了?去看看其它编程学习网 内容吧
首页课程
资料下载
问答资讯