文章详情

短信预约-IT技能 免费直播动态提醒

请输入下面的图形验证码

提交验证

短信预约提醒成功

基于nn.Module类实现线性回归模型

2024-12-03 02:02

关注

上次介绍了顺序模型,但是在大多数情况下,我们基本都是以类的形式实现神经网络。

大多数情况下创建一个继承自 Pytorch 中的 nn.Module 的类,这样可以使用 Pytorch 提供的许多高级 API,而无需自己实现。

下面展示了一个可以从nn.Module创建的最简单的神经网络类的示例。基于 nn.Module的类的最低要求是覆盖__init__()方法和forward()方法。

在这个类中,定义了一个简单的线性网络,具有两个输入和一个输出,并使用 Sigmoid()函数作为网络的激活函数。

  1. import torch 
  2. from torch import nn 
  3.  
  4. class LinearRegression(nn.Module): 
  5.     def __init__(self): 
  6.         #继承父类构造函数 
  7.         super(LinearRegression, self).__init__()  
  8.         #输入和输出的维度都是1 
  9.         self.linear = nn.Linear(1, 1)  
  10.     def forward(self, x): 
  11.         out = self.linear(x) 
  12.         return out 

现在让我们测试一下模型。

  1. # 创建LinearRegression()的实例 
  2. model = LinearRegression() 
  3. print(model)  
  4. # 输出如下 
  5. LinearRegression( 
  6.   (linear): Linear(in_features=1, out_features=1, bias=True

现在让定义一个损失函数和优化函数。

  1. model = LinearRegression()#实例化对象 
  2. num_epochs = 1000#迭代次数 
  3. learning_rate = 1e-2#学习率0.01 
  4. Loss = torch.nn.MSELoss()#损失函数 
  5. optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)#优化函数 

我们创建一个由方程产生的数据集,并通过函数制造噪音

  1. import torch  
  2. from matplotlib import pyplot as plt 
  3. from torch.autograd import Variable 
  4. from torch import nn 
  5. # 创建数据集  unsqueeze 相当于 
  6. x = Variable(torch.unsqueeze(torch.linspace(-1, 1, 100), dim=1)) 
  7. y = Variable(x * 2 + 0.2 + torch.rand(x.size())) 
  8. plt.scatter(x.data.numpy(),y.data.numpy()) 
  9. plt.show() 

关于torch.unsqueeze函数解读。

  1. >>> x = torch.tensor([1, 2, 3, 4]) 
  2. >>> torch.unsqueeze(x, 0) 
  3. tensor([[ 1,  2,  3,  4]]) 
  4. >>> torch.unsqueeze(x, 1) 
  5. tensor([[ 1], 
  6.         [ 2], 
  7.         [ 3], 
  8.         [ 4]]) 

遍历每次epoch,计算出loss,反向传播计算梯度,不断的更新梯度,使用梯度下降进行优化。

  1. for epoch in range(num_epochs): 
  2.     # 预测 
  3.     y_pred= model(x) 
  4.     # 计算loss 
  5.     loss = Loss(y_pred, y) 
  6.     #清空上一步参数值 
  7.     optimizer.zero_grad() 
  8.     #反向传播 
  9.     loss.backward() 
  10.     #更新参数 
  11.     optimizer.step() 
  12.     if epoch % 200 == 0: 
  13.         print("[{}/{}] loss:{:.4f}".format(epoch+1, num_epochs, loss)) 
  14.  
  15. plt.scatter(x.data.numpy(), y.data.numpy()) 
  16. plt.plot(x.data.numpy(), y_pred.data.numpy(), 'r-',lw=5) 
  17. plt.text(0.5, 0,'Loss=%.4f' % loss.data.item(), fontdict={'size': 20, 'color':  'red'}) 
  18. plt.show() 
  19. ####结果如下#### 
  20. [1/1000] loss:4.2052 
  21. [201/1000] loss:0.2690 
  22. [401/1000] loss:0.0925 
  23. [601/1000] loss:0.0810 
  24. [801/1000] loss:0.0802 

  1. [w, b] = model.parameters() 
  2. print(w,b) 
  3. # Parameter containing: 
  4. tensor([[2.0036]], requires_grad=True) Parameter containing: 
  5. tensor([0.7006], requires_grad=True

这里的b=0.7,等于0.2 + torch.rand(x.size()),经过大量的训练torch.rand()一般约等于0.5。

 

来源:Python之王内容投诉

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

软考中级精品资料免费领

  • 历年真题答案解析
  • 备考技巧名师总结
  • 高频考点精准押题
  • 2024年上半年信息系统项目管理师第二批次真题及答案解析(完整版)

    难度     813人已做
    查看
  • 【考后总结】2024年5月26日信息系统项目管理师第2批次考情分析

    难度     354人已做
    查看
  • 【考后总结】2024年5月25日信息系统项目管理师第1批次考情分析

    难度     318人已做
    查看
  • 2024年上半年软考高项第一、二批次真题考点汇总(完整版)

    难度     435人已做
    查看
  • 2024年上半年系统架构设计师考试综合知识真题

    难度     224人已做
    查看

相关文章

发现更多好内容

猜你喜欢

AI推送时光机
位置:首页-资讯-后端开发
咦!没有更多了?去看看其它编程学习网 内容吧
首页课程
资料下载
问答资讯