Skip to content

Commit c48d029

Browse files
committed
feat:应用agc
1 parent 0d8cd39 commit c48d029

File tree

3 files changed

+452
-2
lines changed

3 files changed

+452
-2
lines changed

src/main/java/com/xiaozhi/dialogue/service/VadService.java

Lines changed: 158 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import com.xiaozhi.communication.common.SessionManager;
44
import com.xiaozhi.dialogue.vad.impl.SileroVadModel;
55
import com.xiaozhi.entity.SysDevice;
6+
import com.xiaozhi.utils.AutomaticGainControl;
67
import com.xiaozhi.utils.OpusProcessor;
78

89
import org.slf4j.Logger;
@@ -34,6 +35,9 @@ public class VadService {
3435
@Autowired
3536
private SessionManager sessionManager;
3637

38+
@Autowired
39+
private AutomaticGainControl agc;
40+
3741
// 语音检测前缓冲时长(毫秒)
3842
private int preBufferMs = 500;
3943

@@ -55,6 +59,12 @@ public void init() {
5559
} else {
5660
logger.error("SileroVadModel未注入,VAD功能不可用");
5761
}
62+
63+
if (agc != null) {
64+
logger.info("AGC服务初始化成功");
65+
} else {
66+
logger.warn("AGC服务未注入,将跳过自动增益控制");
67+
}
5868
} catch (Exception e) {
5969
logger.error("初始化VAD服务失败", e);
6070
}
@@ -93,6 +103,11 @@ private class VadState {
93103
private final ByteArrayOutputStream pcmAccumulator = new ByteArrayOutputStream();
94104
private long lastAccumTime = 0;
95105

106+
// AGC相关统计
107+
private String detectedDeviceType = "normal";
108+
private int lowQualityFrameCount = 0;
109+
private int totalFrameCount = 0;
110+
96111
public VadState() {
97112
this.maxPreBufferSize = preBufferMs * 32; // 16kHz, 16bit, mono = 32 bytes/ms
98113
this.lastAccumTime = System.currentTimeMillis();
@@ -139,6 +154,17 @@ public void addProb(float prob) {
139154
if (probs.size() > 10) {
140155
probs.remove(0);
141156
}
157+
158+
// 更新设备质量统计
159+
totalFrameCount++;
160+
if (prob < 0.3f) {
161+
lowQualityFrameCount++;
162+
}
163+
164+
// 每50帧重新评估设备类型
165+
if (totalFrameCount % 50 == 0) {
166+
updateDeviceType();
167+
}
142168
}
143169

144170
public float getLastProb() {
@@ -149,6 +175,33 @@ public List<Float> getProbs() {
149175
return probs;
150176
}
151177

178+
public String getDetectedDeviceType() {
179+
return detectedDeviceType;
180+
}
181+
182+
private void updateDeviceType() {
183+
if (totalFrameCount < 10) {
184+
return; // 样本不足
185+
}
186+
187+
float lowQualityRatio = (float) lowQualityFrameCount / totalFrameCount;
188+
float avgProb = 0;
189+
for (Float prob : probs) {
190+
avgProb += prob;
191+
}
192+
avgProb /= probs.size();
193+
194+
if (lowQualityRatio > 0.7f || avgProb < 0.2f) {
195+
detectedDeviceType = "low_quality_mic";
196+
} else if (lowQualityRatio < 0.2f && avgProb > 0.6f) {
197+
detectedDeviceType = "high_quality_mic";
198+
} else if (avgEnergy < 0.001f) {
199+
detectedDeviceType = "weak_signal";
200+
} else {
201+
detectedDeviceType = "normal";
202+
}
203+
}
204+
152205
// 预缓冲区管理
153206
public void addToPreBuffer(byte[] data) {
154207
if (speaking)
@@ -239,6 +292,11 @@ public void reset() {
239292
opusData.clear();
240293
pcmAccumulator.reset();
241294
lastAccumTime = System.currentTimeMillis();
295+
296+
// 重置AGC相关统计
297+
detectedDeviceType = "normal";
298+
lowQualityFrameCount = 0;
299+
totalFrameCount = 0;
242300
}
243301
}
244302

@@ -255,6 +313,12 @@ public void initSession(String sessionId) {
255313
} else {
256314
state.reset();
257315
}
316+
317+
// 重置AGC状态
318+
if (agc != null) {
319+
agc.resetSession(sessionId);
320+
}
321+
258322
logger.info("VAD会话已初始化: {}", sessionId);
259323
}
260324
}
@@ -314,6 +378,18 @@ public VadResult processAudio(String sessionId, byte[] opusData) {
314378
return new VadResult(VadStatus.ERROR, null);
315379
}
316380

381+
// ========== 应用 AGC ==========
382+
String deviceType = state.getDetectedDeviceType();
383+
byte[] originalPcm = pcmData.clone(); // 保留原始数据用于对比
384+
pcmData = agc.process(sessionId, pcmData, deviceType);
385+
386+
// 获取AGC统计信息
387+
AutomaticGainControl.AgcStats agcStats = agc.getStats(sessionId);
388+
389+
// 根据AGC增益动态调整VAD阈值
390+
speechThreshold = adjustVadThreshold(speechThreshold, agcStats);
391+
silenceThreshold = adjustVadThreshold(silenceThreshold, agcStats);
392+
317393
// 添加到预缓冲区
318394
state.addToPreBuffer(pcmData);
319395

@@ -353,7 +429,15 @@ public VadResult processAudio(String sessionId, byte[] opusData) {
353429
// 语音开始
354430
state.pcmData.clear();
355431
state.setSpeaking(true);
356-
logger.info("检测到语音开始 - SessionId: {}, 概率: {}, 能量: {}", sessionId, speechProb, energy);
432+
433+
// 记录AGC和设备信息
434+
String agcInfo = "";
435+
agcInfo = String.format(", AGC增益: %.2f, 设备类型: %s",
436+
agcStats.gain, state.getDetectedDeviceType());
437+
438+
logger.info("检测到语音开始 - SessionId: {}, 概率: {}, 能量: {}, " +
439+
"调整后阈值: {}{}",
440+
sessionId, speechProb, energy, speechThreshold, agcInfo);
357441

358442
// 获取预缓冲数据
359443
byte[] preBufferData = state.drainPreBuffer();
@@ -400,6 +484,42 @@ public VadResult processAudio(String sessionId, byte[] opusData) {
400484
}
401485
}
402486

487+
/**
488+
* 根据AGC统计信息调整VAD阈值
489+
*/
490+
private float adjustVadThreshold(float baseThreshold, AutomaticGainControl.AgcStats agcStats) {
491+
float gainFactor = agcStats.gain;
492+
float snr = agcStats.snr;
493+
494+
// 基于增益的调整
495+
float gainAdjustment = 1.0f;
496+
if (gainFactor > 10.0f) {
497+
// 非常高的增益,大幅降低阈值
498+
gainAdjustment = 0.5f;
499+
} else if (gainFactor > 5.0f) {
500+
// 高增益,降低阈值
501+
gainAdjustment = 0.7f;
502+
} else if (gainFactor > 2.0f) {
503+
// 中等增益,略微降低阈值
504+
gainAdjustment = 0.85f;
505+
}
506+
507+
// 基于信噪比的调整
508+
float snrAdjustment = 1.0f;
509+
if (snr < 2.0f) {
510+
// 低信噪比,进一步降低阈值
511+
snrAdjustment = 0.8f;
512+
} else if (snr > 10.0f) {
513+
// 高信噪比,可以提高阈值
514+
snrAdjustment = 1.1f;
515+
}
516+
517+
float adjustedThreshold = baseThreshold * gainAdjustment * snrAdjustment;
518+
519+
// 确保阈值在合理范围内
520+
return Math.max(0.05f, Math.min(0.8f, adjustedThreshold));
521+
}
522+
403523
/**
404524
* 执行语音检测
405525
*/
@@ -475,6 +595,11 @@ public void resetSession(String sessionId) {
475595
}
476596
states.remove(sessionId);
477597
locks.remove(sessionId);
598+
599+
// 重置AGC状态
600+
if (agc != null) {
601+
agc.resetSession(sessionId);
602+
}
478603
}
479604
}
480605

@@ -522,6 +647,38 @@ public List<byte[]> getOpusData(String sessionId) {
522647
}
523648
}
524649

650+
/**
651+
* 获取AGC统计信息
652+
*/
653+
public AutomaticGainControl.AgcStats getAgcStats(String sessionId) {
654+
return agc != null ? agc.getStats(sessionId) : new AutomaticGainControl.AgcStats();
655+
}
656+
657+
/**
658+
* 获取检测到的设备类型
659+
*/
660+
public String getDetectedDeviceType(String sessionId) {
661+
Object lock = getLock(sessionId);
662+
synchronized (lock) {
663+
VadState state = states.get(sessionId);
664+
return state != null ? state.getDetectedDeviceType() : "normal";
665+
}
666+
}
667+
668+
/**
669+
* 手动设置设备类型(用于测试或特殊情况)
670+
*/
671+
public void setDeviceType(String sessionId, String deviceType) {
672+
Object lock = getLock(sessionId);
673+
synchronized (lock) {
674+
VadState state = states.get(sessionId);
675+
if (state != null) {
676+
state.detectedDeviceType = deviceType;
677+
logger.info("手动设置设备类型 - SessionId: {}, 类型: {}", sessionId, deviceType);
678+
}
679+
}
680+
}
681+
525682
/**
526683
* VAD状态枚举
527684
*/

0 commit comments

Comments
 (0)