embedding_segments_sum.hpp
1 //*****************************************************************************
2 // Copyright 2017-2021 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16 
17 #pragma once
18 
19 #include "ngraph/axis_set.hpp"
20 #include "ngraph/op/util/index_reduction.hpp"
21 
22 namespace ngraph
23 {
24  namespace op
25  {
26  namespace v3
27  {
28  /// \brief Returns embeddings for given indices
29  class NGRAPH_API EmbeddingSegmentsSum : public Op
30  {
31  public:
32  static constexpr NodeTypeInfo type_info{"EmbeddingSegmentsSum", 3};
33  const NodeTypeInfo& get_type_info() const override { return type_info; }
34  /// \brief Constructs a EmbeddingSegmentsSum operation.
35  EmbeddingSegmentsSum() = default;
36  /// \brief Constructs a EmbeddingSegmentsSum operation.
37  ///
38  /// EmbeddingSegmentsSum constructs an output tensor by replacing every index in a
39  /// given
40  /// input tensor with a row (from the weights matrix) at that index
41  ///
42  /// \param 'emb_table' tensor containing the embedding lookup table of the module of
43  /// shape [num_emb, emb_dim1, emb_dim2, ...] and of type T
44  /// \param 'indices' tensor of shape [num_indices] and of type T_IND. Required
45  /// \param `segment_ids` tensor of shape `[num_indices]` and of type *T_IND* with
46  /// indices
47  /// into the output Tensor. Values should be sorted and can be repeated. Required.
48  /// \param `num_segments` scalar of type *T_IND* indicating the number of segments.
49  /// Required.
50  /// \param 'default_index' scalar of type T_IND containing default index in
51  /// embedding
52  /// table to fill empty "bags". If not provided empty "bags"
53  /// are filled with zeros. Optional.
54  /// \param 'per_sample_weights' tensor of the same shape as indices and of type T.
55  /// Each value in this tensor are multiplied with each
56  /// value pooled from embedding table for each index. Optional.
57 
59  const Output<Node>& indices,
60  const Output<Node>& segment_ids,
61  const Output<Node>& num_segments,
62  const Output<Node>& default_index,
63  const Output<Node>& per_sample_weights);
64 
65  EmbeddingSegmentsSum(const Output<Node>& emb_table,
66  const Output<Node>& indices,
67  const Output<Node>& segment_ids,
68  const Output<Node>& num_segments,
69  const Output<Node>& default_index);
70 
71  EmbeddingSegmentsSum(const Output<Node>& emb_table,
72  const Output<Node>& indices,
73  const Output<Node>& segment_ids,
74  const Output<Node>& num_segments);
75 
76  void validate_and_infer_types() override;
77 
78  virtual std::shared_ptr<Node>
79  clone_with_new_inputs(const OutputVector& new_args) const override;
80 
81  virtual bool visit_attributes(AttributeVisitor& visitor) override { return true; }
82  private:
83  static constexpr int EMB_TABLE = 0;
84  static constexpr int INDICES = 1;
85  static constexpr int SEGMENT_IDS = 2;
86  static constexpr int NUM_SEGMENTS = 3;
87  static constexpr int DEFAULT_INDEX = 4;
88  static constexpr int PER_SAMPLE_WEIGHTS = 5;
89  };
90  }
91  using v3::EmbeddingSegmentsSum;
92  }
93 }
Visits the attributes of a node, primarily for serialization-like tasks.
Definition: attribute_visitor.hpp:71
A handle for one of a node's outputs.
Definition: node_output.hpp:42
Root of all actual ops.
Definition: op.hpp:29
Returns embeddings for given indices.
Definition: embedding_segments_sum.hpp:30
void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types....
EmbeddingSegmentsSum()=default
Constructs an EmbeddingSegmentsSum operation.
EmbeddingSegmentsSum(const Output< Node > &emb_table, const Output< Node > &indices, const Output< Node > &segment_ids, const Output< Node > &num_segments, const Output< Node > &default_index, const Output< Node > &per_sample_weights)
Constructs an EmbeddingSegmentsSum operation.
const NodeTypeInfo & get_type_info() const override
Definition: embedding_segments_sum.hpp:33
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:28
Definition: type.hpp:39