Tensorflow:np数组的next_batch函数 [英] Tensorflow: next_batch function of np array

查看:463
本文介绍了Tensorflow:np数组的next_batch函数的处理方法,对大家解决问题具有一定的参考价值,需要的朋友们下面随着小编来一起学习吧!

问题描述

我的火车数据为

xTrain = numpy.asarray([100, 1, 5, 6 ...])
yTrain = numpy.asarray([200, 2, 10, 12 ...])

如何定义next_batch(size)方法以从训练数据中获取随机元素的大小数.

How to define next_batch(size) method to get size number of random elements from the train data.

推荐答案

您可以将其用作next batch函数:

def batch_data(source, target, batch_size):

   # Shuffle data
   shuffle_indices = np.random.permutation(np.arange(len(target)))
   source = source[shuffle_indices]
   target = target[shuffle_indices]

   for batch_i in range(0, len(source)//batch_size):
      start_i = batch_i * batch_size
      source_batch = source[start_i:start_i + batch_size]
      target_batch = target[start_i:start_i + batch_size]

      yield np.array(source_batch), np.array(target_batch)

这篇关于Tensorflow:np数组的next_batch函数的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!

查看全文
登录 关闭
扫码关注1秒登录
发送“验证码”获取 | 15天全站免登陆