文章详情

短信预约-IT技能 免费直播动态提醒

请输入下面的图形验证码

提交验证

短信预约提醒成功

keras回调函数如何使用

2023-07-05 11:50

关注

这篇文章主要介绍了keras回调函数如何使用的相关知识,内容详细易懂,操作简单快捷,具有一定借鉴价值,相信大家阅读完这篇keras回调函数如何使用文章都会有所收获,下面我们一起来看看吧。

回调函数

fit()方法中使用callbacks参数

# 这里有两个callback函数:早停和模型检查点callbacks_list=[    keras.callbacks.EarlyStopping(        monitor="val_accuracy",#监控指标        patience=2 #两轮内不再改善中断训练    ),    keras.callbacks.ModelCheckpoint(        filepath="checkpoint_path",        monitor="val_loss",        save_best_only=True    )]#模型获取model=get_minist_model()model.compile(optimizer="rmsprop",             loss="sparse_categorical_crossentropy",             metrics=["accuracy"])model.fit(train_images,train_labels,         epochs=10,callbacks=callbacks_list, #该参数使用回调函数         validation_data=(val_images,val_labels))test_metrics=model.evaluate(test_images,test_labels)#计算模型在新数据上的损失和指标predictions=model.predict(test_images)#计算模型在新数据上的分类概率

keras回调函数如何使用

模型的保存和加载

#也可以在训练完成后手动保存模型,只需调用model.save('my_checkpoint_path')。#重新加载模型model_new=keras.models.load_model("checkpoint_path.keras")

通过对Callback类子类化来创建自定义回调函数

on_epoch_begin(epoch, logs) ←----在每轮开始时被调用
on_epoch_end(epoch, logs) ←----在每轮结束时被调用
on_batch_begin(batch, logs) ←----在处理每个批量之前被调用
on_batch_end(batch, logs) ←----在处理每个批量之后被调用
on_train_begin(logs) ←----在训练开始时被调用
on_train_end(logs ←----在训练结束时被调用

from matplotlib import pyplot as plt# 实现记录每一轮中每个batch训练后的损失,并为每个epoch绘制一个图class LossHistory(keras.callbacks.Callback):    def on_train_begin(self, logs):        self.per_batch_losses = []    def on_batch_end(self, batch, logs):        self.per_batch_losses.append(logs.get("loss"))    def on_epoch_end(self, epoch, logs):        plt.clf()        plt.plot(range(len(self.per_batch_losses)), self.per_batch_losses,                 label="Training loss for each batch")        plt.xlabel(f"Batch (epoch {epoch})")        plt.ylabel("Loss")        plt.legend()        plt.savefig(f"plot_at_epoch_{epoch}")        self.per_batch_losses = [] #清空,方便下一轮的技术
model = get_mnist_model()model.compile(optimizer="rmsprop",              loss="sparse_categorical_crossentropy",              metrics=["accuracy"])model.fit(train_images, train_labels,          epochs=10,          callbacks=[LossHistory()],          validation_data=(val_images, val_labels))

keras回调函数如何使用

【其他】模型的定义 和 数据加载

def get_minist_model():    inputs=keras.Input(shape=(28*28,))    features=layers.Dense(512,activation="relu")(inputs)    features=layers.Dropout(0.5)(features)    outputs=layers.Dense(10,activation="softmax")(features)    model=keras.Model(inputs,outputs)    return model    #datsetfrom tensorflow.keras.datasets import mnist(train_images,train_labels),(test_images,test_labels)=mnist.load_data()train_images=train_images.reshape((60000,28*28)).astype("float32")/255test_images=test_images.reshape((10000,28*28)).astype("float32")/255train_images,val_images=train_images[10000:],train_images[:10000]train_labels,val_labels=train_labels[10000:],train_labels[:10000]

关于“keras回调函数如何使用”这篇文章的内容就介绍到这里,感谢各位的阅读!相信大家对“keras回调函数如何使用”知识都有一定的了解,大家如果还想学习更多知识,欢迎关注编程网行业资讯频道。

阅读原文内容投诉

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

软考中级精品资料免费领

  • 历年真题答案解析
  • 备考技巧名师总结
  • 高频考点精准押题
  • 2024年上半年信息系统项目管理师第二批次真题及答案解析(完整版)

    难度     813人已做
    查看
  • 【考后总结】2024年5月26日信息系统项目管理师第2批次考情分析

    难度     354人已做
    查看
  • 【考后总结】2024年5月25日信息系统项目管理师第1批次考情分析

    难度     318人已做
    查看
  • 2024年上半年软考高项第一、二批次真题考点汇总(完整版)

    难度     435人已做
    查看
  • 2024年上半年系统架构设计师考试综合知识真题

    难度     224人已做
    查看

相关文章

发现更多好内容

猜你喜欢

AI推送时光机
位置:首页-资讯-后端开发
咦!没有更多了?去看看其它编程学习网 内容吧
首页课程
资料下载
问答资讯