文章详情

短信预约-IT技能 免费直播动态提醒

请输入下面的图形验证码

提交验证

短信预约提醒成功

GPT时代学算法,Pytorch框架实现线性模型

2024-11-30 01:37

关注

整个代码会发生多大变化呢?

首先是数据生成的部分,这个部分和之前类似:

import numpy as np
import torch
from torch.utils import data
from d2l import torch as d2l
true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = d2l.synthetic_data(true_w, true_b, 1000)

但是从数据读取开始,就变得不同了。

在之前的代码中,我们是自己实现了迭代器,从训练数据中随机抽取数据。但我们没有做无放回的采样设计,也没有做数据的打乱操作。

然而这些内容Pytorch框架都有现成的工具可以使用,我们不需要再自己实现了。

这里需要用到TensorDataset和DataLoader两个类:

def load_array(data_arrays, batch_size, is_train=True): #@save
    """构造一个PyTorch数据迭代器"""
    dataset = data.TensorDataset(*data_arrays)
    return data.DataLoader(dataset, batch_size, shuffle=is_train)

关于这两个类的用法,我们可以直接询问ChatGPT。

图片

简而言之TensorDataset是用来封装tensor数据的,它的主要功能就是和DataLoader配合。

图片

DataLoader是一个迭代器,除了基本的数据读取之外,还提供乱序、采样、多线程读取等功能。

我们调用load_array获得训练数据的迭代器。

batch_size = 10
data_iter = load_array((features, labels), batch_size)

模型部分

在之前的实现当中,我们是自己创建了两个tensor来作为线性回归模型的参数。

然而其实不必这么麻烦,我们可以把线性回归看做是单层的神经网络,在原理和效果上,它们都是完全一样的。因此我们可以通过调用对应的API来很方便地实现模型:

from torch import nn
net = nn.Sequential(nn.Linear(2, 1))

这里的nn是神经网络的英文缩写,nn.Linear(2, 1)定义了一个输入维度是2,输出维度是1的单层线性网络,等同于线性模型。

nn.Sequential模块容器,它能够将输入的多个网络结构按照顺序拼装成一个完整的模型。这是一种非常常用和方便地构建模型的方法,除了这种方法之外,还有其他的方法创建模型,我们在之后遇到的时候再详细展开。

图片

一般来说模型创建好了之后,并不需要特别去初始化,但如果你想要对模型的参数进行调整的话,可以使用weight.data和weight.bias来访问参数:

net[0].weight.data.normal_(0, 0.01)
net[0].bias.data.fill_(0)

接着我们来定义损失函数,Pytorch当中同样封装了损失函数的实现,我们直接调用即可。

loss = nn.MSELoss()

nn.MSELoss即均方差,MSE即mean square error的缩写。

最后是优化算法,Pytorch当中也封装了更新模型中参数的方法,我们不需要手动来使用tensor里的梯度去更新模型了。只需要定义优化方法,让优化方法自动完成即可:

optim = torch.optim.SGD(net.parameters(), lr=0.03)

训练

最后就是把上述这些实现全部串联起来的模型训练了。

整个过程代码量很少,只有几行。

num_epochs = 3
for epoch in range(num_epochs):
    for X, y in data_iter:
        l = loss(net(X) ,y)
        optim.zero_grad()
        l.backward()
        optim.step()
    l = loss(net(features), labels)
    print(f'epoch {epoch + 1}, loss {l:f}')

我们之前自己实现的模型参数更新部分,被一行optim.step()代替了。

不论多么复杂的模型,都可以通过optim.step()来进行参数更新,非常方便!

同样我们可以来检查一下训练完成之后模型的参数值,同样和我们设置的非常接近。

图片

到这里,整个线性回归模型的实现就结束了。

这个模型是所有模型里最简单的了,正因为简单,所以最适合初学者。后面当接触了更多更复杂的模型之后,会发现虽然代码变复杂了,但遵循的仍然是现在这个框架。

来源:Coder梁内容投诉

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

软考中级精品资料免费领

  • 历年真题答案解析
  • 备考技巧名师总结
  • 高频考点精准押题
  • 2024年上半年信息系统项目管理师第二批次真题及答案解析(完整版)

    难度     813人已做
    查看
  • 【考后总结】2024年5月26日信息系统项目管理师第2批次考情分析

    难度     354人已做
    查看
  • 【考后总结】2024年5月25日信息系统项目管理师第1批次考情分析

    难度     318人已做
    查看
  • 2024年上半年软考高项第一、二批次真题考点汇总(完整版)

    难度     435人已做
    查看
  • 2024年上半年系统架构设计师考试综合知识真题

    难度     224人已做
    查看

相关文章

发现更多好内容

猜你喜欢

AI推送时光机
位置:首页-资讯-后端开发
咦!没有更多了?去看看其它编程学习网 内容吧
首页课程
资料下载
问答资讯