文章详情

短信预约-IT技能 免费直播动态提醒

请输入下面的图形验证码

提交验证

短信预约提醒成功

PyTorch中的nn.Module类怎么使用

2023-07-05 04:25

关注

这篇文章主要讲解了“PyTorch中的nn.Module类怎么使用”,文中的讲解内容简单清晰,易于学习与理解,下面请大家跟着小编的思路慢慢深入,一起来研究和学习“PyTorch中的nn.Module类怎么使用”吧!

PyTorch nn.Module类的简介

torch.nn.Module类是所有神经网络模块(modules)的基类,它的实现在torch/nn/modules/module.py中。你的模型也应该继承这个类,主要重载__init__、forward和extra_repr函数。Modules还可以包含其它Modules,从而可以将它们嵌套在树结构中。

只要在自己的类中定义了forward函数,backward函数就会利用Autograd被自动实现。只要实例化一个对象并传入对应的参数就可以自动调用forward函数。因为此时会调用对象的__call__方法,而nn.Module类中的__call__方法会调用forward函数。

nn.Module类中函数介绍:

测试代码如下:

import torchimport torch.nn as nnimport torch.nn.functional as F # nn.functional.py中存放激活函数等的实现 @torch.no_grad()def init_weights(m):    print("xxxx:", m)    if type(m) == nn.Linear:         m.weight.fill_(1.0)         print("yyyy:", m.weight) class Model(nn.Module):    def __init__(self):        # 在实现自己的__init__函数时,为了正确初始化自定义的神经网络模块,一定要先调用super().__init__        super(Model, self).__init__()        self.conv1 = nn.Conv2d(1, 20, 5) # submodule(child module)        self.conv2 = nn.Conv2d(20, 20, 5)        self.add_module("conv3", nn.Conv2d(10, 40, 5)) # 添加一个submodule到当前module,等价于self.conv3 = nn.Conv2d(10, 40, 5)        self.register_buffer("buffer", torch.randn([2,3])) # 给module添加一个presistent(持久的) buffer        self.param1 = nn.Parameter(torch.rand([1])) # module参数的tensor        self.register_parameter("param2", nn.Parameter(torch.rand([1]))) # 向module添加参数         # nn.Sequential: 顺序容器,module将按照它们在构造函数中传递的顺序添加,它允许将整个容器视为单个module        self.feature = nn.Sequential(nn.Linear(2, 2), nn.Linear(2, 2))        self.feature.apply(init_weights) # 将fn递归应用于每个submodule,典型用途为初始化模型参数        self.feature.to(torch.double) # 将参数数据类型转换为double        cpu = torch.device("cpu")        self.feature.to(cpu) # 将参数数据转换到cpu设备上     def forward(self, x):       x = F.relu(self.conv1(x))       return F.relu(self.conv2(x)) model = Model()print("## Model:", model) model.cpu() # 将所有模型参数和buffers移动到CPU上model.float() # 将所有浮点参数和buffers转换为float数据类型model.zero_grad() # 将所有模型参数的梯度设置为零 # state_dict:返回一个字典,保存着module的所有状态,参数和persistent buffers都会包含在字典中,字典的key就是参数和buffer的namesprint("## state_dict:", model.state_dict().keys()) for name, parameters in model.named_parameters(): # 返回module的参数(weight and bias)的迭代器,产生(yield)参数的名称以及参数本身    print(f"## named_parameters: name: {name}; parameters size: {parameters.size()}") for name, buffers in model.named_buffers(): # 返回module的buffers的迭代器,产生(yield)buffer的名称以及buffer本身    print(f"## named_buffers: name: {name}; buffers size: {buffers.size()}") # 注:children和modules中重复的module只被返回一次for children in model.children(): # 返回当前module的child module(submodule)的迭代器    print("## children:", children) for name, children in model.named_children(): # 返回直接submodule的迭代器,产生(yield) submodule的名称以及submodule本身    print(f"## named_children: name: {name}; children: {children}") for modules in model.modules(): # 返回当前模型所有module的迭代器,注意与children的区别    print("## modules:", modules) for name, modules in model.named_modules(): # 返回网络中所有modules的迭代器,产生(yield)module的名称以及module本身,注意与named_children的区别    print(f"## named_modules: name: {name}; module: {modules}") model.train() # 将module设置为训练模式model.eval() # 将module设置为评估模式 print("test finish")

PyTorch中nn.Module理解

nn.Module是Pytorch封装的一个类,是搭建神经网络时需要继承的父类:

import torchimport torch.nn as nn# 括号中加入nn.Module(父类)。Test2变成子类,继承父类(nn.Module)的所有特性。class Test2(nn.Module):      def __init__(self):  # Test2类定义初始化方法       super(Test2, self).__init__()  # 父类初始化       self.M = nn.Parameter(torch.ones(10))            def weightInit(self):        print('Testing')    def forward(self, n):        # print(2 * n)        print(self.M * n)        self.weightInit()# 调用方法network = Test2()network(2)  # 2赋值给forward(self, n)中的n。……省略一部分代码……# 因为Test2是nn.Module的子类,所以也可以执行父类中的方法。如:model_dict = network.state_dict()  # 调用父类中的方法state_dict(),将Test2中训练参数赋值model_dict。for k, v in model_dict.items():  # 查看自己网络参数各层名称、数值print(k)  # 输出网络参数名字    # print(v)  # 输出网络参数数值

继承nn.Module的子类程序是从forward()方法开始执行的,如果要想执行其他方法,必须把它放在forward()方法中。这一点与python中继承有稍许的不同。

感谢各位的阅读,以上就是“PyTorch中的nn.Module类怎么使用”的内容了,经过本文的学习后,相信大家对PyTorch中的nn.Module类怎么使用这一问题有了更深刻的体会,具体使用情况还需要大家实践验证。这里是编程网,小编将为大家推送更多相关知识点的文章,欢迎关注!

阅读原文内容投诉

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

软考中级精品资料免费领

  • 历年真题答案解析
  • 备考技巧名师总结
  • 高频考点精准押题
  • 2024年上半年信息系统项目管理师第二批次真题及答案解析(完整版)

    难度     813人已做
    查看
  • 【考后总结】2024年5月26日信息系统项目管理师第2批次考情分析

    难度     354人已做
    查看
  • 【考后总结】2024年5月25日信息系统项目管理师第1批次考情分析

    难度     318人已做
    查看
  • 2024年上半年软考高项第一、二批次真题考点汇总(完整版)

    难度     435人已做
    查看
  • 2024年上半年系统架构设计师考试综合知识真题

    难度     224人已做
    查看

相关文章

发现更多好内容

猜你喜欢

AI推送时光机
位置:首页-资讯-后端开发
咦!没有更多了?去看看其它编程学习网 内容吧
首页课程
资料下载
问答资讯