batch很好理解,就是batch size。注意:在一个epoch中,最后一个batch的大小可能小于batch size(当样本数不能被batch size整除时)。
dataset.shuffle会维护一个大小为buffer size的shuffle buffer:计算图所需的每个样本都从shuffle buffer中随机取出;每取出一个样本,就从源数据集中按顺序补充一个样本到shuffle buffer中。
# Demo: shuffle(3) -> batch(4) -> repeat(2) on an 11-sample dataset.
import os

# Hide all GPUs so the demo runs on CPU only.
os.environ['CUDA_VISIBLE_DEVICES'] = ""

import numpy as np
import tensorflow as tf

np.random.seed(0)
x = np.random.sample((11, 2))
print(x)
print()

# make a dataset from a numpy array
dataset = tf.data.Dataset.from_tensor_slices(x)
dataset = dataset.shuffle(3)
dataset = dataset.batch(4)
dataset = dataset.repeat(2)

# create the iterator (not named `iter`, to avoid shadowing the builtin)
iterator = dataset.make_one_shot_iterator()
el = iterator.get_next()

with tf.Session() as sess:
    # 11 samples / batch 4 -> 3 batches per epoch (4, 4, 3); 2 epochs -> 6 batches.
    # Pull batches until the iterator is exhausted instead of hard-coding the count.
    while True:
        try:
            print(sess.run(el))
        except tf.errors.OutOfRangeError:
            break

# Sample output. np.random.seed only fixes x; the shuffle order comes from
# TensorFlow's own RNG (no seed is set here), so it may differ between runs.
#
# Source dataset:
# [[ 0.5488135   0.71518937]
#  [ 0.60276338  0.54488318]
#  [ 0.4236548   0.64589411]
#  [ 0.43758721  0.891773  ]
#  [ 0.96366276  0.38344152]
#  [ 0.79172504  0.52889492]
#  [ 0.56804456  0.92559664]
#  [ 0.07103606  0.0871293 ]
#  [ 0.0202184   0.83261985]
#  [ 0.77815675  0.87001215]
#  [ 0.97861834  0.79915856]]
#
# Samples obtained after shuffle + batch (epoch 1):
# [[ 0.4236548   0.64589411]
#  [ 0.60276338  0.54488318]
#  [ 0.43758721  0.891773  ]
#  [ 0.5488135   0.71518937]]
# [[ 0.96366276  0.38344152]
#  [ 0.56804456  0.92559664]
#  [ 0.0202184   0.83261985]
#  [ 0.79172504  0.52889492]]
# [[ 0.07103606  0.0871293 ]
#  [ 0.97861834  0.79915856]
#  [ 0.77815675  0.87001215]]   <- the last batch has 3 samples
#
# (epoch 2):
# [[ 0.60276338  0.54488318]
#  [ 0.5488135   0.71518937]
#  [ 0.43758721  0.891773  ]
#  [ 0.79172504  0.52889492]]
# [[ 0.4236548   0.64589411]
#  [ 0.56804456  0.92559664]
#  [ 0.0202184   0.83261985]
#  [ 0.07103606  0.0871293 ]]
# [[ 0.77815675  0.87001215]
#  [ 0.96366276  0.38344152]
#  [ 0.97861834  0.79915856]]   <- the last batch has 3 samples
1、按照shuffle中设置的buffer size(这里为3),首先从源数据集中取出三个样本放入shuffle buffer:
shuffle buffer:
[ 0.5488135 0.71518937]
[ 0.60276338 0.54488318]
[ 0.4236548 0.64589411]
2、从shuffle buffer中随机取出一个样本放入batch,这里取到的是[ 0.4236548 0.64589411]:
shuffle buffer:
[ 0.5488135 0.71518937]
[ 0.60276338 0.54488318]
batch:
[ 0.4236548 0.64589411]
3、此时shuffle buffer中不足三个样本,于是从源数据集中补充一个样本:
shuffle buffer:
[ 0.5488135 0.71518937]
[ 0.60276338 0.54488318]
[ 0.43758721 0.891773 ]
4、再从shuffle buffer中随机取出一个样本放入batch,这里取到的是[ 0.60276338 0.54488318]:
shuffle buffer:
[ 0.5488135 0.71518937]
[ 0.43758721 0.891773 ]
batch:
[ 0.4236548 0.64589411]
[ 0.60276338 0.54488318]
5、如此反复。这就意味着:如果shuffle的buffer size=1,数据集不会被打乱;如果shuffle的buffer size大于等于数据集的样本数量,则相当于对整个数据集做完全的随机打乱。
# Demo: shuffle(1) -> batch(4) -> repeat(2). A 1-element shuffle buffer can
# only ever hand out the single sample it holds, so the order is unchanged.
import os

# Hide all GPUs so the demo runs on CPU only.
os.environ['CUDA_VISIBLE_DEVICES'] = ""

import numpy as np
import tensorflow as tf

np.random.seed(0)
x = np.random.sample((11, 2))
print(x)
print()

# make a dataset from a numpy array
dataset = tf.data.Dataset.from_tensor_slices(x)
dataset = dataset.shuffle(1)
dataset = dataset.batch(4)
dataset = dataset.repeat(2)

# create the iterator (not named `iter`, to avoid shadowing the builtin)
iterator = dataset.make_one_shot_iterator()
el = iterator.get_next()

with tf.Session() as sess:
    # 2 epochs x 3 batches (4, 4, 3) = 6 batches; stop once the iterator is exhausted.
    while True:
        try:
            print(sess.run(el))
        except tf.errors.OutOfRangeError:
            break

# Sample output:
#
# Source dataset:
# [[ 0.5488135   0.71518937]
#  [ 0.60276338  0.54488318]
#  [ 0.4236548   0.64589411]
#  [ 0.43758721  0.891773  ]
#  [ 0.96366276  0.38344152]
#  [ 0.79172504  0.52889492]
#  [ 0.56804456  0.92559664]
#  [ 0.07103606  0.0871293 ]
#  [ 0.0202184   0.83261985]
#  [ 0.77815675  0.87001215]
#  [ 0.97861834  0.79915856]]
#
# Batches (epoch 1) -- same order as the source:
# [[ 0.5488135   0.71518937]
#  [ 0.60276338  0.54488318]
#  [ 0.4236548   0.64589411]
#  [ 0.43758721  0.891773  ]]
# [[ 0.96366276  0.38344152]
#  [ 0.79172504  0.52889492]
#  [ 0.56804456  0.92559664]
#  [ 0.07103606  0.0871293 ]]
# [[ 0.0202184   0.83261985]
#  [ 0.77815675  0.87001215]
#  [ 0.97861834  0.79915856]]
#
# (epoch 2):
# [[ 0.5488135   0.71518937]
#  [ 0.60276338  0.54488318]
#  [ 0.4236548   0.64589411]
#  [ 0.43758721  0.891773  ]]
# [[ 0.96366276  0.38344152]
#  [ 0.79172504  0.52889492]
#  [ 0.56804456  0.92559664]
#  [ 0.07103606  0.0871293 ]]
# [[ 0.0202184   0.83261985]
#  [ 0.77815675  0.87001215]
#  [ 0.97861834  0.79915856]]
# Demo: repeat(2) -> shuffle(11) -> batch(4). Because repeat comes first, the
# shuffle buffer draws from the concatenation of both epochs. The epochs
# therefore blend together, and the same sample can land twice in one batch.
# Because batch comes last, only the final batch of the 22 samples is partial
# (22 = 5 * 4 + 2).
import os

# Hide all GPUs so the demo runs on CPU only.
os.environ['CUDA_VISIBLE_DEVICES'] = ""

import numpy as np
import tensorflow as tf

np.random.seed(0)
x = np.random.sample((11, 2))
print(x)
print()

# make a dataset from a numpy array
dataset = tf.data.Dataset.from_tensor_slices(x)
dataset = dataset.repeat(2)
dataset = dataset.shuffle(11)
dataset = dataset.batch(4)

# create the iterator (not named `iter`, to avoid shadowing the builtin)
iterator = dataset.make_one_shot_iterator()
el = iterator.get_next()

with tf.Session() as sess:
    # 22 samples / batch 4 -> 6 batches; stop once the iterator is exhausted.
    while True:
        try:
            print(sess.run(el))
        except tf.errors.OutOfRangeError:
            break

# Sample output (shuffle is not seeded, so the order may differ between runs):
#
# Source dataset:
# [[ 0.5488135   0.71518937]
#  [ 0.60276338  0.54488318]
#  [ 0.4236548   0.64589411]
#  [ 0.43758721  0.891773  ]
#  [ 0.96366276  0.38344152]
#  [ 0.79172504  0.52889492]
#  [ 0.56804456  0.92559664]
#  [ 0.07103606  0.0871293 ]
#  [ 0.0202184   0.83261985]
#  [ 0.77815675  0.87001215]
#  [ 0.97861834  0.79915856]]
#
# Batches:
# [[ 0.56804456  0.92559664]
#  [ 0.5488135   0.71518937]
#  [ 0.60276338  0.54488318]
#  [ 0.07103606  0.0871293 ]]
# [[ 0.96366276  0.38344152]
#  [ 0.43758721  0.891773  ]
#  [ 0.43758721  0.891773  ]   <- the same sample appears twice in one batch
#  [ 0.77815675  0.87001215]]
# [[ 0.79172504  0.52889492]
#  [ 0.79172504  0.52889492]   <- the same sample appears twice in one batch
#  [ 0.60276338  0.54488318]
#  [ 0.4236548   0.64589411]]
# [[ 0.07103606  0.0871293 ]
#  [ 0.4236548   0.64589411]
#  [ 0.96366276  0.38344152]
#  [ 0.5488135   0.71518937]]
# [[ 0.97861834  0.79915856]
#  [ 0.0202184   0.83261985]
#  [ 0.77815675  0.87001215]
#  [ 0.56804456  0.92559664]]
# Last batch of the repeat -> shuffle -> batch example above:
# [[ 0.0202184   0.83261985]
#  [ 0.97861834  0.79915856]]
# As shown, the last batch holds 2 samples while all earlier batches hold 4.

# Usage example:
def input_fn(filenames, batch_size=32, num_epochs=1, perform_shuffle=False):
    """Build a TF1 input pipeline over libsvm-style text files.

    Each line is split on spaces. The first token is parsed as a float32
    label. Every remaining ``id:val`` token is split on ':' into an int32
    feature id and a float32 feature value.

    Args:
        filenames: a filename or a list of filenames, passed to
            ``tf.data.TextLineDataset``.
        batch_size: number of examples per batch.
        num_epochs: how many times ``repeat`` replays the data.
        perform_shuffle: if True, shuffle with a 256-element buffer before
            repeating.

    Returns:
        ``(batch_features, batch_labels)`` tensors from a one-shot iterator.
        ``batch_features`` is a dict with keys ``"feat_ids"`` and
        ``"feat_vals"``.
    """
    print('Parsing', filenames)

    def decode_libsvm(line):
        """Parse one text line into ({"feat_ids", "feat_vals"}, label)."""
        columns = tf.string_split([line], ' ')
        labels = tf.string_to_number(columns.values[0], out_type=tf.float32)
        splits = tf.string_split(columns.values[1:], ':')
        # Reshaping the flat split values to dense_shape assumes every token
        # contains exactly one ':' (i.e. splits into 2 parts) -- TODO confirm
        # against the data files.
        id_vals = tf.reshape(splits.values, splits.dense_shape)
        feat_ids, feat_vals = tf.split(id_vals, num_or_size_splits=2, axis=1)
        feat_ids = tf.string_to_number(feat_ids, out_type=tf.int32)
        feat_vals = tf.string_to_number(feat_vals, out_type=tf.float32)
        return {"feat_ids": feat_ids, "feat_vals": feat_vals}, labels

    # Extract lines from the input file(s) with the Dataset API, parse them with
    # 10 parallel calls, then prefetch parsed examples.
    # NOTE(review): this line was garbled in the source; the
    # TextLineDataset(...).map(decode_libsvm, ...) call is reconstructed.
    dataset = (
        tf.data.TextLineDataset(filenames)
        .map(decode_libsvm, num_parallel_calls=10)
        .prefetch(500000)
    )

    # Shuffle BEFORE repeat, with a 256-element window held in memory, so that
    # samples from different epochs do not blend together (see the
    # repeat -> shuffle example above).
    if perform_shuffle:
        dataset = dataset.shuffle(buffer_size=256)

    dataset = dataset.repeat(num_epochs)
    dataset = dataset.batch(batch_size)

    iterator = dataset.make_one_shot_iterator()
    batch_features, batch_labels = iterator.get_next()
    return batch_features, batch_labels
