embeddingbag_offsets_base.hpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16 
17 #pragma once
18 
19 #include "ngraph/axis_set.hpp"
20 #include "ngraph/op/util/index_reduction.hpp"
21 
22 namespace ngraph
23 {
24  namespace op
25  {
26  namespace util
27  {
28  /// \brief Returns embeddings for given indices
29  class NGRAPH_API EmbeddingBagOffsetsBase : public Op
30  {
31  public:
32  static constexpr NodeTypeInfo type_info{"EmbeddingBagOffsetsBase", 3};
33  const NodeTypeInfo& get_type_info() const override { return type_info; }
34  /// \brief Constructs a EmbeddingBagOffsetsBase operation.
36  /// \brief Constructs a EmbeddingBagOffsetsBase operation.
37  ///
38  /// EmbeddingBagOffsetsBase constructs an output tensor by replacing every index in
39  /// a
40  /// given
41  /// input tensor with a row (from the weights matrix) at that index
42  ///
43  /// \param emb_table tensor containing the embedding lookup table of the module of
44  /// shape [num_emb, emb_dim1, emb_dim2, ...] and of type T
45  /// \param tensor of shape [num_indices] and of type T_IND. Required
46  /// \param offsets tensor of shape [batch] and of type T_IND containing the starting
47  /// index positions of each "bag" in indices. Required.
48  /// \param per_sample_weigths tensor of the same shape as indices and of type T.
49  /// Each value in this tensor are multiplied with each
50  /// value pooled from embedding table for each index. Optional.
51  /// \param default_index scalar of type T_IND containing default index in embedding
52  /// table to fill empty "bags". If not provided empty "bags"
53  /// are filled with zeros. Optional.
54 
55  EmbeddingBagOffsetsBase(const Output<Node>& emb_table,
56  const Output<Node>& indices,
57  const Output<Node>& offsets,
58  const Output<Node>& default_index,
59  const Output<Node>& per_sample_weights);
60 
61  EmbeddingBagOffsetsBase(const Output<Node>& emb_table,
62  const Output<Node>& indices,
63  const Output<Node>& offsets,
64  const Output<Node>& default_index);
65 
66  EmbeddingBagOffsetsBase(const Output<Node>& emb_table,
67  const Output<Node>& indices,
68  const Output<Node>& offsets);
69 
70  void validate_and_infer_types() override;
71  bool visit_attributes(AttributeVisitor& visitor) override;
72 
73  private:
74  static constexpr int EMB_TABLE = 0;
75  static constexpr int INDICES = 1;
76  static constexpr int OFFSETS = 2;
77  static constexpr int DEFAULT_INDEX = 3;
78  static constexpr int PER_SAMPLE_WEIGHTS = 4;
79  };
80  }
81  }
82 }
ngraph::op::util::EmbeddingBagOffsetsBase::EmbeddingBagOffsetsBase
EmbeddingBagOffsetsBase()=default
Constructs an EmbeddingBagOffsetsBase operation.
ngraph::op::util::EmbeddingBagOffsetsBase::EmbeddingBagOffsetsBase
EmbeddingBagOffsetsBase(const Output< Node > &emb_table, const Output< Node > &indices, const Output< Node > &offsets, const Output< Node > &default_index, const Output< Node > &per_sample_weights)
Constructs an EmbeddingBagOffsetsBase operation.
ngraph::op::util::EmbeddingBagOffsetsBase::get_type_info
const NodeTypeInfo & get_type_info() const override
Definition: embeddingbag_offsets_base.hpp:33
ngraph
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:28
ngraph::AttributeVisitor
Visits the attributes of a node, primarily for serialization-like tasks.
Definition: attribute_visitor.hpp:70
ngraph::op::util::EmbeddingBagOffsetsBase
Returns embeddings for given indices.
Definition: embeddingbag_offsets_base.hpp:30
ngraph::op::util::EmbeddingBagOffsetsBase::validate_and_infer_types
void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types....
ngraph::op::Op
Root of all actual ops.
Definition: op.hpp:29