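Fixing the copy.deepcopy() error "Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol" when copying a Tensor or model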

2023-10-05 09:06

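During training you often want to run validation as you go, or to test the model once training is finished. The usual approach is to create a model instance, call load_state_dict() to load the trained weights into it, and then call model.eval() to switch it to evaluation mode. That is a good fit for a dedicated test run after training. For validation during training, though, it means writing a pile of extra code. Using copy.deepcopy() to deep-copy the model being trained and validating on the copy is clearly more concise. However, PyTorch Tensors only partly support copy.deepcopy(). If you are not careful when writing the model code, calling copy.deepcopy(model) may fail with this error: Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment. The full error message is: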

  File "/usr/local/lib/python3.6/site-packages/prc/framework/model/validation.py", line 147, in init_val_model
    val_model = copy.deepcopy(model)
  File "/usr/lib64/python3.6/copy.py", line 180, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/usr/lib64/python3.6/copy.py", line 280, in _reconstruct
    state = deepcopy(state, memo)
  File "/usr/lib64/python3.6/copy.py", line 150, in deepcopy
    y = copier(x, memo)
  File "/usr/lib64/python3.6/copy.py", line 240, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/usr/lib64/python3.6/copy.py", line 180, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/usr/lib64/python3.6/copy.py", line 306, in _reconstruct
    value = deepcopy(value, memo)
  File "/usr/lib64/python3.6/copy.py", line 180, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/usr/lib64/python3.6/copy.py", line 280, in _reconstruct
    state = deepcopy(state, memo)
  File "/usr/lib64/python3.6/copy.py", line 150, in deepcopy
    y = copier(x, memo)
  File "/usr/lib64/python3.6/copy.py", line 240, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/usr/lib64/python3.6/copy.py", line 180, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/usr/lib64/python3.6/copy.py", line 306, in _reconstruct
    value = deepcopy(value, memo)
  File "/usr/lib64/python3.6/copy.py", line 180, in deepcopy
    y = _reconstruct(x, memo, *rv)
  File "/usr/lib64/python3.6/copy.py", line 280, in _reconstruct
    state = deepcopy(state, memo)
  File "/usr/lib64/python3.6/copy.py", line 150, in deepcopy
    y = copier(x, memo)
  File "/usr/lib64/python3.6/copy.py", line 240, in _deepcopy_dict
    y[deepcopy(key, memo)] = deepcopy(value, memo)
  File "/usr/lib64/python3.6/copy.py", line 161, in deepcopy
    y = copier(memo)
  File "/root/.local/lib/python3.6/site-packages/torch/_tensor.py", line 55, in __deepcopy__
    raise RuntimeError("Only Tensors created explicitly by the user "
RuntimeError: Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment
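For reference, here is a minimal sketch of the two ways of building a validation model that were described above. TinyNet, make_val_model_by_state_dict and make_val_model_by_deepcopy are hypothetical names used only for illustration. The real model in this post is built from a config dict.

```python
import copy

import torch
import torch.nn as nn


class TinyNet(nn.Module):
    """Hypothetical stand-in for the real model."""

    def __init__(self) -> None:
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.fc(x)


def make_val_model_by_state_dict(model: nn.Module) -> nn.Module:
    # Conventional route: create a fresh instance, load the current weights,
    # then switch it to evaluation mode.
    val_model = TinyNet()
    val_model.load_state_dict(model.state_dict())
    val_model.eval()
    return val_model


def make_val_model_by_deepcopy(model: nn.Module) -> nn.Module:
    # Shorter route: copy the live training model in one call. This raises
    # the RuntimeError above if any attribute of the model (or of one of its
    # submodules) holds a non-leaf Tensor.
    val_model = copy.deepcopy(model)
    val_model.eval()
    return val_model


if __name__ == "__main__":
    torch.manual_seed(0)
    model = TinyNet()
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    # One training step, so the weights differ from their initial values.
    loss = model(torch.randn(8, 4)).pow(2).mean()
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    val_model = make_val_model_by_deepcopy(model)
    with torch.no_grad():
        x = torch.randn(3, 4)
        print(torch.allclose(val_model(x), model(x)))  # True: same weights
    print(model.training, val_model.training)  # True False
```

The deep copy has its own parameter storage. Further training steps on model therefore do not change val_model.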

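Put simply, this error means that copy.deepcopy() cannot copy a non-leaf Tensor. A non-leaf Tensor is one produced by an operation on a requires_grad=True Tensor: its grad_fn is not None, and inside a network it is typically an intermediate result. At first I assumed that some Tensor really had its requires_grad set incorrectly. I spent several late nights checking and debugging the network code without finding any clue, which was frustrating. Then I reasoned that the error was raised inside copy.deepcopy(), whose source is available. So I could debug inside it to find which part of the network was being copied when the exception was thrown. After some fiddling, I found that this spot in copy.py is a good place for a breakpoint. It is the dict-item loop near the end of _reconstruct():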

    if dictiter is not None:
        if deep:
            for key, value in dictiter:
                key = deepcopy(key, memo)
                value = deepcopy(value, memo)
                y[key] = value
        else:
            for key, value in dictiter:
                y[key] = value
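Instead of stepping through copy.py, you can also locate the offending attribute by scanning the model directly. The sketch below is an assumption-based helper, not part of PyTorch, and find_non_leaf_tensors is a hypothetical name. It only looks at Tensors stored directly on modules or nested in lists, tuples and dicts. It reports every Tensor whose is_leaf is False, which is the condition that Tensor.__deepcopy__ rejects.

```python
from typing import Any, Iterator, List, Tuple

import torch
import torch.nn as nn


def _iter_tensors(path: str, value: Any) -> Iterator[Tuple[str, torch.Tensor]]:
    # Yield (path, tensor) pairs for a Tensor or for Tensors nested inside
    # lists, tuples and dicts (nn.Module keeps _parameters/_buffers in dicts).
    if isinstance(value, torch.Tensor):
        yield path, value
    elif isinstance(value, (list, tuple)):
        for i, item in enumerate(value):
            yield from _iter_tensors(f"{path}[{i}]", item)
    elif isinstance(value, dict):
        for key, item in value.items():
            yield from _iter_tensors(f"{path}[{key!r}]", item)


def find_non_leaf_tensors(model: nn.Module) -> List[str]:
    """Hypothetical helper: list attribute paths holding non-leaf Tensors,
    i.e. the ones that make copy.deepcopy(model) raise."""
    found = []
    for mod_name, module in model.named_modules():
        for attr, value in vars(module).items():
            path = f"{mod_name}.{attr}" if mod_name else attr
            for tensor_path, tensor in _iter_tensors(path, value):
                if not tensor.is_leaf:
                    found.append(tensor_path)
    return found
```

Call it right after a forward pass, e.g. print(find_non_leaf_tensors(model)). An attribute that caches a layer's output then shows up by its dotted path, such as head.out in a hypothetical model whose head stores self.out.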

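My network's structure is defined as Python dicts and is built dynamically at run time through a registry mechanism. Because everything is a dict, the key and value at that breakpoint in copy.py correspond to the keys and values of the per-layer dicts in the config file that defines the network. A breakpoint there therefore makes it fairly easy to trace which part triggers the exception. The cause turned out to be a head class that implements the segmentation function. It had a member variable that stored the layer's output Tensor so the loss could be computed later. A layer's output naturally has requires_grad=True and a grad_fn, so it is a non-leaf Tensor. I removed that member variable and had forward() return the result instead. The network's main class now receives that result and passes it to the loss function. After this change, deepcopy(model) no longer raised the error.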
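Here is a minimal sketch of the change described above, assuming a toy segmentation network. SegHeadCached, SegHead and SegNet are hypothetical names and do not come from the author's registry-based code. SegHeadCached keeps its output in self.out, so after one forward pass copy.deepcopy() hits a non-leaf Tensor. SegHead returns its output, and SegNet passes it straight to the loss.

```python
import copy
from typing import Optional

import torch
import torch.nn as nn
import torch.nn.functional as F


class SegHeadCached(nn.Module):
    """Problematic pattern: the output is kept in a member variable."""

    def __init__(self, in_ch: int, num_classes: int) -> None:
        super().__init__()
        self.conv = nn.Conv2d(in_ch, num_classes, kernel_size=1)
        self.out: Optional[torch.Tensor] = None  # becomes a non-leaf Tensor

    def forward(self, feats: torch.Tensor) -> None:
        self.out = self.conv(feats)  # has grad_fn, so deepcopy(self) now fails


class SegHead(nn.Module):
    """Fixed pattern: forward() simply returns the output."""

    def __init__(self, in_ch: int, num_classes: int) -> None:
        super().__init__()
        self.conv = nn.Conv2d(in_ch, num_classes, kernel_size=1)

    def forward(self, feats: torch.Tensor) -> torch.Tensor:
        return self.conv(feats)


class SegNet(nn.Module):
    """Main class: receives the head output and hands it to the loss."""

    def __init__(self, in_ch: int = 3, num_classes: int = 2) -> None:
        super().__init__()
        self.backbone = nn.Conv2d(in_ch, 8, kernel_size=3, padding=1)
        self.head = SegHead(8, num_classes)

    def forward(self, images: torch.Tensor, targets: Optional[torch.Tensor] = None):
        logits = self.head(self.backbone(images))
        if targets is None:
            return logits
        return logits, F.cross_entropy(logits, targets)


if __name__ == "__main__":
    model = SegNet()
    images = torch.randn(2, 3, 8, 8)
    targets = torch.randint(0, 2, (2, 8, 8))
    logits, loss = model(images, targets)
    loss.backward()
    val_model = copy.deepcopy(model)  # succeeds: nothing non-leaf is stored

    bad_head = SegHeadCached(3, 2)
    bad_head(images)
    try:
        copy.deepcopy(bad_head)
    except RuntimeError as err:
        print("deepcopy failed:", err)
```

The design choice is simply that intermediate results travel through return values, not through module state. This keeps the module deep-copyable at any point in training.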

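Note also that explicitly creating a Tensor with requires_grad=True (the default is False) does not by itself make copy.deepcopy() fail, whether the Tensor is on the CPU or the GPU. The key point is that a Tensor the user creates directly is a leaf Tensor whose grad_fn is None. A new Tensor obtained from it by slicing, by moving it to the GPU with .cuda(), and so on is no longer a leaf. PyTorch assumes that any Tensor computed from a requires_grad=True Tensor needs gradients, and it automatically attaches a grad_fn whether or not that Tensor is part of a network. Deep-copying such a Tensor with copy.deepcopy() raises the error above. The example below makes this clear: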

>>> import copy
>>> import torch
>>> t = torch.tensor([1,2,3.5],dtype=torch.float32, requires_grad=True, device='cuda:0')
>>> t
tensor([1.0000, 2.0000, 3.5000], device='cuda:0', requires_grad=True)
>>> x = copy.deepcopy(t)
>>> x
tensor([1.0000, 2.0000, 3.5000], device='cuda:0', requires_grad=True)
>>> t1 = t[:2]
>>> t1
tensor([1., 2.], device='cuda:0', grad_fn=<...>)
>>> x = copy.deepcopy(t1)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/python3.8/lib/python3.8/copy.py", line 153, in deepcopy
    y = copier(memo)
  File "/root/.local/lib/python3.8/site-packages/torch/_tensor.py", line 85, in __deepcopy__
    raise RuntimeError("Only Tensors created explicitly by the user "
RuntimeError: Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment
>>> t = torch.tensor([1,2,3.5],dtype=torch.float32, requires_grad=True)
>>> t1 = t.cuda()
>>> t1
tensor([1.0000, 2.0000, 3.5000], device='cuda:0', grad_fn=<...>)
>>> x = copy.deepcopy(t1)
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/usr/local/python3.8/lib/python3.8/copy.py", line 153, in deepcopy
    y = copier(memo)
  File "/root/.local/lib/python3.8/site-packages/torch/_tensor.py", line 85, in __deepcopy__
    raise RuntimeError("Only Tensors created explicitly by the user "
RuntimeError: Only Tensors created explicitly by the user (graph leaves) support the deepcopy protocol at the moment
>>> t = torch.tensor([1,2,3.5],dtype=torch.float32, requires_grad=False)
>>> t
tensor([1.0000, 2.0000, 3.5000])
>>> x = copy.deepcopy(t)
>>> x
tensor([1.0000, 2.0000, 3.5000])
>>> t1 = t[:2]
>>> t1
tensor([1., 2.])
>>> x = copy.deepcopy(t1)
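The session above needs a CUDA device. The same checks can run on the CPU, as in the sketch below, which prints is_leaf and grad_fn next to whether copy.deepcopy() succeeds. can_deepcopy is a hypothetical helper. The leaf * 2 case is an extra example not in the original session, added to show that arithmetic behaves like slicing.

```python
import copy

import torch


def can_deepcopy(t: torch.Tensor) -> bool:
    """Hypothetical helper: True if copy.deepcopy(t) succeeds."""
    try:
        copy.deepcopy(t)
        return True
    except RuntimeError:
        return False


leaf = torch.tensor([1.0, 2.0, 3.5], requires_grad=True)  # created by the user
cases = {
    "leaf": leaf,                                # graph leaf
    "sliced": leaf[:2],                          # op on a requires_grad leaf
    "scaled": leaf * 2,                          # same for arithmetic
    "plain": torch.tensor([1.0, 2.0, 3.5])[:2],  # no requires_grad at all
}

# Columns: name, is_leaf, has grad_fn, deepcopy succeeds
for name, t in cases.items():
    print(name, t.is_leaf, t.grad_fn is not None, can_deepcopy(t))
# leaf True False True
# sliced False True False
# scaled False True False
# plain True False True
```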

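Why doesn't deepcopy() simply support Tensors that carry gradient history? In principle, copying a snapshot of the current value should pose no problem. In the thread https://discuss.pytorch.org/t/copy-deepcopy-vs-clone/55022/10, a bearded regular who answers many questions on the forum offered a guess.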
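If what you need is a snapshot of a non-leaf Tensor's current value, you do not have to go through deepcopy at all. The minimal sketch below uses standard Tensor methods to show the difference between two options. clone() stays in the autograd graph. detach().clone() is a value-only copy that is a new leaf, so it can be deep-copied.

```python
import copy

import torch

x = torch.tensor([1.0, 2.0, 3.5], requires_grad=True)
y = x[:2] * 3  # non-leaf: copy.deepcopy(y) would raise the error above

# clone(): new storage, but still recorded in the graph, so it is not a leaf
y_clone = y.clone()

# detach().clone(): cut from the graph and given its own storage; a leaf
y_snap = y.detach().clone()
y_snap_copy = copy.deepcopy(y_snap)  # works, because y_snap is a leaf

# Gradients still flow from y_clone back to x.
y_clone.sum().backward()
print(x.grad)  # tensor([3., 3., 0.])
print(y_snap)  # tensor([3., 6.])
```

If the snapshot itself should be trainable, y.detach().clone().requires_grad_(True) gives a fresh leaf that requires gradients.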

Source: https://blog.csdn.net/XCCCCZ/article/details/128794986
