Skip to content

Commit 9ce6b5a

Browse files
committed
add frozen graph support
1 parent 883eb4c commit 9ce6b5a

File tree

1 file changed

+63
-9
lines changed

1 file changed

+63
-9
lines changed

include/cppflow/model.h

Lines changed: 63 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ namespace cppflow {
1919

2020
class model {
2121
public:
22-
explicit model(const std::string& filename);
22+
explicit model(const std::string& filename, bool import_saved_model=true);
2323

2424
std::vector<std::string> get_operations() const;
2525
std::vector<int64_t> get_operation_shape(const std::string& operation) const;
@@ -34,6 +34,7 @@ namespace cppflow {
3434
model &operator=(model &&other) = default;
3535

3636
private:
37+
TF_Buffer * readGraph(const std::string& filename);
3738

3839
std::shared_ptr<TF_Graph> graph;
3940
std::shared_ptr<TF_Session> session;
@@ -43,24 +44,41 @@ namespace cppflow {
4344

4445
namespace cppflow {
4546

46-
inline model::model(const std::string &filename) {
47+
inline model::model(const std::string &filename, bool import_saved_model) {
4748
this->graph = {TF_NewGraph(), TF_DeleteGraph};
4849

4950
// Create the session.
5051
std::unique_ptr<TF_SessionOptions, decltype(&TF_DeleteSessionOptions)> session_options = {TF_NewSessionOptions(), TF_DeleteSessionOptions};
51-
std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> run_options = {TF_NewBufferFromString("", 0), TF_DeleteBuffer};
52-
std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> meta_graph = {TF_NewBuffer(), TF_DeleteBuffer};
5352

5453
auto session_deleter = [](TF_Session* sess) {
5554
TF_DeleteSession(sess, context::get_status());
5655
status_check(context::get_status());
5756
};
5857

59-
int tag_len = 1;
60-
const char* tag = "serve";
61-
this->session = {TF_LoadSessionFromSavedModel(session_options.get(), run_options.get(), filename.c_str(),
62-
&tag, tag_len, this->graph.get(), meta_graph.get(), context::get_status()),
63-
session_deleter};
58+
if (import_saved_model) {
59+
std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> run_options = {TF_NewBufferFromString("", 0), TF_DeleteBuffer};
60+
std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> meta_graph = {TF_NewBuffer(), TF_DeleteBuffer};
61+
62+
int tag_len = 1;
63+
const char* tag = "serve";
64+
this->session = {TF_LoadSessionFromSavedModel(session_options.get(), run_options.get(), filename.c_str(),
65+
&tag, tag_len, this->graph.get(), meta_graph.get(), context::get_status()),
66+
session_deleter};
67+
}
68+
else {
69+
this->session = {TF_NewSession(this->graph.get(), session_options.get(), context::get_status()), session_deleter};
70+
status_check(context::get_status());
71+
72+
// Import the graph definition
73+
TF_Buffer* def = readGraph(filename);
74+
if(def == nullptr) {
75+
throw std::runtime_error("Failed to import gragh def from file");
76+
}
77+
78+
std::unique_ptr<TF_ImportGraphDefOptions, decltype(&TF_DeleteImportGraphDefOptions)> graph_opts = {TF_NewImportGraphDefOptions(), TF_DeleteImportGraphDefOptions};
79+
TF_GraphImportGraphDef(this->graph.get(), def, graph_opts.get(), context::get_status());
80+
TF_DeleteBuffer(def);
81+
}
6482

6583
status_check(context::get_status());
6684
}
@@ -169,6 +187,42 @@ namespace cppflow {
169187
inline tensor model::operator()(const tensor& input) {
170188
return (*this)({{"serving_default_input_1", input}}, {"StatefulPartitionedCall"})[0];
171189
}
190+
191+
192+
inline TF_Buffer * model::readGraph(const std::string& filename) {
193+
std::ifstream file (filename, std::ios::binary | std::ios::ate);
194+
195+
// Error opening the file
196+
if (!file.is_open()) {
197+
std::cerr << "Unable to open file: " << filename << std::endl;
198+
return nullptr;
199+
}
200+
201+
// Cursor is at the end to get size
202+
auto size = file.tellg();
203+
// Move cursor to the beginning
204+
file.seekg (0, std::ios::beg);
205+
206+
// Read
207+
auto data = std::make_unique<char[]>(size);
208+
file.seekg (0, std::ios::beg);
209+
file.read (data.get(), size);
210+
211+
// Error reading the file
212+
if (!file) {
213+
std::cerr << "Unable to read the full file: " << filename << std::endl;
214+
return nullptr;
215+
}
216+
217+
// Create tensorflow buffer from read data
218+
TF_Buffer* buffer = TF_NewBufferFromString(data.get(), size);
219+
220+
// Close file and remove data
221+
file.close();
222+
223+
return buffer;
224+
}
225+
172226
}
173227

174228
#endif //CPPFLOW2_MODEL_H

0 commit comments

Comments
 (0)