Skip to content

Commit 858d51a

Browse files
author
qiutong
committed
rectify bert_convert_to_ov
1 parent fb7a018 commit 858d51a

File tree

1 file changed

+1
-8
lines changed

1 file changed

+1
-8
lines changed

melo/api.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -120,13 +120,6 @@ def bert_convert_to_ov(self, ov_path, language = "ZH"):
120120
example_input = example_input,
121121
)
122122

123-
get_input_names = lambda: ["input_ids", "token_type_ids", "attention_mask"]
124-
for input, input_name in zip(ov_model.inputs, get_input_names()):
125-
input.get_tensor().set_names({input_name})
126-
outputs_name = ['hidden_states']
127-
for output, output_name in zip(ov_model.outputs, outputs_name):
128-
output.get_tensor().set_names({output_name})
129-
130123
"""
131124
reshape model
132125
Set the batch size of all input tensors to 1 to facilitate the use of the C++ infer
@@ -212,7 +205,7 @@ def pad_tensor(input_tensor, pad_length=32):
212205

213206
self.bert_request.start_async(padded_inputs if self.device=="NPU" else inputs_dict , share_inputs=True)
214207
self.bert_request.wait()
215-
bert_output = (self.bert_request.get_tensor("hidden_states").data.copy())
208+
bert_output = self.bert_request.get_output_tensor(0).data.copy()
216209

217210
return bert_output
218211
class TTS(nn.Module):

0 commit comments

Comments
 (0)