branch.hpp
1 // Copyright (C) 2018-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 #pragma once
6 
7 #include "ngraph/node.hpp"
8 #include "ngraph/pattern/op/pattern.hpp"
9 
10 namespace ngraph
11 {
12  namespace pattern
13  {
14  namespace op
15  {
16  /// A branch adds a loop to the pattern. The branch match is successful if the
17  /// destination node pattern matches the graph value. The destination node is a node in
18  /// the pattern graph that will not have been created some time after the Branch node is
19  /// created; use set_destination to add it.
20  ///
21  /// The branch destination is not stored as a shared pointer to prevent reference
22  /// cycles. Thus the destination node must be referenced in some other way to prevent it
23  /// from being deleted.
24  class NGRAPH_API Branch : public Pattern
25  {
26  public:
27  static constexpr NodeTypeInfo type_info{"patternBranch", 0};
28  const NodeTypeInfo& get_type_info() const override;
29  /// \brief Creates a Branch pattern
30  /// \param pattern the destinationing pattern
31  /// \param labels Labels where the destination may occur
33  : Pattern(OutputVector{})
34  {
35  set_output_type(0, element::f32, Shape{});
36  }
37 
38  void set_destination(const Output<Node>& destination)
39  {
40  m_destination_node = destination.get_node();
41  m_destination_index = destination.get_index();
42  }
43 
44  Output<Node> get_destination() const
45  {
46  return m_destination_node == nullptr
47  ? Output<Node>()
48  : Output<Node>{m_destination_node->shared_from_this(),
49  m_destination_index};
50  }
51 
52  bool match_value(pattern::Matcher* matcher,
53  const Output<Node>& pattern_value,
54  const Output<Node>& graph_value) override;
55 
56  protected:
57  Node* m_destination_node{nullptr};
58  size_t m_destination_index{0};
59  };
60  } // namespace op
61  } // namespace pattern
62 } // namespace ngraph
Definition: node.hpp:127
A handle for one of a node's outputs.
Definition: node_output.hpp:33
Node * get_node() const
size_t get_index() const
Definition: node_output.hpp:25
Shape for a tensor.
Definition: shape.hpp:19
Definition: branch.hpp:25
const NodeTypeInfo & get_type_info() const override
Branch()
Creates a Branch pattern.
Definition: branch.hpp:32
Definition: pattern.hpp:73
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:16
Definition: type.hpp:27