tf.contrib.seq2seq.GreedyEmbeddingHelper

A helper for use during inference.

Inherits From: Helper

Uses the argmax of the output (treated as logits) and passes the result through an embedding layer to get the next input.
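To make this mechanism concrete, here is a minimal pure-Python sketch of greedy decoding for a single sequence, assuming a toy decoder. The names `greedy_decode`, `toy_step`, and `EMBEDDING`, the 4-token vocabulary, and the `max_steps` cap are hypothetical illustrations, not part of the TensorFlow API. The real helper operates on batched tensors and is driven by a decoder such as BasicDecoder.

```python
from typing import Callable, List, Sequence, Tuple

Vector = List[float]


def argmax(values: Sequence[float]) -> int:
    """Index of the largest value (the first one on ties)."""
    return max(range(len(values)), key=lambda i: values[i])


def greedy_decode(
    step: Callable[[Vector, Vector], Tuple[Vector, Vector]],
    embedding: Sequence[Vector],
    start_token: int,
    end_token: int,
    initial_state: Vector,
    max_steps: int = 10,
) -> List[int]:
    """Treat each decoder output as logits, take the argmax id, and
    embed that id to form the next input, stopping at end_token."""
    inputs = list(embedding[start_token])  # embed the start token
    state = initial_state
    ids: List[int] = []
    for _ in range(max_steps):
        logits, state = step(inputs, state)
        sample_id = argmax(logits)  # greedy choice
        ids.append(sample_id)
        if sample_id == end_token:  # this sequence is finished
            break
        inputs = list(embedding[sample_id])  # embedding lookup for next input
    return ids


# Hypothetical toy setup: one-hot embeddings and a "decoder" whose
# logits always favor the token id after the current one.
EMBEDDING = [[1.0 if i == j else 0.0 for j in range(4)] for i in range(4)]


def toy_step(inputs: Vector, state: Vector) -> Tuple[Vector, Vector]:
    current = argmax(inputs)
    logits = [1.0 if j == current + 1 else 0.0 for j in range(4)]
    return logits, state


print(greedy_decode(toy_step, EMBEDDING, start_token=0, end_token=3, initial_state=[]))
# → [1, 2, 3]
```

Note that the end token itself is emitted as the last sample before decoding stops.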

Args
embedding: Either a callable that takes a vector tensor of ids (the argmax ids) and returns their embeddings, or the params argument for embedding_lookup. The returned tensor is passed to the decoder as input.
start_tokens: int32 vector shaped [batch_size], the start tokens.
end_token: int32 scalar, the token that marks the end of decoding.
Raises
ValueError: if start_tokens is not a 1D tensor or end_token is not a scalar.
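The argument constraints and the ValueError conditions above can be illustrated with a small pure-Python check. In this sketch a flat list of ints stands in for an int32 vector tensor and a bare int stands in for a scalar; the function name `validate_helper_args` is hypothetical, and TensorFlow itself checks tensor ranks rather than Python types.

```python
from typing import Any, List


def validate_helper_args(start_tokens: Any, end_token: Any) -> List[int]:
    """Mirror the documented ValueError conditions on plain Python values.

    start_tokens must be 1-D (a flat list of ints, one per batch entry);
    end_token must be a scalar (a single int). Returns start_tokens.
    """
    if not isinstance(start_tokens, list) or not all(
        isinstance(t, int) for t in start_tokens
    ):
        raise ValueError("start_tokens must be a vector (1-D)")
    if not isinstance(end_token, int):
        raise ValueError("end_token must be a scalar")
    return start_tokens


validate_helper_args([1, 1, 1], 2)  # OK: a batch of three start tokens
try:
    validate_helper_args([[1], [1]], 2)  # 2-D start_tokens is rejected
except ValueError as err:
    print(err)  # start_tokens must be a vector (1-D)
```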
Attributes
batch_size: Batch size of the tensor returned by sample. Returns a scalar int32 tensor.

sample_ids_dtype: DType of the tensor returned by sample. Returns a DType.

sample_ids_shape: Shape of the tensor returned by sample, excluding the batch dimension. Returns a TensorShape.

Methods

initialize

Returns (initial_finished, initial_inputs): a boolean vector of shape [batch_size] that is all False, and the embeddings of start_tokens.
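As a rough illustration of what initialize produces, the sketch below assumes the embedding is a plain lookup table (a list of rows indexed by token id): no sequence is finished at time 0, and the first inputs are the embeddings of start_tokens. The function name and the example table are hypothetical.

```python
from typing import List, Sequence, Tuple


def initialize(
    start_tokens: Sequence[int], embedding: Sequence[Sequence[float]]
) -> Tuple[List[bool], List[List[float]]]:
    """Return (initial_finished, initial_inputs) for a batch."""
    initial_finished = [False] * len(start_tokens)  # nothing has ended yet
    initial_inputs = [list(embedding[t]) for t in start_tokens]  # embed start ids
    return initial_finished, initial_inputs


table = [[0.0, 0.0], [0.5, 1.5], [2.0, 3.0]]  # hypothetical 3-token table
print(initialize([1, 1], table))
# → ([False, False], [[0.5, 1.5], [0.5, 1.5]])
```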

next_inputs

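next_inputs_fn for GreedyEmbeddingHelper. Returns (finished, next_inputs, next_state): finished is True for each batch entry whose sample id equals end_token, next_inputs holds the embeddings of sample_ids, and the state is passed through unchanged.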
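The per-step logic can be sketched in pure Python, assuming a lookup-table embedding. An entry is marked finished when its sample id equals end_token, and the sampled ids are embedded to form the next inputs. When every entry has finished, the next inputs are never consumed, so this sketch returns the start inputs in that case (which, as far as we can tell, matches the r1.15 source). The function name and parameter layout are hypothetical.

```python
from typing import List, Sequence, Tuple

Table = Sequence[Sequence[float]]


def next_inputs(
    sample_ids: Sequence[int],
    state: object,
    embedding: Table,
    start_inputs: List[List[float]],
    end_token: int,
) -> Tuple[List[bool], List[List[float]], object]:
    """Return (finished, next_inputs, next_state) for one decoding step."""
    finished = [s == end_token for s in sample_ids]  # per-entry end check
    if all(finished):
        # Every sequence has ended, so these inputs are never used.
        next_in = start_inputs
    else:
        next_in = [list(embedding[s]) for s in sample_ids]  # embed argmax ids
    return finished, next_in, state  # state passes through unchanged


table = [[0.0, 0.0], [0.5, 1.5], [2.0, 3.0]]  # hypothetical 3-token table
start = [[0.5, 1.5], [0.5, 1.5]]
print(next_inputs([2, 0], None, table, start, end_token=2))
# → ([True, False], [[2.0, 3.0], [0.0, 0.0]], None)
```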

sample

sample for GreedyEmbeddingHelper. Returns sample_ids, the int32 argmax of outputs over the last axis.
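A pure-Python sketch of the greedy choice, assuming outputs is a batch of logit rows: each row's sample id is the index of its largest value, giving one scalar id per batch entry (consistent with sample_ids_shape excluding only the batch dimension). The function name is hypothetical; ties here resolve to the first maximal index, which is a property of Python's max rather than a documented TensorFlow guarantee.

```python
from typing import List, Sequence


def sample(outputs: Sequence[Sequence[float]]) -> List[int]:
    """Greedy sample ids: argmax over the last axis of each batch row."""
    return [max(range(len(row)), key=row.__getitem__) for row in outputs]


print(sample([[0.1, 2.0, -1.0], [3.0, 0.0, 0.5]]))
# → [1, 0]
```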

© 2020 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/contrib/seq2seq/GreedyEmbeddingHelper