The data preprocessing process of tensorflow/nmt

In the tensorflow/nmt project, the input of both training data and inference data uses the new Dataset API, which, as far as I know, was introduced after TensorFlow 1.2 and makes data manipulation much more convenient. If you are still using the old Queue and Coordinator approach, it is recommended to upgrade to a newer version of TensorFlow and use the Dataset API.
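As a minimal sketch of what such a pipeline looks like (this is not code from the nmt project itself; the file names train.en and train.vi are hypothetical placeholders, and a TensorFlow 1.x runtime is assumed):

import tensorflow as tf

# One sentence per line, tokens separated by spaces (hypothetical files).
src_dataset = tf.data.TextLineDataset("train.en")
tgt_dataset = tf.data.TextLineDataset("train.vi")

# Pair source and target lines, then split each line into tokens.
dataset = tf.data.Dataset.zip((src_dataset, tgt_dataset))
dataset = dataset.map(
    lambda src, tgt: (tf.string_split([src]).values,
                      tf.string_split([tgt]).values))

iterator = dataset.make_initializable_iterator()
src_tokens, tgt_tokens = iterator.get_next()

with tf.Session() as sess:
    sess.run(iterator.initializer)
    print(sess.run([src_tokens, tgt_tokens]))  # one tokenized sentence pair

The real get_iterator function discussed below follows the same pattern, just with more map/filter/batch steps chained onto the dataset.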

This tutorial focuses on both training data and inference data, analyzing their processing flow in detail. You will see how the text data is converted into the numbers the model requires, what the shapes of the intermediate tensors are, and how hyperparameters such as batch_size come into play.

Training data processing

Let's first look at how training data is processed. It is slightly more complicated than the processing of inference data, and once you understand it, the inference-data pipeline is easy to follow.
The training-data processing code lives in the get_iterator function in the file nmt/utils/iterator_utils.py. Let's first look at the parameters this function requires:

parameter                   explanation
src_dataset                 source dataset
tgt_dataset                 target dataset
src_vocab_table             source vocabulary lookup table, mapping each word to an int id (see the sketch below)
tgt_vocab_table             target vocabulary lookup table, mapping each word to an int id
batch_size                  batch size
sos                         start-of-sentence marker
eos                         end-of-sentence marker
random_seed                 random seed, used to shuffle the dataset
num_buckets                 number of buckets
src_max_len                 maximum length of the source data
tgt_max_len                 maximum length of the target data
num_parallel_calls          number of parallel calls used to process data concurrently
output_buffer_size          output buffer size
skip_count                  number of data rows to skip
num_shards                  number of shards the dataset is split into, useful in distributed training
shard_index                 id of this shard after the dataset is sharded
reshuffle_each_iteration    whether to reshuffle the data on every iteration
If anything in the explanations above is unclear, you can check my earlier article on the hyperparameters:
A detailed explanation of the hyperparameters of tensorflow_nmt
<https://blog.csdn.net/stupid_3/article/details/tensorflow_nmt Super parameter.md>
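Before diving into the function body, here is a small sketch of what a vocabulary lookup table such as src_vocab_table looks like. This is not the exact code from the nmt project (which builds its tables in nmt/utils/vocab_utils.py); the file name vocab.en and the UNK_ID value are assumptions for illustration:

import tensorflow as tf

UNK_ID = 0  # id returned for out-of-vocabulary words (assumed here)

# "vocab.en" is a hypothetical vocabulary file with one word per line;
# the line number of a word becomes its integer id.
src_vocab_table = tf.contrib.lookup.index_table_from_file(
    "vocab.en", default_value=UNK_ID)

ids = src_vocab_table.lookup(tf.constant(["hello", "world", "someunknownword"]))

with tf.Session() as sess:
    sess.run(tf.tables_initializer())
    print(sess.run(ids))  # ids of the three words; unknown words map to UNK_ID

This is exactly the kind of word-to-int mapping that get_iterator uses when it calls src_vocab_table.lookup(...) below.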

The main code in this function for processing training data is as follows:
if not output_buffer_size:
    output_buffer_size = batch_size * 1000
src_eos_id = tf.cast(src_vocab_table.lookup(tf.constant(eos)), tf.int32)
tgt_sos_id = tf.cast(tgt_vocab_table.lookup(tf.constant(sos)), tf.int32)
tgt_eos_id = tf.cast(tgt_vocab_table.lookup(tf.constant(eos)), tf.int32)

src_tgt_dataset = tf.data.Dataset.zip((src_dataset, tgt_dataset))

src_tgt_dataset = src_tgt_dataset.shard(num_shards, shard_index)
if skip_count is not None:
    src_tgt_dataset = src_tgt_dataset.skip(skip_count)

src_tgt_dataset = src_tgt_dataset.shuffle(
    output_buffer_size, random_seed, reshuffle_each_iteration)

src_tgt_dataset = src_tgt_dataset.map(
    lambda src, tgt: (
        tf.string_split([src]).values, tf.string_split([tgt]).values),
    num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size)

# Filter zero length input sequences.
src_tgt_dataset = src_tgt_dataset.filter(
    lambda src, tgt: tf.logical_and(tf.size(src) > 0, tf.size(tgt) > 0))

if src_max_len:
    src_tgt_dataset = src_tgt_dataset.map(
        lambda src, tgt: (src[:src_max_len], tgt),
        num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size)
if tgt_max_len:
    src_tgt_dataset = src_tgt_dataset.map(
        lambda src, tgt: (src, tgt[:tgt_max_len]),
        num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size)

# Convert the word strings to ids. Word strings that are not in the
# vocab get the lookup table's default_value integer.
src_tgt_dataset = src_tgt_dataset.map(
    lambda src, tgt: (tf.cast(src_vocab_table.lookup(src), tf.int32),
                      tf.cast(tgt_vocab_table.lookup(tgt), tf.int32)),
    num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size)

# Create a tgt_input prefixed with <sos> and a tgt_output suffixed with <eos>.
src_tgt_dataset = src_tgt_dataset.map(
    lambda src, tgt: (src,
                      tf.concat(([tgt_sos_id], tgt), 0),
                      tf.concat((tgt, [tgt_eos_id]), 0)),
    num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size)

# Add in sequence lengths.
src_tgt_dataset = src_tgt_dataset.map(
    lambda src, tgt_in, tgt_out: (
        src, tgt_in, tgt_out, tf.size(src), tf.size(tgt_in)),
    num_parallel_calls=num_parallel_calls).prefetch(output_buffer_size)
Let's analyze step by step what this code does and how the shape of the data tensor changes along the way.
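To make the end result concrete, here is a hand-worked toy example (the ids, sos/eos values and the sentences are made up, not taken from a real vocabulary) that reproduces the structure of a single element after all the maps above:

import tensorflow as tf

tgt_sos_id, tgt_eos_id = 1, 2
src = tf.constant([4, 5, 6], tf.int32)   # e.g. a 3-word source sentence after lookup
tgt = tf.constant([7, 8], tf.int32)      # e.g. a 2-word target sentence after lookup

tgt_in = tf.concat(([tgt_sos_id], tgt), 0)   # <sos> prepended to the target
tgt_out = tf.concat((tgt, [tgt_eos_id]), 0)  # <eos> appended to the target
element = (src, tgt_in, tgt_out, tf.size(src), tf.size(tgt_in))

with tf.Session() as sess:
    print(sess.run(element))
    # values: src=[4 5 6], tgt_in=[1 7 8], tgt_out=[7 8 2], src_len=3, tgt_len=3

So each element of src_tgt_dataset at this point is a 5-tuple: the source id sequence, the target input sequence, the target output sequence, and the two sequence lengths.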

How to align data

What does num_buckets do

The code where num_buckets takes effect is:
if num_buckets > 1:

    def key_func(unused_1, unused_2, unused_3, src_len, tgt_len):
        # Calculate bucket_width by maximum source sequence length.
        # Pairs with length [0, bucket_width) go to bucket 0, length
        # [bucket_width, 2 * bucket_width) go to bucket 1, etc.  Pairs with length
        # over ((num_bucket-1) * bucket_width) words all go into the last bucket.
        if src_max_len:
            bucket_width = (src_max_len + num_buckets - 1) // num_buckets
        else:
            bucket_width = 10

        # Bucket sentence pairs by the length of their source sentence and target
        # sentence.
        bucket_id = tf.maximum(src_len // bucket_width, tgt_len // bucket_width)
        return tf.to_int64(tf.minimum(num_buckets, bucket_id))

    def reduce_func(unused_key, windowed_data):
        return batching_func(windowed_data)

    batched_dataset = src_tgt_dataset.apply(
        tf.contrib.data.group_by_window(
            key_func=key_func, reduce_func=reduce_func, window_size=batch_size))
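As a hand-worked check of the bucketing rule in key_func (the numbers below are chosen arbitrarily for illustration), the arithmetic can be reproduced in plain Python:

# Plain-Python walk-through of key_func with arbitrary example numbers.
src_max_len, num_buckets = 50, 5
bucket_width = (src_max_len + num_buckets - 1) // num_buckets   # = 10

src_len, tgt_len = 23, 38
bucket_id = max(src_len // bucket_width, tgt_len // bucket_width)  # = max(2, 3) = 3
bucket_id = min(num_buckets, bucket_id)                            # = 3
print(bucket_width, bucket_id)  # 10 3

So this sentence pair is assigned to bucket 3. group_by_window then collects window_size (that is, batch_size) pairs that share the same bucket_id before handing them to reduce_func, so sentences of similar length are batched and padded together, which keeps the amount of padding small.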