Extending Model Optimizer for Custom MXNet* Operations
This section explains how to support a custom MXNet operation (called an operator or layer in the MXNet documentation) that is not part of the standard MXNet operation set. Creating custom operations is described in this guide.
This section covers only how to extract operator attributes in the Model Optimizer. The rest of the operation-enabling pipeline, as well as support for operations from the standard MXNet operation set, is described in the main Customize_Model_Optimizer document.
Writing Extractor for Custom MXNet Operation
Custom MXNet operations have an attribute op (defining the type of the operation) equal to Custom and an attribute op_type, which is the operation type defined by the user. To extract attributes for such operations, implement an extractor class inherited from the MXNetCustomFrontExtractorOp class instead of the FrontExtractorOp class used for standard framework operations. Set the op class attribute of the extractor to the op_type value so that the extractor is triggered for this kind of operation.
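To make the two attributes concrete, the sketch below shows roughly how a custom operation might appear as a node in an MXNet symbol JSON file, and how an extractor could be selected by op_type rather than by op. The node layout (in particular the attrs key, whose name differs between MXNet versions), the node names, and the extractor_key helper are assumptions made for illustration; this is not the Model Optimizer's actual dispatch code.

```python
# Hypothetical node entry, roughly as it might appear in an MXNet symbol JSON file.
# MXNet stores attribute values as strings.
custom_node = {
    "op": "Custom",                 # marks the node as a user-defined operation
    "name": "my_custom_op0",        # illustrative node name
    "attrs": {
        "op_type": "MyCustomOp",    # type under which the user registered the operation
        "my_attribute": "5.6",
    },
}

standard_node = {"op": "Convolution", "name": "conv0", "attrs": {"num_filter": "64"}}


def extractor_key(node: dict) -> str:
    """Return the type that an extractor's `op` class attribute would need to match.

    Illustrative helper only: custom nodes are keyed by `op_type`,
    all other nodes by `op`.
    """
    if node["op"] == "Custom":
        return node["attrs"]["op_type"]
    return node["op"]


print(extractor_key(custom_node))    # MyCustomOp
print(extractor_key(standard_node))  # Convolution
```

Under these assumptions, an extractor whose op class attribute is MyCustomOp would be chosen for the first node, while the second node would go to an extractor for a standard operation.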
The following example shows an extractor for a custom operation registered with the type (op_type value) MyCustomOp that has an attribute my_attribute of the floating-point type with the default value 5.6. The example assumes that the CustomOp class (inherited from the Op class) has already been created for the Model Optimizer operation corresponding to this MXNet custom operation, as described in Customize_Model_Optimizer.
from extension.ops.custom_op import CustomOp  # implementation of the MO operation class
from mo.front.mxnet.extractors.utils import get_mxnet_layer_attrs
from mo.front.extractor import MXNetCustomFrontExtractorOp


class CustomProposalFrontExtractor(MXNetCustomFrontExtractorOp):  # inherit from specific base class
    op = 'MyCustomOp'  # the value corresponding to the `op_type` value of the MXNet operation
    enabled = True  # the extractor is enabled

    @staticmethod
    def extract(node):
        attrs = get_mxnet_layer_attrs(node.symbol_dict)  # parse the attributes to a dictionary with string values
        node_attrs = {
            'my_attribute': attrs.float('my_attribute', 5.6)  # read as float, default 5.6
        }
        CustomOp.update_node_stat(node, node_attrs)  # update the attributes of the node
        # `self` is not defined in a static method, so refer to the class explicitly
        return CustomProposalFrontExtractor.enabled
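The call attrs.float('my_attribute', 5.6) in the extractor reads a string-valued attribute and converts it to a float, falling back to the default when the attribute is absent. The stand-in class below mimics that behavior so the default handling can be tried without installing the Model Optimizer. StringAttrs is a hypothetical name, and the real object returned by get_mxnet_layer_attrs may differ in its details.

```python
class StringAttrs:
    """Hypothetical stand-in for the attribute wrapper used in the extractor."""

    def __init__(self, raw: dict):
        self._raw = dict(raw)  # attribute values are kept as strings

    def float(self, key: str, default=None):
        """Return the attribute converted to float, or `default` if it is missing."""
        if key not in self._raw:
            return default
        return float(self._raw[key])  # built-in float; the method name does not shadow it here


with_value = StringAttrs({"op_type": "MyCustomOp", "my_attribute": "2.5"})
without_value = StringAttrs({"op_type": "MyCustomOp"})

print(with_value.float("my_attribute", 5.6))     # 2.5
print(without_value.float("my_attribute", 5.6))  # 5.6
```

In this sketch a missing attribute yields the default unchanged, so a model that omits my_attribute would produce a node with my_attribute equal to 5.6.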