namespace rnn {
// Namespace-scope function templates for RNN-family (GRU) shape inference.
// Only declarations appear here; no definitions are visible in this section.
// NOTE(review): the pointer parameters were previously written as `\*`
// (a markdown-escape artifact), which is not valid C++; restored to `*`.

/// @brief Shape inference for a GRU cell operation.
/// NOTE(review): purpose inferred from the name; confirm against the definition.
/// @tparam OpType    Operation type; accessed only through a const pointer.
/// @tparam ShapeType Shape representation stored in the input/output vectors.
/// @param op            Operation whose output shapes are computed.
/// @param input_shapes  Shapes of the operation's inputs (read-only).
/// @param output_shapes Out-parameter that receives the computed output shapes.
template <class OpType, class ShapeType>
void gru_cell_shape_infer(const OpType* op,
                          const std::vector<ShapeType>& input_shapes,
                          std::vector<ShapeType>& output_shapes);

/// @brief Checks the ranks of the given input shapes against expected ranks.
/// NOTE(review): presumably pairs input_shapes[i] with expected_ranks[i] and
/// reports a mismatch; the error-reporting mechanism is not visible here — confirm.
/// @tparam OpType    Operation type; accessed only through a const pointer.
/// @tparam ShapeType Shape representation stored in input_shapes.
/// @param op             Operation being validated (likely used for diagnostics).
/// @param input_shapes   Shapes to validate (read-only).
/// @param expected_ranks Expected rank per input (read-only); `Rank` is declared
///                       outside this section.
template <class OpType, class ShapeType>
void validate_inputs_rank(const OpType* op,
                          const std::vector<ShapeType>& input_shapes,
                          const std::vector<Rank>& expected_ranks);

/// @brief Shape inference for a GRU sequence operation.
/// NOTE(review): purpose inferred from the name; confirm against the definition.
/// @tparam OpType    Operation type; accessed only through a const pointer.
/// @tparam ShapeType Shape representation stored in the input/output vectors.
/// @param op            Operation whose output shapes are computed.
/// @param input_shapes  Shapes of the operation's inputs (read-only).
/// @param output_shapes Out-parameter that receives the computed output shapes.
template <class OpType, class ShapeType>
void gru_sequence_shape_infer(const OpType* op,
                              const std::vector<ShapeType>& input_shapes,
                              std::vector<ShapeType>& output_shapes);
}  // namespace rnn