scatter_nd_update.hpp
1 // Copyright (C) 2018-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 #pragma once
6 
7 #include "ngraph/op/op.hpp"
8 #include "ngraph/op/util/scatter_nd_base.hpp"
9 
10 namespace ngraph
11 {
12  namespace op
13  {
14  namespace v3
15  {
16  /// \brief Add updates to slices from inputs addressed by indices
17  class NGRAPH_API ScatterNDUpdate : public util::ScatterNDBase
18  {
19  public:
20  static constexpr NodeTypeInfo type_info{"ScatterNDUpdate", 3};
21  const NodeTypeInfo& get_type_info() const override { return type_info; }
22  ScatterNDUpdate() = default;
23  /// \param inputs Tensor
24  /// \param indices Index tensor: Data type must be `element::i32` or `element::i64`
25  /// \param updates Tensor: Must have same type as inputs
27  const Output<Node>& indices,
28  const Output<Node>& updates)
29  : util::ScatterNDBase(inputs, indices, updates)
30  {
31  }
32 
33  virtual std::shared_ptr<Node>
34  clone_with_new_inputs(const OutputVector& new_args) const override;
35  bool evaluate(const HostTensorVector& outputs,
36  const HostTensorVector& inputs) const override;
37  bool has_evaluate() const override;
38  };
39  } // namespace v3
40  using v3::ScatterNDUpdate;
41  } // namespace op
42 } // namespace ngraph
A handle for one of a node's outputs.
Definition: node_output.hpp:33
Base class for ScatterNDXXX operators.
Definition: scatter_nd_base.hpp:19
Writes `updates` into the slices of `inputs` addressed by `indices`, replacing the existing elements.
Definition: scatter_nd_update.hpp:18
const NodeTypeInfo & get_type_info() const override
Definition: scatter_nd_update.hpp:21
bool evaluate(const HostTensorVector &outputs, const HostTensorVector &inputs) const override
Evaluates the op on `inputs`, putting the results in `outputs`.
bool has_evaluate() const override
Reports whether an `evaluate` implementation is available for this operation.
ScatterNDUpdate(const Output< Node > &inputs, const Output< Node > &indices, const Output< Node > &updates)
Definition: scatter_nd_update.hpp:26
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:16
Definition: type.hpp:27