tf.contrib.recurrent.Recurrent
Compute a recurrent neural net.
tf.contrib.recurrent.Recurrent(theta, state0, inputs, cell_fn, cell_grad=None, extras=None, max_input_length=None, use_tpu=False, aligned_end=False)
Roughly, Recurrent() computes the following:

    state = state0
    for t in inputs' sequence length:
      state = cell_fn(theta, state, inputs[t, :])
      accumulate_state[t, :] = state
    return accumulate_state, state

theta, state, inputs are all structures of tensors.
inputs[t, :] means taking slice t out of every tensor in inputs.
accumulate_state[t, :] = state means that we stash every tensor in 'state' into a slice of the corresponding tensor in accumulate_state.
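
The loop and the slicing conventions above can be sketched in plain Python. This is only an illustration of the described semantics, not the TensorFlow implementation: dicts stand in for "structures of tensors", lists indexed by time step stand in for tensors, and recurrent_reference and the scalar cell h_next = w * h + x are hypothetical names chosen for this example.

```python
def recurrent_reference(theta, state0, inputs, cell_fn):
    """Pure-Python sketch of the loop described above (not the TF implementation).

    theta, state0 and inputs are dicts standing in for structures of tensors;
    every entry of `inputs` is a list indexed by time step.
    """
    # Assumption: all entries of `inputs` share the same sequence length.
    seq_len = len(next(iter(inputs.values())))
    state = state0
    accumulate_state = {key: [] for key in state0}
    for t in range(seq_len):
        # inputs[t, :]: take slice t out of every entry in `inputs`.
        inputs_t = {key: seq[t] for key, seq in inputs.items()}
        # cell_fn returns (state1, extras), as in the Args table.
        state, _extras = cell_fn(theta, state, inputs_t)
        # accumulate_state[t, :] = state: stash every entry of `state`.
        for key, value in state.items():
            accumulate_state[key].append(value)
    return accumulate_state, state


def cell_fn(theta, state, inputs_t):
    # Hypothetical scalar cell: h_next = w * h + x, with no extras.
    h_next = theta["w"] * state["h"] + inputs_t["x"]
    return {"h": h_next}, {}


acc, final = recurrent_reference(
    theta={"w": 0.5},
    state0={"h": 0.0},
    inputs={"x": [1.0, 2.0, 3.0]},
    cell_fn=cell_fn,
)
print(acc)    # {'h': [1.0, 2.5, 4.25]}
print(final)  # {'h': 4.25}
```

Note that the accumulated state has the same keys as state0, with one slice per time step, and the final state equals the last accumulated slice.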
cell_fn is a Python callable that computes one forward step of the recurrent neural network (by building up a TensorFlow graph). Any two calls of cell_fn must describe identical computations.
By construction, Recurrent()'s backward computation does not access any intermediate values computed by cell_fn during forward computation. We may extend Recurrent() to support that by taking a customized backward function of cell_fn.
Args | |
---|---|
theta | weights. A structure of tensors. |
state0 | initial state. A structure of tensors. |
inputs | inputs. A structure of tensors. |
cell_fn | A Python function which computes: state1, extras = cell_fn(theta, state0, inputs[t, :]) |
cell_grad | A Python function which computes: dtheta, dstate0, dinputs[t, :] = cell_grad(theta, state0, inputs[t, :], extras, dstate1) |
extras | A structure of tensors. The second return value of every invocation of cell_fn is a structure of tensors whose keys and shapes match this extras. |
max_input_length | maximum length of effective input. This is used to truncate the computation if the inputs have been allocated to a larger size. A scalar tensor. |
use_tpu | whether or not we are on TPU. |
aligned_end | A boolean indicating whether the sequence is aligned at the end. |
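
To make the cell_fn / cell_grad contract concrete, here is a minimal pure-Python sketch of a matching pair for a hypothetical scalar cell h1 = w * h0 + x, together with a simplified backward pass. It is not the library's code: the helper names forward and backward are invented for this example, and the sketch assumes a gradient flows in only through the final state (not through accumulate_state). It illustrates that cell_grad relies only on its arguments (theta, state0, inputs[t, :], extras, dstate1), never on intermediate values saved during the forward call.

```python
def cell_fn(theta, state0, inputs_t):
    # Hypothetical scalar cell: h1 = w * h0 + x. No extras are needed here.
    h1 = theta["w"] * state0["h"] + inputs_t["x"]
    return {"h": h1}, {}


def cell_grad(theta, state0, inputs_t, extras, dstate1):
    # Hand-written gradient of the cell above. It uses only its arguments,
    # never intermediate values saved by the forward call.
    dh1 = dstate1["h"]
    dtheta = {"w": dh1 * state0["h"]}
    dstate0 = {"h": dh1 * theta["w"]}
    dinputs_t = {"x": dh1}
    return dtheta, dstate0, dinputs_t


def forward(theta, state0, xs):
    # Keep every state so the backward pass can look up state0 for step t.
    states = [state0]
    for x in xs:
        state1, _extras = cell_fn(theta, states[-1], {"x": x})
        states.append(state1)
    return states


def backward(theta, states, xs, dfinal):
    # Walk the sequence in reverse, summing dtheta over all time steps.
    # Simplification: only the gradient of the final state is propagated.
    dstate = dfinal
    dw = 0.0
    dxs = [0.0] * len(xs)
    for t in reversed(range(len(xs))):
        dtheta, dstate, dinputs_t = cell_grad(
            theta, states[t], {"x": xs[t]}, {}, dstate)
        dw += dtheta["w"]
        dxs[t] = dinputs_t["x"]
    return {"w": dw}, dstate, dxs


theta = {"w": 0.5}
xs = [1.0, 2.0, 3.0]
states = forward(theta, {"h": 0.0}, xs)
dtheta, dstate0, dxs = backward(theta, states, xs, dfinal={"h": 1.0})
print(dtheta, dstate0, dxs)  # {'w': 3.0} {'h': 0.125} [0.25, 0.5, 1.0]
```

For this cell the final state is w^3 * h0 + w^2 * x0 + w * x1 + x2, so the printed values can be checked by hand: the derivative with respect to w is 2 * w * x0 + x1 = 3.0 when h0 = 0.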
Returns | |
---|---|
accumulate_state and the final state. |
© 2020 The TensorFlow Authors. All rights reserved.
Licensed under the Creative Commons Attribution License 3.0.
Code samples licensed under the Apache 2.0 License.
https://www.tensorflow.org/versions/r1.15/api_docs/python/tf/contrib/recurrent/Recurrent