einsum.hpp
1 // Copyright (C) 2018-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 #pragma once
6 
7 #include "ngraph/node.hpp"
8 #include "ngraph/op/op.hpp"
9 
10 namespace ngraph
11 {
12  namespace op
13  {
14  namespace v7
15  {
16  /// \brief Einsum operation.
17  class NGRAPH_API Einsum : public Op
18  {
19  public:
20  NGRAPH_RTTI_DECLARATION;
21 
22  Einsum() = default;
23 
24  ///
25  /// \brief Constructs Einsum operation.
26  ///
27  /// \param inputs Input nodes on which Einsum operation performs
28  /// contraction
29  ///
30  /// \param equation Einstein summation convention
31  ///
32  Einsum(const OutputVector& inputs, const std::string& equation);
33 
34  void validate_and_infer_types() override;
35 
36  bool visit_attributes(AttributeVisitor& visitor) override;
37 
38  std::shared_ptr<Node>
39  clone_with_new_inputs(const OutputVector& new_args) const override;
40 
41  /// \brief Get an equation of Einsum operation
42  ///
43  /// \return Einsum equation
44  ///
45  std::string get_equation() const { return m_equation; }
46 
47  /// \brief Check correctness of equation format and extract input subscripts
48  /// and output subscript
49  ///
50  /// \param equation Equation to be parsed and checked
51  ///
52  /// \param input_subscripts A vector of extracted input subscripts
53  ///
54  /// \param output_subscript An output subscript
55  ///
56  static void parse_equation(const std::string& equation,
57  std::vector<std::string>& input_subscripts,
58  std::string& output_subscript);
59 
60  /// \brief Extract labels (from subscript) that can be alphabetic letters or
61  /// ellipsis
62  ///
63  /// \param subscript Subscript
64  ///
65  /// \return A vector of extracted labels from the input subscript in the order
66  /// of appearence
67  ///
68  static std::vector<std::string> extract_labels(const std::string& subscript);
69 
70  private:
71  std::string m_equation;
72  };
73  } // namespace v7
74  } // namespace op
75 } // namespace ngraph
Visits the attributes of a node, primarily for serialization-like tasks.
Definition: attribute_visitor.hpp:59
Root of all actual ops.
Definition: op.hpp:17
Einsum operation.
Definition: einsum.hpp:18
static void parse_equation(const std::string &equation, std::vector< std::string > &input_subscripts, std::string &output_subscript)
Check correctness of equation format and extract input subscripts and output subscript.
void validate_and_infer_types() override
Verifies that attributes and inputs are consistent and computes output shapes and element types....
std::string get_equation() const
Get an equation of Einsum operation.
Definition: einsum.hpp:45
Einsum(const OutputVector &inputs, const std::string &equation)
Constructs Einsum operation.
static std::vector< std::string > extract_labels(const std::string &subscript)
Extract labels (from subscript) that can be alphabetic letters or ellipsis.
The Intel nGraph C++ API.
Definition: attribute_adapter.hpp:16