1919from pathlib import Path
2020
2121from crawl4ai .async_webcrawler import AsyncWebCrawler
22- from crawl4ai .async_configs import CrawlerRunConfig , LinkPreviewConfig
22+ from crawl4ai .async_configs import CrawlerRunConfig , LinkPreviewConfig , LLMConfig
2323from crawl4ai .models import Link , CrawlResult
2424import numpy as np
2525
@@ -178,7 +178,7 @@ class AdaptiveConfig:
178178
179179 # Embedding strategy parameters
180180 embedding_model : str = "sentence-transformers/all-MiniLM-L6-v2"
181- embedding_llm_config : Optional [Dict ] = None # Separate config for embeddings
181+ embedding_llm_config : Optional [Union [ LLMConfig , Dict ] ] = None # Separate config for embeddings
182182 n_query_variations : int = 10
183183 coverage_threshold : float = 0.85
184184 alpha_shape_alpha : float = 0.5
@@ -250,6 +250,30 @@ def validate(self):
250250 assert 0 <= self .embedding_quality_max_confidence <= 1 , "embedding_quality_max_confidence must be between 0 and 1"
251251 assert self .embedding_quality_scale_factor > 0 , "embedding_quality_scale_factor must be positive"
252252 assert 0 <= self .embedding_min_confidence_threshold <= 1 , "embedding_min_confidence_threshold must be between 0 and 1"
253+
@property
def _embedding_llm_config_dict(self) -> Optional[Dict]:
    """Return ``embedding_llm_config`` as a plain dict, or ``None`` if unset.

    A dict value is handed back unchanged (the same object, not a copy) so
    callers that still configure embeddings with a dict keep working.
    Any other value is treated as a config object: ``provider`` and
    ``api_token`` are read directly (and must exist), while the remaining
    fields are optional and default to ``None`` when the object lacks them.
    """
    cfg = self.embedding_llm_config

    # Nothing configured, or already in dict form: pass through untouched.
    if cfg is None or isinstance(cfg, dict):
        return cfg

    # Required attributes first, so the resulting key order stays stable.
    as_dict = {
        'provider': cfg.provider,
        'api_token': cfg.api_token,
    }

    # Optional attributes: tolerate config objects that do not define them.
    optional_fields = (
        'base_url',
        'temperature',
        'max_tokens',
        'top_p',
        'frequency_penalty',
        'presence_penalty',
        'stop',
        'n',
    )
    for field_name in optional_fields:
        as_dict[field_name] = getattr(cfg, field_name, None)

    return as_dict
253277
254278
255279class CrawlStrategy (ABC ):
@@ -593,7 +617,7 @@ def _get_document_terms(self, crawl_result: CrawlResult) -> List[str]:
593617class EmbeddingStrategy (CrawlStrategy ):
594618 """Embedding-based adaptive crawling using semantic space coverage"""
595619
596- def __init__ (self , embedding_model : str = None , llm_config : Dict = None ):
620+ def __init__ (self , embedding_model : str = None , llm_config : Union [ LLMConfig , Dict ] = None ):
597621 self .embedding_model = embedding_model or "sentence-transformers/all-MiniLM-L6-v2"
598622 self .llm_config = llm_config
599623 self ._embedding_cache = {}
@@ -605,14 +629,24 @@ def __init__(self, embedding_model: str = None, llm_config: Dict = None):
605629 self ._kb_embeddings_hash = None # Track KB changes
606630 self ._validation_embeddings_cache = None # Cache validation query embeddings
607631 self ._kb_similarity_threshold = 0.95 # Threshold for deduplication
632+
def _get_embedding_llm_config_dict(self) -> Dict:
    """Resolve the embedding LLM settings as a dict.

    Uses ``self.config._embedding_llm_config_dict`` when a ``config``
    attribute is present, truthy, and yields a non-empty dict. Otherwise
    falls back to the built-in default: the
    ``openai/text-embedding-3-small`` provider with the API token taken
    from the ``OPENAI_API_KEY`` environment variable (``None`` if unset).
    """
    # ``config`` is attached after construction, so it may be missing.
    owner_config = getattr(self, 'config', None)
    resolved = owner_config._embedding_llm_config_dict if owner_config else None
    if resolved:
        return resolved

    # Fallback to default if no config provided
    return {
        'provider': 'openai/text-embedding-3-small',
        'api_token': os.getenv('OPENAI_API_KEY'),
    }
608645
609646 async def _get_embeddings (self , texts : List [str ]) -> Any :
610647 """Get embeddings using configured method"""
611648 from .utils import get_text_embeddings
612- embedding_llm_config = {
613- 'provider' : 'openai/text-embedding-3-small' ,
614- 'api_token' : os .getenv ('OPENAI_API_KEY' )
615- }
649+ embedding_llm_config = self ._get_embedding_llm_config_dict ()
616650 return await get_text_embeddings (
617651 texts ,
618652 embedding_llm_config ,
@@ -679,8 +713,20 @@ async def map_query_semantic_space(self, query: str, n_synthetic: int = 10) -> A
679713 Return as a JSON array of strings."""
680714
681715 # Use the LLM for query generation
682- provider = self .llm_config .get ('provider' , 'openai/gpt-4o-mini' ) if self .llm_config else 'openai/gpt-4o-mini'
683- api_token = self .llm_config .get ('api_token' ) if self .llm_config else None
716+ # Convert LLMConfig to dict if needed
717+ llm_config_dict = None
718+ if self .llm_config :
719+ if isinstance (self .llm_config , dict ):
720+ llm_config_dict = self .llm_config
721+ else :
722+ # Convert LLMConfig object to dict
723+ llm_config_dict = {
724+ 'provider' : self .llm_config .provider ,
725+ 'api_token' : self .llm_config .api_token
726+ }
727+
728+ provider = llm_config_dict .get ('provider' , 'openai/gpt-4o-mini' ) if llm_config_dict else 'openai/gpt-4o-mini'
729+ api_token = llm_config_dict .get ('api_token' ) if llm_config_dict else None
684730
685731 # response = perform_completion_with_backoff(
686732 # provider=provider,
@@ -843,10 +889,7 @@ async def select_links_for_expansion(
843889
844890 # Batch embed only uncached links
845891 if texts_to_embed :
846- embedding_llm_config = {
847- 'provider' : 'openai/text-embedding-3-small' ,
848- 'api_token' : os .getenv ('OPENAI_API_KEY' )
849- }
892+ embedding_llm_config = self ._get_embedding_llm_config_dict ()
850893 new_embeddings = await get_text_embeddings (texts_to_embed , embedding_llm_config , self .embedding_model )
851894
852895 # Cache the new embeddings
@@ -1184,10 +1227,7 @@ async def update_state(self, state: CrawlState, new_results: List[CrawlResult])
11841227 return
11851228
11861229 # Get embeddings for new texts
1187- embedding_llm_config = {
1188- 'provider' : 'openai/text-embedding-3-small' ,
1189- 'api_token' : os .getenv ('OPENAI_API_KEY' )
1190- }
1230+ embedding_llm_config = self ._get_embedding_llm_config_dict ()
11911231 new_embeddings = await get_text_embeddings (new_texts , embedding_llm_config , self .embedding_model )
11921232
11931233 # Deduplicate embeddings before adding to KB
@@ -1256,10 +1296,12 @@ def _create_strategy(self, strategy_name: str) -> CrawlStrategy:
12561296 if strategy_name == "statistical" :
12571297 return StatisticalStrategy ()
12581298 elif strategy_name == "embedding" :
1259- return EmbeddingStrategy (
1299+ strategy = EmbeddingStrategy (
12601300 embedding_model = self .config .embedding_model ,
12611301 llm_config = self .config .embedding_llm_config
12621302 )
1303+ strategy .config = self .config # Pass config to strategy
1304+ return strategy
12631305 else :
12641306 raise ValueError (f"Unknown strategy: { strategy_name } " )
12651307
0 commit comments