class ngraph::pass::low_precision::NetworkHelper

The NetworkHelper class encapsulates helper operations for inspecting and manipulating nodes of an nGraph function.

#include <network_helper.hpp>

/// Collection of static helpers used by the ngraph::pass::low_precision
/// transformations to inspect and rewrite nodes of an nGraph function.
///
/// NOTE(review): only declarations are visible here. Apart from the nested
/// InsertDequantizationResult class, every member is `static`, so the class
/// acts as a namespace-like grouping. Descriptions below that go beyond the
/// signature (marked "presumably") are inferred from names and should be
/// confirmed against the implementation.
class NetworkHelper
{
public:
    // classes

    /// Return type of moveDequantizationAfter / moveDequantizationBefore
    /// (only forward-declared here).
    class InsertDequantizationResult;

    // methods

    /// Presumably returns true if `type` can be cast to at least one entry of `types`.
    static bool is_castable_to_one_of(
        NodeTypeInfo type,
        const std::unordered_set<NodeTypeInfo>& types
        );

    /// Presumably the consumers of `node`'s outputs, either as target inputs
    /// (consumer_inputs) or as the consuming nodes themselves (consumers).
    static std::vector<Input<Node>> consumer_inputs(std::shared_ptr<Node> node);
    static std::vector<std::shared_ptr<Node>> consumers(std::shared_ptr<Node> node);
    /// Presumably true when `op` is computed only from constants (no runtime data input).
    static bool isConstantPath(const std::shared_ptr<Node>& op);

    /// Presumably sets the output precision of a TypeRelaxed `operation` to
    /// `precision`; returns the resulting node.
    template <typename OperationType>
    static std::shared_ptr<Node> setOutDataPrecisionForTypeRelaxed(
        std::shared_ptr<OperationType> operation,
        const element::Type& precision
        );

    /// Presumably sets the output precision of `operation` to `precision`;
    /// returns the resulting node.
    template <typename OperationType>
    static std::shared_ptr<Node> setOutDataPrecision(
        std::shared_ptr<OperationType> operation,
        const element::Type& precision
        );

    /// Presumably constant-folds `operation` using `foldingConstant` and returns
    /// the folded Constant for output `outIdx` (default 0).
    static std::shared_ptr<opset1::Constant> foldDequantizationConstant(
        const std::shared_ptr<opset1::Constant>& foldingConstant,
        const std::shared_ptr<Node>& operation,
        const size_t outIdx = 0
        );

    /// Presumably the output channel count of `layer`; `isOnWeights`
    /// (default false) presumably selects the weights-layout interpretation.
    static size_t getOutputChannelsCount(
        std::shared_ptr<const Node> layer,
        bool isOnWeights = false
        );

    /// Presumably collects ancestors of `layer` recursively, passing through nodes
    /// whose types are listed in `exceptionLayerTypes` (default: empty set).
    /// `portIndex` defaults to -1, presumably meaning "all input ports".
    static std::vector<std::shared_ptr<Node>> getParentsRecursivelyExceptTypes(
        std::shared_ptr<Node> layer,
        const std::unordered_set<NodeTypeInfo>& exceptionLayerTypes = {},
        const int portIndex = -1
        );

    /// Presumably return the input channel count / group count of `layer`;
    /// removeLayer presumably removes `node` from the graph.
    static size_t getInputChannelsCount(std::shared_ptr<Node> layer);
    static size_t getGroupsCount(std::shared_ptr<Node> layer);
    static void removeLayer(std::shared_ptr<Node> node);

    /// Presumably reorders a Multiply -> Add pair into Add -> Multiply;
    /// `multiplyBranch` is presumably the index of the Add input fed by the Multiply.
    static std::shared_ptr<Node> swapMultiplyAndAdd(
        std::shared_ptr<opset1::Add> addAfterMultiply,
        const int multiplyBranch
        );

    /// Presumably copies node information (such as runtime info and friendly name)
    /// from source node(s) to target node(s). `overrideName` defaults to true.
    /// Three overloads: many-to-many, many-to-one and one-to-one.
    static void copyInfo(
        const std::vector<std::shared_ptr<Node>>& sources,
        const std::vector<std::shared_ptr<Node>>& targets,
        bool overrideName = true
        );

    static void copyInfo(
        const std::vector<std::shared_ptr<Node>>& sources,
        const std::shared_ptr<Node>& target,
        bool overrideName = true
        );

    static void copyInfo(
        const std::shared_ptr<Node>& source,
        const std::shared_ptr<Node>& target,
        bool overrideName = true
        );

    /// Presumably: isScalarLike - all values of `constant` are equal;
    /// isZero - `constant` holds zero; toScalar - collapses `constant` to a scalar Constant.
    static bool isScalarLike(std::shared_ptr<opset1::Constant> constant);
    static bool isZero(std::shared_ptr<opset1::Constant> constant);
    static std::shared_ptr<opset1::Constant> toScalar(std::shared_ptr<opset1::Constant> constant);

    /// Presumably returns the Constant input of `node`; with `convertIsExpected`
    /// (default false) a Convert between them is presumably tolerated.
    static std::shared_ptr<Node> getConstantInput(
        const std::shared_ptr<const Node>& node,
        const bool convertIsExpected = false
        );

    /// Presumably recomputes Reshape target values given the shapes of an
    /// elementwise constant and of the elementwise operation.
    static std::vector<size_t> updateReshapeValues(
        const Shape& elementwiseConstantShape,
        const Shape& elementwiseShape,
        const std::vector<size_t>& reshapeValues
        );

    /// Presumably merges consecutive Multiply-by-constant nodes following `multiply`.
    static std::shared_ptr<ngraph::opset1::Multiply> optimizeMultipliesAfter(std::shared_ptr<Node> multiply);

    /// Presumably rounds the constant values of `node` and returns them as a
    /// Constant of `target_type`.
    static std::shared_ptr<opset1::Constant> round(
        std::shared_ptr<Node> node,
        element::Type target_type
        );

    /// Presumably fuses the operations following `fq` back into it.
    /// `defaultPrecisions` defaults to precision_set::int8_support.
    static std::shared_ptr<opset1::FakeQuantize> composeFakeQuantize(
        const std::shared_ptr<opset1::FakeQuantize>& fq,
        const std::vector<ngraph::element::Type>& defaultPrecisions = precision_set::int8_support
        );

    /// Presumably splits `fq` into a quantizing part (precision `precision`, range
    /// [min, max]) and a dequantization part; the tuple holds the two resulting nodes.
    /// `deqPrecision` defaults to element::f32, `outChannelsShapeIndex` to 0.
    static std::tuple<std::shared_ptr<Node>, std::shared_ptr<Node>> decomposeFakeQuantize(
        std::shared_ptr<opset1::FakeQuantize> fq,
        const element::Type precision,
        const float min,
        const float max,
        const bool hasZeroPoint,
        const bool updatePrecision,
        const element::Type deqPrecision = element::f32,
        const size_t outChannelsShapeIndex = 0
        );

    /// Presumably builds a FakeQuantize with output range [min, max] and precision
    /// `precision`, substituting it for `fq` when `replace` (default true) is set.
    static std::shared_ptr<opset1::FakeQuantize> updateFakeQuantize(
        std::shared_ptr<opset1::FakeQuantize> fq,
        element::Type precision,
        float min,
        float max,
        const bool replace = true
        );

    /// Presumably builds a dequantization from scalar `dequantizationMul` and
    /// `dequantizationSub`. `deqPrecision` defaults to element::f32, `input` to nullptr.
    static FakeQuantizeDequantization makeDequantization(
        const float dequantizationMul,
        const float dequantizationSub,
        const ngraph::element::Type originalPrecision,
        const ngraph::PartialShape& dataNodeOutputShape,
        element::Type precision,
        const element::Type deqPrecision = element::f32,
        std::shared_ptr<ngraph::Node> input = nullptr
        );

    /// Presumably creates the dequantization Subtract `parent - subtract_constant`.
    static std::shared_ptr<ngraph::Node> makeDequantizationSubtract(
        const ngraph::Output<ngraph::Node>& parent,
        const ngraph::Output<ngraph::Node>& subtract_constant
        );

    /// Presumably derives the dequantization for `fq` quantized to `precision`
    /// over [min, max]. `deqPrecision` defaults to element::f32.
    static FakeQuantizeDequantization createDequantizationFromFakeQuantize(
        std::shared_ptr<opset1::FakeQuantize> fq,
        element::Type precision,
        float min,
        float max,
        const bool hasZeroPoint,
        const bool updatePrecision,
        const element::Type deqPrecision = element::f32
        );

    /// Presumably whether `node` can be handled as quantize + dequantize for the
    /// Subtract / Multiply case. Both default the precisions to precision_set::int8_support.
    static bool areQuantizeAndDequantizeSupportedForSubtract(
        const std::shared_ptr<const ngraph::Node>& node,
        const std::vector<ngraph::element::Type>& defaultPrecisions = precision_set::int8_support
        );

    static bool areQuantizeAndDequantizeSupportedForMultiply(
        const std::shared_ptr<const ngraph::Node>& node,
        const std::vector<ngraph::element::Type>& defaultPrecisions = precision_set::int8_support
        );

    /// Presumably whether `fakeQuantize` can be processed by the low-precision passes.
    static bool isQuantizeSupported(const std::shared_ptr<opset1::FakeQuantize>& fakeQuantize);

    /// Presumably extracts the dequantization operations on input `parentIndex`
    /// (default 0) of `node`. `inPlace` defaults to false.
    /// NOTE(review): the precisions vector is taken by const value, so it is copied
    /// on every call. Switching to `const&` (as in the sibling declarations) changes
    /// the signature and must be done together with the out-of-line definition.
    static FakeQuantizeDequantization getDequantization(
        const std::shared_ptr<const Node>& node,
        const std::vector<ngraph::element::Type> defaultPrecisions = precision_set::int8_support,
        const size_t parentIndex = 0ul,
        const bool inPlace = false
        );

    /// Presumably the dequantization operations consuming `node`'s output;
    /// `convertIsMandatory` defaults to false.
    static FakeQuantizeDequantization getDequantizationBelow(
        const std::shared_ptr<Node>& node,
        const bool convertIsMandatory = false
        );

    /// Presumably returns a canonicalized copy of `dequantization`.
    static FakeQuantizeDequantization normalizeDequantization(FakeQuantizeDequantization dequantization);

    /// Presumably canonicalizes the constant shape of `eltwise`;
    /// `convertIsExpected` defaults to true.
    static std::shared_ptr<opset1::Constant> normalizeDequantizationShape(
        const std::shared_ptr<Node>& eltwise,
        const bool convertIsExpected = true
        );

    /// Presumably simplifies the dequantization `subtract`; returns the resulting node.
    static std::shared_ptr<Node> optimizeSubtract(std::shared_ptr<opset1::Subtract> subtract);

    /// Presumably moves `dequantization` from before `operation` to after it
    /// (moveDequantizationAfter) or the reverse (moveDequantizationBefore).
    /// Only moveDequantizationAfter takes `defaultPrecisions`.
    static InsertDequantizationResult moveDequantizationAfter(
        const std::shared_ptr<ngraph::Node>& operation,
        const FakeQuantizeDequantization& dequantization,
        const bool updatePrecision,
        const bool moveSubtract,
        const std::vector<ngraph::element::Type>& defaultPrecisions = precision_set::int8_support
        );

    static InsertDequantizationResult moveDequantizationBefore(
        const std::shared_ptr<ngraph::Node>& operation,
        const FakeQuantizeDequantization& dequantization,
        const bool updatePrecision,
        const bool moveSubtract
        );

    /// Presumably splits each constant in `currConstants` into per-input parts of `concat`.
    static std::vector<std::vector<std::shared_ptr<ngraph::opset1::Constant>>> splitConstantsBeforeConcat(
        const std::shared_ptr<ov::Node> concat,
        const std::vector<std::shared_ptr<opset1::Constant>> currConstants
        );

    /// Presumably whether the values of `constant` are representable in `expectedPrecision`.
    static bool checkConstantValuePrecision(
        const element::Type expectedPrecision,
        const std::shared_ptr<Node>& constant
        );

    /// Presumably the index of `child`'s input connected to `parent`, and the
    /// index of `parent`'s output connected to `child`, respectively.
    static size_t getChildInputIndex(
        const std::shared_ptr<ngraph::Node>& parent,
        const std::shared_ptr<ngraph::Node>& child
        );

    static size_t getParentOutputIndex(
        const std::shared_ptr<ngraph::Node>& parent,
        const std::shared_ptr<ngraph::Node>& child
        );

    /// Presumably builds placeholder dequantization values for `dequantization`
    /// in `precision`.
    static FakeQuantizeDequantizationValues createEmptyValues(
        const FakeQuantizeDequantization& dequantization,
        const element::Type precision
        );

    /// Presumably true when `node` is a Constant holding zero.
    static bool isZeroConst(const std::shared_ptr<Node>& node);

    /// Presumably validates the zero point of `node`; `dataPrecision` defaults
    /// to a default-constructed DataPrecision.
    static bool checkZeroPoint(
        const std::shared_ptr<Node>& node,
        const DataPrecision& dataPrecision = DataPrecision()
        );

    /// Presumably returns a scalar version of `node` when possible, otherwise `node`.
    static std::shared_ptr<Node> toScalarIfPossible(std::shared_ptr<Node> node);

    /// Presumably constant-folds `fq`; the second overload adds `roundValues`
    /// and `outChannelsShapeIndex` (default 0).
    static std::shared_ptr<Node> fold_fake_quantize(const std::shared_ptr<opset1::FakeQuantize>& fq);

    static std::shared_ptr<Node> fold_fake_quantize(
        const std::shared_ptr<opset1::FakeQuantize>& fq,
        const bool roundValues,
        int outChannelsShapeIndex = 0
        );

    /// Presumably folds the dequantization on input branch `branchIndex` of `node`.
    /// `defaultPrecisions` defaults to precision_set::int8_support, `inPlace` to false.
    static FakeQuantizeDequantization foldDequantization(
        const std::shared_ptr<Node>& node,
        const size_t branchIndex,
        const std::vector<ngraph::element::Type>& defaultPrecisions = precision_set::int8_support,
        const bool inPlace = false
        );

    /// Presumably gives `node` its own copy of the dequantization branch feeding it;
    /// returns the resulting node.
    static std::shared_ptr<ngraph::Node> separateInStandaloneBranch(
        std::shared_ptr<ngraph::Node> node,
        const std::vector<ngraph::element::Type>& defaultPrecisions = precision_set::int8_support
        );

    /// Presumably fuses a Convert consuming `fakeQuantize` into it.
    static std::shared_ptr<opset1::FakeQuantize> fuseConvert(const std::shared_ptr<opset1::FakeQuantize>& fakeQuantize);

    /// Presumably the element types present in both `v1` and `v2`.
    static std::vector<element::Type> precisionIntersection(
        const std::vector<element::Type>& v1,
        const std::vector<element::Type>& v2
        );

    /// Presumably whether `node` is marked as precision-preserving.
    static bool isPrecisionPreserved(const std::shared_ptr<ngraph::Node>& node);

    /// Presumably inserts `dequantization` after `newNode` in place of `originalNode`.
    static void insertDequantizationAfter(
        const std::shared_ptr<Node>& originalNode,
        const std::shared_ptr<Node>& dequantization,
        const std::shared_ptr<Node>& newNode
        );

    /// Presumably re-points every attribute in `attributes` at `sharedValue`.
    template <typename SharedAttribute>
    static void reassign(
        const std::shared_ptr<typename SharedAttribute::SharedValueAttribute::SharedValue>& sharedValue,
        const std::vector<std::weak_ptr<typename SharedAttribute::SharedValueAttribute>>& attributes
        );

    /// Presumably computes the number of quantization levels for the given
    /// precision range and intervals. The last four parameters are non-const
    /// references and are presumably written as outputs.
    static size_t calculateLevels(
        const float dataPrecisionMin,
        const float dataPrecisionMax,
        const float combinedIntervalLow,
        const float combinedIntervalHigh,
        const float minIntervalLow,
        const float minIntervalHigh,
        float& dequantizationMul,
        float& dequantizationSub,
        float& updatedOutputLowValue,
        float& updatedOutputHighValue
        );

    /// Presumably returns `output` if its constant has a single consumer,
    /// otherwise a single-consumer copy of it.
    static ov::Output<ov::Node> getSingleConsumerConstant(const ov::Output<ov::Node>& output);
};