文章详情

短信预约-IT技能 免费直播动态提醒

请输入下面的图形验证码

提交验证

短信预约提醒成功

Pytorch深度学习之实现病虫害图像分类

2024-04-02 19:55

关注

一、pytorch框架

1.1、概念

PyTorch是一个开源的Python机器学习库,基于Torch,用于自然语言处理等应用程序。

2017年1月,由Facebook人工智能研究院(FAIR)基于Torch推出了PyTorch。它是一个基于Python的可续计算包,提供两个高级功能:

1、具有强大的GPU加速的张量计算(如NumPy)。

2、包含自动求导系统的深度神经网络。

1.2、机器学习与深度学习的区别

两者之间区别很多,在本篇博客中只简单描述一部分。以图片的形式展现。

前者为机器学习的过程。

后者为深度学习的过程。

1.3、在python中导入pytorch成功截图

二、数据集

本次实验使用的是coco数据集中的植物病虫害数据集。分为训练文件Traindata和测试文件TestData.,

TrainData有9种分类,每一种分类有100张图片。

TestData有9中分类,每一种分类有10张图片。

在我下一篇博客中将数据集开源。

下面是我的数据集截图:

三、代码复现

3.1、导入第三方库


import torch
from torch.utils.data import Dataset, DataLoader
import numpy as np
import matplotlib
import os
import cv2
from PIL import Image
import torchvision.transforms as transforms
import torch.optim as optim
from torch.autograd import Variable
import torch.nn as nn
import torch.nn.functional as F
from Test.CNN import Net
import json
from Test.train_data import Mydataset,pad_image

3.2、CNN代码


# 构建神经网络
class Net(nn.Module):#定义网络模块
    def __init__(self):
        super(Net, self).__init__()
        # 卷积,该图片有3层,6个特征,长宽均为5*5的像素点,每隔1步跳一下
        self.conv1 = nn.Conv2d(3, 6, 5)
        #//(conv1): Conv2d(3, 6, kernel_size=(5, 5), stride=(1, 1))
        self.pool = nn.MaxPool2d(2, 2)#最大池化
        #//(pool): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
        self.conv2 = nn.Conv2d(6, 16, 5)#卷积
        #//(conv2): Conv2d(6, 16, kernel_size=(5, 5), stride=(1, 1))
        self.fc1 = nn.Linear(16*77*77, 120)#全连接层,图片的维度为16,
        #(fc1): Linear(in_features=94864, out_features=120, bias=True)
        self.fc2 = nn.Linear(120, 84)#全连接层,输入120个特征输出84个特征
        self.fc3 = nn.Linear(84, 7)#全连接层,输入84个特征输出7个特征
 
   def forward(self, x):
        print("x.shape1: ", x.shape)
        x = self.pool(F.relu(self.conv1(x)))
        print("x.shape2: ", x.shape)
        x = self.pool(F.relu(self.conv2(x)))
        print("x.shape3: ", x.shape)
        x = x.view(-1, 16*77*77)
        print("x.shape4: ", x.shape)
        x = F.relu(self.fc1(x))
        print("x.shape5: ", x.shape)
        x = F.relu(self.fc2(x))
        print("x.shape6: ", x.shape)
        x = self.fc3(x)
        print("x.shape7: ", x.shape)
        return x


3.3、测试代码


img_path = "TestData/test_data/1/Apple2 (1).jpg" #使用相对路径
image = Image.open(img_path).convert('RGB')
image_pad = pad_image(image, (320, 320))
input = transform(image_pad).to(device).unsqueeze(0)
output = F.softmax(net(input), 1)
_, predicted = torch.max(output, 1)
score = float(output[0][predicted]*100)
print(class_map[predicted], " ", str(score)+" %")
plt.imshow(image_pad) # 显示图片

四、训练结果

4.1、LOSS损失函数

4.2、 ACC

4.3、单张图片识别准确率

四、小结

这次搭建的网络是基于深度学习框架Lenet,并自己做了一些修改完成。最终的训练的结果LOSS接近0,ACC接近100%。但是一般的识别率不会达到这么高,该模型可能会过拟合。可采取剪枝等操作减小过拟合。

到此这篇关于Pytorch深度学习之实现病虫害图像分类的文章就介绍到这了,更多相关Pytorch图像分类内容请搜索编程网以前的文章或继续浏览下面的相关文章希望大家以后多多支持编程网!

阅读原文内容投诉

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

软考中级精品资料免费领

  • 历年真题答案解析
  • 备考技巧名师总结
  • 高频考点精准押题
  • 2024年上半年信息系统项目管理师第二批次真题及答案解析(完整版)

    难度     807人已做
    查看
  • 【考后总结】2024年5月26日信息系统项目管理师第2批次考情分析

    难度     351人已做
    查看
  • 【考后总结】2024年5月25日信息系统项目管理师第1批次考情分析

    难度     314人已做
    查看
  • 2024年上半年软考高项第一、二批次真题考点汇总(完整版)

    难度     433人已做
    查看
  • 2024年上半年系统架构设计师考试综合知识真题

    难度     221人已做
    查看

相关文章

发现更多好内容

猜你喜欢

AI推送时光机
位置:首页-资讯-后端开发
咦!没有更多了?去看看其它编程学习网 内容吧
首页课程
资料下载
问答资讯