Skip to content

Commit 0238304

Browse files
committed
2 parents 007b62f + 5e86775 commit 0238304

File tree

9 files changed

+620
-147
lines changed

9 files changed

+620
-147
lines changed

pom.xml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -227,6 +227,12 @@
227227
<artifactId>dashscope-sdk-java</artifactId>
228228
<version>2.20.2</version>
229229
</dependency>
230+
<!-- 讯飞 -->
231+
<dependency>
232+
<groupId>cn.xfyun</groupId>
233+
<artifactId>websdk-java-speech</artifactId>
234+
<version>3.0.2</version>
235+
</dependency>
230236
<!-- COZE -->
231237
<dependency>
232238
<groupId>com.coze</groupId>

src/main/java/com/xiaozhi/dialogue/stt/factory/SttServiceFactory.java

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,7 @@
11
package com.xiaozhi.dialogue.stt.factory;
22

33
import com.xiaozhi.dialogue.stt.SttService;
4-
import com.xiaozhi.dialogue.stt.providers.AliyunSttService;
5-
import com.xiaozhi.dialogue.stt.providers.FunASRSttService;
6-
import com.xiaozhi.dialogue.stt.providers.TencentSttService;
7-
import com.xiaozhi.dialogue.stt.providers.VoskSttService;
4+
import com.xiaozhi.dialogue.stt.providers.*;
85
import com.xiaozhi.entity.SysConfig;
96

107
import org.slf4j.Logger;
@@ -164,6 +161,8 @@ private SttService createApiService(SysConfig config) {
164161
return new AliyunSttService(config);
165162
} else if ("funasr".equals(provider)) {
166163
return new FunASRSttService(config);
164+
} else if ("xfyun".equals(provider)) {
165+
return new XfyunSttService(config);
167166
}
168167
// 可以添加其他服务提供商的支持
169168

src/main/java/com/xiaozhi/dialogue/stt/providers/XfyunSttService.java

Lines changed: 316 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,316 @@
1+
package com.xiaozhi.dialogue.stt.providers;
2+
3+
import cn.xfyun.api.IatClient;
4+
import cn.xfyun.model.response.iat.IatResponse;
5+
import cn.xfyun.model.response.iat.IatResult;
6+
import cn.xfyun.model.response.iat.Text;
7+
import cn.xfyun.service.iat.AbstractIatWebSocketListener;
8+
import com.xiaozhi.dialogue.stt.SttService;
9+
import com.xiaozhi.entity.SysConfig;
10+
import com.xiaozhi.utils.AudioUtils;
11+
import okhttp3.Response;
12+
import okhttp3.WebSocket;
13+
import org.apache.commons.lang3.StringUtils;
14+
import org.slf4j.Logger;
15+
import org.slf4j.LoggerFactory;
16+
import reactor.core.publisher.Sinks;
17+
18+
import java.io.ByteArrayInputStream;
19+
import java.io.File;
20+
import java.net.MalformedURLException;
21+
import java.security.SignatureException;
22+
import java.util.ArrayList;
23+
import java.util.List;
24+
import java.util.UUID;
25+
import java.util.concurrent.*;
26+
import java.util.concurrent.atomic.AtomicBoolean;
27+
28+
public class XfyunSttService implements SttService {
29+
private static final Logger logger = LoggerFactory.getLogger(XfyunSttService.class);
30+
31+
private static final String PROVIDER_NAME = "xfyun";
32+
// 队列等待超时时间
33+
private static final int QUEUE_TIMEOUT_MS = 60000;
34+
// 识别超时时间(60秒)
35+
private static final long RECOGNITION_TIMEOUT_MS = 60000;
36+
37+
// 存储当前活跃的识别会话
38+
private final ConcurrentHashMap<String, IatClient> activeRecognizers = new ConcurrentHashMap<>();
39+
40+
private String secretId;
41+
private String secretKey;
42+
private String appId;
43+
44+
public XfyunSttService(SysConfig config) {
45+
if (config != null) {
46+
this.secretId = config.getApiKey();
47+
this.secretKey = config.getApiSecret();
48+
this.appId = config.getAppId();
49+
}
50+
}
51+
52+
@Override
53+
public String getProviderName() {
54+
return PROVIDER_NAME;
55+
}
56+
57+
@Override
58+
public boolean supportsStreaming() {
59+
return true;
60+
}
61+
62+
@Override
63+
public String recognition(byte[] audioData) {
64+
if (audioData == null || audioData.length == 0) {
65+
logger.warn("音频数据为空!");
66+
return null;
67+
}
68+
List<Text> resultSegments = new ArrayList<>();
69+
// 将原始音频数据转换为MP3格式并保存(用于调试)
70+
String fileName = AudioUtils.saveAsWav(audioData);
71+
File file = new File(fileName);
72+
CountDownLatch recognitionLatch = new CountDownLatch(1);
73+
try {
74+
// 检查配置是否已设置
75+
if (secretId == null || secretKey == null) {
76+
logger.error("讯飞云语音识别配置未设置,无法进行识别");
77+
return null;
78+
}
79+
80+
// 设置听写参数,这里的appid,apiKey,apiSecret是在开放平台控制台获得
81+
IatClient iatClient = new IatClient.Builder()
82+
.signature(appId, secretId, secretKey)
83+
// 动态修正功能:值为wpgs时代表开启(包含修正功能的)流式听写
84+
.dwa("wpgs")
85+
.build();
86+
87+
iatClient.send(file, new AbstractIatWebSocketListener() {
88+
@Override
89+
public void onSuccess(WebSocket webSocket, IatResponse iatResponse) {
90+
if (iatResponse.getCode() != 0) {
91+
logger.warn("code:{}, error:{}, sid:{}", iatResponse.getCode(), iatResponse.getMessage(), iatResponse.getSid());
92+
logger.warn("错误码查询链接:https://www.xfyun.cn/document/error-code");
93+
return;
94+
}
95+
96+
if (iatResponse.getData() != null) {
97+
if (iatResponse.getData().getResult() != null) {
98+
// 解析服务端返回结果
99+
IatResult result = iatResponse.getData().getResult();
100+
Text textObject = result.getText();
101+
handleResultText(textObject, resultSegments);
102+
logger.info("中间识别结果:{}", getFinalResult(resultSegments));
103+
}
104+
105+
if (iatResponse.getData().getStatus() == 2) {
106+
// resp.data.status ==2 说明数据全部返回完毕,可以关闭连接,释放资源
107+
logger.info("session end ");
108+
iatClient.closeWebsocket();
109+
recognitionLatch.countDown();
110+
} else {
111+
// 根据返回的数据自定义处理逻辑
112+
}
113+
}
114+
}
115+
116+
@Override
117+
public void onFail(WebSocket webSocket, Throwable t, Response response) {
118+
// 自定义处理逻辑
119+
recognitionLatch.countDown();
120+
}
121+
});
122+
// 等待识别完成或超时
123+
boolean recognized = recognitionLatch.await(RECOGNITION_TIMEOUT_MS, TimeUnit.MILLISECONDS);
124+
if (!recognized) {
125+
logger.warn("讯飞云识别超时");
126+
}
127+
return getFinalResult(resultSegments);
128+
} catch (Exception e) {
129+
logger.error("处理音频时发生错误!", e);
130+
return null;
131+
}
132+
}
133+
134+
/**
135+
* 处理返回结果(包括全量返回与流式返回(结果修正))
136+
*/
137+
private void handleResultText(Text textObject, List<Text> resultSegments) {
138+
// 处理流式返回的替换结果
139+
if (StringUtils.equals(textObject.getPgs(), "rpl") && textObject.getRg() != null && textObject.getRg().length == 2) {
140+
// 返回结果序号sn字段的最小值为1
141+
int start = textObject.getRg()[0] - 1;
142+
int end = textObject.getRg()[1] - 1;
143+
144+
// 将指定区间的结果设置为删除状态
145+
for (int i = start; i <= end && i < resultSegments.size(); i++) {
146+
resultSegments.get(i).setDeleted(true);
147+
}
148+
// logger.info("替换操作,服务端返回结果为:" + textObject);
149+
}
150+
151+
// 通用逻辑,添加当前文本到结果列表
152+
resultSegments.add(textObject);
153+
}
154+
155+
/**
156+
* 获取最终结果
157+
*/
158+
private String getFinalResult(List<Text> resultSegments) {
159+
StringBuilder finalResult = new StringBuilder();
160+
for (Text text : resultSegments) {
161+
if (text != null && !text.isDeleted()) {
162+
finalResult.append(text.getText());
163+
}
164+
}
165+
return finalResult.toString();
166+
}
167+
168+
@Override
169+
public String streamRecognition(Sinks.Many<byte[]> audioSink) {
170+
// 检查配置是否已设置
171+
if (secretId == null || secretKey == null || appId == null) {
172+
logger.error("讯飞云语音识别配置未设置,无法进行识别");
173+
return null;
174+
}
175+
176+
// 使用阻塞队列存储音频数据
177+
BlockingQueue<byte[]> audioQueue = new LinkedBlockingQueue<>();
178+
AtomicBoolean isCompleted = new AtomicBoolean(false);
179+
CountDownLatch recognitionLatch = new CountDownLatch(1);
180+
List<Text> resultSegments = new ArrayList<>();
181+
182+
// // 订阅Sink并将数据放入队列
183+
// audioSink.asFlux().subscribe(
184+
// data -> audioQueue.offer(data),
185+
// error -> {
186+
// logger.error("音频流处理错误", error);
187+
// isCompleted.set(true);
188+
// },
189+
// () -> isCompleted.set(true)
190+
// );
191+
192+
// 处理合并后的完整字节数组
193+
audioSink.asFlux()
194+
.reduce((bytes1, bytes2) -> {
195+
// 创建新数组并合并两个字节数组
196+
byte[] merged = new byte[bytes1.length + bytes2.length];
197+
System.arraycopy(bytes1, 0, merged, 0, bytes1.length);
198+
System.arraycopy(bytes2, 0, merged, bytes1.length, bytes2.length);
199+
return merged;
200+
})
201+
.subscribe(audioQueue::offer,
202+
error -> {
203+
logger.error("音频流处理错误", error);
204+
isCompleted.set(true);
205+
},
206+
() -> isCompleted.set(true)
207+
);
208+
209+
// 设置听写参数,这里的appid,apiKey,apiSecret是在开放平台控制台获得
210+
IatClient iatClient = new IatClient.Builder()
211+
.signature(appId, secretId, secretKey)
212+
// 动态修正功能:值为wpgs时代表开启(包含修正功能的)流式听写
213+
.dwa("wpgs")
214+
.build();
215+
216+
// 生成唯一的语音ID
217+
String voiceId = UUID.randomUUID().toString();
218+
// 存储到活跃识别器映射中
219+
activeRecognizers.put(voiceId, iatClient);
220+
221+
AbstractIatWebSocketListener socketListener = new AbstractIatWebSocketListener() {
222+
@Override
223+
public void onSuccess(WebSocket webSocket, IatResponse iatResponse) {
224+
if (iatResponse.getCode() != 0) {
225+
logger.warn("code:{}, error:{}, sid:{}", iatResponse.getCode(), iatResponse.getMessage(), iatResponse.getSid());
226+
logger.warn("错误码查询链接:https://www.xfyun.cn/document/error-code");
227+
return;
228+
}
229+
230+
if (iatResponse.getData() != null) {
231+
if (iatResponse.getData().getResult() != null) {
232+
// 解析服务端返回结果
233+
IatResult result = iatResponse.getData().getResult();
234+
Text textObject = result.getText();
235+
handleResultText(textObject, resultSegments);
236+
logger.info("中间识别结果:{}", getFinalResult(resultSegments));
237+
}
238+
if (iatResponse.getData().getStatus() == 2) {
239+
// resp.data.status ==2 说明数据全部返回完毕,可以关闭连接,释放资源
240+
logger.info("session end ");
241+
recognitionLatch.countDown();
242+
iatClient.closeWebsocket();
243+
} else {
244+
// 根据返回的数据自定义处理逻辑
245+
}
246+
}
247+
}
248+
249+
@Override
250+
public void onFail(WebSocket webSocket, Throwable t, Response response) {
251+
// 自定义处理逻辑
252+
// 释放锁,表示识别完成
253+
recognitionLatch.countDown();
254+
iatClient.closeWebsocket();
255+
logger.error("xfyun stt fail,原因:{}", t.getMessage());
256+
}
257+
};
258+
259+
260+
// 使用虚拟线程处理音频识别
261+
try {
262+
Thread.startVirtualThread(() -> {
263+
while (!isCompleted.get() || !audioQueue.isEmpty()) {
264+
265+
byte[] audioChunk = null;
266+
try {
267+
audioChunk = audioQueue.poll(QUEUE_TIMEOUT_MS, TimeUnit.MILLISECONDS);
268+
} catch (InterruptedException e) {
269+
logger.warn("音频数据队列等待被中断", e);
270+
Thread.currentThread().interrupt(); // 重新设置中断标志
271+
break;
272+
}
273+
274+
if (audioChunk != null && activeRecognizers.containsKey(voiceId)) {
275+
try {
276+
iatClient.send(new ByteArrayInputStream(audioChunk), socketListener);
277+
} catch (MalformedURLException e) {
278+
throw new RuntimeException(e);
279+
} catch (SignatureException e) {
280+
throw new RuntimeException(e);
281+
}
282+
283+
// 如果已完成且队列为空,获取最终结果
284+
if (isCompleted.get() && audioQueue.isEmpty()) {
285+
activeRecognizers.remove(voiceId);
286+
break;
287+
}
288+
}
289+
}
290+
}).join(); // 等待虚拟线程完成
291+
} catch (Exception e) {
292+
logger.error("启动虚拟线程失败", e);
293+
}
294+
295+
try {
296+
// 等待识别完成或超时
297+
boolean recognized = recognitionLatch.await(RECOGNITION_TIMEOUT_MS, TimeUnit.MILLISECONDS);
298+
if (!recognized) {
299+
logger.warn("讯飞云识别超时 - VoiceId: {}", voiceId);
300+
// 超时后清理资源
301+
if (activeRecognizers.containsKey(voiceId)) {
302+
try {
303+
iatClient.closeWebsocket();
304+
activeRecognizers.remove(voiceId);
305+
} catch (Exception e) {
306+
logger.error("清理超时识别器资源时发生错误 - VoiceId: {}", voiceId, e);
307+
}
308+
}
309+
}
310+
} catch (Exception e) {
311+
logger.error("创建语音识别会话时发生错误", e);
312+
}
313+
314+
return getFinalResult(resultSegments);
315+
}
316+
}

src/main/java/com/xiaozhi/dialogue/tts/factory/TtsServiceFactory.java

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,7 @@
11
package com.xiaozhi.dialogue.tts.factory;
22

33
import com.xiaozhi.dialogue.tts.TtsService;
4-
import com.xiaozhi.dialogue.tts.providers.AliyunTtsService;
5-
import com.xiaozhi.dialogue.tts.providers.EdgeTtsService;
6-
import com.xiaozhi.dialogue.tts.providers.VolcengineTtsService;
4+
import com.xiaozhi.dialogue.tts.providers.*;
75
import com.xiaozhi.entity.SysConfig;
86

97
import org.slf4j.Logger;
@@ -98,7 +96,9 @@ private TtsService createApiService(SysConfig config, String voiceName, String o
9896
return new AliyunTtsService(config, voiceName, outputPath);
9997
} else if ("volcengine".equals(provider)) {
10098
return new VolcengineTtsService(config, voiceName, outputPath);
101-
} /*
99+
} else if ("xfyun".equals(provider)) {
100+
return new XfyunTtsService(config, voiceName, outputPath);
101+
}/*
102102
* else if ("tencent".equals(provider)) {
103103
* return new TencentTtsService(config, voiceName, outputPath);
104104
* }

0 commit comments

Comments
 (0)