editor.hpp
1 //*****************************************************************************
2 // Copyright 2017-2021 Intel Corporation
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 // http://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //*****************************************************************************
16 
17 #pragma once
18 
19 #include <istream>
20 #include <map>
21 #include <memory>
22 
23 #include "ngraph/op/constant.hpp"
24 #include "ngraph/partial_shape.hpp"
25 #include "ngraph/type/element_type.hpp"
26 #include "onnx_import/utils/onnx_importer_visibility.hpp"
27 
// ModelProto is only forward-declared here, so components that include this header
// but do not depend directly on the ONNX library are not forced to add the ONNX
// include paths to their build configuration.
namespace ONNX_NAMESPACE
{
    class ModelProto;
} // namespace ONNX_NAMESPACE
34 
35 namespace ngraph
36 {
37  namespace onnx_import
38  {
39  /// \brief A class representing a set of utilities allowing modification of an ONNX model
40  ///
41  /// \note This class can be used to modify an ONNX model before it gets translated to
42  /// an ngraph::Function by the import_onnx_model function. It lets you modify the
43  /// model's input types and shapes, extract a subgraph and more. An instance of this
44  /// class can be passed directly to the onnx_importer API.
45  class ONNX_IMPORTER_API ONNXModelEditor final
46  {
47  public:
48  ONNXModelEditor() = delete;
49 
50  /// \brief Creates an editor from a model file located on a storage device. The file
51  /// is parsed and loaded into the m_model_proto member variable.
52  ///
53  /// \param model_path Path to the file containing the model.
54  ONNXModelEditor(const std::string& model_path);
55 
56  /// \brief Modifies the in-memory representation of the model (m_model_proto) by setting
57  /// custom input types for all inputs specified in the provided map.
58  ///
59  /// \param input_types A collection of pairs {input_name: new_input_type} that should be
60  /// used to modified the ONNX model loaded from a file. This method
61  /// throws an exception if the model doesn't contain any of
62  /// the inputs specified in its parameter.
63  void set_input_types(const std::map<std::string, element::Type_t>& input_types);
64 
65  /// \brief Modifies the in-memory representation of the model (m_model_proto) by setting
66  /// custom input shapes for all inputs specified in the provided map.
67  ///
68  /// \param input_shapes A collection of pairs {input_name: new_input_shape} that should
69  /// be used to modified the ONNX model loaded from a file. This
70  /// method throws an exception if the model doesn't contain any of
71  /// the inputs specified in its parameter.
72  void set_input_shapes(const std::map<std::string, ngraph::PartialShape>& input_shapes);
73 
74  /// \brief Modifies the in-memory representation of the model by setting custom input
75  /// values for inputs specified in the provided map.
76  ///
77  /// \note This method modifies existing initializer tensor if its name matches one of
78  /// input_name. Otherwise it adds initializer tensor into the model.
79  /// If input tensor of matching name is present in the model, its type and shape
80  /// are modified accordingly.
81  ///
82  /// \param input_values A collection of pairs {input_name: new_input_values} used to
83  /// update the ONNX model. Initializers already existing are
84  /// overwritten.
86  const std::map<std::string, std::shared_ptr<ngraph::op::Constant>>& input_values);
87 
88  /// \brief Returns a non-const reference to the underlying ModelProto object, possibly
89  /// modified by the editor's API calls
90  ///
91  /// \return A reference to ONNX ModelProto object containing the in-memory model
92  ONNX_NAMESPACE::ModelProto& model() const;
93 
94  /// \brief Returns the path to the original model file
95  const std::string& model_path() const;
96 
97  /// \brief Saves the possibly model held by this class to a file. Serializes in binary
98  /// mode.
99  ///
100  /// \param out_file_path A path to the file where the modified model should be dumped.
101  void serialize(const std::string& out_file_path) const;
102 
103  private:
104  const std::string m_model_path;
105 
106  class Impl;
107  std::unique_ptr<Impl, void (*)(Impl*)> m_pimpl;
108  };
109  } // namespace onnx_import
110 } // namespace ngraph
A class representing a set of utilities allowing modification of an ONNX model.
Definition: editor.hpp:46
void set_input_types(const std::map< std::string, element::Type_t > &input_types)
Modifies the in-memory representation of the model (m_model_proto) by setting custom input types for ...
const std::string & model_path() const
Returns the path to the original model file.
ONNX_NAMESPACE::ModelProto & model() const
Returns a non-const reference to the underlying ModelProto object, possibly modified by the editor's ...
void set_input_values(const std::map< std::string, std::shared_ptr< ngraph::op::Constant >> &input_values)
Modifies the in-memory representation of the model by setting custom input values for inputs specifie...
void set_input_shapes(const std::map< std::string, ngraph::PartialShape > &input_shapes)
Modifies the in-memory representation of the model (m_model_proto) by setting custom input shapes for...
ONNXModelEditor(const std::string &model_path)
Creates an editor from a model file located on a storage device. The file is parsed and loaded into t...
void serialize(const std::string &out_file_path) const
Saves the possibly modified model held by this class to a file. Serializes in binary mode.
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:28