Custom ONNX operators

The ONNX importer provides a mechanism to register custom ONNX operators based on predefined or user-defined nGraph operations. The function responsible for registering a new operator is ngraph::onnx_import::register_operator, which is defined in onnx_import/onnx_utils.hpp.

Registering custom ONNX operator based on predefined nGraph operations

The steps below explain how to register a custom ONNX operator, for example, CustomRelu, in a domain called com.example. CustomRelu is defined as follows:

x >= 0 => f(x) = x * alpha
x < 0 => f(x) = x * beta

where alpha and beta are float constants.
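When testing a converter, it can help to have a plain C++ reference for the definition above to compare the imported graph's output against. The sketch below is a standalone illustration, not part of the ONNX importer API; the function name custom_relu_reference is hypothetical.

```cpp
// Hypothetical scalar reference for CustomRelu:
// values >= 0 are scaled by alpha, negative values are scaled by beta.
float custom_relu_reference(float x, float alpha, float beta) {
    return x >= 0.0f ? x * alpha : x * beta;
}
```

For example, with alpha = 2 and beta = 3, custom_relu_reference(1.5f, 2.0f, 3.0f) returns 3 and custom_relu_reference(-1.5f, 2.0f, 3.0f) returns -4.5.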

  1. Include headers:
    // onnx_import/onnx_utils.hpp provides the ngraph::onnx_import::register_operator function, which registers an operator in the ONNX importer's operator set
    #include <onnx_import/onnx_utils.hpp>
    // ngraph/opsets/opset5.hpp declares the predefined nGraph operator set (opset5)
    #include <ngraph/opsets/opset5.hpp>
    // <memory> provides std::shared_ptr and std::make_shared used below
    #include <memory>
  2. Register the CustomRelu operator in the ONNX importer:
    ngraph::onnx_import::register_operator(
        "CustomRelu", 1, "com.example", [](const ngraph::onnx_import::Node& onnx_node) -> ngraph::OutputVector {
            namespace opset = ngraph::opset5;

            ngraph::OutputVector ng_inputs{onnx_node.get_ng_inputs()};
            const ngraph::Output<ngraph::Node>& data = ng_inputs.at(0);
            // create a constant node with a single element equal to zero
            std::shared_ptr<ngraph::Node> zero_node = opset::Constant::create(data.get_element_type(), ngraph::Shape{}, {0});
            // create a negative map for 'data': 1 for negative values, 0 for positive values or zero,
            // then convert it from boolean type to `data.get_element_type()`
            std::shared_ptr<ngraph::Node> negative_map = std::make_shared<opset::Convert>(
                std::make_shared<opset::Less>(data, zero_node), data.get_element_type());
            // create a positive map for 'data': 0 for negative values, 1 for positive values or zero,
            // then convert it from boolean type to `data.get_element_type()`
            std::shared_ptr<ngraph::Node> positive_map = std::make_shared<opset::Convert>(
                std::make_shared<opset::GreaterEqual>(data, zero_node), data.get_element_type());

            // fetch the alpha and beta attributes from the ONNX node
            float alpha = onnx_node.get_attribute_value<float>("alpha", 1); // if 'alpha' is not provided in the model, the default value is 1
            float beta = onnx_node.get_attribute_value<float>("beta");      // no default is given, so the model is expected to provide 'beta'
            // create scalar constants for alpha and beta with the same element type as 'data',
            // so that both inputs of each Multiply node below have matching types
            std::shared_ptr<ngraph::Node> alpha_node = opset::Constant::create(data.get_element_type(), ngraph::Shape{}, {alpha});
            std::shared_ptr<ngraph::Node> beta_node = opset::Constant::create(data.get_element_type(), ngraph::Shape{}, {beta});

            // f(x) = alpha * (x * positive_map) + beta * (x * negative_map)
            return {
                std::make_shared<opset::Add>(
                    std::make_shared<opset::Multiply>(alpha_node, std::make_shared<opset::Multiply>(data, positive_map)),
                    std::make_shared<opset::Multiply>(beta_node, std::make_shared<opset::Multiply>(data, negative_map)))
            };
        });
    The register_operator function takes four arguments: op_type, opset version, domain, and a function object. The function object is a user-defined function that takes an ngraph::onnx_import::Node as input and, based on it, returns a graph of nGraph operations. The ngraph::onnx_import::Node class represents a node in an ONNX model. It provides functions to fetch input nodes (get_ng_inputs), fetch attribute values (get_attribute_value), and more; see onnx_import/core/node.hpp for the full class declaration. A new operator must be registered before the ONNX model is read. For example, if an ONNX model uses the 'CustomRelu' operator, register_operator("CustomRelu", ...) must be called before InferenceEngine::Core::ReadNetwork. Re-registering ONNX operators within the same process is supported; when an already registered operator is registered again, a warning is printed.
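The converter above avoids a dedicated selection operation: it builds two 0/1 masks (Less and GreaterEqual, converted to the data type) and combines them as alpha * (data * positive_map) + beta * (data * negative_map). The standalone sketch below mirrors that element-wise arithmetic on a std::vector<float> so the decomposition can be checked against the piecewise definition. It does not use nGraph, and the function name custom_relu_decomposed is hypothetical.

```cpp
#include <cstddef>
#include <vector>

// Hypothetical helper: reproduces, element by element, the arithmetic of the
// nGraph subgraph built by the CustomRelu converter.
std::vector<float> custom_relu_decomposed(const std::vector<float>& data, float alpha, float beta) {
    std::vector<float> result(data.size());
    for (std::size_t i = 0; i < data.size(); ++i) {
        // Equivalent of Convert(Less(data, zero_node)): 1 for negative values, 0 otherwise
        const float negative_map = data[i] < 0.0f ? 1.0f : 0.0f;
        // Equivalent of Convert(GreaterEqual(data, zero_node)): 1 for values >= 0, 0 otherwise
        const float positive_map = data[i] >= 0.0f ? 1.0f : 0.0f;
        // Add(Multiply(alpha, data * positive_map), Multiply(beta, data * negative_map))
        result[i] = alpha * (data[i] * positive_map) + beta * (data[i] * negative_map);
    }
    return result;
}
```

One property of the mask-based form, visible in this sketch as well: for an infinite input, the term multiplied by a zero mask evaluates inf * 0 = NaN, so the sum is NaN rather than an infinity. If that matters for a given model, a formulation based on opset::Select may be worth considering.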

The example below shows a model that requires the previously registered 'CustomRelu' operator:

R"ONNX(
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
  node {
    input: "in"
    output: "out"
    name: "customrelu"
    op_type: "CustomRelu"
    domain: "com.example"
    attribute {
      name: "alpha"
      type: FLOAT
      f: 2
    }
    attribute {
      name: "beta"
      type: FLOAT
      f: 3
    }
  }
  name: "custom relu graph"
  input {
    name: "in"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 8
          }
        }
      }
    }
  }
  output {
    name: "out"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 8
          }
        }
      }
    }
  }
}
opset_import {
  domain: "com.example"
  version: 1
}
)ONNX";

For a reference on how to create a graph with nGraph operations, see Custom nGraph Operation. For a complete list of predefined nGraph operations, see the available operation sets.

If an operator is no longer needed, it can be unregistered by calling unregister_operator. The function takes three arguments: op_type, version, and domain.

ngraph::onnx_import::unregister_operator("CustomRelu", 1, "com.example");
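Conceptually, registration and unregistration behave like a lookup table keyed by the (op_type, version, domain) triple. The toy registry below is a hypothetical model of the behavior described above: registering an existing key is allowed but reported, and unregistering removes the entry. It is not how the ONNX importer is implemented; ToyOperatorRegistry, ToyConverter, and the warning text are illustrative assumptions, and replacing the old converter on re-registration is an assumption of this sketch.

```cpp
#include <cstdint>
#include <functional>
#include <iostream>
#include <map>
#include <string>
#include <tuple>
#include <utility>

// Hypothetical stand-in for a converter; the real one maps an
// ngraph::onnx_import::Node to an ngraph::OutputVector.
using ToyConverter = std::function<std::string()>;

// Key: (op_type, version, domain), mirroring the arguments of register_operator.
using OperatorKey = std::tuple<std::string, std::int64_t, std::string>;

class ToyOperatorRegistry {
public:
    void register_operator(const std::string& op_type, std::int64_t version,
                           const std::string& domain, ToyConverter fn) {
        const OperatorKey key{op_type, version, domain};
        if (converters_.count(key) != 0) {
            // Illustrative message only, not the importer's actual warning text.
            std::cerr << "warning: re-registering " << domain << "." << op_type
                      << " version " << version << '\n';
        }
        // Assumption of this sketch: the newest converter replaces the old one.
        converters_[key] = std::move(fn);
    }

    void unregister_operator(const std::string& op_type, std::int64_t version,
                             const std::string& domain) {
        converters_.erase(OperatorKey{op_type, version, domain});
    }

    bool is_registered(const std::string& op_type, std::int64_t version,
                       const std::string& domain) const {
        return converters_.count(OperatorKey{op_type, version, domain}) != 0;
    }

private:
    std::map<OperatorKey, ToyConverter> converters_;
};
```

The triple key explains why unregister_operator needs all three arguments: the same op_type can be registered independently under different versions or domains.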

Registering custom ONNX operator based on custom nGraph operations

The same principles apply when registering a custom ONNX operator based on custom nGraph operations. This example shows how to register a custom ONNX operator based on the Operation class presented in this tutorial, which is used in TemplateExtension.

Extension::Extension() {
#ifdef NGRAPH_ONNX_IMPORT_ENABLED
    ngraph::onnx_import::register_operator(
        Operation::type_info.name, 1, "custom_domain", [](const ngraph::onnx_import::Node& node) -> ngraph::OutputVector {
            ngraph::OutputVector ng_inputs{node.get_ng_inputs()};
            int64_t add = node.get_attribute_value<int64_t>("add");
            return {std::make_shared<Operation>(ng_inputs.at(0), add)};
        });
#endif
}

Here, register_operator is called in the Extension constructor, which guarantees that it runs before InferenceEngine::Core::ReadNetwork, because InferenceEngine::Core::AddExtension must be called before a model with a custom operator is read.

The example below demonstrates how to unregister the operator in the Extension destructor:

Extension::~Extension() {
#ifdef NGRAPH_ONNX_IMPORT_ENABLED
    ngraph::onnx_import::unregister_operator(Operation::type_info.name, 1, "custom_domain");
#endif
}
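The constructor/destructor pairing above is an instance of RAII: the registration lives exactly as long as the owning object. A generic helper capturing that pattern is sketched below. ScopedRegistration is hypothetical and not part of nGraph or Inference Engine; in real code its two callables would wrap ngraph::onnx_import::register_operator and unregister_operator with the same op_type, version, and domain.

```cpp
#include <functional>
#include <utility>

// Hypothetical RAII helper: runs `do_register` on construction and
// `do_unregister` on destruction, like the Extension constructor/destructor pair.
class ScopedRegistration {
public:
    ScopedRegistration(const std::function<void()>& do_register, std::function<void()> do_unregister)
        : unregister_(std::move(do_unregister)) {
        do_register();
    }

    // Copying would run the unregister callable twice, so copies are disabled.
    ScopedRegistration(const ScopedRegistration&) = delete;
    ScopedRegistration& operator=(const ScopedRegistration&) = delete;

    ~ScopedRegistration() {
        if (unregister_) {
            unregister_();
        }
    }

private:
    std::function<void()> unregister_;
};
```

Disabling copies keeps a single owner responsible for the matching unregister call. Holding such a guard as a data member of an extension class is one possible use; it ties unregistration to the extension's lifetime in the same way the destructor above does.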

Note that a custom ONNX operator must be unregistered if it is defined in a dynamic shared library.

Requirements for building with CMake

A program that uses the register_operator functionality requires the ngraph and onnx_importer libraries in addition to Inference Engine. The onnx_importer library is a component of the ngraph package, so find_package(ngraph REQUIRED COMPONENTS onnx_importer) is sufficient to find both. The ngraph package exposes two variables, ${NGRAPH_LIBRARIES} and ${ONNX_IMPORTER_LIBRARIES}, which reference the ngraph and onnx_importer libraries respectively. These variables need to be passed to the target_link_libraries command in the CMakeLists.txt file.

See the CMakeLists.txt below for reference:

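cmake_minimum_required(VERSION 3.13)
project(onnx_custom_op CXX)

set(CMAKE_CXX_STANDARD 11)
set(TARGET_NAME "onnx_custom_op")
# finds the ngraph package together with its onnx_importer component
find_package(ngraph REQUIRED COMPONENTS onnx_importer)
add_library(${TARGET_NAME} STATIC onnx_custom_op.cpp)
# link against both ngraph and the ONNX importer
target_link_libraries(${TARGET_NAME} PUBLIC ${NGRAPH_LIBRARIES} ${ONNX_IMPORTER_LIBRARIES})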