@@ -19,7 +19,13 @@ namespace cppflow {
1919
2020 class model {
2121 public:
22- explicit model (const std::string& filename, bool import_saved_model=true );
22+ enum TYPE
23+ {
24+ SAVED_MODEL,
25+ FROZEN_GRAPH,
26+ };
27+
28+ explicit model (const std::string& filename, const TYPE type=TYPE::SAVED_MODEL);
2329
2430 std::vector<std::string> get_operations () const ;
2531 std::vector<int64_t > get_operation_shape (const std::string& operation) const ;
@@ -44,7 +50,7 @@ namespace cppflow {
4450
4551namespace cppflow {
4652
47- inline model::model (const std::string &filename, bool import_saved_model ) {
53+ inline model::model (const std::string &filename, const TYPE type ) {
4854 this ->graph = {TF_NewGraph (), TF_DeleteGraph};
4955
5056 // Create the session.
@@ -55,7 +61,7 @@ namespace cppflow {
5561 status_check (context::get_status ());
5662 };
5763
58- if (import_saved_model ) {
64+ if (type == TYPE::SAVED_MODEL ) {
5965 std::unique_ptr<TF_Buffer, decltype (&TF_DeleteBuffer)> run_options = {TF_NewBufferFromString (" " , 0 ), TF_DeleteBuffer};
6066 std::unique_ptr<TF_Buffer, decltype (&TF_DeleteBuffer)> meta_graph = {TF_NewBuffer (), TF_DeleteBuffer};
6167
@@ -65,20 +71,23 @@ namespace cppflow {
6571 &tag, tag_len, this ->graph .get (), meta_graph.get (), context::get_status ()),
6672 session_deleter};
6773 }
68- else {
74+ else if (type == TYPE::FROZEN_GRAPH) {
6975 this ->session = {TF_NewSession (this ->graph .get (), session_options.get (), context::get_status ()), session_deleter};
7076 status_check (context::get_status ());
7177
7278 // Import the graph definition
7379 TF_Buffer* def = readGraph (filename);
7480 if (def == nullptr ) {
75- throw std::runtime_error (" Failed to import gragh def from file" );
81+ throw std::runtime_error (" Failed to import graph def from file" );
7682 }
7783
7884 std::unique_ptr<TF_ImportGraphDefOptions, decltype (&TF_DeleteImportGraphDefOptions)> graph_opts = {TF_NewImportGraphDefOptions (), TF_DeleteImportGraphDefOptions};
7985 TF_GraphImportGraphDef (this ->graph .get (), def, graph_opts.get (), context::get_status ());
8086 TF_DeleteBuffer (def);
8187 }
88+ else {
89+ throw std::runtime_error (" Model type unknown" );
90+ }
8291
8392 status_check (context::get_status ());
8493 }
0 commit comments