文章详情

短信预约-IT技能 免费直播动态提醒

请输入下面的图形验证码

提交验证

短信预约提醒成功

python的tf.train.batch函数怎么用

2023-06-30 12:18

关注

这篇文章主要介绍“python的tf.train.batch函数怎么用”的相关知识,小编通过实际案例向大家展示操作过程,操作方法简单快捷,实用性强,希望这篇“python的tf.train.batch函数怎么用”文章能帮助大家解决问题。

tf.train.batch函数

tf.train.batch(    tensors,    batch_size,    num_threads=1,    capacity=32,    enqueue_many=False,    shapes=None,    dynamic_pad=False,    allow_smaller_final_batch=False,    shared_name=None,    name=None)

其中:

tensors:利用slice_input_producer获得的数据组合。

batch_size:设置每次从队列中获取出队数据的数量。

num_threads:用来控制线程的数量,如果其值不唯一,由于线程执行的特性,数据获取可能变成乱序。

capacity:一个整数,用来设置队列中元素的最大数量

allow_samller_final_batch:当其为True时,如果队列中的样本数量小于batch_size,出队的数量会以最终遗留下来的样本进行出队;当其为False时,小于batch_size的样本不会做出队处理。

name:名字

测试代码

1、allow_samller_final_batch=True

import pandas as pdimport numpy as npimport tensorflow as tf# 生成数据def generate_data():    num = 18    label = np.arange(num)    return label# 获取数据def get_batch_data():    label = generate_data()    input_queue = tf.train.slice_input_producer([label], shuffle=False,num_epochs=2)    label_batch = tf.train.batch(input_queue, batch_size=5, num_threads=1, capacity=64,allow_smaller_final_batch=True)    return label_batch# 数据组label = get_batch_data()sess = tf.Session()# 初始化变量sess.run(tf.global_variables_initializer())sess.run(tf.local_variables_initializer())# 初始化batch训练的参数coord = tf.train.Coordinator()threads = tf.train.start_queue_runners(sess,coord)try:    while not coord.should_stop():        # 自动获取下一组数据        l = sess.run(label)        print(l)except tf.errors.OutOfRangeError:    print('Done training')finally:    coord.request_stop()coord.join(threads)sess.close()

运行结果为:

[0 1 2 3 4]
[5 6 7 8 9]
[10 11 12 13 14]
[15 16 17  0  1]
[2 3 4 5 6]
[ 7  8  9 10 11]
[12 13 14 15 16]
[17]
Done training

2、allow_samller_final_batch=False

相比allow_samller_final_batch=True,输出结果少了[17]

import pandas as pdimport numpy as npimport tensorflow as tf# 生成数据def generate_data():    num = 18    label = np.arange(num)    return label# 获取数据def get_batch_data():    label = generate_data()    input_queue = tf.train.slice_input_producer([label], shuffle=False,num_epochs=2)    label_batch = tf.train.batch(input_queue, batch_size=5, num_threads=1, capacity=64,allow_smaller_final_batch=False)    return label_batch# 数据组label = get_batch_data()sess = tf.Session()# 初始化变量sess.run(tf.global_variables_initializer())sess.run(tf.local_variables_initializer())# 初始化batch训练的参数coord = tf.train.Coordinator()threads = tf.train.start_queue_runners(sess,coord)try:    while not coord.should_stop():        # 自动获取下一组数据        l = sess.run(label)        print(l)except tf.errors.OutOfRangeError:    print('Done training')finally:    coord.request_stop()coord.join(threads)sess.close()

运行结果为:

[0 1 2 3 4]
[5 6 7 8 9]
[10 11 12 13 14]
[15 16 17  0  1]
[2 3 4 5 6]
[ 7  8  9 10 11]
[12 13 14 15 16]
Done training

关于“python的tf.train.batch函数怎么用”的内容就介绍到这里了,感谢大家的阅读。如果想了解更多行业相关的知识,可以关注编程网行业资讯频道,小编每天都会为大家更新不同的知识点。

阅读原文内容投诉

免责声明:

① 本站未注明“稿件来源”的信息均来自网络整理。其文字、图片和音视频稿件的所属权归原作者所有。本站收集整理出于非商业性的教育和科研之目的,并不意味着本站赞同其观点或证实其内容的真实性。仅作为临时的测试数据,供内部测试之用。本站并未授权任何人以任何方式主动获取本站任何信息。

② 本站未注明“稿件来源”的临时测试数据将在测试完成后最终做删除处理。有问题或投稿请发送至: 邮箱/279061341@qq.com QQ/279061341

软考中级精品资料免费领

  • 历年真题答案解析
  • 备考技巧名师总结
  • 高频考点精准押题
  • 2024年上半年信息系统项目管理师第二批次真题及答案解析(完整版)

    难度     813人已做
    查看
  • 【考后总结】2024年5月26日信息系统项目管理师第2批次考情分析

    难度     354人已做
    查看
  • 【考后总结】2024年5月25日信息系统项目管理师第1批次考情分析

    难度     318人已做
    查看
  • 2024年上半年软考高项第一、二批次真题考点汇总(完整版)

    难度     435人已做
    查看
  • 2024年上半年系统架构设计师考试综合知识真题

    难度     224人已做
    查看

相关文章

发现更多好内容

猜你喜欢

AI推送时光机
位置:首页-资讯-后端开发
咦!没有更多了?去看看其它编程学习网 内容吧
首页课程
资料下载
问答资讯