文章详情

短信预约-IT技能 免费直播动态提醒

请输入下面的图形验证码

提交验证

短信预约提醒成功

为什么加载之前保存的Keras模型得出不一样的结果:经验和教训

2024-12-02 08:00

关注

我不会详细介绍如何使用和保存Keras模型,只是假设读者熟悉该过程,直接介绍如何处理加载时意外的模型行为。也就是说,在训练存储在Model变量中的Keras模型之后,我们希望将其保存为原样,那样下次加载时我们可以跳过训练,就进行预测。

我首选的方法是保存模型的权重,权重在模型创建开始时是随机的,随着模型的训练而加以更新。于是我点击了model.save_weights(“model.h5”)。创建了“model.h5”文件,含有模型学习到的权重。接下来,在另一个会话中,我使用与以前相同的架构重建模型,并使用 new_model.load_weights(“model.h5”)加载我保存的训练权重。一切似乎都很好。只是我点击 new_model.predict(test_data)后,得到的准确性为零,不知道为什么。

事实证明,模型无法做出正确的预测有诸多原因。我在本文试着总结最常见的原因,并介绍如何解决。

1. 先仔细检查数据。

我知道这似乎很明显,但是从磁盘重新加载模型时,一有疏忽就会导致性能下降。比如说,如果您在构建语言模型,应确保在每个新会话中,您执行以下操作:

当然,您可能会遇到其他与数据相关的问题,具体取决于您从事的领域。然而,请始终检查数据表示的一致性。

2. 度量指标问题

导致错误或结果不一致的另一个原因是,准确性度量指标的选择。在构建模型并保存其权重时,我们通常执行以下操作:

def build_model(max_len, n_tags): 
input_layer = Input(shape=(max_len, ))
output_layer = Dense(n_tags, activation = 'softmax')(input_layer)
model = Model(input_layer, output_layer)

return model

model = build_model()
model.compile(optimizer="adam",
loss="sparse_categorical_crossentropy", metrics=["accuracy"])

model.fit(..)
model.save_weights("model.h5")

如果我们需要在新的会话/脚本中打开它,需要执行以下操作:

def build_model(max_len, n_tags): 
input_layer = Input(shape=(max_len, ))
output_layer = Dense(n_tags, activation = 'softmax')(input_layer)
model = Model(input_layer, output_layer)

return model
model = build_model()
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
model.load_weights("model.h5")
model.evaluate()

这可能抛出错误,具体视所使用的特定的Keras/Tensorflow版本而定。编译模型并选择“准确性”作为指标时,会出现问题。Keras识别准确性的各种定义:“稀疏分类准确性”、“分类准确性”等;视您使用的数据而定,不同的定义是优选的解决方案。这是由于如果我们将度量指标设为“准确性”,Keras将试着分配其中一种特定的准确性类型,具体取决于它认为哪一种最适合数据分布。它可能会在不同的运行中推断出不同的准确性指标。这里最好的解决方法是,始终明确设置准确性指标,而不是让Keras自行选择。比如说,把

model.compile(optimizer="adam", 
loss="sparse_categorical_crossentropy", metrics=["accuracy"])

换成:

model.compile(optimizer="adam", 
loss="sparse_categorical_crossentropy", metrics=["sparse_categorical_accuracy"])

3. 随机性

在与以前相同的数据上重新训练Keras神经网络时,您很少两次获得同样的结果。这是由于Keras中的神经网络在初始化权重时使用随机性,因此每次运行时权重的初始化方式都不同,因此在学习过程中这些权重会以不同方式更新,于是在进行预测时不太可能获得相同的准确性结果。

如果出于某种原因,您需要在训练之前使权重相等,可以在代码前面设置随机数生成器:

from numpy.random import seed
seed(42)
from tensorflow import set_random_seed
set_random_seed(42)

numpy随机种子用于Keras,而至于Tensorflow后端,我们需要将其自己的随机数生成器设置为相等的种子。该代码片段将确保每次运行代码时,您的神经网络权重都会被同等地初始化。

4. 留意自定义层的使用

Keras提供了众多层(Dense、LSTM、Dropout和BatchNormalizaton等),但有时我们希望对模型中的数据采取某种特定的操作,但又没有为它定义的特定层。一般来说,Keras提供了两种类型的层:Lambda和基础层类。但对这两种层要很小心,如果您将模型架构保存为json格式更要小心。Lambda层的棘手地方在于序列化限制。由于它与Python字节码的序列化一同保存,它只能加载到保存它的同一个环境中,即它不可移植。遇到该问题时,通常建议覆盖keras.layers.Layer层,或者只保存其权重,从头开始重建模型,而不是保存整个模型。

5. 自定义对象

很多时候,您会想要使用自定义函数应用于数据,或计算损失/准确性等指标的函数。

Keras允许这种使用,为此让我们可以在保存/加载模型时指定额外的参数。假设我们想要将我们自行创建的特殊的损失函数与之前保存的模型一并加载:

model = load_model("model.h5", custom_objects=
{"custom_loss":custom_loss})

如果我们在新环境中加载该模型,必须在新环境中小心定义custom_loss函数,因为默认情况下,保存模型时不会记住这些函数。即使我们保存了模型的整个架构,它也会保存该自定义函数的名称,但函数体是我们需要额外提供的东西。

6. 全局变量初始化器

如果您使用Tensorflow 1.x作为后端——您可能仍然需要该后端用于许多应用程序,这点尤为重要。运行tf 1.x会话时,您需要运行tf.global_variables_initializer(),它随机初始化所有变量。这么做的副作用是,当您尝试保存模型时,它可能重新初始化所有权重。您可以手动停止该行为,只需运行:

from keras.backend import manual_variable_initialization manual_variable_initialization(True)

结语

本文列出了最常导致您的Keras模型无法在新环境中正确加载的几个因素。有时这些问题导致不可预测的结果,而在其他情况下,它们只会抛出错误。它们何时发生、如何发生,在很大程度上也取决于您使用的Python版本以及Tensorflow和Keras版本,因为其中一些版本不相兼容,从而导致意外的行为。但愿读完本文后,下次遇到此类问题时您知道从何处入手。

原文Why Loading a Previously Saved Keras Model Gives Different Results: Lessons Learned,作者:Kristina Popova

来源:51CTO内容投诉

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

软考中级精品资料免费领

  • 历年真题答案解析
  • 备考技巧名师总结
  • 高频考点精准押题
  • 2024年上半年信息系统项目管理师第二批次真题及答案解析(完整版)

    难度     813人已做
    查看
  • 【考后总结】2024年5月26日信息系统项目管理师第2批次考情分析

    难度     354人已做
    查看
  • 【考后总结】2024年5月25日信息系统项目管理师第1批次考情分析

    难度     318人已做
    查看
  • 2024年上半年软考高项第一、二批次真题考点汇总(完整版)

    难度     435人已做
    查看
  • 2024年上半年系统架构设计师考试综合知识真题

    难度     224人已做
    查看

相关文章

发现更多好内容

猜你喜欢

AI推送时光机
位置:首页-资讯-后端开发
咦!没有更多了?去看看其它编程学习网 内容吧
首页课程
资料下载
问答资讯