19 #include <onnx/onnx_pb.h>
21 #include "ngraph/op/constant.hpp"
22 #include "ngraph/op/parameter.hpp"
23 #include "ngraph/partial_shape.hpp"
24 #include "ngraph/type/element_type.hpp"
26 #include "onnx_import/default_opset.hpp"
27 #include "onnx_import/utils/common.hpp"
// Base-class initializer: forwards a fixed diagnostic message to ngraph_error.
// NOTE(review): the declaration of the enclosing error type is not visible in
// this excerpt — confirm its name and namespace against the full file.
41 : ngraph_error{
"value info has no element type specified"}
/// Wraps an ONNX ValueInfoProto without copying it: only the address of
/// `value_info_proto` is stored in m_value_info_proto.
/// When the proto has a tensor type and that tensor type has a shape, the
/// shape is converted with to_ng_shape() and stored in m_partial_shape.
/// NOTE(review): what happens when the tensor type has no shape is not
/// visible in this excerpt — confirm against the full file.
55 explicit ValueInfo(
const ONNX_NAMESPACE::ValueInfoProto& value_info_proto)
56 : m_value_info_proto{&value_info_proto}
58 if (value_info_proto.type().has_tensor_type())
60 const auto& onnx_tensor = value_info_proto.type().tensor_type();
62 if (onnx_tensor.has_shape())
64 m_partial_shape = to_ng_shape(onnx_tensor.shape());
76 const std::string& get_name()
const {
return m_value_info_proto->name(); }
77 const PartialShape& get_shape()
const {
return m_partial_shape; }
// Element-type lookup: has_elem_type() on the proto's tensor type is checked
// first; otherwise the ONNX elem_type is mapped through
// common::get_ngraph_element_type() and returned.
// NOTE(review): the function header and the statement executed when the
// element type is missing are not visible in this excerpt — presumably it
// raises the "value info has no element type specified" error; confirm.
80 if (!m_value_info_proto->type().tensor_type().has_elem_type())
84 return common::get_ngraph_element_type(
85 m_value_info_proto->type().tensor_type().elem_type());
88 std::shared_ptr<ngraph::Node>
89 get_ng_node(ParameterVector& parameters,
90 const std::map<std::string, Tensor>& initializers)
const
92 const auto it = initializers.find(get_name());
93 if (it != std::end(initializers))
95 return get_ng_constant(it->second);
97 parameters.push_back(get_ng_parameter());
98 return parameters.back();
/// Builds an ngraph::op::Parameter from get_element_type() and get_shape(),
/// then sets its friendly name to get_name().
/// NOTE(review): the declaration of `parameter` and the return statement are
/// not visible in this excerpt — confirm against the full file.
102 std::shared_ptr<ngraph::op::Parameter> get_ng_parameter()
const
105 std::make_shared<ngraph::op::Parameter>(get_element_type(), get_shape());
106 parameter->set_friendly_name(get_name());
110 std::shared_ptr<ngraph::op::Constant> get_ng_constant(
const Tensor& tensor)
const
112 return tensor.get_ng_constant();
/// Converts an ONNX TensorShapeProto into an nGraph PartialShape.
/// Every dimension for which has_dim_value() is true contributes its
/// dim_value() to `dims`; the result is PartialShape{dims}.
/// NOTE(review): the branch taken when dim_size() == 0 and the handling of
/// dimensions without a dim_value are not visible in this excerpt — confirm
/// against the full file.
115 PartialShape to_ng_shape(
const ONNX_NAMESPACE::TensorShapeProto& onnx_shape)
const
117 if (onnx_shape.dim_size() == 0)
122 std::vector<Dimension> dims;
123 for (
const auto& onnx_dim : onnx_shape.dim())
125 if (onnx_dim.has_dim_value())
127 dims.emplace_back(onnx_dim.dim_value());
134 return PartialShape{dims};
// Address of the proto handed to the constructor; the proto itself is not copied.
138 const ONNX_NAMESPACE::ValueInfoProto* m_value_info_proto;
// Shape assigned in the constructor from the proto's tensor shape, when one is present.
139 PartialShape m_partial_shape;
142 inline std::ostream& operator<<(std::ostream& outs,
const ValueInfo& info)
144 return (outs <<
"<ValueInfo: " << info.get_name() <<
">");