To add a custom nGraph operation, create a new class that extends `ngraph::op::Op`, which is in turn derived from `ngraph::Node`, the base class for all graph operations in nGraph. Follow the steps below:
- Define a `NodeTypeInfo` object that identifies the type of the operation to graph users and helps with dynamic type resolution. The type info of an nGraph operation currently consists of a string identifier and a version number, but this may change in the future.
- Implement constructors that can optionally take the operation inputs and attributes as parameters.
- Override the shape inference method `validate_and_infer_types`. This method is called multiple times during graph manipulations to determine the shapes and element types of the operation's outputs. You can access the input shapes through the `get_input_partial_shape()` method and the input element types through the `get_input_element_type()` method of `ngraph::Node`. Set the inferred shape and element type of each output using `set_output_type`.
- Override the `copy_with_new_args` method, which allows graph manipulation routines to create copies of this operation and connect them to different nodes during optimization.
- Override the `visit_attributes` method, which allows serialization and deserialization of the attributes of `MyCustomOp`. An `AttributeVisitor` is passed to the method, and the implementation is expected to walk over all the attributes in the op using the type-aware `on_attribute` helper. Helpers are already implemented for standard C++ types such as `int64_t`, `float`, `bool`, and `std::vector`, as well as for existing nGraph-defined types.
Code Sample
For a code sample, see the `my_custom_op.hpp` and `my_custom_op.cpp` files below.
Click to expand/collapse my_custom_op.hpp
#include "ngraph/op/op.hpp"
namespace op {
class MyCustomOp : public Op {
public:
static constexpr NodeTypeInfo type_info{"MyCustomOp", 0};
const NodeTypeInfo& get_type_info() const override { return type_info; }
MyCustomOp() = default;
MyCustomOp(const Output<Node> & input0, const Output<Node> & input1, const int attribute)
void validate_and_infer_types() override;
std::shared_ptr<Node> copy_with_new_args(const NodeVector& new_args) const override;
bool visit_attributes(AttributeVisitor& visitor) override;
int get_attribute() const { return m_attribute; }
void set_attribute(const int attribute) { m_attribute = attribute; }
private:
int m_attribute;
};
}
}
Click to expand/collapse my_custom_op.cpp
#include "my_custom_op.hpp"

#include <memory>

using std::make_shared;
using std::shared_ptr;
using namespace ngraph;
constexpr NodeTypeInfo op::MyCustomOp::type_info;
op::MyCustomOp::MyCustomOp(const Output<Node>& input0, const Output<Node>& input1, const int attribute)
: Op({input0, input1}), m_attribute(attribute) {
constructor_validate_and_infer_types();
}
// Shape/type inference: output 0 mirrors input 0's partial shape, output 1
// mirrors input 1's partial shape, and both outputs use input 0's element type.
// NOTE(review): the two inputs are not checked for matching element types.
void op::MyCustomOp::validate_and_infer_types()
{
    const auto first_input_shape = get_input_partial_shape(0);
    const auto second_input_shape = get_input_partial_shape(1);
    const auto shared_element_type = get_input_element_type(0);

    set_output_type(0, shared_element_type, first_input_shape);
    set_output_type(1, shared_element_type, second_input_shape);
}
shared_ptr<Node> op::MyCustomOp::copy_with_new_args(const NodeVector& new_args) const {
check_new_args_count(this, new_args);
return make_shared<MyCustomOp>(new_args.at(0), new_args.at(1), m_attribute);
}
// Reports the op's only attribute, m_attribute, to the visitor under the key
// "attribute". The visitor may read or overwrite it (serialization vs.
// deserialization). Always returns true.
bool op::MyCustomOp::visit_attributes(AttributeVisitor& visitor)
{
    static constexpr char attribute_key[] = "attribute";
    visitor.on_attribute(attribute_key, m_attribute);
    return true;
}