文章详情

短信预约-IT技能 免费直播动态提醒

请输入下面的图形验证码

提交验证

短信预约提醒成功

【pytorch】torch.cdist使用说明

2023-09-02 15:42

关注

torch.cdist的使用介绍如官网所示,

在这里插入图片描述

它是批量计算两个向量集合的距离。

其中, x1和x2是输入的两个向量集合。

p 默认为2,为欧几里德距离。

它的功能上等同于 scipy.spatial.distance.cdist(input,’minkowski’, p=p)

如果x1的shape是 [B,P,M], x2的shape是[B,R,M],则cdist的结果shape是 [B,P,R]

x1一般是输入矢量,而x2一般是码本。

x2中所有的元素分别与x1中的每一个元素求欧几里德距离(当p默认为2时)

如下面示例

import torchx1 = torch.FloatTensor([0.1, 0.2, 0, 0.5]).view(4, 1)x2 = torch.FloatTensor([0.2, 0.3]).view(2, 1)print(torch.cdist(x1,x2))

x2中的所有元素分别与x1中的每一个元素求欧几里德距离,即有如下步骤

x 11 = ( 0.1 − 0.2 ) 2 =0.1 x 12 = ( 0.1 − 0.3 ) 2 =0.2 x 21 = ( 0.2 − 0.2 ) 2 =0 x 22 = ( 0.2 − 0.3 ) 2 =0.1 x 31 = ( 0 − 0.2 ) 2 =0.2 x 32 = ( 0 − 0.3 ) 2 =0.3 x 41 = ( 0.5 − 0.2 ) 2 =0.3 x 42 = ( 0.5 − 0.3 ) 2 =0.2 x_{11} = \sqrt{ (0.1-0.2)^2} = 0.1 \newline x_{12} = \sqrt { (0.1-0.3)^2} = 0.2 \newline x_{21} = \sqrt { (0.2-0.2)^2} = 0 \newline x_{22} = \sqrt { (0.2-0.3)^2} = 0.1 \newline x_{31} = \sqrt { (0-0.2)^2} = 0.2 \newline x_{32} = \sqrt { (0-0.3)^2} = 0.3 \newline x_{41} = \sqrt { (0.5-0.2)^2 } =0.3\newline x_{42} = \sqrt { (0.5-0.3)^2 } = 0.2\newline x11=(0.10.2)2 =0.1x12=(0.10.3)2 =0.2x21=(0.20.2)2 =0x22=(0.20.3)2 =0.1x31=(00.2)2 =0.2x32=(00.3)2 =0.3x41=(0.50.2)2 =0.3x42=(0.50.3)2 =0.2

所以运行结果为
在这里插入图片描述

如下面示例

import torchx1 = torch.FloatTensor([0.1, 0.2, 0.1, 0.5, 0.2, -0.9, 0.8, 0.4]).view(4, 2)x2 = torch.FloatTensor([0.2, 0.3, 0, 0.1]).view(2, 2)print(torch.cdist(x1,x2))

x1和x2数据是二维的,
在这里插入图片描述

x2中的所有元素分别与x1中的每一个元素求欧几里德距离,即有如下步骤

x 11 = ( 0.1 − 0.2 ) 2 + ( 0.2 − 0.3 ) 2 = 0.02 =0.1414 x 12 = ( 0.1 − 0.0 ) 2 + ( 0.2 − 0.1 ) 2 = 0.02 =0.1414 x 21 = ( 0.1 − 0.2 ) 2 + ( 0.5 − 0.3 ) 2 = 0.05 =0.2236 x 22 = ( 0.1 − 0.0 ) 2 + ( 0.5 − 0.1 ) 2 = 0.17 =0.4123 x 31 = ( 0.2 − 0.2 ) 2 + ( − 0.9 − 0.3 ) 2 =1.2 x 32 = ( 0.2 − 0.0 ) 2 + ( − 0.9 − 0.1 ) 2 = ( 1.04)=1.0198 x 41 = ( 0.8 − 0.2 ) 2 + ( 0.4 − 0.3 ) 2 = ( 0.37)=0.6083 x 42 = ( 0.8 − 0.0 ) 2 + ( 0.4 − 0.1 ) 2 = ( 0.73)=0.8544 x_{11} = \sqrt{ (0.1-0.2)^2 + (0.2-0.3)^2 } = \sqrt{0.02} = 0.1414 \newline x_{12} = \sqrt { (0.1-0.0)^2 + (0.2-0.1)^2 } = \sqrt{0.02} = 0.1414 \newline x_{21} = \sqrt { (0.1-0.2)^2 + (0.5-0.3)^2 } = \sqrt{0.05} = 0.2236 \newline x_{22} = \sqrt { (0.1-0.0)^2 + (0.5-0.1)^2 } = \sqrt{0.17} = 0.4123 \newline x_{31} = \sqrt { (0.2-0.2)^2 + (-0.9-0.3)^2} = 1.2 \newline x_{32} = \sqrt { (0.2-0.0)^2 + (-0.9-0.1)^2} = \sqrt(1.04) = 1.0198 \newline x_{41} = \sqrt { (0.8-0.2)^2 + (0.4-0.3)^2 } = \sqrt(0.37) = 0.6083 \newline x_{42} = \sqrt { (0.8-0.0)^2 + (0.4-0.1)^2 } = \sqrt(0.73) = 0.8544 \newline x11=(0.10.2)2+(0.20.3)2 =0.02 =0.1414x12=(0.10.0)2+(0.20.1)2 =0.02 =0.1414x21=(0.10.2)2+(0.50.3)2 =0.05 =0.2236x22=(0.10.0)2+(0.50.1)2 =0.17 =0.4123x31=(0.20.2)2+(0.90.3)2 =1.2x32=(0.20.0)2+(0.90.1)2 =( 1.04)=1.0198x41=(0.80.2)2+(0.40.3)2 =( 0.37)=0.6083x42=(0.80.0)2+(0.40.1)2 =( 0.73)=0.8544

所以结果如下

在这里插入图片描述

p=2的欧几里德距离也是L2范式,如果p=1即是L1范式
上面的例子修改一下p参数

import torchx1 = torch.FloatTensor([0.1, 0.2, 0.1, 0.5, 0.2, -0.9, 0.8, 0.4]).view(4, 2)x2 = torch.FloatTensor([0.2, 0.3, 0, 0.1]).view(2, 2)print(torch.cdist(x1,x2,p=1))

结果如下,这里就不一个一个运算了。
在这里插入图片描述

来源地址:https://blog.csdn.net/mimiduck/article/details/128886148

阅读原文内容投诉

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

软考中级精品资料免费领

  • 历年真题答案解析
  • 备考技巧名师总结
  • 高频考点精准押题
  • 2024年上半年信息系统项目管理师第二批次真题及答案解析(完整版)

    难度     813人已做
    查看
  • 【考后总结】2024年5月26日信息系统项目管理师第2批次考情分析

    难度     354人已做
    查看
  • 【考后总结】2024年5月25日信息系统项目管理师第1批次考情分析

    难度     318人已做
    查看
  • 2024年上半年软考高项第一、二批次真题考点汇总(完整版)

    难度     435人已做
    查看
  • 2024年上半年系统架构设计师考试综合知识真题

    难度     224人已做
    查看

相关文章

发现更多好内容

猜你喜欢

AI推送时光机
位置:首页-资讯-后端开发
咦!没有更多了?去看看其它编程学习网 内容吧
首页课程
资料下载
问答资讯