class ngraph::Function

Overview

A user-defined function (see Detailed Documentation below).

#include <function.hpp>

// Synopsis of ngraph::Function, a user-defined function: a graph described by
// its results, its parameters and, optionally, sinks and variables.
class Function
{
public:
    // typedefs

    // Signature of a node-ordering routine: takes the root nodes and returns a
    // node vector. Installed with set_topological_sort().
    // NOTE(review): presumably returns the nodes in topological order — confirm.
    typedef std::function<std::vector<std::shared_ptr<Node>>(const std::vector<std::shared_ptr<Node>>&root_nodes)> topological_sort_t;

    // fields

    // Static type descriptor for this class; get_type_info() returns a
    // DiscreteTypeInfo reference.
    static constexpr DiscreteTypeInfo type_info {"Function", 0};

    // construction

    // Each constructor takes the function's results (as nodes, outputs or
    // Result ops), optionally sinks, parameters and variables, and an optional
    // name that defaults to "".
    Function(
        const NodeVector& results,
        const ParameterVector& parameters,
        const std::string& name = ""
        );

    Function(
        const OutputVector& results,
        const ParameterVector& parameters,
        const std::string& name = ""
        );

    Function(
        const std::shared_ptr<Node>& result,
        const ParameterVector& parameters,
        const std::string& name = ""
        );

    Function(
        const ResultVector& results,
        const ParameterVector& parameters,
        const std::string& name = ""
        );

    Function(
        const ResultVector& results,
        const SinkVector& sinks,
        const ParameterVector& parameters,
        const std::string& name = ""
        );

    Function(
        const OutputVector& results,
        const SinkVector& sinks,
        const ParameterVector& parameters,
        const std::string& name = ""
        );

    Function(
        const ResultVector& results,
        const SinkVector& sinks,
        const ParameterVector& parameters,
        const VariableVector& variables,
        const std::string& name = ""
        );

    Function(
        const OutputVector& results,
        const SinkVector& sinks,
        const ParameterVector& parameters,
        const VariableVector& variables,
        const std::string& name = ""
        );

    Function(
        const ResultVector& results,
        const ParameterVector& parameters,
        const VariableVector& variables,
        const std::string& name = ""
        );

    Function(
        const OutputVector& results,
        const ParameterVector& parameters,
        const VariableVector& variables,
        const std::string& name = ""
        );

    // Parameters and variables are generated automatically by traversing the
    // graph from the results.
    Function(const OutputVector& results, const std::string& name = "");

    // Parameters and variables are generated automatically by traversing the
    // graph from the results and the sinks.
    Function(
        const OutputVector& results,
        const SinkVector& sinks,
        const std::string& name = ""
        );

    // methods

    const DiscreteTypeInfo& get_type_info() const;
    // Number of outputs of this function.
    size_t get_output_size() const;
    // Op that generates output i.
    std::shared_ptr<Node> get_output_op(size_t i) const;
    Output<Node> output(size_t i) const;
    // Element type of output i.
    const element::Type& get_output_element_type(size_t i) const;
    // Shape of output i.
    const Shape& get_output_shape(size_t i) const;
    // Partial shape of output i.
    const PartialShape& get_output_partial_shape(size_t i) const;
    // Checks that there is a single result and returns it.
    std::shared_ptr<Node> get_result() const;
    // Unique name of the function.
    const std::string& get_name() const;
    // Sets a friendly name (mainly for debugging); the unique name is kept.
    void set_friendly_name(const std::string& name);
    // Friendly name, or the unique name if no friendly name has been set.
    const std::string& get_friendly_name() const;
    std::vector<std::shared_ptr<Node>> get_ops() const;
    std::vector<std::shared_ptr<Node>> get_ordered_ops() const;
    void map_unordered_ops(std::function<void(Node*)> f) const;
    void replace_node(std::shared_ptr<Node> old, std::shared_ptr<Node> repl);
    void validate_nodes_and_infer_types() const;
    // Sum of all node sizes plus constant data. Useful only for comparing
    // graphs; it is not the actual memory consumption.
    size_t get_graph_size() const;
    // True if any op in the function has a partial shape.
    bool is_dynamic() const;

    // Redirects all users of the parameter at parameter_index to `parameter`
    // and replaces that entry in the parameter list.
    void replace_parameter(
        size_t parameter_index,
        const std::shared_ptr<op::Parameter>& parameter
        );

    // Installs the routine used to order nodes (see topological_sort_t).
    void set_topological_sort(topological_sort_t);
    virtual bool visit_attributes(AttributeVisitor& visitor);
    // The function's parameters.
    const ParameterVector& get_parameters() const;
    // The function's results (outputs).
    const ResultVector& get_results() const;
    // Index of `parameter`, or -1.
    int64_t get_parameter_index(const std::shared_ptr<op::Parameter>& parameter) const;
    // Index of `value` or of the result referencing it, or -1.
    int64_t get_result_index(const Output<Node>& value) const;

    // Evaluates the function on input_tensors (one per input) and writes the
    // results into output_tensors (one per result). evaluation_context holds
    // extra settings and attributes that can be shared across nodes.
    bool evaluate(
        const HostTensorVector& output_tensors,
        const HostTensorVector& input_tensors,
        EvaluationContext evaluation_context = EvaluationContext()
        ) const;

    // The function's sinks.
    const SinkVector& get_sinks() const;
    // The following add/remove methods only edit the corresponding lists. They
    // do not validate the graph or add/delete graph nodes; validate manually
    // after all changes.
    void add_sinks(const SinkVector& sinks);
    void remove_sink(const std::shared_ptr<op::Sink>& sink);
    void add_results(const ResultVector& results);
    void remove_result(const std::shared_ptr<op::Result>& result);
    void add_parameters(const ParameterVector& params);
    // Parameter indices may change after removal.
    void remove_parameter(const std::shared_ptr<op::Parameter>& param);
    void add_variables(const VariableVector& variables);
    void remove_variable(const VariablePtr& variable);
    // The function's variables.
    const VariableVector& get_variables() const;
    // Variable with the given variable_id.
    VariablePtr get_variable_by_id(const std::string& variable_id) const;
};

Detailed Documentation

A user-defined function.

Construction

Function(const OutputVector& results, const std::string& name = "")

Constructs a Function. The lists of parameters and variables are generated automatically by traversing the graph from the results.

Function(
    const OutputVector& results,
    const SinkVector& sinks,
    const std::string& name = ""
    )

Constructs a Function. The lists of parameters and variables are generated automatically by traversing the graph from the results and the sinks.

Methods

size_t get_output_size() const

Return the number of outputs for this function.

std::shared_ptr<Node> get_output_op(size_t i) const

Return the op that generates output i.

const element::Type& get_output_element_type(size_t i) const

Return the element type of output i.

const Shape& get_output_shape(size_t i) const

Return the shape of output i.

const PartialShape& get_output_partial_shape(size_t i) const

Return the partial shape of output i.

std::shared_ptr<Node> get_result() const

Check that there is a single result and return it.

const std::string& get_name() const

Get the unique name of the function.

Returns:

A const reference to the function’s unique name.

void set_friendly_name(const std::string& name)

Sets a friendly name for a function. This does not overwrite the unique name of the function and is retrieved via get_friendly_name(). Used mainly for debugging.

Parameters:

name

is the friendly name to set

const std::string& get_friendly_name() const

Gets the friendly name for a function. If no friendly name has been set via set_friendly_name then the function’s unique name is returned.

Returns:

A const reference to the function’s friendly name.

size_t get_graph_size() const

Returns the sum of the size of all nodes in the graph plus the size of all constant data. This has little value beyond comparing the relative size of graphs and should not be considered the actual memory consumption of a graph.

bool is_dynamic() const

Returns true if any of the ops defined in the function has a partial (dynamic) shape.

void replace_parameter(
    size_t parameter_index,
    const std::shared_ptr<op::Parameter>& parameter
    )

Replace the parameter at index parameter_index of the function with parameter.

All users of the parameter at index parameter_index are redirected to parameter, and entry parameter_index in the function’s parameter list is replaced with parameter.

Parameters:

parameter_index

The index of the parameter to replace.

parameter

The parameter to substitute for the parameter at index parameter_index.

const ParameterVector& get_parameters() const

Return the function parameters.

const ResultVector& get_results() const

Return the list of the function’s results (outputs).

int64_t get_parameter_index(const std::shared_ptr<op::Parameter>& parameter) const

Return the index of parameter in the function’s parameter list, or -1 if it is not one of the function’s parameters.

int64_t get_result_index(const Output<Node>& value) const

Return the index of the result that is value or that references value, or -1 if there is no such result.

bool evaluate(
    const HostTensorVector& output_tensors,
    const HostTensorVector& input_tensors,
    EvaluationContext evaluation_context = EvaluationContext()
    ) const

Evaluate the function on inputs, putting results in outputs.

Parameters:

output_tensors

Tensors for the outputs to compute, one for each result.

input_tensors

Tensors for the inputs, one for each input.

evaluation_context

Storage of additional settings and attributes that can be used when evaluating the function. This additional information can be shared across nodes.

const SinkVector& get_sinks() const

Return the list of the function’s sinks.

void add_sinks(const SinkVector& sinks)

Add new sink nodes to the list of sinks. The method doesn’t validate the graph; validation should be done manually after all changes.

Parameters:

sinks

new sink nodes

void remove_sink(const std::shared_ptr<op::Sink>& sink)

Remove a sink node from the list of sinks. The method doesn’t delete the node from the graph.

Parameters:

sink

Sink to delete

void add_results(const ResultVector& results)

Add new Result nodes to the list of results. The method doesn’t validate the graph; validation should be done manually after all changes.

Parameters:

results

new Result nodes

void remove_result(const std::shared_ptr<op::Result>& result)

Remove a Result node from the list of results. The method doesn’t delete the node from the graph.

Parameters:

result

Result node to delete

void add_parameters(const ParameterVector& params)

Add new Parameter nodes to the list.

The method doesn’t change or validate the graph; that must be done manually. For example, to replace a ReadValue node with a Parameter, perform the following steps:

  • replace the ReadValue node with the Parameter in the graph

  • call add_parameters() to add the new input to the list

  • run graph validation to check the correctness of the changes

Parameters:

params

new Parameter nodes

void remove_parameter(const std::shared_ptr<op::Parameter>& param)

Remove a Parameter node from the list of parameters. The method doesn’t delete the node from the graph; you need to replace the Parameter with another operation manually. Note: the indices of the remaining parameters may change.

One possible use of this method is to replace an input with a variable. To do so, perform the following steps:

  • replace the Parameter node with a ReadValue node

  • call remove_parameter(param) to remove the input from the list

  • check whether any parameter indices are stored or used elsewhere, and update them for all inputs, because indices may change

  • run graph validation to check all changes

Parameters:

param

Parameter node to delete

void add_variables(const VariableVector& variables)

Add new variables to the list of variables. The method doesn’t validate the graph; validation should be done manually after all changes.

Parameters:

variables

new variables to add

void remove_variable(const VariablePtr& variable)

Remove a variable from the list of variables. The method doesn’t delete the nodes that use this variable from the graph.

Parameters:

variable

Variable to delete

const VariableVector& get_variables() const

Return the list of the function’s variables.

VariablePtr get_variable_by_id(const std::string& variable_id) const

Return the variable with the specified variable_id.