Class ov::op::v1::GatherTree
-
class GatherTree : public ov::op::Op
Generates the complete beams from the token indices selected at each step and the parent beam indices.
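The backtracking this operation performs can be sketched in plain C++, without OpenVINO types. The sketch assumes the semantics described in the GatherTree-1 operation specification: the output starts filled with end_token; for each batch entry the beam is rebuilt backwards from step min(MAX_TIME, max_seq_len[b]) - 1 by following parent_idx; and every position after the first end_token in a beam is overwritten with end_token. The names Tensor3 and gather_tree, and the use of int elements, are hypothetical and for illustration only. This is not the OpenVINO kernel.

```cpp
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical dense, row-major tensor of shape [MAX_TIME, BATCH_SIZE, BEAM_WIDTH].
struct Tensor3 {
    std::size_t T, B, W;
    std::vector<int> data;
    Tensor3(std::size_t t, std::size_t b, std::size_t w, int fill = 0)
        : T(t), B(b), W(w), data(t * b * w, fill) {}
    int& at(std::size_t t, std::size_t b, std::size_t w) { return data[(t * B + b) * W + w]; }
    int at(std::size_t t, std::size_t b, std::size_t w) const { return data[(t * B + b) * W + w]; }
};

// Sketch of GatherTree semantics: rebuild each beam by following parent pointers backwards.
Tensor3 gather_tree(const Tensor3& step_ids, const Tensor3& parent_idx,
                    const std::vector<int>& max_seq_len, int end_token) {
    Tensor3 out(step_ids.T, step_ids.B, step_ids.W, end_token);  // default value is end_token
    for (std::size_t b = 0; b < step_ids.B; ++b) {
        // Effective length of this sequence, clamped to [0, MAX_TIME].
        const std::size_t len = std::min<std::size_t>(
            step_ids.T, static_cast<std::size_t>(std::max(0, max_seq_len[b])));
        if (len == 0) continue;
        for (std::size_t w = 0; w < step_ids.W; ++w) {
            // The last valid step takes this beam's own token.
            out.at(len - 1, b, w) = step_ids.at(len - 1, b, w);
            int parent = parent_idx.at(len - 1, b, w);
            // Walk from step len - 2 down to 0, following the parent beam at each step.
            for (std::size_t t = len - 1; t-- > 0;) {
                assert(parent >= 0 && static_cast<std::size_t>(parent) < step_ids.W);
                out.at(t, b, w) = step_ids.at(t, b, static_cast<std::size_t>(parent));
                parent = parent_idx.at(t, b, static_cast<std::size_t>(parent));
            }
            // After the first end_token in the beam, pad the remaining steps with end_token.
            bool finished = false;
            for (std::size_t t = 0; t < len; ++t) {
                if (finished) out.at(t, b, w) = end_token;
                else if (out.at(t, b, w) == end_token) finished = true;
            }
        }
    }
    return out;
}
```

For example, take MAX_TIME = 3, one batch entry and two beams. Let step_ids over time be [1, 2], [3, 4], [5, 6] and parent_idx be [0, 0], [0, 0], [1, 0]. With max_seq_len = {3} and an end_token that never occurs, the reconstructed beams are 1, 4, 5 and 1, 3, 6. With max_seq_len = {2}, the last step of each beam is left as end_token.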
Public Functions
-
GatherTree(const Output<Node> &step_ids, const Output<Node> &parent_idx, const Output<Node> &max_seq_len, const Output<Node> &end_token)
- Parameters
step_ids – Tensor of shape [MAX_TIME, BATCH_SIZE, BEAM_WIDTH] with the token indices selected at each step
parent_idx – Tensor of shape [MAX_TIME, BATCH_SIZE, BEAM_WIDTH] with parent beam indices
max_seq_len – Tensor of shape [BATCH_SIZE] with maximum lengths for each sequence in the batch
end_token – Scalar tensor holding the end-of-sequence token value
-
virtual void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types. Must be implemented by concrete child classes so that it can be run any number of times.
Throws if the node is invalid.
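The shape checks this method is responsible for can be sketched for static shapes only. The sketch assumes the input shapes listed for the constructor: step_ids is rank 3, parent_idx matches it, max_seq_len is [BATCH_SIZE], end_token is a scalar, and the output has the shape of step_ids. The name infer_gather_tree_shape and the Shape alias are hypothetical. The real method works on possibly dynamic shapes and may apply other or additional checks.

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

using Shape = std::vector<std::size_t>;  // hypothetical static shape; empty means scalar

// Hypothetical static-shape checks; throws on invalid inputs, returns the output shape.
Shape infer_gather_tree_shape(const Shape& step_ids, const Shape& parent_idx,
                              const Shape& max_seq_len, const Shape& end_token) {
    if (step_ids.size() != 3)
        throw std::invalid_argument("step_ids must be [MAX_TIME, BATCH_SIZE, BEAM_WIDTH]");
    if (parent_idx != step_ids)
        throw std::invalid_argument("parent_idx must have the same shape as step_ids");
    if (max_seq_len.size() != 1 || max_seq_len[0] != step_ids[1])
        throw std::invalid_argument("max_seq_len must have shape [BATCH_SIZE]");
    if (!end_token.empty())
        throw std::invalid_argument("end_token must be a scalar");
    return step_ids;  // the gathered beams keep the layout of step_ids
}
```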