Class ov::op::util::EmbeddingBagPackedBase

class EmbeddingBagPackedBase : public ov::op::Op

Returns embeddings for given indices.

Subclassed by ov::op::v3::EmbeddingBagPackedSum

Public Functions

EmbeddingBagPackedBase() = default

Constructs an EmbeddingBagPackedBase operation.

EmbeddingBagPackedBase(const Output<Node> &emb_table, const Output<Node> &indices, const Output<Node> &per_sample_weights)

Constructs an EmbeddingBagPackedBase operation.

EmbeddingBagPackedBase constructs an output tensor by replacing every index in the indices tensor with the row of the embedding table (emb_table) at that index.

Parameters:
  • emb_table – Tensor containing the embedding lookup table of the module, of shape [num_emb, emb_dim1, emb_dim2, …] and of type T.

  • indices – Tensor of shape [batch, indices_per_bag] and of type T_IND. Required.

  • per_sample_weights – Tensor of the same shape as indices and of type T. Each value in this tensor is multiplied with the corresponding value pooled from the embedding table for each index. Optional.
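The lookup described by these parameters can be sketched in plain C++ without the OpenVINO runtime. This is an illustrative sketch only: it assumes a 2-D embedding table and the sum reduction used by the EmbeddingBagPackedSum subclass, and the helper name embedding_bag_packed_sum is hypothetical, not part of the OpenVINO API.

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical helper (not an OpenVINO function): sum-reduces embedding rows per bag.
//   emb_table:          [num_emb][emb_dim]
//   indices:            [batch][indices_per_bag]
//   per_sample_weights: empty (unweighted) or the same shape as indices
std::vector<std::vector<float>> embedding_bag_packed_sum(
    const std::vector<std::vector<float>>& emb_table,
    const std::vector<std::vector<std::size_t>>& indices,
    const std::vector<std::vector<float>>& per_sample_weights = {}) {
    const std::size_t emb_dim = emb_table.empty() ? 0 : emb_table.front().size();
    const bool weighted = !per_sample_weights.empty();
    if (weighted && per_sample_weights.size() != indices.size()) {
        throw std::invalid_argument("per_sample_weights must have the same shape as indices");
    }

    // One output row per bag, initialised to zero.
    std::vector<std::vector<float>> output(indices.size(), std::vector<float>(emb_dim, 0.0f));
    for (std::size_t b = 0; b < indices.size(); ++b) {
        if (weighted && per_sample_weights[b].size() != indices[b].size()) {
            throw std::invalid_argument("per_sample_weights must have the same shape as indices");
        }
        for (std::size_t j = 0; j < indices[b].size(); ++j) {
            const std::size_t row = indices[b][j];
            if (row >= emb_table.size()) {
                throw std::out_of_range("index is outside the embedding table");
            }
            // Scale the looked-up row by its weight (1 when no weights are given).
            const float w = weighted ? per_sample_weights[b][j] : 1.0f;
            for (std::size_t d = 0; d < emb_dim; ++d) {
                output[b][d] += w * emb_table[row][d];
            }
        }
    }
    return output;
}
```

With a table of three rows {1, 2}, {3, 4}, {5, 6} and a bag of indices {0, 2}, the unweighted result for that bag is {6, 8}. With weights {0.5, 2} for the same bag, it is {10.5, 13}.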

virtual void validate_and_infer_types() override

Verifies that attributes and inputs are consistent and computes output shapes and element types. Must be implemented by concrete child classes so that it can be run any number of times.

Throws if the node is invalid.
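The kind of consistency check and shape computation described here can be illustrated with a small standalone function. This is a sketch under stated assumptions, not the OpenVINO implementation: it handles only fully static shapes, ignores element types, and assumes the output shape is [batch, emb_dim1, emb_dim2, …]. The names Shape and infer_packed_output_shape are hypothetical.

```cpp
#include <cstddef>
#include <stdexcept>
#include <vector>

// Hypothetical static-shape alias; OpenVINO itself also supports dynamic shapes.
using Shape = std::vector<std::size_t>;

// Hypothetical helper: validates input shapes and returns the assumed output
// shape [batch, emb_dim1, emb_dim2, ...]. It has no side effects, so it can be
// called any number of times with the same result.
Shape infer_packed_output_shape(const Shape& emb_table,
                                const Shape& indices,
                                const Shape* per_sample_weights = nullptr) {
    if (emb_table.empty()) {
        throw std::invalid_argument("emb_table must have at least one dimension");
    }
    if (indices.size() != 2) {
        throw std::invalid_argument("indices must have shape [batch, indices_per_bag]");
    }
    if (per_sample_weights != nullptr && *per_sample_weights != indices) {
        throw std::invalid_argument("per_sample_weights must have the same shape as indices");
    }

    // Output keeps the batch dimension, then every embedding dimension after num_emb.
    Shape output{indices[0]};
    output.insert(output.end(), emb_table.begin() + 1, emb_table.end());
    return output;
}
```

For example, an embedding table of shape [5, 3, 4] with indices of shape [2, 6] yields an output shape of [2, 3, 4] under these assumptions, while a one-dimensional indices shape causes the function to throw.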