ie_infer_async_request_thread_safe_default.hpp
1 // Copyright (C) 2018-2021 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 #pragma once
6 
10 
11 #include <cpp_interfaces/interface/ie_iinfer_request_internal.hpp>
12 
13 #include <exception>
14 #include <future>
15 #include <map>
16 #include <memory>
17 #include <mutex>
18 #include <string>
19 #include <tuple>
20 #include <utility>
21 #include <vector>
22 
23 namespace InferenceEngine {
24 
25 /**
26  * @ingroup ie_dev_api_async_infer_request_api
27  * @brief Base class with a default implementation of an asynchronous, multi-staged inference request.
28  * To customize the pipeline stages, a derived class should change the contents
29  * of the AsyncInferRequestThreadSafeDefault::_pipeline member container.
30  * It consists of pairs of executors and the tasks they run.
31  * Plugins are recommended to use this class as the base class for their asynchronous inference request implementation.
32  * @note To synchronize the derived context with the pipeline stages,
33  * a derived class should call AsyncInferRequestThreadSafeDefault::StopAndWait() in its destructor.
34  * @par Example
35  * Here is an example of asynchronous inference request implementation for some accelerator device.
36  * It uses 5 different executors to run different stages of a synchronous inference request.
37  *
38  * @snippet example_async_infer_request.cpp async_infer_request:define_pipeline
39  */
41  enum InferState {Idle, Busy, Canceled, Stop};
42  using Futures = std::vector<std::shared_future<void>>;
43  using Promise = std::shared_ptr<std::promise<void>>;
44  enum Stage_e : std::uint8_t { executor, task };
45  IInferRequestInternal::Ptr _syncRequest;
46 
// Scoped guard used by Infer(): while it is alive the request's user callback is detached, so the
// final pipeline step (which only invokes a non-empty _callback) does not call it; the destructor
// puts the saved callback back. Both transfers happen under the request's _mutex.
47  friend struct DisableCallbackGuard;
48  struct DisableCallbackGuard {
// Swaps the request's callback into the guard; the request is left with the guard's previous value
// (presumably an empty callback — the guard's own `_callback` declaration is not visible in this listing).
49  explicit DisableCallbackGuard(AsyncInferRequestThreadSafeDefault* this_)
50  : _this{this_} {
51  std::lock_guard<std::mutex> lock{_this->_mutex};
52  std::swap(_callback, _this->_callback);
53  }
// Restores the saved callback, overwriting whatever the request holds at that moment.
54  ~DisableCallbackGuard() {
55  std::lock_guard<std::mutex> lock{_this->_mutex};
56  _this->_callback = _callback;
57  }
// Non-owning pointer to the guarded request.
// NOTE(review): the member line declaring the guard's own `_callback` is missing from this listing.
58  AsyncInferRequestThreadSafeDefault* _this = nullptr;
60  };
61 
62  struct ImmediateStreamsExecutor : public InferenceEngine::ITaskExecutor {
63  explicit ImmediateStreamsExecutor(const IStreamsExecutor::Ptr& streamsExecutor) : _streamsExecutor{streamsExecutor} {}
64  void run(InferenceEngine::Task task) override {_streamsExecutor->Execute(std::move(task));}
65  IStreamsExecutor::Ptr _streamsExecutor;
66  };
67 
68  template<typename F>
69  void InferImpl(const F& f) {
70  _syncRequest->checkBlobs();
71  InferState state = InferState::Idle;
72  {
73  std::lock_guard<std::mutex> lock{_mutex};
74  state = _state;
75  switch (_state) {
76  case InferState::Busy :
77  IE_THROW(RequestBusy);
78  case InferState::Canceled :
79  IE_THROW(InferCancelled);
80  case InferState::Idle : {
81  _futures.erase(std::remove_if(std::begin(_futures), std::end(_futures),
82  [](const std::shared_future<void>& future) {
83  if (future.valid()) {
84  return (std::future_status::ready ==
85  future.wait_for(std::chrono::milliseconds {0}));
86  } else {
87  return true;
88  }
89  }),
90  _futures.end());
91  _promise = {};
92  _futures.emplace_back(_promise.get_future().share());
93  } break;
94  case InferState::Stop : break;
95  }
96  _state = InferState::Busy;
97  }
98  if (state != InferState::Stop) {
99  try {
100  f();
101  } catch (...) {
102  _promise.set_exception(std::current_exception());
103  std::lock_guard<std::mutex> lock{_mutex};
104  _state = InferState::Idle;
105  throw;
106  }
107  }
108  }
109 
110 protected:
111  /**
112  * @brief Throws exception if inference request is busy or canceled
113  */
114  void CheckState() const {
115  std::lock_guard<std::mutex> lock {_mutex};
116  switch (_state) {
117  case InferState::Busy :
118  IE_THROW(RequestBusy);
119  case InferState::Canceled :
120  IE_THROW(InferCancelled);
121  default: break;
122  }
123  }
124 
125 public:
126  /**
127  * @brief A shared pointer to AsyncInferRequestThreadSafeDefault
128  */
129  using Ptr = std::shared_ptr<AsyncInferRequestThreadSafeDefault>;
130 
131  /**
132  * @brief Wraps a IInferRequestInternal::Ptr implementation and constructs a
133  * AsyncInferRequestThreadSafeDefault::_pipeline where `taskExecutor` is used to run IInferRequestInternal::Infer
134  * asynchronously.
135  *
136  * @param[in] request The synchronous request
137  * @param[in] taskExecutor The task executor
138  * @param[in] callbackExecutor The callback executor
139  */
141  const ITaskExecutor::Ptr& taskExecutor,
142  const ITaskExecutor::Ptr& callbackExecutor) :
143  _syncRequest {request},
144  _requestExecutor {taskExecutor},
145  _callbackExecutor {callbackExecutor},
146  _pipeline {{taskExecutor, [this] {_syncRequest->InferImpl();}}},
147  _syncPipeline {{std::make_shared<ImmediateExecutor>(), [this] {_syncRequest->InferImpl();}}} {
148  auto streamsExecutor = std::dynamic_pointer_cast<IStreamsExecutor>(taskExecutor);
149  if (streamsExecutor != nullptr) {
150  _syncPipeline = {{std::make_shared<ImmediateStreamsExecutor>(std::move(streamsExecutor)), [this] {_syncRequest->InferImpl();}}};
151  }
152  }
153 
154  /**
155  * @brief Destroys the object, stops AsyncInferRequestThreadSafeDefault::_pipeline and waits for a finish.
156  */
158  StopAndWait();
159  }
160 
161  /**
162  * @brief Waits for completion of all pipeline stages
163  * If the pipeline raises an exception it will be rethrown here
164  * @param millis_timeout A timeout is `ms` to wait or special enum value of InferRequest::WaitMode
165  * @return A status code
166  */
167  StatusCode Wait(int64_t millis_timeout) override {
168  if (millis_timeout < InferRequest::WaitMode::RESULT_READY) {
169  IE_THROW(ParameterMismatch)
170  << " Timeout can't be less "
171  << InferRequest::WaitMode::RESULT_READY << " for InferRequest::Wait\n";
172  }
173  auto status = std::future_status::deferred;
174 
175  // Just use the last '_futures' member to wait pipeline completion
176  auto future = [&] {
177  std::lock_guard<std::mutex> lock {_mutex};
178  return _futures.empty() ? std::shared_future<void> {} : _futures.back();
179  }();
180 
181  if (!future.valid()) {
182  return StatusCode::INFER_NOT_STARTED;
183  }
184 
185  switch (millis_timeout) {
186  case InferRequest::WaitMode::RESULT_READY: {
187  future.wait();
188  status = std::future_status::ready;
189  } break;
190  case InferRequest::WaitMode::STATUS_ONLY: {
191  status = future.wait_for(std::chrono::milliseconds {0});
192  } break;
193  default: {
194  status = future.wait_for(std::chrono::milliseconds {millis_timeout});
195  } break;
196  }
197 
198  if (std::future_status::ready == status) {
199  future.get();
200  return StatusCode::OK;
201  } else {
202  return StatusCode::RESULT_NOT_READY;
203  }
204  }
205 
206  void StartAsync() override {
207  InferImpl([&] {StartAsync_ThreadUnsafe();});
208  }
209 
210  void Infer() override {
211  DisableCallbackGuard disableCallbackGuard{this};
212  InferImpl([&] {Infer_ThreadUnsafe();});
213  Wait(InferRequest::WaitMode::RESULT_READY);
214  }
215 
216  std::map<std::string, InferenceEngineProfileInfo> GetPerformanceCounts() const override {
217  CheckState();
218  return _syncRequest->GetPerformanceCounts();
219  }
220 
221  void SetBlob(const std::string& name, const Blob::Ptr& data) override {
222  CheckState();
223  _syncRequest->SetBlob(name, data);
224  }
225 
226  void SetBlob(const std::string& name, const Blob::Ptr& data, const PreProcessInfo& info) override {
227  CheckState();
228  _syncRequest->SetBlob(name, data, info);
229  }
230 
231  Blob::Ptr GetBlob(const std::string& name) override {
232  CheckState();
233  return _syncRequest->GetBlob(name);
234  }
235 
236  const PreProcessInfo& GetPreProcess(const std::string& name) const override {
237  return _syncRequest->GetPreProcess(name);
238  }
239 
240  void SetBatch(int batch) override {
241  CheckState();
242  _syncRequest->SetBatch(batch);
243  };
244 
245  void SetCallback(Callback callback) override {
246  CheckState();
247  _callback = std::move(callback);
248  }
249 
250  std::vector<std::shared_ptr<InferenceEngine::IVariableStateInternal>> QueryState() override {
251  CheckState();
252  return _syncRequest->QueryState();
253  }
254 
255  void ThrowIfCanceled() const {
256  std::lock_guard<std::mutex> lock{_mutex};
257  if (_state == InferState::Canceled) {
258  IE_THROW(InferCancelled);
259  }
260  }
261 
262  void Cancel() override {
263  std::lock_guard<std::mutex> lock{_mutex};
264  if (_state == InferState::Busy) {
265  _state = InferState::Canceled;
266  }
267  }
268 
269 protected:
270  /**
271  * @brief Each pipeline stage is a @ref Task that is executed by specified ITaskExecutor implementation
272  */
273  using Stage = std::pair<ITaskExecutor::Ptr, Task>;
274  /**
275  * @brief Pipeline is vector of stages
276  */
277  using Pipeline = std::vector<Stage>;
278 
279  /**
280  * @brief Creates and run the first stage task. If destructor was not called add a new std::future to the
281  * AsyncInferRequestThreadSafeDefault::_futures list that would be used to wait
282  * AsyncInferRequestThreadSafeDefault::_pipeline finish
283  * @param[in] itBeginStage Iterator to begin of pipeline
284  * @param[in] itEndStage End pipeline iterator
285  * @param[in] callbackExecutor Final or error stage executor
286  */
287  void RunFirstStage(const Pipeline::iterator itBeginStage, const Pipeline::iterator itEndStage,
288  const ITaskExecutor::Ptr callbackExecutor = {}) {
289  auto& firstStageExecutor = std::get<Stage_e::executor>(*itBeginStage);
290  IE_ASSERT(nullptr != firstStageExecutor);
291  firstStageExecutor->run(MakeNextStageTask(itBeginStage, itEndStage, std::move(callbackExecutor)));
292  }
293 
294  /**
295  * @brief Forbids pipeline start and wait for all started pipelines.
296  * @note Should be called in derived class destructor to wait for completion of usage of derived context captured by
297  * pipeline tasks
298  */
299  void StopAndWait() {
300  Futures futures;
301  InferState state = InferState::Idle;
302  {
303  std::lock_guard<std::mutex> lock{_mutex};
304  state = _state;
305  if (state != InferState::Stop) {
306  _callback = {};
307  _state = InferState::Stop;
308  futures = std::move(_futures);
309  }
310  }
311  if (state != InferState::Stop) {
312  for (auto&& future : futures) {
313  if (future.valid()) {
314  future.wait();
315  }
316  }
317  }
318  }
319 
320 
321  ITaskExecutor::Ptr _requestExecutor; //!< Used to run inference CPU tasks.
322  ITaskExecutor::Ptr _callbackExecutor; //!< Used to run post inference callback in asynchronous pipline
323  ITaskExecutor::Ptr _syncCallbackExecutor; //!< Used to run post inference callback in synchronous pipline
324  Pipeline _pipeline; //!< Pipeline variable that should be filled by inherited class.
325  Pipeline _syncPipeline; //!< Synchronous pipeline variable that should be filled by inherited class.
326 
327  /**
328  * @brief Starts an asynchronous pipeline thread unsafe.
329  * @note Used by StartAsync which ensures thread-safety and calls this method after.
330  */
331  virtual void StartAsync_ThreadUnsafe() {
332  RunFirstStage(_pipeline.begin(), _pipeline.end(), _callbackExecutor);
333  }
334 
335  /**
336  * @brief Performs inference of pipeline in syncronous mode
337  * @note Used by Infer which ensures thread-safety and calls this method after.
338  */
339  virtual void Infer_ThreadUnsafe() {
340  RunFirstStage(_syncPipeline.begin(), _syncPipeline.end(), _syncCallbackExecutor);
341  }
342 
343  /**
344  * @brief Implements Infer() using StartAsync() and Wait()
345  */
347  StartAsync_ThreadUnsafe();
348  }
349 
350 private:
351  /**
352  * @brief Creates the task that runs the pipeline stage pointed to by `itStage`.
353  * After that stage's task returns, the created task submits the next stage (built by a recursive
354  * MakeNextStageTask() call) to the next stage's executor. After the last stage, or if an exception
355  * escapes a stage task or the submission of the next stage, the final step runs: it resets the state
356  * to Idle, invokes the user callback (if one is set) with the exception (or nullptr), and completes the
357  * promise taken from `_promise` with a value or with the (possibly callback-raised) exception.
358  * @param[in] itStage Iterator to the pipeline stage run by the created task
359  * @param[in] itEndStage End pipeline iterator
360  * @param[in] callbackExecutor Executor that runs the final step; if null, the final step runs inline
361  * @return The task running stage `itStage`
362  */
363  Task MakeNextStageTask(const Pipeline::iterator itStage, const Pipeline::iterator itEndStage,
364  const ITaskExecutor::Ptr callbackExecutor) {
// std::bind stores a copy of callbackExecutor inside the task and passes it to the mutable lambda as
// an lvalue, so the lambda can move it on to the next stage (presumably in place of a C++14 init-capture).
365  return std::bind([this, itStage, itEndStage](ITaskExecutor::Ptr& callbackExecutor) mutable {
366  std::exception_ptr currentException = nullptr;
367  auto& thisStage = *itStage;
368  auto itNextStage = itStage + 1;
369  try {
370  auto& stageTask = std::get<Stage_e::task>(thisStage);
371  IE_ASSERT(nullptr != stageTask);
372  stageTask();
373  if (itEndStage != itNextStage) {
374  auto& nextStage = *itNextStage;
375  auto& nextStageExecutor = std::get<Stage_e::executor>(nextStage);
376  IE_ASSERT(nullptr != nextStageExecutor);
// callbackExecutor is moved into the next stage's task here; the final step is then that task's job.
// NOTE(review): if run() throws after this move, the final step below sees a null callbackExecutor
// and runs inline instead of on the callback executor.
377  nextStageExecutor->run(MakeNextStageTask(itNextStage, itEndStage, std::move(callbackExecutor)));
378  }
379  } catch (...) {
380  currentException = std::current_exception();
381  }
382 
// This task performs the final step only when it ran the last stage or something threw.
383  if ((itEndStage == itNextStage) || (nullptr != currentException)) {
384  auto lastStageTask = [this, currentException]() mutable {
// Take the promise before marking the request Idle: once Idle, InferImpl() may reset _promise
// for the next run.
385  auto promise = std::move(_promise);
386  Callback callback;
387  {
388  std::lock_guard<std::mutex> lock{_mutex};
389  _state = InferState::Idle;
390  callback = _callback;
391  }
// The callback is invoked outside the lock; an exception thrown by it replaces currentException.
392  if (callback) {
393  try {
394  auto local_callback = std::move(callback);
395  local_callback(currentException);
396  } catch (...) {
397  currentException = std::current_exception();
398  }
399  }
400  if (nullptr == currentException) {
401  promise.set_value();
402  } else {
403  promise.set_exception(currentException);
404  }
405  };
406 
407  if (nullptr == callbackExecutor) {
408  lastStageTask();
409  } else {
410  callbackExecutor->run(std::move(lastStageTask));
411  }
412  }
413  }, std::move(callbackExecutor));
414  }
415 
416  std::promise<void> _promise;
417  mutable std::mutex _mutex;
418  Futures _futures;
419  InferState _state = InferState::Idle;
420 };
421 } // namespace InferenceEngine
Base class with default implementation of asynchronous multi staged inference request....
Definition: ie_infer_async_request_thread_safe_default.hpp:40
void InferUsingAsync()
Implements Infer() using StartAsync() and Wait()
Definition: ie_infer_async_request_thread_safe_default.hpp:346
StatusCode Wait(int64_t millis_timeout) override
Waits for completion of all pipeline stages. If the pipeline raises an exception, it will be rethrown h...
Definition: ie_infer_async_request_thread_safe_default.hpp:167
std::map< std::string, InferenceEngineProfileInfo > GetPerformanceCounts() const override
Queries performance measures per layer to get feedback of what is the most time consuming layer....
Definition: ie_infer_async_request_thread_safe_default.hpp:216
void RunFirstStage(const Pipeline::iterator itBeginStage, const Pipeline::iterator itEndStage, const ITaskExecutor::Ptr callbackExecutor={})
Creates and runs the first stage task. If the destructor was not called, adds a new std::future to the Async...
Definition: ie_infer_async_request_thread_safe_default.hpp:287
std::pair< ITaskExecutor::Ptr, Task > Stage
Each pipeline stage is a Task that is executed by specified ITaskExecutor implementation.
Definition: ie_infer_async_request_thread_safe_default.hpp:273
void SetBlob(const std::string &name, const Blob::Ptr &data) override
Set input/output data to infer.
Definition: ie_infer_async_request_thread_safe_default.hpp:221
std::shared_ptr< AsyncInferRequestThreadSafeDefault > Ptr
A shared pointer to AsyncInferRequestThreadSafeDefault.
Definition: ie_infer_async_request_thread_safe_default.hpp:129
void Cancel() override
Cancel current inference request execution.
Definition: ie_infer_async_request_thread_safe_default.hpp:262
void Infer() override
Infers specified input(s) in synchronous mode.
Definition: ie_infer_async_request_thread_safe_default.hpp:210
void SetBatch(int batch) override
Sets new batch size when dynamic batching is enabled in executable network that created this request.
Definition: ie_infer_async_request_thread_safe_default.hpp:240
ITaskExecutor::Ptr _syncCallbackExecutor
Used to run post inference callback in synchronous pipeline.
Definition: ie_infer_async_request_thread_safe_default.hpp:323
void SetCallback(Callback callback) override
Set callback function which will be called on success or failure of asynchronous request.
Definition: ie_infer_async_request_thread_safe_default.hpp:245
AsyncInferRequestThreadSafeDefault(const IInferRequestInternal::Ptr &request, const ITaskExecutor::Ptr &taskExecutor, const ITaskExecutor::Ptr &callbackExecutor)
Wraps a IInferRequestInternal::Ptr implementation and constructs a AsyncInferRequestThreadSafeDefault...
Definition: ie_infer_async_request_thread_safe_default.hpp:140
Blob::Ptr GetBlob(const std::string &name) override
Get input/output data to infer.
Definition: ie_infer_async_request_thread_safe_default.hpp:231
virtual void StartAsync_ThreadUnsafe()
Starts an asynchronous pipeline thread unsafe.
Definition: ie_infer_async_request_thread_safe_default.hpp:331
std::vector< Stage > Pipeline
Pipeline is vector of stages.
Definition: ie_infer_async_request_thread_safe_default.hpp:277
void StartAsync() override
Start inference of specified input(s) in asynchronous mode.
Definition: ie_infer_async_request_thread_safe_default.hpp:206
Pipeline _syncPipeline
Synchronous pipeline variable that should be filled by inherited class.
Definition: ie_infer_async_request_thread_safe_default.hpp:325
void CheckState() const
Throws exception if inference request is busy or canceled.
Definition: ie_infer_async_request_thread_safe_default.hpp:114
void SetBlob(const std::string &name, const Blob::Ptr &data, const PreProcessInfo &info) override
Sets pre-process for input data.
Definition: ie_infer_async_request_thread_safe_default.hpp:226
~AsyncInferRequestThreadSafeDefault()
Destroys the object, stops AsyncInferRequestThreadSafeDefault::_pipeline and waits for a finish.
Definition: ie_infer_async_request_thread_safe_default.hpp:157
ITaskExecutor::Ptr _callbackExecutor
Used to run post inference callback in asynchronous pipeline.
Definition: ie_infer_async_request_thread_safe_default.hpp:322
virtual void Infer_ThreadUnsafe()
Performs inference of pipeline in synchronous mode.
Definition: ie_infer_async_request_thread_safe_default.hpp:339
std::vector< std::shared_ptr< InferenceEngine::IVariableStateInternal > > QueryState() override
Queries memory states.
Definition: ie_infer_async_request_thread_safe_default.hpp:250
Pipeline _pipeline
Pipeline variable that should be filled by inherited class.
Definition: ie_infer_async_request_thread_safe_default.hpp:324
void StopAndWait()
Forbids pipeline start and wait for all started pipelines.
Definition: ie_infer_async_request_thread_safe_default.hpp:299
ITaskExecutor::Ptr _requestExecutor
Used to run inference CPU tasks.
Definition: ie_infer_async_request_thread_safe_default.hpp:321
const PreProcessInfo & GetPreProcess(const std::string &name) const override
Gets pre-process for input data.
Definition: ie_infer_async_request_thread_safe_default.hpp:236
std::shared_ptr< Blob > Ptr
An internal API of synchronous inference request to be implemented by plugin, which is used in InferR...
Definition: ie_iinfer_request_internal.hpp:28
virtual void InferImpl()
The minimal infer function to be implemented by plugins. It infers specified input(s) in synchronous ...
Callback _callback
A callback.
Definition: ie_iinfer_request_internal.hpp:239
std::shared_ptr< IInferRequestInternal > Ptr
A shared pointer to a IInferRequestInternal interface.
Definition: ie_iinfer_request_internal.hpp:33
std::function< void(std::exception_ptr)> Callback
Alias for callback type.
Definition: ie_iinfer_request_internal.hpp:147
std::shared_ptr< IStreamsExecutor > Ptr
Definition: ie_istreams_executor.hpp:36
Interface for Task Executor. Inference Engine uses InferenceEngine::ITaskExecutor interface to run al...
Definition: ie_itask_executor.hpp:46
std::shared_ptr< ITaskExecutor > Ptr
Definition: ie_itask_executor.hpp:51
std::function< void()> Task
Inference Engine Task Executor can use any copyable callable without parameters and output as a task....
Definition: ie_itask_executor.hpp:25
#define IE_THROW(...)
#define IE_ASSERT(EXPRESSION)
A header file for Inference Engine Immediate Executor implementation.
A header file for Inference Engine Streams-based Executor Interface.
A header file for Inference Engine Task Executor Interface.
Inference Engine Plugin API namespace.