文章详情

短信预约-IT技能 免费直播动态提醒

请输入下面的图形验证码

提交验证

短信预约提醒成功

Pytorch中retain_graph的坑如何解决

2023-07-05 04:22

关注

本篇内容主要讲解“Pytorch中retain_graph的坑如何解决”,感兴趣的朋友不妨来看看。本文介绍的方法操作简单快捷,实用性强。下面就让小编来带大家学习“Pytorch中retain_graph的坑如何解决”吧!

Pytorch中retain_graph的坑

在查看SRGAN源码时有如下损失函数,其中设置了retain_graph=True,其作用就是

在更新D网络时的loss反向传播过程中使用了retain_graph=True,目的为是为保留该过程中计算的梯度,后续G网络更新时使用;

        ############################        # (1) Update D network: maximize D(x)-1-D(G(z))        ###########################        real_img = Variable(target)        if torch.cuda.is_available():            real_img = real_img.cuda()        z = Variable(data)        if torch.cuda.is_available():            z = z.cuda()        fake_img = netG(z)         netD.zero_grad()        real_out = netD(real_img).mean()        fake_out = netD(fake_img).mean()        d_loss = 1 - real_out + fake_out        d_loss.backward(retain_graph=True) #####        optimizerD.step()         ############################        # (2) Update G network: minimize 1-D(G(z)) + Perception Loss + Image Loss + TV Loss        ###########################        netG.zero_grad()        g_loss = generator_criterion(fake_out, fake_img, real_img)        g_loss.backward()        optimizerG.step()        fake_img = netG(z)        fake_out = netD(fake_img).mean()         g_loss = generator_criterion(fake_out, fake_img, real_img)        running_results['g_loss'] += g_loss.data[0] * batch_size        d_loss = 1 - real_out + fake_out        running_results['d_loss'] += d_loss.data[0] * batch_size        running_results['d_score'] += real_out.data[0] * batch_size        running_results['g_score'] += fake_out.data[0] * batch_size

也就是说,只要我们有一个loss,我们就可以先loss.backward(retain_graph=True)  让它先计算梯度,若下面还有其他损失,但是可能你想扩展代码,可能有些loss是不用的,所以先加了 if 等判别语句进行了干预,使用loss.backward(retain_graph=True)就可以单独的计算梯度,屡试不爽。

但是另外一个问题在于,如果你都这么用的话,显存会爆炸,因为他保留了梯度,所以都没有及时释放掉,浪费资源。

而正确的做法应该是,在你最后一个loss 后面,一定要加上loss.backward()这样的形式,也就是让最后一个loss 释放掉之前所有暂时保存下来得梯度!!

Pytorch中有多次backward时需要retain_graph参数

Pytorch中的机制是每次调用loss.backward()时都会free掉计算图中所有缓存的buffers,当模型中可能有多次backward()时,因为前一次调用backward()时已经释放掉了buffer,所以下一次调用时会因为buffers不存在而报错

解决办法

loss.backward(retain_graph=True)

错误使用

因为每次调用bckward时都没有将buffers释放掉,所以会导致内存溢出,迭代越来越慢(因为梯度都保存了,没有free)

正确使用

最后一个 backward() 不要加 retain_graph 参数,这样每次更新完成后会释放占用的内存,也就不会出现越来越慢的情况了

到此,相信大家对“Pytorch中retain_graph的坑如何解决”有了更深的了解,不妨来实际操作一番吧!这里是编程网网站,更多相关内容可以进入相关频道进行查询,关注我们,继续学习!

阅读原文内容投诉

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

软考中级精品资料免费领

  • 历年真题答案解析
  • 备考技巧名师总结
  • 高频考点精准押题
  • 2024年上半年信息系统项目管理师第二批次真题及答案解析(完整版)

    难度     813人已做
    查看
  • 【考后总结】2024年5月26日信息系统项目管理师第2批次考情分析

    难度     354人已做
    查看
  • 【考后总结】2024年5月25日信息系统项目管理师第1批次考情分析

    难度     318人已做
    查看
  • 2024年上半年软考高项第一、二批次真题考点汇总(完整版)

    难度     435人已做
    查看
  • 2024年上半年系统架构设计师考试综合知识真题

    难度     224人已做
    查看

相关文章

发现更多好内容

猜你喜欢

AI推送时光机
位置:首页-资讯-后端开发
咦!没有更多了?去看看其它编程学习网 内容吧
首页课程
资料下载
问答资讯