@@ -19,7 +19,7 @@ namespace cppflow {
1919
2020 class model {
2121 public:
22- explicit model (const std::string& filename);
22+ explicit model (const std::string& filename, bool import_saved_model= true );
2323
2424 std::vector<std::string> get_operations () const ;
2525 std::vector<int64_t > get_operation_shape (const std::string& operation) const ;
@@ -34,6 +34,7 @@ namespace cppflow {
3434 model &operator =(model &&other) = default ;
3535
3636 private:
37+ TF_Buffer * readGraph (const std::string& filename);
3738
3839 std::shared_ptr<TF_Graph> graph;
3940 std::shared_ptr<TF_Session> session;
@@ -43,24 +44,41 @@ namespace cppflow {
4344
4445namespace cppflow {
4546
46- inline model::model (const std::string &filename) {
47+ inline model::model (const std::string &filename, bool import_saved_model ) {
4748 this ->graph = {TF_NewGraph (), TF_DeleteGraph};
4849
4950 // Create the session.
5051 std::unique_ptr<TF_SessionOptions, decltype (&TF_DeleteSessionOptions)> session_options = {TF_NewSessionOptions (), TF_DeleteSessionOptions};
51- std::unique_ptr<TF_Buffer, decltype (&TF_DeleteBuffer)> run_options = {TF_NewBufferFromString (" " , 0 ), TF_DeleteBuffer};
52- std::unique_ptr<TF_Buffer, decltype (&TF_DeleteBuffer)> meta_graph = {TF_NewBuffer (), TF_DeleteBuffer};
5352
5453 auto session_deleter = [](TF_Session* sess) {
5554 TF_DeleteSession (sess, context::get_status ());
5655 status_check (context::get_status ());
5756 };
5857
59- int tag_len = 1 ;
60- const char * tag = " serve" ;
61- this ->session = {TF_LoadSessionFromSavedModel (session_options.get (), run_options.get (), filename.c_str (),
62- &tag, tag_len, this ->graph .get (), meta_graph.get (), context::get_status ()),
63- session_deleter};
58+ if (import_saved_model) {
59+ std::unique_ptr<TF_Buffer, decltype (&TF_DeleteBuffer)> run_options = {TF_NewBufferFromString (" " , 0 ), TF_DeleteBuffer};
60+ std::unique_ptr<TF_Buffer, decltype (&TF_DeleteBuffer)> meta_graph = {TF_NewBuffer (), TF_DeleteBuffer};
61+
62+ int tag_len = 1 ;
63+ const char * tag = " serve" ;
64+ this ->session = {TF_LoadSessionFromSavedModel (session_options.get (), run_options.get (), filename.c_str (),
65+ &tag, tag_len, this ->graph .get (), meta_graph.get (), context::get_status ()),
66+ session_deleter};
67+ }
68+ else {
69+ this ->session = {TF_NewSession (this ->graph .get (), session_options.get (), context::get_status ()), session_deleter};
70+ status_check (context::get_status ());
71+
72+ // Import the graph definition
73+ TF_Buffer* def = readGraph (filename);
74+ if (def == nullptr ) {
75+ throw std::runtime_error (" Failed to import gragh def from file" );
76+ }
77+
78+ std::unique_ptr<TF_ImportGraphDefOptions, decltype (&TF_DeleteImportGraphDefOptions)> graph_opts = {TF_NewImportGraphDefOptions (), TF_DeleteImportGraphDefOptions};
79+ TF_GraphImportGraphDef (this ->graph .get (), def, graph_opts.get (), context::get_status ());
80+ TF_DeleteBuffer (def);
81+ }
6482
6583 status_check (context::get_status ());
6684 }
@@ -169,6 +187,42 @@ namespace cppflow {
169187 inline tensor model::operator ()(const tensor& input) {
170188 return (*this )({{" serving_default_input_1" , input}}, {" StatefulPartitionedCall" })[0 ];
171189 }
190+
191+
192+ inline TF_Buffer * model::readGraph (const std::string& filename) {
193+ std::ifstream file (filename, std::ios::binary | std::ios::ate);
194+
195+ // Error opening the file
196+ if (!file.is_open ()) {
197+ std::cerr << " Unable to open file: " << filename << std::endl;
198+ return nullptr ;
199+ }
200+
201+ // Cursor is at the end to get size
202+ auto size = file.tellg ();
203+ // Move cursor to the beginning
204+ file.seekg (0 , std::ios::beg);
205+
206+ // Read
207+ auto data = std::make_unique<char []>(size);
208+ file.seekg (0 , std::ios::beg);
209+ file.read (data.get (), size);
210+
211+ // Error reading the file
212+ if (!file) {
213+ std::cerr << " Unable to read the full file: " << filename << std::endl;
214+ return nullptr ;
215+ }
216+
217+ // Create tensorflow buffer from read data
218+ TF_Buffer* buffer = TF_NewBufferFromString (data.get (), size);
219+
220+ // Close file and remove data
221+ file.close ();
222+
223+ return buffer;
224+ }
225+
172226}
173227
174228#endif // CPPFLOW2_MODEL_H
0 commit comments