wrap_type.hpp
1 //*****************************************************************************
2 // Copyright 2017-2021 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16 
17 #pragma once
18 
19 #include "ngraph/node.hpp"
20 #include "ngraph/pattern/op/pattern.hpp"
21 
22 namespace ngraph
23 {
24  namespace pattern
25  {
26  namespace op
27  {
28  class NGRAPH_API WrapType : public Pattern
29  {
30  public:
31  static constexpr NodeTypeInfo type_info{"patternAnyType", 0};
32  const NodeTypeInfo& get_type_info() const override;
33 
34  explicit WrapType(NodeTypeInfo wrapped_type,
35  const ValuePredicate& pred =
36  [](const Output<Node>& output) { return true; },
37  const OutputVector& input_values = {})
38  : Pattern(input_values, pred)
39  , m_wrapped_types({wrapped_type})
40  {
41  set_output_type(0, element::Type_t::dynamic, PartialShape::dynamic());
42  }
43 
44  explicit WrapType(std::vector<NodeTypeInfo> wrapped_types,
45  const ValuePredicate& pred =
46  [](const Output<Node>& output) { return true; },
47  const OutputVector& input_values = {})
48  : Pattern(input_values, pred)
49  , m_wrapped_types(std::move(wrapped_types))
50  {
51  set_output_type(0, element::Type_t::dynamic, PartialShape::dynamic());
52  }
53 
54  bool match_value(pattern::Matcher* matcher,
55  const Output<Node>& pattern_value,
56  const Output<Node>& graph_value) override;
57 
58  NodeTypeInfo get_wrapped_type() const;
59 
60  const std::vector<NodeTypeInfo>& get_wrapped_types() const;
61 
62  private:
63  std::vector<NodeTypeInfo> m_wrapped_types;
64  };
65  }
66 
67  template <class... Args>
68  std::shared_ptr<Node> wrap_type(const OutputVector& inputs,
69  const pattern::op::ValuePredicate& pred)
70  {
71  std::vector<DiscreteTypeInfo> info{Args::type_info...};
72  return std::make_shared<op::WrapType>(info, pred, inputs);
73  }
74 
75  template <class... Args>
76  std::shared_ptr<Node> wrap_type(const OutputVector& inputs = {})
77  {
78  return wrap_type<Args...>(inputs, [](const Output<Node>& output) { return true; });
79  }
80 
81  template <class... Args>
82  std::shared_ptr<Node> wrap_type(const pattern::op::ValuePredicate& pred)
83  {
84  return wrap_type<Args...>({}, pred);
85  }
86  }
87 }
A handle for one of a node's outputs.
Definition: node_output.hpp:42
static PartialShape dynamic(Rank r=Rank::dynamic())
Construct a PartialShape with the given rank and all dimensions (if any) dynamic.
Definition: pattern.hpp:82
Definition: wrap_type.hpp:29
const NodeTypeInfo & get_type_info() const override
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:28
Definition: type.hpp:39