value_info.hpp
1 //*****************************************************************************
2 // Copyright 2017-2020 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16 
17 #pragma once
18 
19 #include <onnx/onnx_pb.h>
20 
21 #include "ngraph/op/constant.hpp"
22 #include "ngraph/op/parameter.hpp"
23 #include "ngraph/partial_shape.hpp"
24 #include "ngraph/type/element_type.hpp"
25 #include "node.hpp"
26 #include "onnx_import/default_opset.hpp"
27 #include "onnx_import/utils/common.hpp"
28 #include "tensor.hpp"
29 
30 namespace ngraph
31 {
32  namespace onnx_import
33  {
34  namespace error
35  {
36  namespace value_info
37  {
38  struct unspecified_element_type : ngraph_error
39  {
41  : ngraph_error{"value info has no element type specified"}
42  {
43  }
44  };
45  } // namespace value_info
46  } // namespace error
47 
48  class ValueInfo
49  {
50  public:
51  ValueInfo(ValueInfo&&) = default;
52  ValueInfo(const ValueInfo&) = default;
53 
54  ValueInfo() = delete;
55  explicit ValueInfo(const ONNX_NAMESPACE::ValueInfoProto& value_info_proto)
56  : m_value_info_proto{&value_info_proto}
57  {
58  if (value_info_proto.type().has_tensor_type())
59  {
60  const auto& onnx_tensor = value_info_proto.type().tensor_type();
61 
62  if (onnx_tensor.has_shape())
63  {
64  m_partial_shape = to_ng_shape(onnx_tensor.shape());
65  }
66  else
67  {
68  m_partial_shape = PartialShape::dynamic();
69  }
70  }
71  }
72 
73  ValueInfo& operator=(const ValueInfo&) = delete;
74  ValueInfo& operator=(ValueInfo&&) = delete;
75 
76  const std::string& get_name() const { return m_value_info_proto->name(); }
77  const PartialShape& get_shape() const { return m_partial_shape; }
78  const element::Type& get_element_type() const
79  {
80  if (!m_value_info_proto->type().tensor_type().has_elem_type())
81  {
83  }
84  return common::get_ngraph_element_type(
85  m_value_info_proto->type().tensor_type().elem_type());
86  }
87 
88  std::shared_ptr<ngraph::Node>
89  get_ng_node(ParameterVector& parameters,
90  const std::map<std::string, Tensor>& initializers) const
91  {
92  const auto it = initializers.find(get_name());
93  if (it != std::end(initializers))
94  {
95  return get_ng_constant(it->second);
96  }
97  parameters.push_back(get_ng_parameter());
98  return parameters.back();
99  }
100 
101  protected:
102  std::shared_ptr<ngraph::op::Parameter> get_ng_parameter() const
103  {
104  auto parameter =
105  std::make_shared<ngraph::op::Parameter>(get_element_type(), get_shape());
106  parameter->set_friendly_name(get_name());
107  return parameter;
108  }
109 
110  std::shared_ptr<ngraph::op::Constant> get_ng_constant(const Tensor& tensor) const
111  {
112  return tensor.get_ng_constant();
113  }
114 
115  PartialShape to_ng_shape(const ONNX_NAMESPACE::TensorShapeProto& onnx_shape) const
116  {
117  if (onnx_shape.dim_size() == 0)
118  {
119  return Shape{}; // empty list of dimensions denotes a scalar
120  }
121 
122  std::vector<Dimension> dims;
123  for (const auto& onnx_dim : onnx_shape.dim())
124  {
125  if (onnx_dim.has_dim_value())
126  {
127  dims.emplace_back(onnx_dim.dim_value());
128  }
129  else // has_dim_param() == true or it is empty dim
130  {
131  dims.push_back(Dimension::dynamic());
132  }
133  }
134  return PartialShape{dims};
135  }
136 
137  private:
138  const ONNX_NAMESPACE::ValueInfoProto* m_value_info_proto;
139  PartialShape m_partial_shape;
140  };
141 
142  inline std::ostream& operator<<(std::ostream& outs, const ValueInfo& info)
143  {
144  return (outs << "<ValueInfo: " << info.get_name() << ">");
145  }
146 
147  } // namespace onnx_import
148 
149 } // namespace ngraph
ngraph::onnx_import::Tensor
Definition: tensor.hpp:383
ngraph::onnx_import::error::value_info::unspecified_element_type
Definition: value_info.hpp:39
ngraph::Dimension::dynamic
static Dimension dynamic()
Create a dynamic dimension.
Definition: dimension.hpp:130
ngraph::element::Type
Definition: element_type.hpp:61
ngraph
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:28
ngraph::onnx_import::ValueInfo
Definition: value_info.hpp:49
ngraph::PartialShape::dynamic
static PartialShape dynamic(Rank r=Rank::dynamic())
Construct a PartialShape with the given rank and all dimensions (if any) dynamic.