文章详情

短信预约-IT技能 免费直播动态提醒

请输入下面的图形验证码

提交验证

短信预约提醒成功

返回最大值的index pytorch方式是什么

2023-07-02 18:32

关注

这篇文章主要讲解了“返回最大值的index pytorch方式是什么”,文中的讲解内容简单清晰,易于学习与理解,下面请大家跟着小编的思路慢慢深入,一起来研究和学习“返回最大值的index pytorch方式是什么”吧!

返回最大值的index

import torcha=torch.tensor([[.1,.2,.3],                [1.1,1.2,1.3],                [2.1,2.2,2.3],                [3.1,3.2,3.3]])print(a.argmax(dim=1))print(a.argmax())

输出:

tensor([ 2,  2,  2,  2])
tensor(11)

pytorch 找最大值

题意:使用神经网络实现,从数组中找出最大值。

提供数据:两个 csv 文件,一个存训练集:n 个 m 维特征自然数数据,另一个存每条数据对应的 label ,就是每条数据中的最大值。

这里将随机构建训练集:

#%%import numpy as np import pandas as pd import torch import random import torch.utils.data as Dataimport torch.nn as nnimport torch.optim as optim  def GetData(m, n):    dataset = []    for j in range(m):        max_v = random.randint(0, 9)        data = [random.randint(0, 9) for i in range(n)]        dataset.append(data)    label = [max(dataset[i]) for i in  range(len(dataset))]    data_list = np.column_stack((dataset, label))    data_list = data_list.astype(np.float32)    return data_list #%%# 数据集封装 重载函数len, getitemclass GetMaxEle(Data.Dataset):    def __init__(self, trainset):        self.data = trainset      def __getitem__(self, index):        item = self.data[index]        x = item[:-1]        y = item[-1]        return x, y        def __len__(self):        return len(self.data) # %% 定义网络模型class SingleNN(nn.Module):    def __init__(self, n_feature, n_hidden, n_output):        super(SingleNN, self).__init__()                self.hidden = nn.Linear(n_feature, n_hidden)        self.relu = nn.ReLU()        self.predict = nn.Linear(n_hidden, n_output)     def forward(self, x):        x = self.hidden(x)        x = self.relu(x)        x = self.predict(x)        return x  def train(m, n, batch_size, PATH):    # 随机生成 m 个 n 个维度的训练样本    data_list =GetData(m, n)    dataset = GetMaxEle(data_list)    trainset = torch.utils.data.DataLoader(dataset, batch_size=batch_size,                                      shuffle=True)     net = SingleNN(n_feature=10, n_hidden=100,                   n_output=10)    criterion = nn.CrossEntropyLoss()    optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)    #    total_epoch = 100    for epoch in range(total_epoch):        for index, data in enumerate(trainset):            input_x, labels = data            labels = labels.long()            optimizer.zero_grad()             output = net(input_x)            # print(output)            # print(labels)            loss = criterion(output, labels)            loss.backward()            optimizer.step()         # scheduled_optimizer.step()        print(f"Epoch {epoch}, loss:{loss.item()}")     # %% 保存参数    torch.save(net.state_dict(), PATH)    #测试   def test(m, n, batch_size, PATH):    data_list = GetData(m, n)    dataset = GetMaxEle(data_list)    testloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size)    dataiter = iter(testloader)    input_x, labels = dataiter.next()    net = SingleNN(n_feature=10, n_hidden=100,                   n_output=10)    net.load_state_dict(torch.load(PATH))    outputs = net(input_x)     _, predicted = torch.max(outputs, 1)    print("Ground_truth:",labels.numpy())    print("predicted:",predicted.numpy())  if __name__ == "__main__":    m = 1000    n = 10    batch_size = 64    PATH = './max_list.pth'    train(m, n, batch_size, PATH)    test(m, n, batch_size, PATH)

初始的想法是使用全连接网络+分类来实现, 但是结果不尽人意,主要原因:不同类别之间的样本量差太大,几乎90%都是最大值。

比如代码中随机构建 10 个 0~9 的数字构成一个样本[2, 3, 5, 8, 9, 5, 3, 9, 3, 6], 该样本标签是9。

感谢各位的阅读,以上就是“返回最大值的index pytorch方式是什么”的内容了,经过本文的学习后,相信大家对返回最大值的index pytorch方式是什么这一问题有了更深刻的体会,具体使用情况还需要大家实践验证。这里是编程网,小编将为大家推送更多相关知识点的文章,欢迎关注!

阅读原文内容投诉

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

软考中级精品资料免费领

  • 历年真题答案解析
  • 备考技巧名师总结
  • 高频考点精准押题
  • 2024年上半年信息系统项目管理师第二批次真题及答案解析(完整版)

    难度     813人已做
    查看
  • 【考后总结】2024年5月26日信息系统项目管理师第2批次考情分析

    难度     354人已做
    查看
  • 【考后总结】2024年5月25日信息系统项目管理师第1批次考情分析

    难度     318人已做
    查看
  • 2024年上半年软考高项第一、二批次真题考点汇总(完整版)

    难度     435人已做
    查看
  • 2024年上半年系统架构设计师考试综合知识真题

    难度     224人已做
    查看

相关文章

发现更多好内容

猜你喜欢

AI推送时光机
位置:首页-资讯-后端开发
咦!没有更多了?去看看其它编程学习网 内容吧
首页课程
资料下载
问答资讯