文章详情

短信预约-IT技能 免费直播动态提醒

请输入下面的图形验证码

提交验证

短信预约提醒成功

pytorch实践线性模型3d源码分析

2023-07-06 01:53

关注

这篇文章主要介绍“pytorch实践线性模型3d源码分析”的相关知识,小编通过实际案例向大家展示操作过程,操作方法简单快捷,实用性强,希望这篇“pytorch实践线性模型3d源码分析”文章能帮助大家解决问题。

y = wx +b
通过meshgrid 得到两个二维矩阵
关键理解:
plot_surface需要的xyz是二维np数组
这里提前准备meshgrid来生产x和y需要的参数
下图的W和I即plot_surface需要xy

pytorch实践线性模型3d源码分析

Z即我们需要的权重损失
计算方式要和W,I. I的每行中内容是一样的就是y=wx+b的b是一样的

    fig = plt.figure()    ax = fig.add_axes(Axes3D(fig))    ax.plot_surface(W, I, Z=MSE_data)

总的实验代码

import matplotlib.pyplot as pltimport numpy as npfrom mpl_toolkits.mplot3d import Axes3Dclass LinearModel:    @staticmethod    def forward(w, x):        return w * x    @staticmethod    def forward_with_intercept(w, x, b):        return w * x + b    @staticmethod    def get_loss(w, x, y_origin, exp=2, b=None):        if b:            y = LinearModel.forward_with_intercept(w, x, b)        else:            y = LinearModel.forward(w, x)        return pow(y_origin - y, exp)def test_2d():    x_data = [1.0, 2.0, 3.0]    y_data = [2.0, 4.0, 6.0]    weight_data = []    MSE_data = []    # 设定实验的权重范围    for w in np.arange(0.0, 4.1, 0.1):        weight_data.append(w)        loss_total = 0        # 计算每个权重在数据集上的MSE平均平方方差        for x_val, y_val in zip(x_data, y_data):            loss_total += LinearModel.get_loss(w, x_val, y_val)        MSE_data.append(loss_total / len(x_data))    # 绘图    plt.xlabel("weight")    plt.ylabel("MSE")    plt.plot(weight_data, MSE_data)    plt.show()def test_3d():    x_data = [1.0, 2.0, 3.0]    y_data = [5.0, 8.0, 11.0]    weight_data = np.arange(0.0, 4.1, 0.1)    intercept_data = np.arange(0.0, 4.1, 0.1)    W, I = np.meshgrid(weight_data, intercept_data)    MSE_data = []    # 设定实验的权重范围 循环要先写截距的 meshgrid 的返回第二个是相当于41*41 同一行值相同 ,要在第二层循环去遍历权重    for intercept in intercept_data:        MSE_data_tmp = []        for w in weight_data:            loss_total = 0            # 计算每个权重在数据集上的MSE平均平方方差            for x_val, y_val in zip(x_data, y_data):                loss_total += LinearModel.get_loss(w, x_val, y_val, b=intercept)            MSE_data_tmp.append(loss_total / len(x_data))        MSE_data.append(MSE_data_tmp)    MSE_data = np.array(MSE_data)    fig = plt.figure()    ax = fig.add_axes(Axes3D(fig))    ax.plot_surface(W, I, Z=MSE_data)    plt.xlabel("weight")    plt.ylabel("intercept")    plt.show()if __name__ == '__main__':    test_2d()    test_3d()

关于“pytorch实践线性模型3d源码分析”的内容就介绍到这里了,感谢大家的阅读。如果想了解更多行业相关的知识,可以关注编程网行业资讯频道,小编每天都会为大家更新不同的知识点。

阅读原文内容投诉

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

软考中级精品资料免费领

  • 历年真题答案解析
  • 备考技巧名师总结
  • 高频考点精准押题
  • 2024年上半年信息系统项目管理师第二批次真题及答案解析(完整版)

    难度     807人已做
    查看
  • 【考后总结】2024年5月26日信息系统项目管理师第2批次考情分析

    难度     351人已做
    查看
  • 【考后总结】2024年5月25日信息系统项目管理师第1批次考情分析

    难度     314人已做
    查看
  • 2024年上半年软考高项第一、二批次真题考点汇总(完整版)

    难度     433人已做
    查看
  • 2024年上半年系统架构设计师考试综合知识真题

    难度     221人已做
    查看

相关文章

发现更多好内容

猜你喜欢

AI推送时光机
位置:首页-资讯-后端开发
咦!没有更多了?去看看其它编程学习网 内容吧
首页课程
资料下载
问答资讯