文章详情

短信预约-IT技能 免费直播动态提醒

请输入下面的图形验证码

提交验证

短信预约提醒成功

用 PyTorch 实现基于字符的循环神经网络

2024-12-03 14:57

关注

这个想法(来自 循环神经网络的不合理效应 )可以让你在文本上训练一个基于字符的 循环神经网络(recurrent neural network)(RNN),并得到一些出乎意料好的结果。

[[358756]]

不过,虽然没有得到我想要的结果,但是我还是想分享一些示例代码和结果,希望对其他开始尝试使用 PyTorch 和 RNN 的人有帮助。

这是 Jupyter 笔记本格式的代码: char-rnn in PyTorch.ipynb 。你可以点击这个网页最上面那个按钮 “Open in Colab”,就可以在 Google 的 Colab 服务中打开,并使用免费的 GPU 进行训练。所有的东西加起来大概有 75 行代码,我将在这篇博文中尽可能地详细解释。

第一步:准备数据

首先,我们要下载数据。我使用的是 古登堡项目(Project Gutenberg)中的这个数据: Hans Christian Anderson’s fairy tales 。

  1. !wget -O fairy-tales.txt 

这个是准备数据的代码。我使用 fastai 库中的 Vocab 类进行数据处理,它能将一堆字母转换成“词表”,然后用这个“词表”把字母变成数字。

之后我们就得到了一个大的数字数组(training_set),我们可以用于训练我们的模型。

  1. from fastai.text import * 
  2. text = unidecode.unidecode(open('fairy-tales.txt').read()) 
  3. v = Vocab.create((x for x in text), max_vocab=400min_freq=1
  4. training_set = torch.Tensor(v.numericalize([x for x in text])).type(torch.LongTensor).cuda() 
  5. num_letters = len(v.itos) 

第二步:定义模型

这个是 PyTorch 中 LSTM 类的封装。除了封装 LSTM 类以外,它还做了三件事:

  1. class MyLSTM(nn.Module): 
  2.     def __init__(self, input_size, hidden_size): 
  3.         super().__init__() 
  4.         self.lstm = nn.LSTM(input_size, hidden_size, batch_first=True
  5.         self.h2o = nn.Linear(hidden_size, input_size) 
  6.         self.input_size=input_size 
  7.         self.hidden = None 
  8.  
  9.     def forward(self, input): 
  10.         input = torch.nn.functional.one_hot(input, num_classes=self.input_size).type(torch.FloatTensor).cuda().unsqueeze(0) 
  11.         if self.hidden is None: 
  12.             l_output, selfself.hidden = self.lstm(input) 
  13.         else: 
  14.             l_output, selfself.hidden = self.lstm(input, self.hidden) 
  15.         self.hidden = (self.hidden[0].detach(), self.hidden[1].detach()) 
  16.  
  17.         return self.h2o(l_output) 

这个代码还做了一些比较神奇但是不太明显的功能。如果你的输入是一个向量(比如 [1,2,3,4,5,6]),对应六个字母,那么我的理解是 nn.LSTM 会在内部使用 沿时间反向传播 更新隐藏向量 6 次。

第三步:编写训练代码

模型不会自己训练的!

我最开始的时候尝试用 fastai 库中的一个辅助类(也是 PyTorch 中的封装)。我有点疑惑因为我不知道它在做什么,所以最后我自己编写了模型训练代码。

下面这些代码(epoch() 方法)就是有关于一轮训练过程的基本信息。基本上就是重复做下面这几件事情:

  1. class Trainer(): 
  2.   def __init__(self): 
  3.       self.rnn = MyLSTM(input_size, hidden_size).cuda() 
  4.       self.optimizer = torch.optim.Adam(self.rnn.parameters(), amsgrad=Truelrlr=lr) 
  5.   def epoch(self): 
  6.       i = 0 
  7.       while i < len(training_set) - 40: 
  8.         seq_len = random.randint(10, 40) 
  9.         input, target = training_set[i:i+seq_len],training_set[i+1:i+1+seq_len] 
  10.         i += seq_len 
  11.         # forward pass 
  12.         output = self.rnn(input) 
  13.         loss = F.cross_entropy(output.squeeze()[-1:], target[-1:]) 
  14.         # compute gradients and take optimizer step 
  15.         self.optimizer.zero_grad() 
  16.         loss.backward() 
  17.         self.optimizer.step() 

使用 nn.LSTM 沿着时间反向传播,不要自己写代码

开始的时候我自己写代码每次传一个字母到 LSTM 层中,之后定期计算导数,就像下面这样:

  1. for i in range(20): 
  2.     input, target = next(iter) 
  3.     output, hidden = self.lstm(input, hidden) 
  4. loss = F.cross_entropy(output, target) 
  5. hiddenhidden = hidden.detach() 
  6. self.optimizer.zero_grad() 
  7. loss.backward() 
  8. self.optimizer.step() 

这段代码每次传入 20 个字母,每次一个,并且在最后训练了一次。这个步骤就被称为 沿时间反向传播 ,Karpathy 在他的博客中就是用这种方法。

这个方法有些用处,我编写的损失函数开始能够下降一段时间,但之后就会出现峰值。我不知道为什么会出现这种现象,但之后我改为一次传入 20 个字符到 LSTM 之后(按 seq_len 维度),再进行反向传播,情况就变好了。

第四步:训练模型!

我在同样的数据上重复执行了这个训练代码大概 300 次,直到模型开始输出一些看起来像英文的文本。差不多花了一个多小时吧。

这种情况下我也不关注模型是不是过拟合了,但是如果你在真实场景中训练模型,应该要在验证集上验证你的模型。

第五步:生成输出!

最后一件要做的事就是用这个模型生成一些输出。我写了一个辅助方法从这个训练好的模型中生成文本(make_preds 和 next_pred)。这里主要是把向量的维度对齐,重要的一点是:

  1. output = rnn(input) 
  2. prediction_vector = F.softmax(output/temperature) 
  3. letter = v.textify(torch.multinomial(prediction_vector, 1).flatten(), sep='').replace('_', ' ') 

基本上做的事情就是这些:

如果我们想要处理的文本长度为 300,那么只需要重复这个过程 300 次就可以了。

结果!

我把预测函数中的参数设置为 temperature = 1 得到了下面的这些由模型生成的结果。看起来有点像英语,这个结果已经很不错了,因为这个模型要从头开始“学习”英语,并且是在字符序列的级别上进行学习的。

虽然这些话没有什么含义,但我们也不知道到底想要得到什么输出。

“An who was you colotal said that have to have been a little crimantable and beamed home the beetle. “I shall be in the head of the green for the sound of the wood. The pastor. “I child hand through the emperor’s sorthes, where the mother was a great deal down the conscious, which are all the gleam of the wood they saw the last great of the emperor’s forments, the house of a large gone there was nothing of the wonded the sound of which she saw in the converse of the beetle. “I shall know happy to him. This stories herself and the sound of the young mons feathery in the green safe.”

“That was the pastor. The some and hand on the water sound of the beauty be and home to have been consider and tree and the face. The some to the froghesses and stringing to the sea, and the yellow was too intention, he was not a warm to the pastor. The pastor which are the faten to go and the world from the bell, why really the laborer’s back of most handsome that she was a caperven and the confectioned and thoughts were seated to have great made

下面这些结果是当 temperature=0.1 时生成的,它选择字符的方式更接近于“每次都选择出现概率最高的字符”。这就使得输出结果有很多是重复的。

ole the sound of the beauty of the beetle. “She was a great emperor of the sea, and the sun was so warm to the confectioned the beetle. “I shall be so many for the beetle. “I shall be so many for the beetle. “I shall be so standen for the world, and the sun was so warm to the sea, and the sun was so warm to the sea, and the sound of the world from the bell, where the beetle was the sea, and the sound of the world from the bell, where the beetle was the sea, and the sound of the wood flowers and the sound of the wood, and the sound of the world from the bell, where the world from the wood, and the sound of the

这段输出对这几个单词 beetles、confectioners、sun 和 sea 有着奇怪的执念。

总结!

至此,我的结果远不及 Karpathy 的好,可能有一下几个原因:

但我得到了一些大致说得过去的结果!还不错!

 

来源:Linux中国内容投诉

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

软考中级精品资料免费领

  • 历年真题答案解析
  • 备考技巧名师总结
  • 高频考点精准押题
  • 2024年上半年信息系统项目管理师第二批次真题及答案解析(完整版)

    难度     813人已做
    查看
  • 【考后总结】2024年5月26日信息系统项目管理师第2批次考情分析

    难度     354人已做
    查看
  • 【考后总结】2024年5月25日信息系统项目管理师第1批次考情分析

    难度     318人已做
    查看
  • 2024年上半年软考高项第一、二批次真题考点汇总(完整版)

    难度     435人已做
    查看
  • 2024年上半年系统架构设计师考试综合知识真题

    难度     224人已做
    查看

相关文章

发现更多好内容

猜你喜欢

AI推送时光机
位置:首页-资讯-后端开发
咦!没有更多了?去看看其它编程学习网 内容吧
首页课程
资料下载
问答资讯