文章详情

短信预约-IT技能 免费直播动态提醒

请输入下面的图形验证码

提交验证

短信预约提醒成功

如何深入理解Pytorch微调torchvision模型

2023-06-25 14:36

关注

如何深入理解Pytorch微调torchvision模型,针对这个问题,这篇文章详细介绍了相对应的分析和解答,希望可以帮助更多想解决这个问题的小伙伴找到更简单易行的方法。

一、简介

在本小节,深入探讨如何对torchvision进行微调和特征提取。所有模型都已经预先在1000类的magenet数据集上训练完成。 本节将深入介绍如何使用几个现代的CNN架构,并将直观展示如何微调任意的PyTorch模型。
本节将执行两种类型的迁移学习:

通常这两种迁移学习方法都会遵循一下步骤:

二、导入相关包

from __future__ import print_functionfrom __future__ import divisionimport torchimport torch.nn as nnimport torch.optim as optimimport numpy as npimport torchvision from torchvision import datasets,models,transformsimport matplotlib.pyplot as pltimport timeimport osimport copyprint("Pytorch version:",torch.__version__)print("torchvision version:",torchvision.__version__)

运行结果

如何深入理解Pytorch微调torchvision模型

三、数据输入

数据集——>我在这里

链接:https://pan.baidu.com/s/1G3yRfKTQf9sIq1iCSoymWQ
提取码:1234

#%%输入data_dir="D:\Python\Pytorch\data\hymenoptera_data"# 从[resnet,alexnet,vgg,squeezenet,desenet,inception]model_name='squeezenet'# 数据集中类别数量num_classes=2# 训练的批量大小batch_size=8# 训练epoch数num_epochs=15# 用于特征提取的标志。为FALSE,微调整个模型,为TRUE只更新图层参数feature_extract=True

四、辅助函数

1、模型训练和验证

#%%模型训练和验证device=torch.device("cuda:0" if torch.cuda.is_available() else "cpu")def train_model(model,dataloaders,criterion,optimizer,num_epochs=25,is_inception=False):    since=time.time()    val_acc_history=[]    best_model_wts=copy.deepcopy(model.state_dict())    best_acc=0.0    for epoch in range(num_epochs):        print('Epoch{}/{}'.format(epoch, num_epochs-1))        print('-'*10)        # 每个epoch都有一个训练和验证阶段        for phase in['train','val']:            if phase=='train':                model.train()            else:                model.eval()                            running_loss=0.0            running_corrects=0            # 迭代数据            for inputs,labels in dataloaders[phase]:                inputs=inputs.to(device)                labels=labels.to(device)                # 梯度置零                optimizer.zero_grad()                # 向前传播                with torch.set_grad_enabled(phase=='train'):                    # 获取模型输出并计算损失,开始的特殊情况在训练中他有一个辅助输出                    # 在训练模式下,通过将最终输出和辅助输出相加来计算损耗,在测试中值考虑最终输出                    if is_inception and phase=='train':                        outputs,aux_outputs=model(inputs)                        loss1=criterion(outputs,labels)                        loss2=criterion(aux_outputs,labels)                        loss=loss1+0.4*loss2                    else:                        outputs=model(inputs)                        loss=criterion(outputs,labels)                                            _,preds=torch.max(outputs,1)                                        if phase=='train':                        loss.backward()                        optimizer.step()                                        # 添加                running_loss+=loss.item()*inputs.size(0)                running_corrects+=torch.sum(preds==labels.data)                            epoch_loss=running_loss/len(dataloaders[phase].dataset)            epoch_acc=running_corrects.double()/len(dataloaders[phase].dataset)                        print('{}loss : {:.4f} acc:{:.4f}'.format(phase, epoch_loss,epoch_acc))                        if phase=='train' and epoch_acc>best_acc:                best_acc=epoch_acc                best_model_wts=copy.deepcopy(model.state_dict())            if phase=='val':                val_acc_history.append(epoch_acc)                    print()    time_elapsed=time.time()-since    print('training complete in {:.0f}s'.format(time_elapsed//60, time_elapsed%60))    print('best val acc:{:.4f}'.format(best_acc))        model.load_state_dict(best_model_wts)    return model,val_acc_history

2、设置模型参数的'.requires_grad属性'

当我们进行特征提取时,此辅助函数将模型中参数的 .requires_grad 属性设置为False。
默认情况下,当我们加载一个预训练模型时,所有参数都是 .requires_grad = True,如果我们从头开始训练或微调,这种设置就没问题。
但是,如果我们要运行特征提取并且只想为新初始化的层计算梯度,那么我们希望所有其他参数不需要梯度变化。

#%%设置模型参数的.require——grad属性def set_parameter_requires_grad(model,feature_extracting):    if feature_extracting:        for param in model.parameters():            param.require_grad=False

关于如何深入理解Pytorch微调torchvision模型问题的解答就分享到这里了,希望以上内容可以对大家有一定的帮助,如果你还有很多疑惑没有解开,可以关注编程网行业资讯频道了解更多相关知识。

阅读原文内容投诉

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

软考中级精品资料免费领

  • 历年真题答案解析
  • 备考技巧名师总结
  • 高频考点精准押题
  • 2024年上半年信息系统项目管理师第二批次真题及答案解析(完整版)

    难度     813人已做
    查看
  • 【考后总结】2024年5月26日信息系统项目管理师第2批次考情分析

    难度     354人已做
    查看
  • 【考后总结】2024年5月25日信息系统项目管理师第1批次考情分析

    难度     318人已做
    查看
  • 2024年上半年软考高项第一、二批次真题考点汇总(完整版)

    难度     435人已做
    查看
  • 2024年上半年系统架构设计师考试综合知识真题

    难度     224人已做
    查看

相关文章

发现更多好内容

猜你喜欢

AI推送时光机
位置:首页-资讯-后端开发
咦!没有更多了?去看看其它编程学习网 内容吧
首页课程
资料下载
问答资讯