ie_infer_async_request_thread_safe_default.hpp
1 // Copyright (C) 2018-2020 Intel Corporation
2 // SPDX-License-Identifier: Apache-2.0
3 //
4 
5 #pragma once
6 
9 
10 #include <cpp_interfaces/interface/ie_iinfer_async_request_internal.hpp>
11 #include <cpp_interfaces/impl/ie_infer_async_request_thread_safe_internal.hpp>
13 #include <ie_system_conf.h>
14 
15 #include <exception>
16 #include <future>
17 #include <map>
18 #include <memory>
19 #include <mutex>
20 #include <string>
21 #include <tuple>
22 #include <utility>
23 #include <vector>
24 
25 namespace InferenceEngine {
26 
27 /**
28  * @ingroup ie_dev_api_async_infer_request_api
29  * @brief Base class with default implementation of asynchronous multi staged inference request.
30  * To customize pipeline stages derived class should change the content
31  * of AsyncInferRequestThreadSafeDefault::_pipeline member container.
32  * It consists of pairs of tasks and executors which will run the task.
33  * The class is recommended to be used by plugins as a base class for asynchronous inference request implementation.
34  * @note To synchronize derived context with stages
35  * derived class should call AsyncInferRequestThreadSafeDefault::StopAndWait() function in destructor.
36  * @par Example
37  * Here is an example of asynchronous inference request implementation for some accelerator device.
38  * It uses 5 different executors to run different stages of a synchronous inference request.
39  *
40  * @snippet example_async_infer_request.cpp async_infer_request:define_pipeline
41  */
43  using AtomicCallback = std::atomic<IInferRequest::CompletionCallback>;
44  using Futures = std::vector<std::shared_future<void>>;
45  using Promise = std::shared_ptr<std::promise<void>>;
46  enum Stage_e : std::uint8_t { executor, task };
47  struct DisableCallbackGuard{
48  explicit DisableCallbackGuard(AtomicCallback& callback)
49  : _callbackRef(callback), _callback(callback.exchange(nullptr)) {}
50  ~DisableCallbackGuard() {
51  _callbackRef = _callback;
52  }
53  AtomicCallback& _callbackRef;
55  };
56  InferRequestInternal::Ptr _syncRequest;
57 
58 public:
59  /**
60  * @brief A shared pointer to AsyncInferRequestThreadSafeDefault
61  */
62  using Ptr = std::shared_ptr<AsyncInferRequestThreadSafeDefault>;
63 
64  /**
65  * @brief Wraps a InferRequestInternal::Ptr implementation and constructs a
66  * AsyncInferRequestThreadSafeDefault::_pipeline where `taskExecutor` is used to run InferRequestInternal::Infer
67  * asynchronously.
68  *
69  * @param[in] request The synchronous request
70  * @param[in] taskExecutor The task executor
71  * @param[in] callbackExecutor The callback executor
72  */
74  const ITaskExecutor::Ptr& taskExecutor,
75  const ITaskExecutor::Ptr& callbackExecutor)
76  : _syncRequest {request},
77  _requestExecutor {taskExecutor},
78  _callbackExecutor {callbackExecutor},
79  _pipeline {{taskExecutor, [this] {_syncRequest->Infer();}}},
80  _syncPipeline{{std::make_shared<ImmediateExecutor>(), [this] {_syncRequest->Infer();}}} {
81  }
82 
83  /**
84  * @brief Destroys the object, stops AsyncInferRequestThreadSafeDefault::_pipeline and waits for a finish.
85  */
87  StopAndWait();
88  }
89 
90  /**
91  * @brief Waits for completion of all pipeline stages
92  * If the pipeline raises an exception it will be rethrown here
93  * @param millis_timeout A timeout in `ms` to wait, or a special enum value of IInferRequest::WaitMode
94  * @return A status code
95  */
96  StatusCode Wait(int64_t millis_timeout) override {
97  if (millis_timeout < IInferRequest::WaitMode::RESULT_READY) {
98  THROW_IE_EXCEPTION << PARAMETER_MISMATCH_str + "Timeout can't be less "
99  << IInferRequest::WaitMode::RESULT_READY << " for InferRequest::Wait\n";
100  }
101  auto status = std::future_status::deferred;
102 
103  // Just use the last '_futures' member to wait pipeline completion
104  auto future = [&] {
105  std::lock_guard<std::mutex> lock {_mutex};
106  return _futures.empty() ? std::shared_future<void> {} : _futures.back();
107  }();
108 
109  if (!future.valid()) {
110  return StatusCode::INFER_NOT_STARTED;
111  }
112 
113  switch (millis_timeout) {
114  case IInferRequest::WaitMode::RESULT_READY: {
115  future.wait();
116  status = std::future_status::ready;
117  } break;
118  case IInferRequest::WaitMode::STATUS_ONLY: {
119  status = future.wait_for(std::chrono::milliseconds {0});
120  } break;
121  default: {
122  status = future.wait_for(std::chrono::milliseconds {millis_timeout});
123  } break;
124  }
125 
126  if (std::future_status::ready == status) {
127  future.get();
128  return StatusCode::OK;
129  } else {
130  return StatusCode::RESULT_NOT_READY;
131  }
132  }
133 
134  /**
135  * @brief Sets the pointer to public interface.
136  * @note Needed to correctly handle ownership between objects
137  * @param[in] ptr A shared pointer to a public IInferRequest interface.
138  */
140  _publicInterface = std::shared_ptr<IInferRequest>(ptr.get(), [](IInferRequest*) {});
141  }
142 
143 protected:
144  /**
145  * @brief Each pipeline stage is a @ref Task that is executed by specified ITaskExecutor implementation
146  */
147  using Stage = std::pair<ITaskExecutor::Ptr, Task>;
148  /**
149  * @brief Pipeline is vector of stages
150  */
151  using Pipeline = std::vector<Stage>;
152 
153  /**
154  * @brief Creates and runs the first stage task. If the destructor was not called, adds a new std::future to the
155  * AsyncInferRequestThreadSafeDefault::_futures list that will be used to wait for
156  * AsyncInferRequestThreadSafeDefault::_pipeline to finish
157  * @param[in] itBeginStage Iterator to begin of pipeline
158  * @param[in] itEndStage End pipeline iterator
159  * @param[in] callbackExecutor Final or error stage executor
160  */
161  void RunFirstStage(const Pipeline::iterator itBeginStage, const Pipeline::iterator itEndStage,
162  const ITaskExecutor::Ptr callbackExecutor = {}) {
163  _promise = {};
164  bool stop = [&] {
165  std::lock_guard<std::mutex> lock(_mutex);
166  if (!_stop) {
167  _futures.erase(std::remove_if(std::begin(_futures), std::end(_futures),
168  [](const std::shared_future<void>& future) {
169  if (future.valid()) {
170  return (std::future_status::ready ==
171  future.wait_for(std::chrono::milliseconds {0}));
172  } else {
173  return true;
174  }
175  }),
176  _futures.end());
177 
178  _futures.emplace_back(_promise.get_future().share());
179  }
180  return _stop;
181  }();
182 
183  if (!stop) {
184  try {
185  auto& firstStageExecutor = std::get<Stage_e::executor>(*itBeginStage);
186  IE_ASSERT(nullptr != firstStageExecutor);
187  firstStageExecutor->run(MakeNextStageTask(itBeginStage, itEndStage, std::move(callbackExecutor)));
188  } catch (...) {
189  _promise.set_exception(std::current_exception());
190  throw;
191  }
192  }
193  }
194 
195  /**
196  * @brief Forbids pipeline start and waits for all started pipelines.
197  * @note Should be called in derived class destructor to wait for completion of usage of derived context captured by
198  * pipeline tasks
199  */
200  void StopAndWait() {
201  _callback = nullptr;
202  {
203  std::lock_guard<std::mutex> lock(_mutex);
204  if (!_stop) {
205  _stop = true;
206  for (auto&& future : _futures) {
207  if (future.valid()) {
208  future.wait();
209  }
210  }
211  }
212  }
213  }
214 
215  /**
216  * @brief Implements Infer() using StartAsync() and Wait()
217  */
219  DisableCallbackGuard disableCallbackGuard{_callback};
220  StartAsync_ThreadUnsafe();
221  Wait(InferenceEngine::IInferRequest::WaitMode::RESULT_READY);
222  }
223 
224  /**
225  * @brief Implements Infer() using synchronous pipeline and Wait()
226  */
227  void InferUsingSync() {
228  DisableCallbackGuard disableCallbackGuard{_callback};
229  _syncRequest->checkBlobs();
230  RunFirstStage(_syncPipeline.begin(), _syncPipeline.end(), _syncCallbackExecutor);
231  Wait(InferenceEngine::IInferRequest::WaitMode::RESULT_READY);
232  }
233 
234  ITaskExecutor::Ptr _requestExecutor; //!< Used to run inference CPU tasks.
235  ITaskExecutor::Ptr _callbackExecutor; //!< Used to run post inference callback in asynchronous pipeline
236  ITaskExecutor::Ptr _syncCallbackExecutor; //!< Used to run post inference callback in synchronous pipeline
237  Pipeline _pipeline; //!< Pipeline variable that should be filled by inherited class.
238  Pipeline _syncPipeline; //!< Synchronous pipeline variable that should be filled by inherited class.
239 
240  void StartAsync_ThreadUnsafe() override {
241  _syncRequest->checkBlobs();
242  RunFirstStage(_pipeline.begin(), _pipeline.end(), _callbackExecutor);
243  }
244 
/**
 * @brief Runs inference by delegating to InferUsingSync().
 */
245  void Infer_ThreadUnsafe() override {
246  InferUsingSync();
247  }
248 
/**
 * @brief Forwards the call to GetPerformanceCounts() of the wrapped synchronous request.
 * @param perfMap Map passed through unchanged to the wrapped request
 */
249  void GetPerformanceCounts_ThreadUnsafe(std::map<std::string, InferenceEngineProfileInfo>& perfMap) const override {
250  _syncRequest->GetPerformanceCounts(perfMap);
251  }
252 
/**
 * @brief Forwards the call to the two-argument SetBlob() of the wrapped synchronous request.
 * @param name Blob name, passed through unchanged
 * @param data Blob, passed through unchanged
 */
253  void SetBlob_ThreadUnsafe(const char* name, const Blob::Ptr& data) override {
254  _syncRequest->SetBlob(name, data);
255  }
256 
/**
 * @brief Forwards the call to the three-argument SetBlob() of the wrapped synchronous request.
 * @param name Blob name, passed through unchanged
 * @param data Blob, passed through unchanged
 * @param info Preprocessing information, passed through unchanged
 */
257  void SetBlob_ThreadUnsafe(const char* name, const Blob::Ptr& data, const PreProcessInfo& info) override {
258  _syncRequest->SetBlob(name, data, info);
259  }
260 
/**
 * @brief Forwards the call to GetBlob() of the wrapped synchronous request.
 * @param name Blob name, passed through unchanged
 * @param data Reference passed through to the wrapped request
 */
261  void GetBlob_ThreadUnsafe(const char* name, Blob::Ptr& data) override {
262  _syncRequest->GetBlob(name, data);
263  }
264 
/**
 * @brief Forwards the call to GetPreProcess() of the wrapped synchronous request.
 * @param name Name passed through unchanged
 * @param info Pointer passed through unchanged to the wrapped request
 */
265  void GetPreProcess_ThreadUnsafe(const char* name, const PreProcessInfo** info) const override {
266  _syncRequest->GetPreProcess(name, info);
267  }
268 
270  _callback = callback;
271  }
272 
273  void GetUserData_ThreadUnsafe(void** data) override {
274  if (data == nullptr) THROW_IE_EXCEPTION << NOT_ALLOCATED_str;
275  *data = _userData;
276  }
277 
/**
 * @brief Stores the given pointer in _userData; GetUserData_ThreadUnsafe() hands the stored value back.
 * @param data Pointer to store (stored as is, not dereferenced here)
 */
278  void SetUserData_ThreadUnsafe(void* data) override {
279  _userData = data;
280  }
281 
/**
 * @brief Forwards the call to SetBatch() of the wrapped synchronous request.
 * @param batch Value passed through unchanged
 */
282  void SetBatch_ThreadUnsafe(int batch) override {
283  _syncRequest->SetBatch(batch);
284  }
285 
286 private:
287  /**
288  * @brief Create a task with next pipeline stage.
289  * Each call to MakeNextStageTask() generates @ref Task objects for each stage.
290  * On the last stage, or if an exception is raised from a `_pipeline` task,
291  * the last stage task is called directly or passed to the callback executor if one is present. The last stage task calls the
292  * callback, if one is present, captures the `_promise` member and uses it to forward completion or the exception to
293  * one of the `_futures` members
294  * @param[in] itStage Iterator to next stage of pipeline
295  * @param[in] itEndStage End pipeline iterator
296  * @param[in] callbackExecutor Executor that will run final stage with callback call
297  * @return A next stage task
298  */
299  Task MakeNextStageTask(const Pipeline::iterator itStage, const Pipeline::iterator itEndStage,
300  const ITaskExecutor::Ptr callbackExecutor) {
// std::bind stores its own copy of callbackExecutor and hands it to the lambda by reference
// each time the task is invoked; the iterators and `this` are captured by the lambda itself.
301  return std::bind([this, itStage, itEndStage](ITaskExecutor::Ptr& callbackExecutor) mutable {
302  StatusCode requestStatus = StatusCode::OK;
303  std::exception_ptr localCurrentException = nullptr;
304  auto& thisStage = *itStage;
305  auto itNextStage = itStage + 1;
306 
// Run this stage's task; when more stages remain, submit the next-stage task to that
// stage's executor. Any exception thrown here is recorded below instead of propagating.
307  try {
308  auto& stageTask = std::get<Stage_e::task>(thisStage);
309  IE_ASSERT(nullptr != stageTask);
310  stageTask();
311  if (itEndStage != itNextStage) {
312  auto& nextStage = *itNextStage;
313  auto& nextStageExecutor = std::get<Stage_e::executor>(nextStage);
314  IE_ASSERT(nullptr != nextStageExecutor);
315  nextStageExecutor->run(MakeNextStageTask(itNextStage, itEndStage, std::move(callbackExecutor)));
316  }
317  } catch (InferenceEngine::details::InferenceEngineException& ie_ex) {
// Use the status carried by the exception when it has one, GENERAL_ERROR otherwise.
318  requestStatus = ie_ex.hasStatus() ? ie_ex.getStatus() : StatusCode::GENERAL_ERROR;
// NOTE(review): std::make_exception_ptr stores a copy made with the static type of ie_ex;
// std::current_exception() would keep the dynamic type — confirm the copy is intended.
319  localCurrentException = std::make_exception_ptr(ie_ex);
320  } catch (...) {
321  requestStatus = StatusCode::GENERAL_ERROR;
322  localCurrentException = std::current_exception();
323  }
324 
// The last stage task is built either after the final stage or as soon as a stage failed.
325  if ((itEndStage == itNextStage) || (nullptr != localCurrentException)) {
326  auto lastStageTask = [this, requestStatus, localCurrentException]() mutable {
// Move the member promise into a local and read the callback once, before the busy flag is changed.
327  auto promise = std::move(_promise);
328  auto callback = _callback.load();
// NOTE(review): setIsRequestBusy() is declared outside this listing; presumably it returns
// whether the flag was updated — confirm. Nothing below runs when it returns false.
329  if (setIsRequestBusy(false)) {
330  if (nullptr != callback) {
// Make the pipeline exception (if any) visible to the callback through CurrentException().
331  InferenceEngine::CurrentException() = localCurrentException;
332  try {
333  callback(_publicInterface, requestStatus);
334  } catch (...) {
// An exception thrown by the callback replaces the recorded pipeline exception.
335  localCurrentException = std::current_exception();
336  }
// NOTE(review): the embedded listing numbering skips from 336 to 338 here; a statement
// may be missing from this copy — confirm against the upstream header.
338  }
// Forward completion or the recorded exception to the future registered for this run.
339  if (nullptr == localCurrentException) {
340  promise.set_value();
341  } else {
342  promise.set_exception(localCurrentException);
343  }
344  }
345  };
346 
// Without a callback executor the last stage task runs inline, right here.
347  if (nullptr == callbackExecutor) {
348  lastStageTask();
349  } else {
350  callbackExecutor->run(std::move(lastStageTask));
351  }
352  }
353  }, std::move(callbackExecutor));
354  }
355 
356  void* _userData = nullptr;
357  AtomicCallback _callback = {nullptr};
358  IInferRequest::Ptr _publicInterface;
359  std::promise<void> _promise;
360  mutable std::mutex _mutex;
361  Futures _futures;
362  bool _stop = false;
363 };
364 } // namespace InferenceEngine
InferenceEngine
Inference Engine Plugin API namespace.
InferenceEngine::AsyncInferRequestThreadSafeDefault::Ptr
std::shared_ptr< AsyncInferRequestThreadSafeDefault > Ptr
A shared pointer to AsyncInferRequestThreadSafeDefault.
Definition: ie_infer_async_request_thread_safe_default.hpp:62
InferenceEngine::AsyncInferRequestThreadSafeInternal
Wrapper of asynchronous inference request to support thread-safe execution.
Definition: ie_infer_async_request_thread_safe_internal.hpp:21
InferenceEngine::PreProcessInfo
InferenceEngine::AsyncInferRequestThreadSafeDefault::Infer_ThreadUnsafe
void Infer_ThreadUnsafe() override
Performs inference of pipeline in synchronous mode.
Definition: ie_infer_async_request_thread_safe_default.hpp:245
InferenceEngine::AsyncInferRequestThreadSafeDefault::SetBlob_ThreadUnsafe
void SetBlob_ThreadUnsafe(const char *name, const Blob::Ptr &data) override
Sets the blob thread unsafe.
Definition: ie_infer_async_request_thread_safe_default.hpp:253
exception2status.hpp
Wrappers from c++ function to c-style one.
ie_system_conf.h
Abstraction over platform specific implementations.
InferenceEngine::AsyncInferRequestThreadSafeDefault::InferUsingSync
void InferUsingSync()
Implements Infer() using synchronous pipeline and Wait()
Definition: ie_infer_async_request_thread_safe_default.hpp:227
InferenceEngine::AsyncInferRequestThreadSafeDefault::_callbackExecutor
ITaskExecutor::Ptr _callbackExecutor
Used to run post inference callback in asynchronous pipeline.
Definition: ie_infer_async_request_thread_safe_default.hpp:235
InferenceEngine::AsyncInferRequestThreadSafeDefault::SetBatch_ThreadUnsafe
void SetBatch_ThreadUnsafe(int batch) override
Sets the dynamic batch thread unsafe.
Definition: ie_infer_async_request_thread_safe_default.hpp:282
ie_itask_executor.hpp
A header file for Inference Engine Task Executor Interface.
InferenceEngine::AsyncInferRequestThreadSafeDefault::_syncCallbackExecutor
ITaskExecutor::Ptr _syncCallbackExecutor
Used to run post inference callback in synchronous pipeline.
Definition: ie_infer_async_request_thread_safe_default.hpp:236
InferenceEngine::AsyncInferRequestThreadSafeDefault::GetPreProcess_ThreadUnsafe
void GetPreProcess_ThreadUnsafe(const char *name, const PreProcessInfo **info) const override
Gets the preprocessing information thread unsafe.
Definition: ie_infer_async_request_thread_safe_default.hpp:265
InferenceEngine::AsyncInferRequestThreadSafeDefault::SetCompletionCallback_ThreadUnsafe
void SetCompletionCallback_ThreadUnsafe(IInferRequest::CompletionCallback callback) override
Sets the completion callback thread unsafe.
Definition: ie_infer_async_request_thread_safe_default.hpp:269
StatusCode
StatusCode
InferenceEngine::AsyncInferRequestThreadSafeDefault::Wait
StatusCode Wait(int64_t millis_timeout) override
Waits for completion of all pipeline stages If the pipeline raises an exception it will be rethrown h...
Definition: ie_infer_async_request_thread_safe_default.hpp:96
InferenceEngine::AsyncInferRequestThreadSafeDefault::SetBlob_ThreadUnsafe
void SetBlob_ThreadUnsafe(const char *name, const Blob::Ptr &data, const PreProcessInfo &info) override
Sets the blob with preprocessing information thread unsafe.
Definition: ie_infer_async_request_thread_safe_default.hpp:257
InferenceEngine::AsyncInferRequestThreadSafeDefault::~AsyncInferRequestThreadSafeDefault
~AsyncInferRequestThreadSafeDefault()
Destroys the object, stops AsyncInferRequestThreadSafeDefault::_pipeline and waits for a finish.
Definition: ie_infer_async_request_thread_safe_default.hpp:86
InferenceEngine::AsyncInferRequestThreadSafeDefault::StartAsync_ThreadUnsafe
void StartAsync_ThreadUnsafe() override
Starts an asynchronous pipeline thread unsafe.
Definition: ie_infer_async_request_thread_safe_default.hpp:240
InferenceEngine::AsyncInferRequestThreadSafeDefault::InferUsingAsync
void InferUsingAsync()
Implements Infer() using StartAsync() and Wait()
Definition: ie_infer_async_request_thread_safe_default.hpp:218
InferenceEngine::IInferRequest
InferenceEngine::AsyncInferRequestThreadSafeDefault::StopAndWait
void StopAndWait()
Forbids pipeline start and wait for all started pipelines.
Definition: ie_infer_async_request_thread_safe_default.hpp:200
NOT_ALLOCATED_str
#define NOT_ALLOCATED_str
Defines the not allocated message.
Definition: exception2status.hpp:138
InferenceEngine::InferRequestInternal::Ptr
std::shared_ptr< InferRequestInternal > Ptr
A shared pointer to a InferRequestInternal implementation.
Definition: ie_infer_request_internal.hpp:37
InferenceEngine::AsyncInferRequestThreadSafeDefault::_pipeline
Pipeline _pipeline
Pipeline variable that should be filled by inherited class.
Definition: ie_infer_async_request_thread_safe_default.hpp:237
InferenceEngine::ITaskExecutor::Ptr
std::shared_ptr< ITaskExecutor > Ptr
Definition: ie_itask_executor.hpp:51
InferenceEngine::Blob::Ptr
std::shared_ptr< Blob > Ptr
InferenceEngine::AsyncInferRequestThreadSafeDefault::_syncPipeline
Pipeline _syncPipeline
Synchronous pipeline variable that should be filled by inherited class.
Definition: ie_infer_async_request_thread_safe_default.hpp:238
InferenceEngine::IInferRequest::CompletionCallback
void(* CompletionCallback)(InferenceEngine::IInferRequest::Ptr context, InferenceEngine::StatusCode code)
THROW_IE_EXCEPTION
#define THROW_IE_EXCEPTION
InferenceEngine::CurrentException
std::exception_ptr & CurrentException()
Provides the reference to static thread_local std::exception_ptr.
InferenceEngine::AsyncInferRequestThreadSafeDefault::RunFirstStage
void RunFirstStage(const Pipeline::iterator itBeginStage, const Pipeline::iterator itEndStage, const ITaskExecutor::Ptr callbackExecutor={})
Creates and run the first stage task. If destructor was not called add a new std::future to the Async...
Definition: ie_infer_async_request_thread_safe_default.hpp:161
InferenceEngine::AsyncInferRequestThreadSafeDefault::SetUserData_ThreadUnsafe
void SetUserData_ThreadUnsafe(void *data) override
Sets the user data thread unsafe.
Definition: ie_infer_async_request_thread_safe_default.hpp:278
InferenceEngine::AsyncInferRequestThreadSafeDefault::Stage
std::pair< ITaskExecutor::Ptr, Task > Stage
Each pipeline stage is a Task that is executed by specified ITaskExecutor implementation.
Definition: ie_infer_async_request_thread_safe_default.hpp:147
PARAMETER_MISMATCH_str
#define PARAMETER_MISMATCH_str
Defines the parameter mismatch message.
Definition: exception2status.hpp:96
InferenceEngine::IInferRequest::Ptr
std::shared_ptr< IInferRequest > Ptr
InferenceEngine::AsyncInferRequestThreadSafeDefault::Pipeline
std::vector< Stage > Pipeline
Pipeline is vector of stages.
Definition: ie_infer_async_request_thread_safe_default.hpp:151
InferenceEngine::AsyncInferRequestThreadSafeDefault::_requestExecutor
ITaskExecutor::Ptr _requestExecutor
Used to run inference CPU tasks.
Definition: ie_infer_async_request_thread_safe_default.hpp:234
InferenceEngine::AsyncInferRequestThreadSafeDefault::GetPerformanceCounts_ThreadUnsafe
void GetPerformanceCounts_ThreadUnsafe(std::map< std::string, InferenceEngineProfileInfo > &perfMap) const override
Gets the performance counts thread unsafe.
Definition: ie_infer_async_request_thread_safe_default.hpp:249
InferenceEngine::AsyncInferRequestThreadSafeDefault::SetPointerToPublicInterface
void SetPointerToPublicInterface(InferenceEngine::IInferRequest::Ptr ptr)
Sets the pointer to public interface.
Definition: ie_infer_async_request_thread_safe_default.hpp:139
ie_immediate_executor.hpp
A header file for Inference Engine Immediate Executor implementation.
InferenceEngine::AsyncInferRequestThreadSafeDefault::GetUserData_ThreadUnsafe
void GetUserData_ThreadUnsafe(void **data) override
Gets the user data thread unsafe.
Definition: ie_infer_async_request_thread_safe_default.hpp:273
InferenceEngine::AsyncInferRequestThreadSafeDefault::GetBlob_ThreadUnsafe
void GetBlob_ThreadUnsafe(const char *name, Blob::Ptr &data) override
Gets the input or output blob thread unsafe.
Definition: ie_infer_async_request_thread_safe_default.hpp:261
IE_ASSERT
#define IE_ASSERT(EXPRESSION)
InferenceEngine::AsyncInferRequestThreadSafeDefault
Base class with default implementation of asynchronous multi staged inference request....
Definition: ie_infer_async_request_thread_safe_default.hpp:42
InferenceEngine::Task
std::function< void()> Task
Inference Engine Task Executor can use any copyable callable without parameters and output as a task....
Definition: ie_itask_executor.hpp:25
InferenceEngine::AsyncInferRequestThreadSafeDefault::AsyncInferRequestThreadSafeDefault
AsyncInferRequestThreadSafeDefault(const InferRequestInternal::Ptr &request, const ITaskExecutor::Ptr &taskExecutor, const ITaskExecutor::Ptr &callbackExecutor)
Wraps a InferRequestInternal::Ptr implementation and constructs a AsyncInferRequestThreadSafeDefault:...
Definition: ie_infer_async_request_thread_safe_default.hpp:73