class ngraph::pattern::op::Label

Overview

Fails if the predicate returns false on the graph value. More…

#include <label.hpp>

class Label: public ngraph::pattern::op::Pattern
{
public:
    // fields

    static constexpr NodeTypeInfo type_info {"patternLabel", 0};

    // construction

    Label(
        const element::Type& type,
        const PartialShape& s,
        const ValuePredicate pred,
        const OutputVector& wrapped_values
        );

    Label(
        const element::Type& type = element::dynamic,
        const PartialShape& s = PartialShape::dynamic()
        );

    Label(const element::Type& type, const PartialShape& s, ValuePredicate pred);
    Label(const element::Type& type, const PartialShape& s, NodePredicate pred);

    Label(
        const element::Type& type,
        const PartialShape& s,
        const NodePredicate pred,
        const NodeVector& wrapped_values
        );

    Label(
        const Output<Node>& value,
        const ValuePredicate pred,
        const OutputVector& wrapped_values
        );

    Label(const Output<Node>& value, const ValuePredicate pred);
    Label(const Output<Node>& value, const NodePredicate pred);
    Label(const Output<Node>& value);

    Label(
        const Output<Node>& node,
        const NodePredicate pred,
        const NodeVector& wrapped_values
        );

    // methods

    virtual const NodeTypeInfo& get_type_info() const;

    virtual bool match_value(
        Matcher* matcher,
        const Output<Node>& pattern_value,
        const Output<Node>& graph_value
        );
};

Inherited Members

public:
    // typedefs

    typedef DiscreteTypeInfo type_info_t;
    typedef std::map<std::string, std::shared_ptr<Variant>> RTMap;

    // fields

    NGRAPH_DEPRECATED("The tensor name was deprecated. Use get_input_tensor(i).get_names() instead.") const std::string& get_input_tensor_name(size_t i) const;
    std::unordered_set<descriptor::Tensor*> liveness_new_list;
    std::unordered_set<descriptor::Tensor*> liveness_free_list;

    // methods

    virtual void validate_and_infer_types();
    void constructor_validate_and_infer_types();
    virtual bool visit_attributes(AttributeVisitor&);
    virtual const op::AutoBroadcastSpec& get_autob() const;
    virtual bool has_evaluate() const;

    virtual bool evaluate(
        const HostTensorVector& output_values,
        const HostTensorVector& input_values
        ) const;

    virtual bool evaluate(
        const HostTensorVector& output_values,
        const HostTensorVector& input_values,
        const EvaluationContext& evaluationContext
        ) const;

    virtual bool evaluate_lower(const HostTensorVector& output_values) const;
    virtual bool evaluate_upper(const HostTensorVector& output_values) const;

    virtual bool constant_fold(
        OutputVector& output_values,
        const OutputVector& inputs_values
        );

    virtual OutputVector decompose_op() const;
    virtual const type_info_t& get_type_info() const = 0;
    const char* get_type_name() const;
    void set_arguments(const NodeVector& arguments);
    void set_arguments(const OutputVector& arguments);
    void set_argument(size_t position, const Output<Node>& argument);

    void set_output_type(
        size_t i,
        const element::Type& element_type,
        const PartialShape& pshape
        );

    void set_output_size(size_t output_size);
    void invalidate_values();
    virtual void revalidate_and_infer_types();
    virtual std::string description() const;
    const std::string& get_name() const;
    void set_friendly_name(const std::string& name);
    const std::string& get_friendly_name() const;
    virtual bool is_dynamic() const;
    size_t get_instance_id() const;
    virtual std::ostream& write_description(std::ostream& os, uint32_t depth = 0) const;
    const std::vector<std::shared_ptr<Node>>& get_control_dependencies() const;
    const std::vector<Node*>& get_control_dependents() const;
    void add_control_dependency(std::shared_ptr<Node> node);
    void remove_control_dependency(std::shared_ptr<Node> node);
    void clear_control_dependencies();
    void clear_control_dependents();
    void add_node_control_dependencies(std::shared_ptr<Node> source_node);
    void add_node_control_dependents(std::shared_ptr<Node> source_node);
    void transfer_control_dependents(std::shared_ptr<Node> replacement);
    size_t get_output_size() const;
    const element::Type& get_output_element_type(size_t i) const;
    const element::Type& get_element_type() const;
    const Shape& get_output_shape(size_t i) const;
    const PartialShape& get_output_partial_shape(size_t i) const;
    Output<const Node> get_default_output() const;
    Output<Node> get_default_output();
    virtual size_t get_default_output_index() const;
    size_t no_default_index() const;
    const Shape& get_shape() const;
    descriptor::Tensor& get_output_tensor(size_t i) const;
    descriptor::Tensor& get_input_tensor(size_t i) const;
    NGRAPH_DEPRECATED("The tensor name was deprecated. Use get_output_tensor(i).get_names() instead.") const std::string& get_output_tensor_name(size_t i) const;
    std::set<Input<Node>> get_output_target_inputs(size_t i) const;
    size_t get_input_size() const;
    const element::Type& get_input_element_type(size_t i) const;
    const Shape& get_input_shape(size_t i) const;
    const PartialShape& get_input_partial_shape(size_t i) const;
    Node* get_input_node_ptr(size_t index) const;
    std::shared_ptr<Node> get_input_node_shared_ptr(size_t index) const;
    Output<Node> get_input_source_output(size_t i) const;
    virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector& inputs) const = 0;
    std::shared_ptr<Node> copy_with_new_inputs(const OutputVector& new_args) const;

    std::shared_ptr<Node> copy_with_new_inputs(
        const OutputVector& inputs,
        const std::vector<std::shared_ptr<Node>>& control_dependencies
        ) const;

    bool has_same_type(std::shared_ptr<const Node> node) const;
    RTMap& get_rt_info();
    const RTMap& get_rt_info() const;
    const std::unordered_set<std::string>& get_provenance_tags() const;
    void add_provenance_tag(const std::string& tag);

    template <typename T>
    void add_provenance_tags(T tag_set);

    void add_provenance_tags_above(
        const OutputVector& base,
        const std::unordered_set<std::string>& tag_set
        );

    void remove_provenance_tag(const std::string& tag);
    void add_provenance_group_member(const std::shared_ptr<Node>& node);
    void remove_provenance_group_member(const std::shared_ptr<Node>& node);

    void replace_provenance_group_member(
        const std::shared_ptr<Node>& current_node,
        const std::shared_ptr<Node>& replacement_node
        );

    const std::set<std::shared_ptr<Node>>& get_provenance_group_members() const;
    std::shared_ptr<Node> add_provenance_group_members_above(const OutputVector& base);
    void merge_provenance_tags_from(const std::shared_ptr<const Node>& source);
    void transfer_provenance_tags(const std::shared_ptr<Node>& replacement);
    NodeVector get_users(bool check_is_used = false) const;
    virtual size_t get_version() const;
    virtual std::shared_ptr<Node> get_default_value() const;
    bool operator < (const Node& other) const;
    std::vector<Input<Node>> inputs();
    std::vector<Input<const Node>> inputs() const;
    std::vector<Output<Node>> input_values() const;
    std::vector<Output<Node>> outputs();
    std::vector<Output<const Node>> outputs() const;
    Input<Node> input(size_t input_index);
    Input<const Node> input(size_t input_index) const;
    Output<Node> input_value(size_t input_index) const;
    Output<Node> output(size_t output_index);
    Output<const Node> output(size_t output_index) const;
    void set_op_annotations(std::shared_ptr<ngraph::op::util::OpAnnotations> op_annotations);
    std::shared_ptr<ngraph::op::util::OpAnnotations> get_op_annotations() const;

    virtual bool match_value(
        pattern::Matcher* matcher,
        const Output<Node>& pattern_value,
        const Output<Node>& graph_value
        );

    virtual bool match_node(
        pattern::Matcher* matcher,
        const Output<Node>& graph_value
        );

    virtual std::shared_ptr<Node> clone_with_new_inputs(const OutputVector&) const;
    ValuePredicate get_predicate() const;

Detailed Documentation

Fails if the predicate returns false on the graph value.

The graph value is added to the matched values list. If the Label is already associated with a value, the match succeeds if the value is the same as the graph value. Otherwise, the label is associated with the graph value and the match succeeds if the pattern input matches the graph value.

DEPRECATED: If no inputs are given to Label, a True node serves as the input. If more than one input is given, an Or pattern of the inputs serves as the input.

Construction

Label(
    const element::Type& type,
    const PartialShape& s,
    const ValuePredicate pred,
    const OutputVector& wrapped_values
    )

Creates a Label node containing a sub-pattern described by the element type `type` and the partial shape `s`.

This Label node can be bound only to the nodes in the input graph that match the pattern specified by `wrapped_values`.

Example:

auto add = a + b; // a and b are op::Parameter in this example
auto label = std::make_shared<pattern::op::Label>(element::f32,
                                                  Shape{2,2},
                                                  nullptr,
                                                  OutputVector{add});

Label(
    const Output<Node>& value,
    const ValuePredicate pred,
    const OutputVector& wrapped_values
    )

Creates a Label node containing a sub-pattern described by the element type and shape of `value`.

This Label node can be bound only to the nodes in the input graph that match the pattern specified by `wrapped_values`.

Example:

auto add = a + b; // a and b are op::Parameter in this example
auto label = std::make_shared<pattern::op::Label>(add,
                                                  nullptr,
                                                  OutputVector{add});

Methods

virtual const NodeTypeInfo& get_type_info() const

Returns the NodeTypeInfo for the node’s class. During transition to type_info, returns a dummy type_info for Node if the class has not been updated yet.