@@ -42,6 +42,7 @@ namespace cppflow {
4242 private:
4343 TF_Buffer * readGraph (const std::string& filename);
4444
45+ std::shared_ptr<TF_Status> status;
4546 std::shared_ptr<TF_Graph> graph;
4647 std::shared_ptr<TF_Session> session;
4748 };
@@ -51,14 +52,15 @@ namespace cppflow {
5152namespace cppflow {
5253
5354 inline model::model (const std::string &filename, const TYPE type) {
55+ this ->status = {TF_NewStatus (), &TF_DeleteStatus};
5456 this ->graph = {TF_NewGraph (), TF_DeleteGraph};
5557
5658 // Create the session.
5759 std::unique_ptr<TF_SessionOptions, decltype (&TF_DeleteSessionOptions)> session_options = {TF_NewSessionOptions (), TF_DeleteSessionOptions};
5860
59- auto session_deleter = [](TF_Session* sess) {
60- TF_DeleteSession (sess, context::get_status ());
61- status_check (context::get_status ());
61+ auto session_deleter = [this ](TF_Session* sess) {
62+ TF_DeleteSession (sess, this -> status . get ());
63+ status_check (this -> status . get ());
6264 };
6365
6466 if (type == TYPE::SAVED_MODEL) {
@@ -68,12 +70,12 @@ namespace cppflow {
6870 int tag_len = 1 ;
6971 const char * tag = " serve" ;
7072 this ->session = {TF_LoadSessionFromSavedModel (session_options.get (), run_options.get (), filename.c_str (),
71- &tag, tag_len, this ->graph .get (), meta_graph.get (), context::get_status ()),
73+ &tag, tag_len, this ->graph .get (), meta_graph.get (), this -> status . get ()),
7274 session_deleter};
7375 }
7476 else if (type == TYPE::FROZEN_GRAPH) {
75- this ->session = {TF_NewSession (this ->graph .get (), session_options.get (), context::get_status ()), session_deleter};
76- status_check (context::get_status ());
77+ this ->session = {TF_NewSession (this ->graph .get (), session_options.get (), this -> status . get ()), session_deleter};
78+ status_check (this -> status . get ());
7779
7880 // Import the graph definition
7981 TF_Buffer* def = readGraph (filename);
@@ -82,14 +84,14 @@ namespace cppflow {
8284 }
8385
8486 std::unique_ptr<TF_ImportGraphDefOptions, decltype (&TF_DeleteImportGraphDefOptions)> graph_opts = {TF_NewImportGraphDefOptions (), TF_DeleteImportGraphDefOptions};
85- TF_GraphImportGraphDef (this ->graph .get (), def, graph_opts.get (), context::get_status ());
87+ TF_GraphImportGraphDef (this ->graph .get (), def, graph_opts.get (), this -> status . get ());
8688 TF_DeleteBuffer (def);
8789 }
8890 else {
8991 throw std::runtime_error (" Model type unknown" );
9092 }
9193
92- status_check (context::get_status ());
94+ status_check (this -> status . get ());
9395 }
9496
9597 inline std::vector<std::string> model::get_operations () const {
@@ -122,16 +124,16 @@ namespace cppflow {
122124 // DIMENSIONS
123125
124126 // Get number of dimensions
125- int n_dims = TF_GraphGetTensorNumDims (this ->graph .get (), out_op, context::get_status ());
127+ int n_dims = TF_GraphGetTensorNumDims (this ->graph .get (), out_op, this -> status . get ());
126128
127129 // If is not a scalar
128130 if (n_dims > 0 ) {
129131 // Get dimensions
130132 auto * dims = new int64_t [n_dims];
131- TF_GraphGetTensorShape (this ->graph .get (), out_op, dims, n_dims, context::get_status ());
133+ TF_GraphGetTensorShape (this ->graph .get (), out_op, dims, n_dims, this -> status . get ());
132134
133135 // Check error on Model Status
134- status_check (context::get_status ());
136+ status_check (this -> status . get ());
135137
136138 shape = std::vector<int64_t >(dims, dims + n_dims);
137139
@@ -181,8 +183,8 @@ namespace cppflow {
181183 TF_SessionRun (this ->session .get (), NULL ,
182184 inp_ops.data (), inp_val.data (), static_cast <int >(inputs.size ()),
183185 out_ops.data (), out_val.get (), static_cast <int >(outputs.size ()),
184- NULL , 0 ,NULL , context::get_status ());
185- status_check (context::get_status ());
186+ NULL , 0 ,NULL , this -> status . get ());
187+ status_check (this -> status . get ());
186188
187189 std::vector<tensor> result;
188190 result.reserve (outputs.size ());
0 commit comments