Skip to content

Commit 243ff2f

Browse files
authored
Merge pull request #51 from ljn917/tstring
Compatibility with TF 2.4 TF_TString
2 parents 8357eee + 81615cc commit 243ff2f

File tree

1 file changed

+14
-1
lines changed

1 file changed

+14
-1
lines changed

src/Model.cpp

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,12 @@ void Model::init() {
5858
}
5959

6060
void Model::save(const std::string &ckpt) {
61+
#ifdef TENSORFLOW_C_TF_TSTRING_H_
62+
std::unique_ptr<TF_TString, decltype(&TF_TString_Dealloc)> tstr(new TF_TString, &TF_TString_Dealloc);
63+
TF_TString_Copy(tstr.get(), ckpt.c_str(), ckpt.size());
64+
auto deallocator = [](void* data, size_t len, void* arg) {};
65+
TF_Tensor* t = TF_NewTensor(TF_STRING, nullptr, 0, tstr.get(), 1, deallocator, nullptr);
66+
#else
6167
// Encode file_name to tensor
6268
size_t size = 8 + TF_StringEncodedSize(ckpt.length());
6369
TF_Tensor* t = TF_AllocateTensor(TF_STRING, nullptr, 0, size);
@@ -67,6 +73,7 @@ void Model::save(const std::string &ckpt) {
6773

6874
memset(data, 0, 8); // 8-byte offset of first string.
6975
TF_StringEncode(ckpt.c_str(), ckpt.length(), (char*)(data + 8), size - 8, status);
76+
#endif // TENSORFLOW_C_TF_TSTRING_H_
7077

7178
// Check errors
7279
if (!this->status_check(false)) {
@@ -95,13 +102,19 @@ void Model::save(const std::string &ckpt) {
95102
}
96103

97104
void Model::restore(const std::string& ckpt) {
98-
105+
#ifdef TENSORFLOW_C_TF_TSTRING_H_
106+
std::unique_ptr<TF_TString, decltype(&TF_TString_Dealloc)> tstr(new TF_TString, &TF_TString_Dealloc);
107+
TF_TString_Copy(tstr.get(), ckpt.c_str(), ckpt.size());
108+
auto deallocator = [](void* data, size_t len, void* arg) {};
109+
TF_Tensor* t = TF_NewTensor(TF_STRING, nullptr, 0, tstr.get(), 1, deallocator, nullptr);
110+
#else
99111
// Encode file_name to tensor
100112
size_t size = 8 + TF_StringEncodedSize(ckpt.size());
101113
TF_Tensor* t = TF_AllocateTensor(TF_STRING, nullptr, 0, size);
102114
char* data = static_cast<char *>(TF_TensorData(t));
103115
for (int i=0; i<8; i++) {data[i]=0;}
104116
TF_StringEncode(ckpt.c_str(), ckpt.size(), data + 8, size - 8, status);
117+
#endif // TENSORFLOW_C_TF_TSTRING_H_
105118

106119
// Check errors
107120
if (!this->status_check(false)) {

0 commit comments

Comments
 (0)