The data preprocessing pipeline of tensorflow nmt

In the tensorflow/nmt project, both training data and inference data are fed in through the new Dataset API. This API is available from tensorflow 1.2 onward and makes data handling much more convenient. If you are still using the old Queue and Coordinator approach, it is recommended to upgrade to a newer version of tensorflow and use the Dataset API.

This tutorial covers both training data and inference data and analyzes the processing in detail. You will see how text data is converted into the numeric values the model requires, what the dimensions of the intermediate tensors are, and how batch_size and the other hyperparameters take effect.

Training data processing

Let's first look at how training data is processed. It is slightly more complicated than the processing of inference data, so once you understand the training pipeline, the inference pipeline is easy to follow.
The code that processes training data is the get_iterator function in nmt/utils/iterator_utils.py. Let's take a look at the parameters this function requires:

Parameter                  Description
src_dataset                Source dataset
tgt_dataset                Target dataset
src_vocab_table            Source word lookup table; maps each word to an int
tgt_vocab_table            Target word lookup table; maps each word to an int
batch_size                 Batch size
sos                        Start-of-sentence marker
eos                        End-of-sentence marker
random_seed                Random seed, used when shuffling the dataset
num_buckets                Number of buckets
src_max_len                Maximum length of the source data
tgt_max_len                Maximum length of the target data
num_parallel_calls         Number of elements processed in parallel
output_buffer_size         Output buffer size
skip_count                 Number of data rows to skip
num_shards                 Number of shards the dataset is split into; useful in distributed training
shard_index                ID of the shard to use after the dataset is split
reshuffle_each_iteration   Whether to reshuffle on every iteration
If anything in the explanations above is unclear, see my earlier article on the hyperparameters:
A detailed explanation of the tensorflow_nmt hyperparameters
<https://blog.csdn.net/stupid_3/article/details/tensorflow_nmt Super parameter of .md>
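To make num_shards, shard_index and skip_count from the table more concrete, here is a minimal plain-Python sketch of what sharding and skipping do to a list of sentence pairs. It is not the tensorflow implementation: the helper names shard and skip and the toy pairs are made up for illustration, and the only assumption is the documented behaviour that a shard keeps the elements whose position modulo num_shards equals shard_index.

```python
from itertools import islice


def shard(items, num_shards, shard_index):
    """Keep the elements whose position modulo num_shards equals shard_index."""
    return [x for i, x in enumerate(items) if i % num_shards == shard_index]


def skip(items, skip_count):
    """Drop the first skip_count elements and keep the rest."""
    return list(islice(items, skip_count, None))


# Hypothetical toy dataset of ten (source, target) sentence pairs.
pairs = [("src%d" % i, "tgt%d" % i) for i in range(10)]

# Worker 1 of 2 sees every second pair, starting at position 1.
worker_1 = shard(pairs, num_shards=2, shard_index=1)

# Skipping then drops the first two pairs of that shard.
resumed = skip(worker_1, skip_count=2)

print(worker_1)
print(resumed)
```

In get_iterator the same order applies: the zipped dataset is sharded first, and skip_count is applied to the shard that remains.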

The main code this function uses to process training data is as follows:

    if not output_buffer_size:
        output_buffer_size = batch_size * 1000
    src_eos_id = tf.cast(src_vocab_table.lookup(tf.constant(eos)), tf.int32)
    tgt_sos_id = tf.cast(tgt_vocab_table.lookup(tf.constant(sos)), tf.int32)
    tgt_eos_id = tf.cast(tgt_vocab_table.lookup(tf.constant(eos)), tf.int32)

    src_tgt_dataset = tf.data.Dataset.zip((src_dataset, tgt_dataset))

    src_tgt_dataset = src_tgt_dataset.shard(num_shards, shard_index)
    if skip_count is not None:
        src_tgt_dataset = src_tgt_dataset.skip(skip_count)

    src_tgt_dataset = src_tgt_dataset.shuffle(
        output_buffer_size, random_seed, reshuffle_each_iteration)

    src_tgt_dataset = src_tgt_dataset.map(
        lambda src, tgt: (
            tf.string_split([src]).values, tf.string_split([tgt]).values),
        num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size)

    # Filter zero length input sequences.
    src_tgt_dataset = src_tgt_dataset.filter(
        lambda src, tgt: tf.logical_and(tf.size(src) > 0, tf.size(tgt) > 0))

    if src_max_len:
        src_tgt_dataset = src_tgt_dataset.map(
            lambda src, tgt: (src[:src_max_len], tgt),
            num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size)
    if tgt_max_len:
        src_tgt_dataset = src_tgt_dataset.map(
            lambda src, tgt: (src, tgt[:tgt_max_len]),
            num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size)

    # Convert the word strings to ids. Word strings that are not in the
    # vocab get the lookup table's default_value integer.
    src_tgt_dataset = src_tgt_dataset.map(
        lambda src, tgt: (tf.cast(src_vocab_table.lookup(src), tf.int32),
                          tf.cast(tgt_vocab_table.lookup(tgt), tf.int32)),
        num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size)

    # Create a tgt_input prefixed with <sos> and a tgt_output suffixed with <eos>.
    src_tgt_dataset = src_tgt_dataset.map(
        lambda src, tgt: (src,
                          tf.concat(([tgt_sos_id], tgt), 0),
                          tf.concat((tgt, [tgt_eos_id]), 0)),
        num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size)

    # Add in sequence lengths.
    src_tgt_dataset = src_tgt_dataset.map(
        lambda src, tgt_in, tgt_out: (
            src, tgt_in, tgt_out, tf.size(src), tf.size(tgt_in)),
        num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size)

Let's analyze step by step what this code does and how the data tensors change along the way.
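As a starting point for that analysis, the per-example steps (split, filter, truncate, look up ids, add the start and end markers, record lengths) can be traced on a single sentence pair in plain Python. This is only a sketch of the logic, not the tensorflow code: the toy vocabularies, the UNK_ID default and the preprocess helper are assumptions made for illustration, and str.split() only approximates tf.string_split.

```python
# Hypothetical toy vocabularies; the real lookup tables are built from vocab files.
UNK_ID = 0  # assumed default_value returned for out-of-vocabulary words
src_vocab = {"<unk>": 0, "<s>": 1, "</s>": 2, "i": 3, "like": 4, "tea": 5}
tgt_vocab = {"<unk>": 0, "<s>": 1, "</s>": 2, "ich": 3, "mag": 4, "tee": 5}


def preprocess(src_line, tgt_line, src_max_len=None, tgt_max_len=None,
               sos="<s>", eos="</s>"):
    """Mirror the per-example map/filter steps for one sentence pair."""
    # 1. Split each line into word strings.
    src, tgt = src_line.split(), tgt_line.split()
    # 2. Filter out pairs where either side is empty.
    if not src or not tgt:
        return None
    # 3. Truncate to the maximum lengths, if given.
    if src_max_len:
        src = src[:src_max_len]
    if tgt_max_len:
        tgt = tgt[:tgt_max_len]
    # 4. Convert words to ids; unknown words get the default id.
    src_ids = [src_vocab.get(w, UNK_ID) for w in src]
    tgt_ids = [tgt_vocab.get(w, UNK_ID) for w in tgt]
    # 5. Decoder input is prefixed with sos, decoder output is suffixed with eos.
    tgt_in = [tgt_vocab[sos]] + tgt_ids
    tgt_out = tgt_ids + [tgt_vocab[eos]]
    # 6. Append the sequence lengths.
    return src_ids, tgt_in, tgt_out, len(src_ids), len(tgt_in)


example = preprocess("i like green tea", "ich mag tee", src_max_len=3)
print(example)  # ([3, 4, 0], [1, 3, 4, 5], [3, 4, 5, 2], 3, 4)
```

Each element therefore grows from a pair of strings into a 5-tuple: three variable-length id vectors and two scalar lengths. Note that the target length is measured on tgt_in, so it already counts the start marker.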

How to align data

What num_buckets does

The code where num_buckets takes effect is:

    if num_buckets > 1:

        def key_func(unused_1, unused_2, unused_3, src_len, tgt_len):
            # Calculate bucket_width by maximum source sequence length.
            # Pairs with length [0, bucket_width) go to bucket 0, length
            # [bucket_width, 2 * bucket_width) go to bucket 1, etc. Pairs with length
            # over ((num_bucket-1) * bucket_width) words all go into the last bucket.
            if src_max_len:
                bucket_width = (src_max_len + num_buckets - 1) // num_buckets
            else:
                bucket_width = 10

            # Bucket sentence pairs by the length of their source sentence and target
            # sentence.
            bucket_id = tf.maximum(src_len // bucket_width, tgt_len // bucket_width)
            return tf.to_int64(tf.minimum(num_buckets, bucket_id))

        def reduce_func(unused_key, windowed_data):
            return batching_func(windowed_data)

        batched_dataset = src_tgt_dataset.apply(
            tf.contrib.data.group_by_window(
                key_func=key_func, reduce_func=reduce_func, window_size=batch_size))
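
The arithmetic inside key_func is easier to follow with concrete numbers. The sketch below restates it in plain Python; the function name bucket_id_for and the values src_max_len=50 and num_buckets=5 are assumptions chosen only for this illustration.

```python
def bucket_id_for(src_len, tgt_len, num_buckets, src_max_len=None):
    """Plain-Python restatement of the bucket id computed by key_func."""
    if src_max_len:
        # Ceiling division: spread lengths up to src_max_len over num_buckets.
        bucket_width = (src_max_len + num_buckets - 1) // num_buckets
    else:
        bucket_width = 10
    # The longer side of the pair decides the bucket.
    raw_id = max(src_len // bucket_width, tgt_len // bucket_width)
    # Clamp very long pairs into the last bucket id.
    return min(num_buckets, raw_id)


# Assumed settings: src_max_len=50 and num_buckets=5 give bucket_width=10.
lengths = [(3, 4), (12, 9), (25, 31), (50, 70)]
ids = [bucket_id_for(s, t, num_buckets=5, src_max_len=50) for s, t in lengths]
print(ids)  # [0, 1, 3, 5]
```

Pairs that share a bucket id are collected by group_by_window into windows of at most batch_size elements and handed to reduce_func, so each batch contains sentences of similar length. Note that the clamp uses num_buckets rather than num_buckets - 1, so with these values the largest id is 5.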