文章详情

短信预约-IT技能 免费直播动态提醒

请输入下面的图形验证码

提交验证

短信预约提醒成功

pytorch中retain_graph==True的作用说明

2023-02-21 12:00

关注

pytorch retain_graph==True的作用说明

总的来说进行一次backward之后,各个节点的值会清除,这样进行第二次backward会报错,如果加上retain_graph==True后,可以再来一次backward。 

retain_graph参数的作用

官方定义:

retain_graph (bool, optional) – If False, the graph used to compute the grad will be freed. Note that in nearly all cases setting this option to True is not needed and often can be worked around in a much more efficient way. Defaults to the value of create_graph.

大意是如果设置为False,计算图中的中间变量在计算完后就会被释放。

但是在平时的使用中这个参数默认都为False从而提高效率,和creat_graph的值一样。

具体看一个例子理解

假设一个我们有一个输入x,y = x **2, z = y*4,然后我们有两个输出,一个output_1 = z.mean(),另一个output_2 = z.sum()。

然后我们对两个output执行backward。

import torch
x = torch.randn((1,4),dtype=torch.float32,requires_grad=True)
y = x ** 2
z = y * 4
print(x)
print(y)
print(z)
loss1 = z.mean()
loss2 = z.sum()
print(loss1,loss2)
loss1.backward()    # 这个代码执行正常,但是执行完中间变量都free了,所以下一个出现了问题
print(loss1,loss2)
loss2.backward()    # 这时会引发错误

程序正常执行到第12行,所有的变量正常保存。

但是在第13行报错:

RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.

分析:计算节点数值保存了,但是计算图x-y-z结构被释放了,而计算loss2的backward仍然试图利用x-y-z的结构,因此会报错。

因此需要retain_graph参数为True去保留中间参数从而两个loss的backward()不会相互影响。

正确的代码应当把第11行以及之后改成

create_graph参数比较简单,参考官方定义:

create_graph (bool, optional) – If True, graph of the derivative will be constructed, allowing to compute higher order derivative products. Defaults to False.

Pytorch retain_graph=True错误信息

(Pytorch:RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time)

具有多个loss值

retain_graph设置True,一般多用于两次backward

# 假如有两个Loss,先执行第一个的backward,再执行第二个backward
loss1.backward(retain_graph=True) # 这样计算图就不会立即释放
loss2.backward() # 执行完这个后,所有中间变量都会被释放,以便下一次的循环
optimizer.step() # 更新参数

retain_graph设置True后一定要知道释放,否则显卡会占用越来越多,代码速度也会跑的越来越慢。

有的时候我明明仅有一个模型的也会出现这种错误

第一种是输入的原因。

// Example
x = torch.randn((100,1), requires_grad = True)
y = 1 + 2 * x + 0.3 * torch.randn(100,1)
x_train, y_train = x[:70], y[:70]
x_val, y_val = x[70:], y[70:]

for epoch in range(n_epochs):
    ...
    prediction = model(x_train)
    loss.backward()
    ...

在多次循环的过程中,input的梯度没有清除,而且我们也不需要计算输入的梯度,因此将x的require_grad设置为False就可以解决问题。

第二种是我在训练LSTM时候发现的。

class LSTMpred(nn.Module):
    def __init__(self, input_size, hidden_dim):
        self.hidden = self.init_hidden()
       ...
    def init_hidden(self):    #这里我们是需要个隐层参数的
        return (torch.zeros(1, 1, self.hidden_dim, requires_grad=True),
                torch.zeros(1, 1, self.hidden_dim, requires_grad=True))
    def forward(self, seq):
        ...

这里面的self.hidden我们在每一次训练的时候都要重新初始化隐层参数:

for epoch in range(Epoch):
    ...
    model.hidden = model.init_hidden()
    modout = model(seq)
    ...

3. 我的看法

其实,想想这几种情况都是一回事,都是网络在反向传播中不允许多个backward(),也就是梯度下降反馈的时候,有多个循环过程中共用了同一个需要计算梯度的变量,在前一个循环清除梯度后,后面一个循环过程就会在这个变量上栽跟头(个人想法)。

总结

以上为个人经验,希望能给大家一个参考,也希望大家多多支持编程网。

阅读原文内容投诉

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

软考中级精品资料免费领

  • 历年真题答案解析
  • 备考技巧名师总结
  • 高频考点精准押题
  • 2024年上半年信息系统项目管理师第二批次真题及答案解析(完整版)

    难度     813人已做
    查看
  • 【考后总结】2024年5月26日信息系统项目管理师第2批次考情分析

    难度     354人已做
    查看
  • 【考后总结】2024年5月25日信息系统项目管理师第1批次考情分析

    难度     318人已做
    查看
  • 2024年上半年软考高项第一、二批次真题考点汇总(完整版)

    难度     435人已做
    查看
  • 2024年上半年系统架构设计师考试综合知识真题

    难度     224人已做
    查看

相关文章

发现更多好内容

猜你喜欢

AI推送时光机
位置:首页-资讯-后端开发
咦!没有更多了?去看看其它编程学习网 内容吧
首页课程
资料下载
问答资讯