Versioned name: GatherTree-1
Category: Beam search post-processing
Short description: Generates the complete beams from the ids per each step and the parent beam ids.
Detailed description
The GatherTree operation implements the same algorithm as the GatherTree operation in TensorFlow.
Pseudo code:
final_idx[ :, :, :] = end_token
for batch in range(BATCH_SIZE):
for beam in range(BEAM_WIDTH):
max_sequence_in_beam = min(MAX_TIME, max_seq_len[batch])
parent = parent_idx[max_sequence_in_beam - 1, batch, beam]
final_idx[max_sequence_in_beam - 1, batch, beam] = step_idx[max_sequence_in_beam - 1, batch, beam]
for level in reversed(range(max_sequence_in_beam - 1)):
final_idx[level, batch, beam] = step_idx[level, batch, parent]
parent = parent_idx[level, batch, parent]
# For a given beam, past the time step containing the first decoded end_token
# all values are filled in with end_token.
finished = False
for time in range(max_sequence_in_beam):
if(finished):
final_idx[time, batch, beam] = end_token
elif(final_idx[time, batch, beam] == end_token):
finished = True
Element data types for all input tensors should match each other.
Attributes: GatherTree has no attributes
Inputs
- 1:
step_ids
– a tensor of shape [MAX_TIME, BATCH_SIZE, BEAM_WIDTH]
of type T
with indices from per each step. Required.
- 2:
parent_idx
– a tensor of shape [MAX_TIME, BATCH_SIZE, BEAM_WIDTH]
of type T
with parent beam indices. Required.
- 3:
max_seq_len
– a tensor of shape [BATCH_SIZE]
of type T
with maximum lengths for each sequence in the batch. Required.
- 4:
end_token
– a scalar tensor of type T
with value of the end marker in a sequence. Required.
Outputs
- 1:
final_idx
– a tensor of shape [MAX_TIME, BATCH_SIZE, BEAM_WIDTH]
of type T
.
Types
- T:
float32
or int32
; float32
should have integer values only.
Example
<layer type="GatherTree" ...>
<input>
<port id="0">
<dim>100</dim>
<dim>1</dim>
<dim>10</dim>
</port>
<port id="1">
<dim>100</dim>
<dim>1</dim>
<dim>10</dim>
</port>
<port id="2">
<dim>1</dim>
</port>
<port id="3">
</port>
</input>
<output>
<port id="0">
<dim>100</dim>
<dim>1</dim>
<dim>10</dim>
</port>
</output>
</layer>