|
| 1 | +// MIT License |
1 | 2 | // |
2 | | -// Created by serizba on 27/6/20. |
| 3 | +// Copyright (c) 2020 Sergio Izquierdo |
| 4 | +// Copyright (c) 2020 Jiannan Liu |
3 | 5 | // |
| 6 | +// Permission is hereby granted, free of charge, to any person obtaining a copy |
| 7 | +// of this software and associated documentation files (the "Software"), to deal |
| 8 | +// in the Software without restriction, including without limitation the rights |
| 9 | +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell |
| 10 | +// copies of the Software, and to permit persons to whom the Software is |
| 11 | +// furnished to do so, subject to the following conditions: |
| 12 | +// |
| 13 | +// The above copyright notice and this permission notice shall be included in |
| 14 | +// all copies or substantial portions of the Software. |
| 15 | +// |
| 16 | +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR |
| 17 | +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, |
| 18 | +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE |
| 19 | +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER |
| 20 | +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, |
| 21 | +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE |
| 22 | +// SOFTWARE. |
| 23 | + |
| 24 | +/*! |
| 25 | + * @file context.h |
| 26 | + * @author Jiannan Liu |
| 27 | + * @author Sergio Izquierdo |
| 28 | + * @date @showdate "%B %d, %Y" 2020-06-27 |
| 29 | + */ |
| 30 | + |
| 31 | +#ifndef INCLUDE_CPPFLOW_CONTEXT_H_ |
| 32 | +#define INCLUDE_CPPFLOW_CONTEXT_H_ |
| 33 | + |
| 34 | +// C headers |
| 35 | +#include <tensorflow/c/c_api.h> |
| 36 | +#include <tensorflow/c/eager/c_api.h> |
4 | 37 |
|
5 | | -#ifndef CPPFLOW2_CONTEXT_H |
6 | | -#define CPPFLOW2_CONTEXT_H |
7 | | - |
| 38 | +// C++ headers |
8 | 39 | #include <memory> |
9 | 40 | #include <stdexcept> |
10 | 41 | #include <utility> |
11 | 42 |
|
12 | | -#include <tensorflow/c/c_api.h> |
13 | | -#include <tensorflow/c/eager/c_api.h> |
14 | | - |
15 | 43 | namespace cppflow { |
16 | 44 |
|
17 | | - inline bool status_check(TF_Status* status) { |
18 | | - if (TF_GetCode(status) != TF_OK) { |
19 | | - throw std::runtime_error(TF_Message(status)); |
20 | | - } |
21 | | - return true; |
22 | | - } |
| 45 | +inline bool status_check(TF_Status* status) { |
| 46 | + if (TF_GetCode(status) != TF_OK) { |
| 47 | + throw std::runtime_error(TF_Message(status)); |
| 48 | + } |
| 49 | + return true; |
| 50 | +} |
23 | 51 |
|
24 | | - class context { |
25 | | - public: |
26 | | - static TFE_Context* get_context(); |
| 52 | +class context { |
| 53 | + public: |
| 54 | + explicit context(TFE_ContextOptions* opts = nullptr); |
27 | 55 |
|
28 | | - // only use get_status() for eager ops |
29 | | - static TF_Status* get_status(); |
| 56 | + context(context const&) = delete; |
| 57 | + context(context&&) noexcept; |
30 | 58 |
|
31 | | - private: |
32 | | - TFE_Context* tfe_context{nullptr}; |
| 59 | + ~context(); |
33 | 60 |
|
34 | | - public: |
35 | | - explicit context(TFE_ContextOptions* opts = nullptr); |
| 61 | + context& operator=(context const&) = delete; |
| 62 | + context& operator=(context&&) noexcept; |
36 | 63 |
|
37 | | - context(context const&) = delete; |
38 | | - context& operator=(context const&) = delete; |
39 | | - context(context&&) noexcept; |
40 | | - context& operator=(context&&) noexcept; |
| 64 | + static TFE_Context* get_context(); |
41 | 65 |
|
42 | | - ~context(); |
43 | | - }; |
| 66 | + // only use get_status() for eager ops |
| 67 | + static TF_Status* get_status(); |
44 | 68 |
|
45 | | - // TODO: create ContextManager class if needed |
46 | | - // Set new context, thread unsafe, must be called at the beginning. |
47 | | - // TFE_ContextOptions* tfe_opts = ... |
48 | | - // cppflow::get_global_context() = cppflow::context(tfe_opts); |
49 | | - inline context& get_global_context() { |
50 | | - static context global_context; |
51 | | - return global_context; |
52 | | - } |
| 69 | + private: |
| 70 | + TFE_Context* tfe_context{nullptr}; |
| 71 | +}; // Class context |
53 | 72 |
|
| 73 | +// @todo create ContextManager class if needed |
| 74 | +// Set new context, thread unsafe, must be called at the beginning. |
| 75 | +// TFE_ContextOptions* tfe_opts = ... |
| 76 | +// cppflow::get_global_context() = cppflow::context(tfe_opts); |
| 77 | +inline context& get_global_context() { |
| 78 | + static context global_context; |
| 79 | + return global_context; |
54 | 80 | } |
| 81 | +} // namespace cppflow |
55 | 82 |
|
56 | 83 | namespace cppflow { |
57 | 84 |
|
58 | | - inline TFE_Context* context::get_context() { |
59 | | - return get_global_context().tfe_context; |
60 | | - } |
61 | | - |
62 | | - inline TF_Status* context::get_status() { |
63 | | - thread_local std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> local_tf_status(TF_NewStatus(), &TF_DeleteStatus); |
64 | | - return local_tf_status.get(); |
65 | | - } |
66 | | - |
67 | | - inline context::context(TFE_ContextOptions* opts) { |
68 | | - auto tf_status = context::get_status(); |
69 | | - if(opts == nullptr) { |
70 | | - std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> new_opts(TFE_NewContextOptions(), &TFE_DeleteContextOptions); |
71 | | - this->tfe_context = TFE_NewContext(new_opts.get(), tf_status); |
72 | | - } else { |
73 | | - this->tfe_context = TFE_NewContext(opts, tf_status); |
74 | | - } |
75 | | - status_check(tf_status); |
76 | | - } |
77 | | - |
78 | | - inline context::context(context&& ctx) noexcept : |
79 | | - tfe_context(std::exchange(ctx.tfe_context, nullptr)) |
80 | | - { |
81 | | - } |
82 | | - |
83 | | - inline context& context::operator=(context&& ctx) noexcept { |
84 | | - tfe_context = std::exchange(ctx.tfe_context, tfe_context); |
85 | | - return *this; |
86 | | - } |
87 | | - |
88 | | - inline context::~context() { |
89 | | - TFE_DeleteContext(this->tfe_context); |
90 | | - } |
| 85 | +inline TFE_Context* context::get_context() { |
| 86 | + return get_global_context().tfe_context; |
| 87 | +} |
| 88 | + |
| 89 | +inline TF_Status* context::get_status() { |
| 90 | + thread_local std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> |
| 91 | + local_tf_status(TF_NewStatus(), &TF_DeleteStatus); |
| 92 | + return local_tf_status.get(); |
| 93 | +} |
91 | 94 |
|
| 95 | +inline context::context(TFE_ContextOptions* opts) { |
| 96 | + auto tf_status = context::get_status(); |
| 97 | + if (opts == nullptr) { |
| 98 | + std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> |
| 99 | + new_opts(TFE_NewContextOptions(), &TFE_DeleteContextOptions); |
| 100 | + this->tfe_context = TFE_NewContext(new_opts.get(), tf_status); |
| 101 | + } else { |
| 102 | + this->tfe_context = TFE_NewContext(opts, tf_status); |
| 103 | + } |
| 104 | + status_check(tf_status); |
92 | 105 | } |
93 | 106 |
|
94 | | -#endif //CPPFLOW2_CONTEXT_H |
| 107 | +inline context::context(context&& ctx) noexcept |
| 108 | + : tfe_context(std::exchange(ctx.tfe_context, nullptr)) {} |
| 109 | + |
| 110 | +inline context& context::operator=(context&& ctx) noexcept { |
| 111 | + tfe_context = std::exchange(ctx.tfe_context, tfe_context); |
| 112 | + return *this; |
| 113 | +} |
| 114 | + |
| 115 | +inline context::~context() { |
| 116 | + TFE_DeleteContext(this->tfe_context); |
| 117 | +} |
| 118 | + |
| 119 | +} // namespace cppflow |
| 120 | + |
| 121 | +#endif // INCLUDE_CPPFLOW_CONTEXT_H_ |
0 commit comments