Skip to content

Commit 4a291f4

Browse files
committed
use enum for model types
1 parent 9ce6b5a commit 4a291f4

File tree

1 file changed

+14
-5
lines changed

1 file changed

+14
-5
lines changed

include/cppflow/model.h

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,13 @@ namespace cppflow {
1919

2020
class model {
2121
public:
22-
explicit model(const std::string& filename, bool import_saved_model=true);
22+
enum TYPE
23+
{
24+
SAVED_MODEL,
25+
FROZEN_GRAPH,
26+
};
27+
28+
explicit model(const std::string& filename, const TYPE type=TYPE::SAVED_MODEL);
2329

2430
std::vector<std::string> get_operations() const;
2531
std::vector<int64_t> get_operation_shape(const std::string& operation) const;
@@ -44,7 +50,7 @@ namespace cppflow {
4450

4551
namespace cppflow {
4652

47-
inline model::model(const std::string &filename, bool import_saved_model) {
53+
inline model::model(const std::string &filename, const TYPE type) {
4854
this->graph = {TF_NewGraph(), TF_DeleteGraph};
4955

5056
// Create the session.
@@ -55,7 +61,7 @@ namespace cppflow {
5561
status_check(context::get_status());
5662
};
5763

58-
if (import_saved_model) {
64+
if (type == TYPE::SAVED_MODEL) {
5965
std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> run_options = {TF_NewBufferFromString("", 0), TF_DeleteBuffer};
6066
std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> meta_graph = {TF_NewBuffer(), TF_DeleteBuffer};
6167

@@ -65,20 +71,23 @@ namespace cppflow {
6571
&tag, tag_len, this->graph.get(), meta_graph.get(), context::get_status()),
6672
session_deleter};
6773
}
68-
else {
74+
else if (type == TYPE::FROZEN_GRAPH) {
6975
this->session = {TF_NewSession(this->graph.get(), session_options.get(), context::get_status()), session_deleter};
7076
status_check(context::get_status());
7177

7278
// Import the graph definition
7379
TF_Buffer* def = readGraph(filename);
7480
if(def == nullptr) {
75-
throw std::runtime_error("Failed to import gragh def from file");
81+
throw std::runtime_error("Failed to import graph def from file");
7682
}
7783

7884
std::unique_ptr<TF_ImportGraphDefOptions, decltype(&TF_DeleteImportGraphDefOptions)> graph_opts = {TF_NewImportGraphDefOptions(), TF_DeleteImportGraphDefOptions};
7985
TF_GraphImportGraphDef(this->graph.get(), def, graph_opts.get(), context::get_status());
8086
TF_DeleteBuffer(def);
8187
}
88+
else {
89+
throw std::runtime_error("Model type unknown");
90+
}
8291

8392
status_check(context::get_status());
8493
}

0 commit comments

Comments
 (0)