Class ov::op::v3::EmbeddingSegmentsSum

class EmbeddingSegmentsSum : public ov::op::Op

Returns embeddings for given indices.

Public Functions

EmbeddingSegmentsSum() = default

Constructs an EmbeddingSegmentsSum operation.

EmbeddingSegmentsSum(const Output<Node> &emb_table, const Output<Node> &indices, const Output<Node> &segment_ids, const Output<Node> &num_segments, const Output<Node> &default_index, const Output<Node> &per_sample_weights)

Constructs an EmbeddingSegmentsSum operation.

EmbeddingSegmentsSum replaces every index in the indices tensor with the row of the embedding table (the weights matrix) at that index, then sums the rows that share a segment ID, so the output holds one pooled row per segment.

Parameters:
  • 'emb_table' – tensor containing the embedding lookup table of the module of shape [num_emb, emb_dim1, emb_dim2, …] and of type T

  • 'indices' – tensor of shape [num_indices] and of type T_IND. Required.

  • 'segment_ids' – tensor of shape [num_indices] and of type T_IND with indices into the output tensor. Values should be sorted and can be repeated. Required.

  • 'num_segments' – scalar of type T_IND indicating the number of segments. Required.

  • 'default_index' – scalar of type T_IND containing the default index in the embedding table used to fill empty “bags”. If not provided, empty “bags” are filled with zeros. Optional.

  • 'per_sample_weights' – tensor of the same shape as indices and of type T. Each value in this tensor is multiplied with the corresponding row pooled from the embedding table for that index. Optional.
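
The parameter descriptions above can be made concrete with a small reference computation. The following is a minimal sketch, not OpenVINO code: `embedding_segments_sum_ref` is a hypothetical helper, the embedding table is assumed to be a 2D row-major [num_emb, emb_dim] matrix stored in a flat vector, T is fixed to `float`, T_IND is fixed to `std::int64_t`, and the optional inputs are modelled as nullable pointers.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical reference helper (not part of OpenVINO).
// emb_table: flat row-major [num_emb, emb_dim] matrix (assumed 2D for brevity).
// Returns a flat row-major [num_segments, emb_dim] matrix.
std::vector<float> embedding_segments_sum_ref(
    const std::vector<float>& emb_table,
    std::size_t emb_dim,
    const std::vector<std::int64_t>& indices,
    const std::vector<std::int64_t>& segment_ids,
    std::int64_t num_segments,
    const std::int64_t* default_index = nullptr,
    const std::vector<float>* per_sample_weights = nullptr) {
    assert(emb_dim > 0 && emb_table.size() % emb_dim == 0);
    assert(num_segments >= 0);
    assert(indices.size() == segment_ids.size());
    assert(!per_sample_weights || per_sample_weights->size() == indices.size());

    const std::size_t num_emb = emb_table.size() / emb_dim;
    const auto segments = static_cast<std::size_t>(num_segments);
    std::vector<float> out(segments * emb_dim, 0.0f);
    std::vector<bool> seen(segments, false);  // tracks non-empty "bags"

    // Accumulate each (optionally weighted) embedding row into its segment.
    for (std::size_t i = 0; i < indices.size(); ++i) {
        const auto row = static_cast<std::size_t>(indices[i]);
        const auto seg = static_cast<std::size_t>(segment_ids[i]);
        assert(row < num_emb && seg < segments);
        const float w = per_sample_weights ? (*per_sample_weights)[i] : 1.0f;
        for (std::size_t d = 0; d < emb_dim; ++d) {
            out[seg * emb_dim + d] += w * emb_table[row * emb_dim + d];
        }
        seen[seg] = true;
    }

    // Empty segments stay zero unless a default index is supplied.
    if (default_index) {
        const auto row = static_cast<std::size_t>(*default_index);
        assert(row < num_emb);
        for (std::size_t seg = 0; seg < segments; ++seg) {
            if (seen[seg]) continue;
            for (std::size_t d = 0; d < emb_dim; ++d) {
                out[seg * emb_dim + d] = emb_table[row * emb_dim + d];
            }
        }
    }
    return out;
}
```

For a 3x2 table {1, 2, 3, 4, 5, 6} with indices {0, 2, 1}, segment_ids {0, 0, 2} and num_segments 3, segment 0 sums rows 0 and 2, segment 1 is an empty bag, and segment 2 holds row 1. Passing a default index of 1 would fill the empty segment with row 1 instead of zeros.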

virtual void validate_and_infer_types() override

Verifies that attributes and inputs are consistent and computes output shapes and element types. Must be implemented by concrete child classes so that it can be run any number of times.

Throws if the node is invalid.
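
As an illustration of the kind of consistency check and shape computation described here, the sketch below works on plain static shapes. It is an assumption-laden stand-in, not the actual OpenVINO implementation: `Shape` and `infer_output_shape` are hypothetical names, the checks are derived only from the parameter descriptions above, and the output shape rule [num_segments, emb_dim1, emb_dim2, …] is assumed from the pooling semantics.

```cpp
#include <cstdint>
#include <stdexcept>
#include <vector>

// Hypothetical stand-in for a fully static tensor shape.
using Shape = std::vector<std::int64_t>;

// Hypothetical helper: validates the input shapes and returns the assumed
// output shape, i.e. emb_table's shape with its first dimension replaced
// by num_segments.
Shape infer_output_shape(const Shape& emb_table,
                         const Shape& indices,
                         const Shape& segment_ids,
                         std::int64_t num_segments) {
    if (emb_table.empty()) {
        throw std::invalid_argument("emb_table must have rank >= 1");
    }
    if (indices.size() != 1 || segment_ids.size() != 1) {
        throw std::invalid_argument("indices and segment_ids must be 1D");
    }
    if (indices != segment_ids) {
        throw std::invalid_argument("indices and segment_ids must have the same shape");
    }
    if (num_segments < 0) {
        throw std::invalid_argument("num_segments must be non-negative");
    }
    Shape out = emb_table;   // keep emb_dim1, emb_dim2, ...
    out[0] = num_segments;   // one pooled row per segment
    return out;
}
```

The function has no side effects, so it can be called repeatedly with the same result, mirroring the requirement that validation can be run any number of times.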