any_of.hpp
1 // Copyright (C) 2018-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 #pragma once
6 
7 #include "ngraph/node.hpp"
8 #include "ngraph/pattern/op/pattern.hpp"
9 
10 namespace ngraph
11 {
12  namespace pattern
13  {
14  namespace op
15  {
16  /// The graph value is added to the matched values list. If the predicate is true for
17  /// the
18  /// graph node, a submatch is performed on the input of AnyOf and each input of the
19  /// graph node. The first match that succeeds results in a successful match. Otherwise
20  /// the match fails.
21  ///
22  /// AnyOf may be given a type and shape for use in strict mode.
23  class NGRAPH_API AnyOf : public Pattern
24  {
25  public:
26  static constexpr NodeTypeInfo type_info{"patternAnyOf", 0};
27  const NodeTypeInfo& get_type_info() const override;
28  /// \brief creates a AnyOf node containing a sub-pattern described by \sa type and
29  /// \sa shape.
30  AnyOf(const element::Type& type,
31  const PartialShape& s,
32  ValuePredicate pred,
33  const OutputVector& wrapped_values)
34  : Pattern(wrapped_values, pred)
35  {
36  if (wrapped_values.size() != 1)
37  {
38  throw ngraph_error("AnyOf expects exactly one argument");
39  }
40  set_output_type(0, type, s);
41  }
42  AnyOf(const element::Type& type,
43  const PartialShape& s,
44  NodePredicate pred,
45  const NodeVector& wrapped_values)
46  : AnyOf(
47  type,
48  s,
49  [pred](const Output<Node>& value) {
50  return pred(value.get_node_shared_ptr());
51  },
52  as_output_vector(wrapped_values))
53  {
54  }
55 
56  /// \brief creates a AnyOf node containing a sub-pattern described by the type and
57  /// shape of \sa node.
58  AnyOf(const Output<Node>& node,
59  ValuePredicate pred,
60  const OutputVector& wrapped_values)
61  : AnyOf(node.get_element_type(), node.get_partial_shape(), pred, wrapped_values)
62  {
63  }
64  AnyOf(std::shared_ptr<Node> node,
65  NodePredicate pred,
66  const NodeVector& wrapped_values)
67  : AnyOf(node, as_value_predicate(pred), as_output_vector(wrapped_values))
68  {
69  }
70  bool match_value(Matcher* matcher,
71  const Output<Node>& pattern_value,
72  const Output<Node>& graph_value) override;
73  };
74  } // namespace op
75  } // namespace pattern
76 } // namespace ngraph
Definition: node.hpp:127
A handle for one of a node's outputs.
Definition: node_output.hpp:33
Definition: node_output.hpp:25
Class representing a shape that may be partially or totally dynamic.
Definition: partial_shape.hpp:34
Definition: element_type.hpp:51
Base error for ngraph runtime errors.
Definition: except.hpp:16
Definition: matcher.hpp:63
Definition: any_of.hpp:24
const NodeTypeInfo & get_type_info() const override
AnyOf(const Output< Node > &node, ValuePredicate pred, const OutputVector &wrapped_values)
creates a AnyOf node containing a sub-pattern described by the type and shape of
Definition: any_of.hpp:58
AnyOf(const element::Type &type, const PartialShape &s, ValuePredicate pred, const OutputVector &wrapped_values)
creates a AnyOf node containing a sub-pattern described by
Definition: any_of.hpp:30
Definition: pattern.hpp:73
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:16
Definition: type.hpp:27