Description
Many thanks for this extremely helpful library.
Currently, models are loaded from a file: the model constructor calls readGraph(filename).
Sometimes it is helpful to load a model from embedded resources instead. I have solved this for frozen graphs and am posting the change here as a suggestion, in case you want to add it to master.
(I'm not that familiar with GitHub, so please excuse me for not branching off master and pushing, or whatever the terms are!)
I added a new constructor to model.h that takes a pointer to a std::vector<uchar> as its only parameter. It passes bufferModel->data() and bufferModel->size() to TF_NewBufferFromString, instead of calling readGraph(filename) as the existing constructor does.
```cpp
// Must also be declared inside the model class in model.h:
//     model(const std::vector<uchar>* bufferModel);
// uchar is not a standard C++ type; it is assumed to be a typedef for unsigned char.
inline model::model(const std::vector<uchar>* bufferModel)
{
    if (bufferModel == nullptr || bufferModel->empty())
    {
        throw std::invalid_argument("Model buffer is null or empty");
    }

    this->status = {TF_NewStatus(), &TF_DeleteStatus};
    this->graph = {TF_NewGraph(), TF_DeleteGraph};

    // Create the session.
    std::unique_ptr<TF_SessionOptions, decltype(&TF_DeleteSessionOptions)>
        session_options = {TF_NewSessionOptions(), TF_DeleteSessionOptions};

    auto session_deleter = [this](TF_Session* sess) {
        TF_DeleteSession(sess, this->status.get());
        status_check(this->status.get());
    };

    this->session = {TF_NewSession(this->graph.get(),
                                   session_options.get(),
                                   this->status.get()),
                     session_deleter};
    status_check(this->status.get());

    // Import the graph definition from memory instead of readGraph(filename).
    // TF_NewBufferFromString copies the bytes, so the caller's vector can be
    // released once the constructor returns.
    TF_Buffer* def = TF_NewBufferFromString(bufferModel->data(), bufferModel->size());
    if (def == nullptr)
    {
        throw std::runtime_error("Failed to create graph def buffer from memory");
    }

    std::unique_ptr<TF_ImportGraphDefOptions, decltype(&TF_DeleteImportGraphDefOptions)> graph_opts = {
        TF_NewImportGraphDefOptions(), TF_DeleteImportGraphDefOptions};
    TF_GraphImportGraphDef(this->graph.get(), def, graph_opts.get(), this->status.get());
    TF_DeleteBuffer(def);
    status_check(this->status.get());
}
```
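To compare the new constructor against the existing file-based path, or to use it on platforms without Windows resources, the same buffer can be filled from a file on disk. This is a minimal sketch, assuming uchar is unsigned char; readFileBytes is a hypothetical helper, not part of cppflow.

```cpp
#include <fstream>
#include <iterator>
#include <stdexcept>
#include <string>
#include <vector>

// Hypothetical helper: read a whole file (e.g. a frozen graph .pb) into a byte vector.
// Opens in binary mode so no newline translation alters the protobuf bytes.
std::vector<unsigned char> readFileBytes(const std::string& path)
{
    std::ifstream in(path, std::ios::binary);
    if (!in)
    {
        throw std::runtime_error("Cannot open " + path);
    }
    // istreambuf_iterator reads raw characters, including whitespace and zero bytes.
    return std::vector<unsigned char>(std::istreambuf_iterator<char>(in),
                                      std::istreambuf_iterator<char>());
}
```

With `std::vector<uchar> bytes = readFileBytes("model.pb");`, constructing the model as `cppflow::model m(&bytes);` should then load the same graph as the filename-based constructor, since both end up handing the full file contents to TF_GraphImportGraphDef.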
I then load the .pb model from the resources as a std::vector<uchar> using a LoadModel function in my own project's source code, and pass it to the new model constructor. My project happens to use wxWidgets, so this is conveniently done as follows. I include this only in case it helps someone in the future; it is not itself a suggestion for cppflow.
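```cpp
#include <vector>
// FindResource, LoadResource, etc. come from the Win32 headers, and wxGetInstance()
// (the application's HINSTANCE) from wxWidgets on MSW; both are included elsewhere in the project.

// Loads an RT_RCDATA resource (e.g. a frozen .pb graph) into `model`.
// Returns false, leaving `model` empty, if the resource cannot be found or read.
bool LoadModel(const wxString& resName, std::vector<uchar>& model)
{
    model.clear();
    HRSRC hrsrc = FindResource(wxGetInstance(), resName, RT_RCDATA);
    if (hrsrc == NULL) return false;
    HGLOBAL hglobal = LoadResource(wxGetInstance(), hrsrc);
    if (hglobal == NULL) return false;
    // Resource memory is read-only and owned by the module; it needs no unlocking or freeing.
    const void* data = LockResource(hglobal);
    if (data == NULL) return false;
    DWORD datalen = SizeofResource(wxGetInstance(), hrsrc);
    if (datalen == 0) return false;
    const uchar* charBuf = static_cast<const uchar*>(data);
    model.assign(charBuf, charBuf + datalen);
    return true;
}
```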
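Where the Win32 resource API is not available (for example on Linux or macOS), one portable way to embed a model is to compile its bytes directly into the program. The sketch below assumes the model was converted with `xxd -i model.pb > model_pb.h`, which emits an `unsigned char model_pb[]` array and an `unsigned int model_pb_len` length; the placeholder bytes stand in for that generated header, and embeddedModelBytes is a hypothetical name.

```cpp
#include <vector>

// Placeholder for the contents of a header generated by `xxd -i model.pb`.
// A real generated array would hold the full frozen graph bytes.
unsigned char model_pb[] = {0x01, 0x02, 0x03, 0x04};
unsigned int model_pb_len = 4;

// Copy the embedded bytes into the vector type the buffer-based constructor expects.
std::vector<unsigned char> embeddedModelBytes()
{
    return std::vector<unsigned char>(model_pb, model_pb + model_pb_len);
}
```

The copy into a vector is only there to match the constructor's parameter type; because TF_NewBufferFromString copies the data again, the vector can be a short-lived local.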