The ONNX* importer provides a mechanism to register custom ONNX operators based on predefined or custom nGraph operations. The function responsible for registering a new operator is called ngraph::onnx_import::register_operator and is defined in onnx_import/onnx_utils.hpp.
Register Custom ONNX Operator Based on Predefined nGraph Operations
The steps below explain how to register a custom ONNX operator, for example, CustomRelu, in a domain called com.example. CustomRelu is defined as follows:
x >= 0 => f(x) = x * alpha
x < 0 => f(x) = x * beta
where alpha and beta are float constants.
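The piecewise definition above can be checked against the mask-based arithmetic that the registration code builds from nGraph operations. The sketch below uses plain C++ rather than nGraph; the function names custom_relu_scalar and custom_relu_masked are illustrative and not part of any library.

```cpp
#include <cstddef>
#include <vector>

// Piecewise definition: x >= 0 -> x * alpha, x < 0 -> x * beta.
inline float custom_relu_scalar(float x, float alpha, float beta) {
    return x >= 0.0f ? x * alpha : x * beta;
}

// Mask-based form: alpha * (x * positive_map) + beta * (x * negative_map),
// where each map element is 1 when its condition holds and 0 otherwise.
inline std::vector<float> custom_relu_masked(const std::vector<float>& data,
                                             float alpha, float beta) {
    std::vector<float> out(data.size());
    for (std::size_t i = 0; i < data.size(); ++i) {
        const float positive_map = data[i] >= 0.0f ? 1.0f : 0.0f;
        const float negative_map = data[i] < 0.0f ? 1.0f : 0.0f;
        out[i] = alpha * (data[i] * positive_map) + beta * (data[i] * negative_map);
    }
    return out;
}
```

Exactly one of the two maps is 1 for every element, so only one of the two products contributes to the sum and both forms give the same result.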
- Include headers:
#include <onnx_import/onnx_utils.hpp>
#include <ngraph/opsets/opset5.hpp>
- Register the CustomRelu operator in the ONNX importer:
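ngraph::onnx_import::register_operator(
    "CustomRelu", 1, "com.example", [](const ngraph::onnx_import::Node& onnx_node) -> ngraph::OutputVector {
        namespace opset = ngraph::opset5;

        ngraph::OutputVector ng_inputs{onnx_node.get_ng_inputs()};
        const ngraph::Output<ngraph::Node>& data = ng_inputs.at(0);
        // Scalar constant equal to zero, in the element type of the input
        std::shared_ptr<ngraph::Node> zero_node = opset::Constant::create(data.get_element_type(),
                                                                          ngraph::Shape{}, {0});
        // 1 where data < 0, else 0, converted from boolean to the input element type
        std::shared_ptr<ngraph::Node> negative_map = std::make_shared<opset::Convert>(
            std::make_shared<opset::Less>(data, zero_node), data.get_element_type());
        // 1 where data >= 0, else 0, converted from boolean to the input element type
        std::shared_ptr<ngraph::Node> positive_map = std::make_shared<opset::Convert>(
            std::make_shared<opset::GreaterEqual>(data, zero_node), data.get_element_type());

        // Fetch attributes from the ONNX node; alpha falls back to 1 when it is not provided
        float alpha = onnx_node.get_attribute_value<float>("alpha", 1);
        float beta = onnx_node.get_attribute_value<float>("beta");

        std::shared_ptr<ngraph::Node> alpha_node = opset::Constant::create(ngraph::element::f32,
                                                                           ngraph::Shape{}, {alpha});
        std::shared_ptr<ngraph::Node> beta_node = opset::Constant::create(ngraph::element::f32,
                                                                          ngraph::Shape{}, {beta});

        // alpha * (data * positive_map) + beta * (data * negative_map)
        return {
            std::make_shared<opset::Add>(
                std::make_shared<opset::Multiply>(alpha_node, std::make_shared<opset::Multiply>(data, positive_map)),
                std::make_shared<opset::Multiply>(beta_node, std::make_shared<opset::Multiply>(data, negative_map))
            )
        };
    });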
NGRAPH_HELPER_DLL_IMPORT void register_operator(const std::string &name, std::int64_t version, const std::string &domain, Operator fn)
The register_operator function takes four arguments: op_type, opset version, domain, and a function object. The function object is a user-defined function that takes ngraph::onnx_import::Node as an input and, based on that, returns a graph with nGraph operations. The ngraph::onnx_import::Node class represents a node in an ONNX model. It provides functions to fetch input node(s) using get_ng_inputs, attribute value using get_attribute_value, and many more. See onnx_import/core/node.hpp for full class declaration.
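The registration example calls get_attribute_value in two ways: with a default value for alpha and without one for beta. The sketch below mimics that distinction with a hypothetical ToyNode class in plain C++. It is not the nGraph Node class, and the choice to throw when a required attribute is missing is an assumption made for illustration; consult onnx_import/core/node.hpp for the actual behavior.

```cpp
#include <map>
#include <stdexcept>
#include <string>
#include <utility>

// Hypothetical stand-in for a node's attribute storage; not the nGraph class.
class ToyNode {
public:
    explicit ToyNode(std::map<std::string, float> attributes)
        : m_attributes(std::move(attributes)) {}

    // Optional attribute: fall back to default_value when the name is absent.
    float get_attribute_value(const std::string& name, float default_value) const {
        const auto it = m_attributes.find(name);
        return it == m_attributes.end() ? default_value : it->second;
    }

    // Required attribute: assumed here to fail loudly when the name is absent.
    float get_attribute_value(const std::string& name) const {
        const auto it = m_attributes.find(name);
        if (it == m_attributes.end()) {
            throw std::out_of_range("missing required attribute: " + name);
        }
        return it->second;
    }

private:
    std::map<std::string, float> m_attributes;
};
```

With this shape, a model that omits alpha still imports with alpha equal to the default, while a model that omits beta is rejected when the operator function runs.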
New operator registration must happen before an ONNX model is read. For example, if a model uses the CustomRelu operator, call register_operator("CustomRelu", ...) before InferenceEngine::Core::ReadNetwork. Reregistering ONNX operators within the same process is supported. If you register an existing operator, you get a warning.
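To make the registration rules concrete, here is a minimal sketch of a registry keyed by op_type, version, and domain. ToyRegistry, ToyOperator, and ToyKey are hypothetical names written in plain C++; this is not how the ONNX importer is implemented, only an illustration of the behavior described above: an entry must exist before it can be looked up, and registering the same key again replaces the entry and emits a warning.

```cpp
#include <cstdint>
#include <functional>
#include <iostream>
#include <map>
#include <string>
#include <tuple>
#include <utility>

// Hypothetical toy types; they only illustrate the keying and warning rules.
using ToyOperator = std::function<float(float)>;
using ToyKey = std::tuple<std::string, std::int64_t, std::string>;  // op_type, version, domain

class ToyRegistry {
public:
    // Registers fn under (name, version, domain). Returns true when an
    // existing entry was replaced, in which case a warning is printed.
    bool register_operator(const std::string& name, std::int64_t version,
                           const std::string& domain, ToyOperator fn) {
        const ToyKey key(name, version, domain);
        const bool replaced = m_operators.count(key) != 0;
        if (replaced) {
            std::cerr << "warning: overwriting operator " << domain << "." << name
                      << " version " << version << "\n";
        }
        m_operators[key] = std::move(fn);
        return replaced;
    }

    // Removes the entry. Returns true when something was actually removed.
    bool unregister_operator(const std::string& name, std::int64_t version,
                             const std::string& domain) {
        return m_operators.erase(ToyKey(name, version, domain)) != 0;
    }

    bool is_registered(const std::string& name, std::int64_t version,
                       const std::string& domain) const {
        return m_operators.count(ToyKey(name, version, domain)) != 0;
    }

private:
    std::map<ToyKey, ToyOperator> m_operators;
};
```

Because the key includes the version and the domain, the same op_type can be registered separately for different opset versions or domains without the entries colliding.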
The example below shows a model that requires the previously registered CustomRelu operator:
R"ONNX(
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
node {
input: "in"
output: "out"
name: "customrelu"
op_type: "CustomRelu"
domain: "com.example"
attribute {
name: "alpha"
type: FLOAT
f: 2
}
attribute {
name: "beta"
type: FLOAT
f: 3
}
}
name: "custom relu graph"
input {
name: "in"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 8
}
}
}
}
}
output {
name: "out"
type {
tensor_type {
elem_type: 1
shape {
dim {
dim_value: 8
}
}
}
}
}
}
opset_import {
domain: "com.example"
version: 1
}
)ONNX";
To create a graph with nGraph operations, visit Custom nGraph Operations. For a complete list of predefined nGraph operators, visit Available Operations Sets.
If you do not need an operator anymore, unregister it by calling unregister_operator. The function takes three arguments: op_type, version, and domain.
NGRAPH_HELPER_DLL_IMPORT void unregister_operator(const std::string &name, std::int64_t version, const std::string &domain)
Register Custom ONNX Operator Based on Custom nGraph Operations
The same principles apply when registering a custom ONNX operator based on custom nGraph operations. This example shows how to register a custom ONNX operator based on Operation presented in this tutorial, which is used in TemplateExtension.
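Extension::Extension() {
#ifdef NGRAPH_ONNX_IMPORT_ENABLED
    // NOTE: the op_type, version, and domain arguments below are assumed values;
    // adjust them to match your operator and model.
    ngraph::onnx_import::register_operator(
        Operation::type_info.name, 1, "custom_domain",
        [](const ngraph::onnx_import::Node& node) -> ngraph::OutputVector {
            ngraph::OutputVector ng_inputs{node.get_ng_inputs()};
            int64_t add = node.get_attribute_value<int64_t>("add");
            return {std::make_shared<Operation>(ng_inputs.at(0), add)};
        });
#ifdef OPENCV_IMPORT_ENABLED
    ngraph::onnx_import::register_operator(
        FFTOp::type_info.name, 1, "custom_domain",
        [](const ngraph::onnx_import::Node& node) -> ngraph::OutputVector {
            ngraph::OutputVector ng_inputs{node.get_ng_inputs()};
            bool inverse = node.get_attribute_value<int64_t>("inverse");
            return {std::make_shared<FFTOp>(ng_inputs.at(0), inverse)};
        });
#endif
#endif
}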
Here, the register_operator function is called in the constructor of Extension. The constructor makes sure that the function is called before InferenceEngine::Core::ReadNetwork, because InferenceEngine::Core::AddExtension must be called before a model with a custom operator is read.
The example below demonstrates how to unregister an operator from the destructor of Extension:
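Extension::~Extension() {
#ifdef NGRAPH_ONNX_IMPORT_ENABLED
    // NOTE: the arguments are assumed values; they must match those passed to register_operator.
    ngraph::onnx_import::unregister_operator(Operation::type_info.name, 1, "custom_domain");
#ifdef OPENCV_IMPORT_ENABLED
    ngraph::onnx_import::unregister_operator(FFTOp::type_info.name, 1, "custom_domain");
#endif
#endif
}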
NOTE: It is mandatory to unregister a custom ONNX operator if it is defined in a dynamic shared library.
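One way to keep registration and unregistration paired, as the Extension constructor and destructor do, is a small scope guard. The sketch below is written in plain C++ with hypothetical stand-ins: toy_registered_ops and ScopedOperatorRegistration are illustrative names, and the set insert and erase calls mark the places where register_operator and unregister_operator would be called in real code.

```cpp
#include <cstdint>
#include <set>
#include <string>
#include <tuple>
#include <utility>

// Hypothetical stand-in for the importer's operator table.
using ToyOpKey = std::tuple<std::string, std::int64_t, std::string>;  // op_type, version, domain

inline std::set<ToyOpKey>& toy_registered_ops() {
    static std::set<ToyOpKey> ops;
    return ops;
}

// Registers on construction and unregisters on destruction, so the operator
// cannot stay registered after the owning object is gone.
class ScopedOperatorRegistration {
public:
    ScopedOperatorRegistration(std::string name, std::int64_t version, std::string domain)
        : m_key(std::move(name), version, std::move(domain)) {
        toy_registered_ops().insert(m_key);  // stands in for register_operator(...)
    }

    ~ScopedOperatorRegistration() {
        toy_registered_ops().erase(m_key);   // stands in for unregister_operator(...)
    }

    // Copying would unregister the same operator twice, so it is disabled.
    ScopedOperatorRegistration(const ScopedOperatorRegistration&) = delete;
    ScopedOperatorRegistration& operator=(const ScopedOperatorRegistration&) = delete;

private:
    ToyOpKey m_key;
};
```

Holding such a guard as a member of the extension object ties the operator's lifetime to the extension's, which is the same effect the constructor and destructor pair above achieves by hand.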
Requirements for Building with CMake
A program that uses the register_operator functionality requires ngraph and onnx_importer libraries in addition to the Inference Engine. The onnx_importer is a component of the ngraph package, so find_package(ngraph REQUIRED COMPONENTS onnx_importer) can find both. The ngraph package exposes two variables, ${NGRAPH_LIBRARIES} and ${ONNX_IMPORTER_LIBRARIES}, which reference the ngraph and onnx_importer libraries. Those variables need to be passed to the target_link_libraries command in the CMakeLists.txt file.
See CMakeLists.txt below for reference:
set(CMAKE_CXX_STANDARD 11)
set(TARGET_NAME "onnx_custom_op")
find_package(ngraph REQUIRED COMPONENTS onnx_importer)
add_library(${TARGET_NAME} STATIC onnx_custom_op.cpp)
target_link_libraries(${TARGET_NAME} PUBLIC ${NGRAPH_LIBRARIES} ${ONNX_IMPORTER_LIBRARIES})