Custom ONNX* Operators

The ONNX* importer provides a mechanism to register custom ONNX operators based on predefined or custom nGraph operations. The function that registers a new operator is ngraph::onnx_import::register_operator, declared in `onnx_import/onnx_utils.hpp <https://docs.openvinotoolkit.org/latest/ngraph_cpp_api/onnx__utils_8hpp_source.html>`__.

Register Custom ONNX Operator Based on Predefined nGraph Operations

The steps below explain how to register a custom ONNX operator, for example, CustomRelu, in a domain called com.example. CustomRelu is defined as follows:

x >= 0 => f(x) = x * alpha
x <  0 => f(x) = x * beta

where alpha and beta are float constants.
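When checking the output of the registered operator, it can help to have the definition as an ordinary C++ function. The following is a minimal sketch with no nGraph dependency. The name custom_relu is a hypothetical helper and is not part of any OpenVINO API.

```cpp
#include <cassert>
#include <vector>

// Hypothetical reference implementation of the CustomRelu definition:
//   x >= 0  =>  f(x) = x * alpha
//   x <  0  =>  f(x) = x * beta
inline float custom_relu(float x, float alpha, float beta) {
    return x >= 0.0f ? x * alpha : x * beta;
}

// Element-wise version for a tensor stored as a flat vector.
inline std::vector<float> custom_relu(const std::vector<float>& data, float alpha, float beta) {
    std::vector<float> out;
    out.reserve(data.size());
    for (float x : data) {
        out.push_back(custom_relu(x, alpha, beta));
    }
    assert(out.size() == data.size());  // element-wise: shape is preserved
    return out;
}
```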

  1. Include headers:

// onnx_import/onnx_utils.hpp provides the ngraph::onnx_import::register_operator function, which registers an operator with the ONNX importer.
#include <onnx_import/onnx_utils.hpp>
// ngraph/opsets/opset5.hpp provides the declarations of the predefined nGraph operation set opset5.
#include <ngraph/opsets/opset5.hpp>

  2. Register the CustomRelu operator in the ONNX importer:

ngraph::onnx_import::register_operator(
    "CustomRelu", 1, "com.example", [](const ngraph::onnx_import::Node& onnx_node) -> ngraph::OutputVector {
        namespace opset = ngraph::opset5;

        ngraph::OutputVector ng_inputs{onnx_node.get_ng_inputs()};
        const ngraph::Output<ngraph::Node>& data = ng_inputs.at(0);
        // create constant node with a single element that's equal to zero
        std::shared_ptr<ngraph::Node> zero_node = opset::Constant::create(data.get_element_type(), ngraph::Shape{}, {0});
        // create a negative map for the 'data' node: 1 for negative values, 0 for positive values or zero
        // then convert it from boolean type to `data.get_element_type()`
        std::shared_ptr<ngraph::Node> negative_map = std::make_shared<opset::Convert>(
            std::make_shared<opset::Less>(data, zero_node), data.get_element_type());
        // create a positive map for the 'data' node: 0 for negative values, 1 for positive values or zero
        // then convert it from boolean type to `data.get_element_type()`
        std::shared_ptr<ngraph::Node> positive_map = std::make_shared<opset::Convert>(
            std::make_shared<opset::GreaterEqual>(data, zero_node), data.get_element_type());

        // fetch alpha and beta attributes from ONNX node
        float alpha = onnx_node.get_attribute_value<float>("alpha", 1); // if 'alpha' attribute is not provided in the model, then the default value is 1
        float beta = onnx_node.get_attribute_value<float>("beta");
        // create a scalar constant node holding 'alpha', with the same element type as 'data'
        // (a fixed f32 type would make Multiply fail validation for non-f32 inputs)
        std::shared_ptr<ngraph::Node> alpha_node = opset::Constant::create(data.get_element_type(), ngraph::Shape{}, {alpha});
        // create a scalar constant node holding 'beta', with the same element type as 'data'
        std::shared_ptr<ngraph::Node> beta_node = opset::Constant::create(data.get_element_type(), ngraph::Shape{}, {beta});

        return {
            std::make_shared<opset::Add>(
                std::make_shared<opset::Multiply>(alpha_node, std::make_shared<opset::Multiply>(data, positive_map)),
                std::make_shared<opset::Multiply>(beta_node, std::make_shared<opset::Multiply>(data, negative_map))
            )
        };
});
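The lambda does not branch. It builds two 0/1 masks and combines them arithmetically. The following plain C++ sketch mirrors that subgraph node by node and can help when reasoning about what the registered operator computes. It has no nGraph dependency, and custom_relu_masked is a hypothetical name.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Hypothetical plain C++ mirror of the nGraph subgraph built by the lambda above:
//   positive_map = Convert(GreaterEqual(data, zero))  -> 1.0f where x >= 0, else 0.0f
//   negative_map = Convert(Less(data, zero))          -> 1.0f where x <  0, else 0.0f
//   result = alpha * (data * positive_map) + beta * (data * negative_map)
std::vector<float> custom_relu_masked(const std::vector<float>& data, float alpha, float beta) {
    std::vector<float> result(data.size());
    for (std::size_t i = 0; i < data.size(); ++i) {
        const float x = data[i];
        const float positive_map = (x >= 0.0f) ? 1.0f : 0.0f;
        const float negative_map = (x < 0.0f) ? 1.0f : 0.0f;
        // For any non-NaN x, exactly one of the two masks is 1.
        assert(std::isnan(x) || positive_map + negative_map == 1.0f);
        result[i] = alpha * (x * positive_map) + beta * (x * negative_map);
    }
    return result;
}
```

For finite inputs this matches the piecewise definition of CustomRelu exactly, because one of the two products is always multiplied by 0. Infinite inputs are different. Under IEEE 754, inf * 0 is NaN, so this sketch returns NaN for an input of +inf rather than +inf. The same arithmetic presumably applies to the subgraph above when it is evaluated in IEEE 754 floating point.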

The register_operator function takes four arguments: op_type, opset version, domain, and a function object. The function object is user-defined. It takes an ngraph::onnx_import::Node as input and returns the nGraph operations that implement that node, as an ngraph::OutputVector. The ngraph::onnx_import::Node class represents a node in an ONNX model. It provides functions to fetch the node's inputs (get_ng_inputs), its attribute values (get_attribute_value), and more. See ` <https://docs.openvinotoolkit.org/latest/ngraph_cpp_api/core_2include_2ngraph_2node_8hpp_source.html>`__ for the full class declaration.

A new operator must be registered before an ONNX model that uses it is read. For example, if a model uses the CustomRelu operator, call register_operator("CustomRelu", ...) before InferenceEngine::Core::ReadNetwork. Re-registering ONNX operators within the same process is supported. Registering an operator that already exists produces a warning.
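A small self-contained model of a registry can illustrate the ordering and re-registration rules. This is a hypothetical sketch for explanation only. ToyOperatorRegistry and its members are invented names, and this is not how ngraph::onnx_import actually stores operators. The sketch also matches only on the exact (domain, op_type, version) key, while the real importer's version matching may be more elaborate.

```cpp
#include <cassert>
#include <cstdint>
#include <functional>
#include <iostream>
#include <map>
#include <string>
#include <tuple>
#include <utility>

// Hypothetical, simplified operator registry (NOT the ngraph::onnx_import implementation).
// It illustrates the rules from the text:
//   - an operator is identified by (domain, op_type, version);
//   - a model can only be resolved if its operators were registered first;
//   - registering an existing operator again prints a warning (this toy then overwrites it).
class ToyOperatorRegistry {
public:
    using Key = std::tuple<std::string, std::string, std::int64_t>;  // (domain, op_type, version)
    using Converter = std::function<void()>;  // stand-in for the real ONNX-node-to-nGraph function object

    void register_operator(const std::string& op_type, std::int64_t version,
                           const std::string& domain, Converter fn) {
        assert(fn && "converter must be callable");
        const Key key{domain, op_type, version};
        if (ops_.count(key) != 0) {
            std::cerr << "warning: overwriting operator " << domain << '.' << op_type
                      << " version " << version << '\n';
            ++warnings_;
        }
        ops_[key] = std::move(fn);
    }

    void unregister_operator(const std::string& op_type, std::int64_t version,
                             const std::string& domain) {
        ops_.erase(Key{domain, op_type, version});
    }

    // What a model reader would check for each node it encounters.
    bool is_registered(const std::string& op_type, std::int64_t version,
                       const std::string& domain) const {
        return ops_.count(Key{domain, op_type, version}) != 0;
    }

    int warning_count() const { return warnings_; }

private:
    std::map<Key, Converter> ops_;
    int warnings_ = 0;
};
```

In this toy, reading a model amounts to calling is_registered for each node. If CustomRelu were registered only after that check, the lookup would fail. This mirrors why register_operator has to run before InferenceEngine::Core::ReadNetwork.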

The following example shows a model that uses the previously registered CustomRelu operator. The model is written in protobuf text format and stored in a C++ raw string literal:

R"ONNX(
ir_version: 3
producer_name: "nGraph ONNX Importer"
graph {
  node {
    input: "in"
    output: "out"
    name: "customrelu"
    op_type: "CustomRelu"
    domain: "com.example"
    attribute {
        name: "alpha"
        type: FLOAT
        f: 2
    }
    attribute {
        name: "beta"
        type: FLOAT
        f: 3
    }
  }
  name: "custom relu graph"
  input {
    name: "in"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 8
          }
        }
      }
    }
  }
  output {
    name: "out"
    type {
      tensor_type {
        elem_type: 1
        shape {
          dim {
            dim_value: 8
          }
        }
      }
    }
  }
}
opset_import {
  domain: "com.example"
  version: 1
}
)ONNX";
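Here is a worked example using this model's attributes (alpha = 2, beta = 3) and an arbitrarily chosen 8-element input. CustomRelu multiplies each non-negative element by alpha and each negative element by beta:

```latex
\alpha = 2,\quad \beta = 3,\qquad
x = (-4,\,-3,\,-2,\,-1,\,0,\,1,\,2,\,3)
\;\Longrightarrow\;
f(x) = (-12,\,-9,\,-6,\,-3,\,0,\,2,\,4,\,6)
```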

To create a graph with nGraph operations, see Custom nGraph Operations. For a complete list of predefined nGraph operations, see Available Operations Sets.

If you do not need an operator anymore, unregister it by calling unregister_operator. The function takes three arguments: op_type, version, and domain.

ngraph::onnx_import::unregister_operator("CustomRelu", 1, "com.example");

Register Custom ONNX Operator Based on Custom nGraph Operations

The same principles apply when registering a custom ONNX operator based on custom nGraph operations. The following example registers a custom ONNX operator based on the Operation class presented in this tutorial, which is used in TemplateExtension:

Extension::Extension() {
#ifdef NGRAPH_ONNX_IMPORT_ENABLED
    ngraph::onnx_import::register_operator(Operation::type_info.name, 1, "custom_domain", [](const ngraph::onnx_import::Node& node) -> ngraph::OutputVector {
        ngraph::OutputVector ng_inputs {node.get_ng_inputs()};
        int64_t add = node.get_attribute_value<int64_t>("add");
        return {std::make_shared<Operation>(ng_inputs.at(0), add)};
    });
    #ifdef OPENCV_IMPORT_ENABLED
    ngraph::onnx_import::register_operator(FFTOp::type_info.name, 1, "custom_domain", [](const ngraph::onnx_import::Node& node) -> ngraph::OutputVector {
        ngraph::OutputVector ng_inputs {node.get_ng_inputs()};
        bool inverse = node.get_attribute_value<int64_t>("inverse");
        return {std::make_shared<FFTOp>(ng_inputs.at(0), inverse)};
    });
    #endif
#endif
}

Here, register_operator is called in the constructor of Extension. InferenceEngine::Core::AddExtension must be called before a model with a custom operator is read, and the Extension object is constructed before it is passed to AddExtension. Together, these guarantee that the operator is registered before InferenceEngine::Core::ReadNetwork runs.

The example below demonstrates how to unregister an operator from the destructor of Extension:

Extension::~Extension() {
#ifdef NGRAPH_ONNX_IMPORT_ENABLED
    ngraph::onnx_import::unregister_operator(Operation::type_info.name, 1, "custom_domain");
    #ifdef OPENCV_IMPORT_ENABLED
    ngraph::onnx_import::unregister_operator(FFTOp::type_info.name, 1, "custom_domain");
    #endif  // OPENCV_IMPORT_ENABLED
#endif      // NGRAPH_ONNX_IMPORT_ENABLED
}
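Registering in the constructor and unregistering in the destructor is the RAII idiom. The following sketch factors that idiom into a small helper. ScopedRegistration is a hypothetical name and is not part of OpenVINO. Its two callbacks stand in for the register_operator and unregister_operator calls shown above.

```cpp
#include <cassert>
#include <functional>
#include <utility>

// Hypothetical RAII helper mirroring what Extension does above: run the registration
// action in the constructor and the matching unregistration in the destructor, so the
// two calls cannot get out of sync.
class ScopedRegistration {
public:
    ScopedRegistration(std::function<void()> do_register, std::function<void()> do_unregister)
        : unregister_(std::move(do_unregister)) {
        assert(do_register && unregister_);
        do_register();
    }

    ~ScopedRegistration() { unregister_(); }

    // A copy would unregister the same operator twice, so copying is disabled.
    ScopedRegistration(const ScopedRegistration&) = delete;
    ScopedRegistration& operator=(const ScopedRegistration&) = delete;

private:
    std::function<void()> unregister_;
};
```

In an extension, the callbacks would wrap the ngraph::onnx_import::register_operator and ngraph::onnx_import::unregister_operator calls from the examples above. The guard would be held as a member of the extension object so that its lifetime matches the extension's.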

REQUIRED: A custom ONNX operator defined in a dynamic shared library must be unregistered before that library is unloaded. Otherwise, the importer keeps a function object whose code is no longer loaded.

Requirements for Building with CMake

A program that uses register_operator requires the ngraph::ngraph and ngraph::onnx_ngraph_frontend libraries in addition to the Inference Engine. Because onnx_ngraph_frontend is a component of the ngraph package, find_package(ngraph REQUIRED COMPONENTS onnx_ngraph_frontend) finds both. Pass both libraries to the target_link_libraries command in the CMakeLists.txt file.

See CMakeLists.txt below for reference:

set(CMAKE_CXX_STANDARD 11)

set(TARGET_NAME "onnx_custom_op")

find_package(ngraph REQUIRED COMPONENTS onnx_ngraph_frontend)

add_library(${TARGET_NAME} STATIC onnx_custom_op.cpp onnx_custom_op.hpp)

target_link_libraries(${TARGET_NAME} PUBLIC ngraph::ngraph ngraph::onnx_ngraph_frontend)