namespace v7 {

// Forward declarations of the classes in this namespace.
class DFT;
class Einsum;
class Gather;
class Gelu;
class IDFT;
class Roll;

// Global function templates: shape_infer overloads.

/// @brief Shape inference declaration for an Einsum operation.
/// @tparam TShape Shape type used for both the input and output shape lists.
/// @param op            Pointer to the Einsum operation being inferred.
/// @param input_shapes  Shapes of the operation's inputs (read-only).
/// @param output_shapes Out-parameter (non-const reference); presumably filled
///                      with the inferred output shapes — TODO confirm against
///                      the definition.
template <class TShape>
void shape_infer(const Einsum* op,
                 const std::vector<TShape>& input_shapes,
                 std::vector<TShape>& output_shapes);

/// @brief Shape inference declaration for a Roll operation, value-returning form.
/// @tparam TShape Shape type used for the input and returned shape lists.
/// @param op            Pointer to the Roll operation being inferred.
/// @param input_shapes  Shapes of the operation's inputs (read-only).
/// @param constant_data Optional map from a size_t key to HostTensorPtr;
///                      defaults to an empty map. NOTE(review): keys are
///                      presumably input port indices — confirm.
/// @return The list of inferred shapes.
template <class TShape>
std::vector<TShape> shape_infer(const Roll* op,
                                const std::vector<TShape>& input_shapes,
                                const std::map<size_t, HostTensorPtr>& constant_data = {});

/// @brief Shape inference declaration for a Roll operation, out-parameter form.
/// @tparam TShape Shape type used for the input and output shape lists.
/// @param op            Pointer to the Roll operation being inferred.
/// @param input_shapes  Shapes of the operation's inputs (read-only).
/// @param output_shapes Out-parameter (non-const reference); presumably
///                      receives the inferred shapes — TODO confirm.
/// @param constant_data Optional map from a size_t key to HostTensorPtr;
///                      defaults to an empty map. NOTE(review): keys are
///                      presumably input port indices — confirm.
template <class TShape>
void shape_infer(const Roll* op,
                 const std::vector<TShape>& input_shapes,
                 std::vector<TShape>& output_shapes,
                 const std::map<size_t, HostTensorPtr>& constant_data = {});

}  // namespace v7