branch.hpp
1 //*****************************************************************************
2 // Copyright 2017-2021 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16 
17 #pragma once
18 
19 #include "ngraph/node.hpp"
20 #include "ngraph/pattern/op/pattern.hpp"
21 
22 namespace ngraph
23 {
24  namespace pattern
25  {
26  namespace op
27  {
28  /// A branch adds a loop to the pattern. The branch match is successful if the
29  /// destination node pattern matches the graph value. The destination node is a node in
30  /// the pattern graph that will not have been created some time after the Branch node is
31  /// created; use set_destination to add it.
32  ///
33  /// The branch destination is not stored as a shared pointer to prevent reference
34  /// cycles. Thus the destination node must be referenced in some other way to prevent it
35  /// from being deleted.
36  class NGRAPH_API Branch : public Pattern
37  {
38  public:
39  static constexpr NodeTypeInfo type_info{"patternBranch", 0};
40  const NodeTypeInfo& get_type_info() const override;
41  /// \brief Creates a Branch pattern
42  /// \param pattern the destinationing pattern
43  /// \param labels Labels where the destination may occur
45  : Pattern(OutputVector{})
46  {
47  set_output_type(0, element::f32, Shape{});
48  }
49 
50  void set_destination(const Output<Node>& destination)
51  {
52  m_destination_node = destination.get_node();
53  m_destination_index = destination.get_index();
54  }
55 
56  Output<Node> get_destination() const
57  {
58  return m_destination_node == nullptr
59  ? Output<Node>()
60  : Output<Node>{m_destination_node->shared_from_this(),
61  m_destination_index};
62  }
63 
64  bool match_value(pattern::Matcher* matcher,
65  const Output<Node>& pattern_value,
66  const Output<Node>& graph_value) override;
67 
68  protected:
69  Node* m_destination_node{nullptr};
70  size_t m_destination_index{0};
71  };
72  }
73  }
74 }
Definition: node.hpp:132
A handle for one of a node's outputs.
Definition: node_output.hpp:42
Node * get_node() const
size_t get_index() const
Definition: node_output.hpp:36
Shape for a tensor.
Definition: shape.hpp:31
Definition: branch.hpp:37
const NodeTypeInfo & get_type_info() const override
Branch()
Creates a Branch pattern.
Definition: branch.hpp:44
Definition: pattern.hpp:82
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:28
Definition: type.hpp:39