33import com .xiaozhi .communication .common .SessionManager ;
44import com .xiaozhi .dialogue .vad .impl .SileroVadModel ;
55import com .xiaozhi .entity .SysDevice ;
6+ import com .xiaozhi .utils .AutomaticGainControl ;
67import com .xiaozhi .utils .OpusProcessor ;
78
89import org .slf4j .Logger ;
@@ -34,6 +35,9 @@ public class VadService {
3435 @ Autowired
3536 private SessionManager sessionManager ;
3637
38+ @ Autowired
39+ private AutomaticGainControl agc ;
40+
3741 // 语音检测前缓冲时长(毫秒)
3842 private int preBufferMs = 500 ;
3943
@@ -55,6 +59,12 @@ public void init() {
5559 } else {
5660 logger .error ("SileroVadModel未注入,VAD功能不可用" );
5761 }
62+
63+ if (agc != null ) {
64+ logger .info ("AGC服务初始化成功" );
65+ } else {
66+ logger .warn ("AGC服务未注入,将跳过自动增益控制" );
67+ }
5868 } catch (Exception e ) {
5969 logger .error ("初始化VAD服务失败" , e );
6070 }
@@ -93,6 +103,11 @@ private class VadState {
93103 private final ByteArrayOutputStream pcmAccumulator = new ByteArrayOutputStream ();
94104 private long lastAccumTime = 0 ;
95105
106+ // AGC相关统计
107+ private String detectedDeviceType = "normal" ;
108+ private int lowQualityFrameCount = 0 ;
109+ private int totalFrameCount = 0 ;
110+
96111 public VadState () {
97112 this .maxPreBufferSize = preBufferMs * 32 ; // 16kHz, 16bit, mono = 32 bytes/ms
98113 this .lastAccumTime = System .currentTimeMillis ();
@@ -139,6 +154,17 @@ public void addProb(float prob) {
139154 if (probs .size () > 10 ) {
140155 probs .remove (0 );
141156 }
157+
158+ // 更新设备质量统计
159+ totalFrameCount ++;
160+ if (prob < 0.3f ) {
161+ lowQualityFrameCount ++;
162+ }
163+
164+ // 每50帧重新评估设备类型
165+ if (totalFrameCount % 50 == 0 ) {
166+ updateDeviceType ();
167+ }
142168 }
143169
144170 public float getLastProb () {
@@ -149,6 +175,33 @@ public List<Float> getProbs() {
149175 return probs ;
150176 }
151177
178+ public String getDetectedDeviceType () {
179+ return detectedDeviceType ;
180+ }
181+
182+ private void updateDeviceType () {
183+ if (totalFrameCount < 10 ) {
184+ return ; // 样本不足
185+ }
186+
187+ float lowQualityRatio = (float ) lowQualityFrameCount / totalFrameCount ;
188+ float avgProb = 0 ;
189+ for (Float prob : probs ) {
190+ avgProb += prob ;
191+ }
192+ avgProb /= probs .size ();
193+
194+ if (lowQualityRatio > 0.7f || avgProb < 0.2f ) {
195+ detectedDeviceType = "low_quality_mic" ;
196+ } else if (lowQualityRatio < 0.2f && avgProb > 0.6f ) {
197+ detectedDeviceType = "high_quality_mic" ;
198+ } else if (avgEnergy < 0.001f ) {
199+ detectedDeviceType = "weak_signal" ;
200+ } else {
201+ detectedDeviceType = "normal" ;
202+ }
203+ }
204+
152205 // 预缓冲区管理
153206 public void addToPreBuffer (byte [] data ) {
154207 if (speaking )
@@ -239,6 +292,11 @@ public void reset() {
239292 opusData .clear ();
240293 pcmAccumulator .reset ();
241294 lastAccumTime = System .currentTimeMillis ();
295+
296+ // 重置AGC相关统计
297+ detectedDeviceType = "normal" ;
298+ lowQualityFrameCount = 0 ;
299+ totalFrameCount = 0 ;
242300 }
243301 }
244302
@@ -255,6 +313,12 @@ public void initSession(String sessionId) {
255313 } else {
256314 state .reset ();
257315 }
316+
317+ // 重置AGC状态
318+ if (agc != null ) {
319+ agc .resetSession (sessionId );
320+ }
321+
258322 logger .info ("VAD会话已初始化: {}" , sessionId );
259323 }
260324 }
@@ -314,6 +378,18 @@ public VadResult processAudio(String sessionId, byte[] opusData) {
314378 return new VadResult (VadStatus .ERROR , null );
315379 }
316380
381+ // ========== 应用 AGC ==========
382+ String deviceType = state .getDetectedDeviceType ();
383+ byte [] originalPcm = pcmData .clone (); // 保留原始数据用于对比
384+ pcmData = agc .process (sessionId , pcmData , deviceType );
385+
386+ // 获取AGC统计信息
387+ AutomaticGainControl .AgcStats agcStats = agc .getStats (sessionId );
388+
389+ // 根据AGC增益动态调整VAD阈值
390+ speechThreshold = adjustVadThreshold (speechThreshold , agcStats );
391+ silenceThreshold = adjustVadThreshold (silenceThreshold , agcStats );
392+
317393 // 添加到预缓冲区
318394 state .addToPreBuffer (pcmData );
319395
@@ -353,7 +429,15 @@ public VadResult processAudio(String sessionId, byte[] opusData) {
353429 // 语音开始
354430 state .pcmData .clear ();
355431 state .setSpeaking (true );
356- logger .info ("检测到语音开始 - SessionId: {}, 概率: {}, 能量: {}" , sessionId , speechProb , energy );
432+
433+ // 记录AGC和设备信息
434+ String agcInfo = "" ;
435+ agcInfo = String .format (", AGC增益: %.2f, 设备类型: %s" ,
436+ agcStats .gain , state .getDetectedDeviceType ());
437+
438+ logger .info ("检测到语音开始 - SessionId: {}, 概率: {}, 能量: {}, " +
439+ "调整后阈值: {}{}" ,
440+ sessionId , speechProb , energy , speechThreshold , agcInfo );
357441
358442 // 获取预缓冲数据
359443 byte [] preBufferData = state .drainPreBuffer ();
@@ -400,6 +484,42 @@ public VadResult processAudio(String sessionId, byte[] opusData) {
400484 }
401485 }
402486
487+ /**
488+ * 根据AGC统计信息调整VAD阈值
489+ */
490+ private float adjustVadThreshold (float baseThreshold , AutomaticGainControl .AgcStats agcStats ) {
491+ float gainFactor = agcStats .gain ;
492+ float snr = agcStats .snr ;
493+
494+ // 基于增益的调整
495+ float gainAdjustment = 1.0f ;
496+ if (gainFactor > 10.0f ) {
497+ // 非常高的增益,大幅降低阈值
498+ gainAdjustment = 0.5f ;
499+ } else if (gainFactor > 5.0f ) {
500+ // 高增益,降低阈值
501+ gainAdjustment = 0.7f ;
502+ } else if (gainFactor > 2.0f ) {
503+ // 中等增益,略微降低阈值
504+ gainAdjustment = 0.85f ;
505+ }
506+
507+ // 基于信噪比的调整
508+ float snrAdjustment = 1.0f ;
509+ if (snr < 2.0f ) {
510+ // 低信噪比,进一步降低阈值
511+ snrAdjustment = 0.8f ;
512+ } else if (snr > 10.0f ) {
513+ // 高信噪比,可以提高阈值
514+ snrAdjustment = 1.1f ;
515+ }
516+
517+ float adjustedThreshold = baseThreshold * gainAdjustment * snrAdjustment ;
518+
519+ // 确保阈值在合理范围内
520+ return Math .max (0.05f , Math .min (0.8f , adjustedThreshold ));
521+ }
522+
403523 /**
404524 * 执行语音检测
405525 */
@@ -475,6 +595,11 @@ public void resetSession(String sessionId) {
475595 }
476596 states .remove (sessionId );
477597 locks .remove (sessionId );
598+
599+ // 重置AGC状态
600+ if (agc != null ) {
601+ agc .resetSession (sessionId );
602+ }
478603 }
479604 }
480605
@@ -522,6 +647,38 @@ public List<byte[]> getOpusData(String sessionId) {
522647 }
523648 }
524649
650+ /**
651+ * 获取AGC统计信息
652+ */
653+ public AutomaticGainControl .AgcStats getAgcStats (String sessionId ) {
654+ return agc != null ? agc .getStats (sessionId ) : new AutomaticGainControl .AgcStats ();
655+ }
656+
657+ /**
658+ * 获取检测到的设备类型
659+ */
660+ public String getDetectedDeviceType (String sessionId ) {
661+ Object lock = getLock (sessionId );
662+ synchronized (lock ) {
663+ VadState state = states .get (sessionId );
664+ return state != null ? state .getDetectedDeviceType () : "normal" ;
665+ }
666+ }
667+
668+ /**
669+ * 手动设置设备类型(用于测试或特殊情况)
670+ */
671+ public void setDeviceType (String sessionId , String deviceType ) {
672+ Object lock = getLock (sessionId );
673+ synchronized (lock ) {
674+ VadState state = states .get (sessionId );
675+ if (state != null ) {
676+ state .detectedDeviceType = deviceType ;
677+ logger .info ("手动设置设备类型 - SessionId: {}, 类型: {}" , sessionId , deviceType );
678+ }
679+ }
680+ }
681+
525682 /**
526683 * VAD状态枚举
527684 */
0 commit comments