label.hpp
// Copyright (C) 2018-2021 Intel Corporation
// SPDX-License-Identifier: Apache-2.0
//
4 
5 #pragma once
6 
7 #include "ngraph/node.hpp"
8 #include "ngraph/pattern/op/pattern.hpp"
9 
10 namespace ngraph
11 {
12  namespace pattern
13  {
14  namespace op
15  {
16  /// Fails if the predicate returns false on the graph value.
17  ///
18  /// The graph value is added to the matched values list. If the Label is already
19  /// associated with a value, the match succeeds if the value is the same as the graph
20  /// value. Otherwise, the label is associated with the graph value and the match
21  /// succeeds if the pattern input matches the graph value.
22  ///
23  /// DEPRECATED: If no inputs are given to Label, a True node is serves as the input. If
24  /// more than one inputs are given, an Or pattern of the inputs serves as the input.
25  class NGRAPH_API Label : public Pattern
26  {
27  public:
28  static constexpr NodeTypeInfo type_info{"patternLabel", 0};
29  const NodeTypeInfo& get_type_info() const override;
30  /// \brief creates a Label node containing a sub-pattern described by \sa type and
31  /// \sa shape.
32  ///
33  /// this Label node can be bound only to the nodes in the input graph
34  /// that match the pattern specified by \sa wrapped_nodes
35  /// Example:
36  /// \code{.cpp}
37  /// auto add = a + b; // a and b are op::Parameter in this example
38  /// auto label = std::make_shared<pattern::op::Label>(element::f32,
39  /// Shape{2,2},
40  /// nullptr,
41  /// OutputVector{add});
42  /// \endcode
43  Label(const element::Type& type,
44  const PartialShape& s,
45  const ValuePredicate pred,
46  const OutputVector& wrapped_values)
47  : Pattern(OutputVector{wrap_values(wrapped_values)}, pred)
48  {
49  set_output_type(0, type, s);
50  }
51 
52  explicit Label(const element::Type& type = element::dynamic,
54  : Label(
55  type, s, [](const Output<Node>&) { return true; }, OutputVector())
56  {
57  }
58 
59  Label(const element::Type& type, const PartialShape& s, ValuePredicate pred)
60  : Label(type, s, pred, OutputVector{})
61  {
62  }
63 
64  Label(const element::Type& type, const PartialShape& s, NodePredicate pred)
65  : Label(type, s, as_value_predicate(pred), OutputVector{})
66  {
67  }
68 
69  Label(const element::Type& type,
70  const PartialShape& s,
71  const NodePredicate pred,
72  const NodeVector& wrapped_values)
73  : Label(type, s, as_value_predicate(pred), as_output_vector(wrapped_values))
74  {
75  }
76 
77  /// \brief creates a Label node containing a sub-pattern described by the type and
78  /// shape of \sa node.
79  ///
80  /// this Label node can be bound only to the nodes in the input graph
81  /// that match the pattern specified by \sa wrapped_values
82  /// Example:
83  /// \code{.cpp}
84  /// auto add = a + b; // a and b are op::Parameter in this example
85  /// auto label = std::make_shared<pattern::op::Label>(add,
86  /// nullptr,
87  /// OutputVector{add});
88  /// \endcode
89  Label(const Output<Node>& value,
90  const ValuePredicate pred,
91  const OutputVector& wrapped_values)
92  : Label(
93  value.get_element_type(), value.get_partial_shape(), pred, wrapped_values)
94  {
95  }
96  Label(const Output<Node>& value, const ValuePredicate pred)
97  : Label(
98  value.get_element_type(), value.get_partial_shape(), pred, OutputVector{})
99  {
100  }
101 
102  Label(const Output<Node>& value, const NodePredicate pred)
103  : Label(value.get_element_type(),
104  value.get_partial_shape(),
105  as_value_predicate(pred),
106  OutputVector{})
107  {
108  }
109  Label(const Output<Node>& value)
110  : Label(
111  value.get_element_type(),
112  value.get_partial_shape(),
113  [](const Output<Node>&) { return true; },
114  OutputVector{})
115  {
116  }
117  Label(const Output<Node>& node,
118  const NodePredicate pred,
119  const NodeVector& wrapped_values)
120  : Label(node.get_element_type(),
121  node.get_partial_shape(),
122  as_value_predicate(pred),
123  as_output_vector(wrapped_values))
124  {
125  }
126 
127  bool match_value(Matcher* matcher,
128  const Output<Node>& pattern_value,
129  const Output<Node>& graph_value) override;
130 
131  protected:
132  static Output<Node> wrap_values(const OutputVector& wrapped_values);
133  };
134  } // namespace op
135 
136  NGRAPH_API
137  std::shared_ptr<Node> any_input();
138 
139  NGRAPH_API
140  std::shared_ptr<Node> any_input(const pattern::op::ValuePredicate& pred);
141  } // namespace pattern
142 } // namespace ngraph
Definition: node.hpp:127
A handle for one of a node's outputs.
Definition: node_output.hpp:33
Definition: node_output.hpp:25
Class representing a shape that may be partially or totally dynamic.
Definition: partial_shape.hpp:34
static PartialShape dynamic(Rank r=Rank::dynamic())
Construct a PartialShape with the given rank and all dimensions (if any) dynamic.
Definition: element_type.hpp:51
Definition: label.hpp:26
Label(const element::Type &type, const PartialShape &s, const ValuePredicate pred, const OutputVector &wrapped_values)
creates a Label node containing a sub-pattern described by
Definition: label.hpp:43
Label(const Output< Node > &value, const ValuePredicate pred, const OutputVector &wrapped_values)
creates a Label node containing a sub-pattern described by the type and shape of
Definition: label.hpp:89
const NodeTypeInfo & get_type_info() const override
Definition: pattern.hpp:73
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:16
Definition: type.hpp:27