Skip to content

Commit 14b42b1

Browse files
authored
Merge pull request unclecode#1471 from unclecode/fix/adaptive-crawler-llm-config
Fix: allow custom LLM providers for adaptive crawler embedding config…
2 parents 0482c1e + 3bc56dd commit 14b42b1

File tree

4 files changed

+381
-19
lines changed

4 files changed

+381
-19
lines changed

crawl4ai/adaptive_crawler.py

Lines changed: 60 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
from pathlib import Path
2020

2121
from crawl4ai.async_webcrawler import AsyncWebCrawler
22-
from crawl4ai.async_configs import CrawlerRunConfig, LinkPreviewConfig
22+
from crawl4ai.async_configs import CrawlerRunConfig, LinkPreviewConfig, LLMConfig
2323
from crawl4ai.models import Link, CrawlResult
2424
import numpy as np
2525

@@ -178,7 +178,7 @@ class AdaptiveConfig:
178178

179179
# Embedding strategy parameters
180180
embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2"
181-
embedding_llm_config: Optional[Dict] = None # Separate config for embeddings
181+
embedding_llm_config: Optional[Union[LLMConfig, Dict]] = None # Separate config for embeddings
182182
n_query_variations: int = 10
183183
coverage_threshold: float = 0.85
184184
alpha_shape_alpha: float = 0.5
@@ -250,6 +250,30 @@ def validate(self):
250250
assert 0 <= self.embedding_quality_max_confidence <= 1, "embedding_quality_max_confidence must be between 0 and 1"
251251
assert self.embedding_quality_scale_factor > 0, "embedding_quality_scale_factor must be positive"
252252
assert 0 <= self.embedding_min_confidence_threshold <= 1, "embedding_min_confidence_threshold must be between 0 and 1"
253+
254+
@property
def _embedding_llm_config_dict(self) -> Optional[Dict]:
    """Return ``embedding_llm_config`` in plain-dict form.

    Returns:
        ``None`` when no embedding LLM config is set; the very same dict
        object when the config is already a dict; otherwise a new dict
        built from the config object's attributes. ``provider`` and
        ``api_token`` are read directly, every other key falls back to
        ``None`` when the attribute is absent.
    """
    llm_cfg = self.embedding_llm_config

    # Nothing to convert: unset stays None, a dict is handed back untouched
    # so existing dict-based callers keep working.
    if llm_cfg is None or isinstance(llm_cfg, dict):
        return llm_cfg

    # Object form: required fields first, then the optional ones.
    as_dict = {
        'provider': llm_cfg.provider,
        'api_token': llm_cfg.api_token,
    }
    optional_fields = (
        'base_url',
        'temperature',
        'max_tokens',
        'top_p',
        'frequency_penalty',
        'presence_penalty',
        'stop',
        'n',
    )
    for field_name in optional_fields:
        as_dict[field_name] = getattr(llm_cfg, field_name, None)
    return as_dict
253277

254278

255279
class CrawlStrategy(ABC):
@@ -593,7 +617,7 @@ def _get_document_terms(self, crawl_result: CrawlResult) -> List[str]:
593617
class EmbeddingStrategy(CrawlStrategy):
594618
"""Embedding-based adaptive crawling using semantic space coverage"""
595619

596-
def __init__(self, embedding_model: str = None, llm_config: Dict = None):
620+
def __init__(self, embedding_model: str = None, llm_config: Union[LLMConfig, Dict] = None):
597621
self.embedding_model = embedding_model or "sentence-transformers/all-MiniLM-L6-v2"
598622
self.llm_config = llm_config
599623
self._embedding_cache = {}
@@ -605,14 +629,24 @@ def __init__(self, embedding_model: str = None, llm_config: Dict = None):
605629
self._kb_embeddings_hash = None # Track KB changes
606630
self._validation_embeddings_cache = None # Cache validation query embeddings
607631
self._kb_similarity_threshold = 0.95 # Threshold for deduplication
632+
633+
def _get_embedding_llm_config_dict(self) -> Dict:
    """Resolve the LLM settings used for embedding requests as a dict.

    Resolution order:
      1. ``self.config._embedding_llm_config_dict`` when a config object
         has been attached to the strategy and it yields a non-empty dict.
      2. ``self.llm_config`` as stored by the constructor. A dict is
         returned as-is; any other object is read attribute-wise
         (``provider``, ``api_token`` and, if present, ``base_url``).
      3. The built-in default: OpenAI ``text-embedding-3-small`` with the
         token taken from the ``OPENAI_API_KEY`` environment variable.

    Step 2 ensures a strategy built directly with ``llm_config=...`` (and
    no attached ``config``) uses the provider the caller supplied instead
    of silently falling back to OpenAI.

    Returns:
        A dict containing at least ``provider`` and ``api_token`` keys.
    """
    # ``config`` is assigned after construction, so it may be missing.
    attached = getattr(self, 'config', None)
    if attached:
        resolved = attached._embedding_llm_config_dict
        if resolved:
            return resolved

    # Honor the constructor-supplied config before the OpenAI default.
    own = getattr(self, 'llm_config', None)
    if own:
        if isinstance(own, dict):
            return own
        return {
            'provider': own.provider,
            'api_token': own.api_token,
            'base_url': getattr(own, 'base_url', None),
        }

    # Fallback to default if no config provided
    return {
        'provider': 'openai/text-embedding-3-small',
        'api_token': os.getenv('OPENAI_API_KEY'),
    }
608645

609646
async def _get_embeddings(self, texts: List[str]) -> Any:
610647
"""Get embeddings using configured method"""
611648
from .utils import get_text_embeddings
612-
embedding_llm_config = {
613-
'provider': 'openai/text-embedding-3-small',
614-
'api_token': os.getenv('OPENAI_API_KEY')
615-
}
649+
embedding_llm_config = self._get_embedding_llm_config_dict()
616650
return await get_text_embeddings(
617651
texts,
618652
embedding_llm_config,
@@ -679,8 +713,20 @@ async def map_query_semantic_space(self, query: str, n_synthetic: int = 10) -> A
679713
Return as a JSON array of strings."""
680714

681715
# Use the LLM for query generation
682-
provider = self.llm_config.get('provider', 'openai/gpt-4o-mini') if self.llm_config else 'openai/gpt-4o-mini'
683-
api_token = self.llm_config.get('api_token') if self.llm_config else None
716+
# Convert LLMConfig to dict if needed
717+
llm_config_dict = None
718+
if self.llm_config:
719+
if isinstance(self.llm_config, dict):
720+
llm_config_dict = self.llm_config
721+
else:
722+
# Convert LLMConfig object to dict
723+
llm_config_dict = {
724+
'provider': self.llm_config.provider,
725+
'api_token': self.llm_config.api_token
726+
}
727+
728+
provider = llm_config_dict.get('provider', 'openai/gpt-4o-mini') if llm_config_dict else 'openai/gpt-4o-mini'
729+
api_token = llm_config_dict.get('api_token') if llm_config_dict else None
684730

685731
# response = perform_completion_with_backoff(
686732
# provider=provider,
@@ -843,10 +889,7 @@ async def select_links_for_expansion(
843889

844890
# Batch embed only uncached links
845891
if texts_to_embed:
846-
embedding_llm_config = {
847-
'provider': 'openai/text-embedding-3-small',
848-
'api_token': os.getenv('OPENAI_API_KEY')
849-
}
892+
embedding_llm_config = self._get_embedding_llm_config_dict()
850893
new_embeddings = await get_text_embeddings(texts_to_embed, embedding_llm_config, self.embedding_model)
851894

852895
# Cache the new embeddings
@@ -1184,10 +1227,7 @@ async def update_state(self, state: CrawlState, new_results: List[CrawlResult])
11841227
return
11851228

11861229
# Get embeddings for new texts
1187-
embedding_llm_config = {
1188-
'provider': 'openai/text-embedding-3-small',
1189-
'api_token': os.getenv('OPENAI_API_KEY')
1190-
}
1230+
embedding_llm_config = self._get_embedding_llm_config_dict()
11911231
new_embeddings = await get_text_embeddings(new_texts, embedding_llm_config, self.embedding_model)
11921232

11931233
# Deduplicate embeddings before adding to KB
@@ -1256,10 +1296,12 @@ def _create_strategy(self, strategy_name: str) -> CrawlStrategy:
12561296
if strategy_name == "statistical":
12571297
return StatisticalStrategy()
12581298
elif strategy_name == "embedding":
1259-
return EmbeddingStrategy(
1299+
strategy = EmbeddingStrategy(
12601300
embedding_model=self.config.embedding_model,
12611301
llm_config=self.config.embedding_llm_config
12621302
)
1303+
strategy.config = self.config # Pass config to strategy
1304+
return strategy
12631305
else:
12641306
raise ValueError(f"Unknown strategy: {strategy_name}")
12651307

Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
import asyncio
2+
import os
3+
from crawl4ai import AsyncWebCrawler, AdaptiveCrawler, AdaptiveConfig, LLMConfig
4+
5+
6+
async def test_configuration(name: str, config: AdaptiveConfig, url: str, query: str):
    """Run one adaptive crawl with ``config`` and print a report.

    Prints a banner carrying ``name``, digests ``url`` for ``query``, then
    prints the crawl statistics, up to five relevant pages (each with a
    preview of at most 200 characters) and a closing summary.

    Returns:
        The value returned by ``adaptive.digest``.
    """
    print(f"\n{'='*60}")
    print(f"Configuration: {name}")
    print(f"{'='*60}")

    async with AsyncWebCrawler(verbose=False) as crawler:
        adaptive = AdaptiveCrawler(crawler, config)
        result = await adaptive.digest(start_url=url, query=query)

        print("\n" + "="*50)
        print("CRAWL STATISTICS")
        print("="*50)
        adaptive.print_stats(detailed=False)

        # Most relevant content found during the crawl
        print("\n" + "="*50)
        print("MOST RELEVANT PAGES")
        print("="*50)

        for rank, page in enumerate(adaptive.get_relevant_content(top_k=5), 1):
            print(f"\n{rank}. {page['url']}")
            print(f" Relevance Score: {page['score']:.2%}")

            # Only pages with content get a preview line
            text = page['content'] or ""
            if not text:
                continue
            preview = text[:200].replace('\n', ' ')
            if len(text) > 200:
                preview += "..."
            print(f" Preview: {preview}")

        print(f"\n{'='*50}")
        print(f"Pages crawled: {len(result.crawled_urls)}")
        print(f"Final confidence: {adaptive.confidence:.1%}")
        print(f"Stopped reason: {result.metrics.get('stopped_reason', 'max_pages')}")

        if result.metrics.get('is_irrelevant', False):
            print("⚠️ Query detected as irrelevant!")

        return result
48+
49+
50+
async def llm_embedding():
    """Run the embedding-strategy example with an ``LLMConfig`` provider.

    Builds an ``LLMConfig`` for ``openai/text-embedding-3-small`` (token
    read from ``OPENAI_API_KEY``), passes it to ``AdaptiveConfig`` as
    ``embedding_llm_config`` and runs ``test_configuration`` against the
    Python asyncio documentation page.
    """
    print("EMBEDDING STRATEGY CONFIGURATION EXAMPLES")
    print("=" * 60)

    # Start page for the crawl
    docs_url = "https://docs.python.org/3/library/asyncio.html"

    embedding_provider = LLMConfig(
        provider='openai/text-embedding-3-small',
        api_token=os.getenv('OPENAI_API_KEY'),
        temperature=0.7,
        max_tokens=2000,
    )

    # embedding_llm_config also accepts the plain-dict form:
    #     {'provider': 'openai/text-embedding-3-small',
    #      'api_token': os.getenv('OPENAI_API_KEY')}
    adaptive_settings = AdaptiveConfig(
        strategy="embedding",
        max_pages=10,
        # Use OpenAI embeddings
        embedding_llm_config=embedding_provider,
        # OpenAI embeddings are high quality, can be stricter
        embedding_k_exp=4.0,
        n_query_variations=12,
    )

    # Alternative query to try: "event-driven architecture patterns"
    await test_configuration(
        name="OpenAI Embeddings",
        config=adaptive_settings,
        url=docs_url,
        query="async await context managers coroutines",
    )
89+
90+
91+
92+
async def basic_adaptive_crawling():
    """Run an adaptive crawl with default settings and print a report.

    Crawls the Python asyncio documentation for an async-related query,
    prints the crawl statistics, up to five relevant pages (each with a
    preview of at most 200 characters), the final confidence, and a
    one-line verdict based on the 0.8 / 0.6 confidence thresholds.
    """
    async with AsyncWebCrawler(verbose=True) as crawler:
        # No config given: default settings (statistical strategy).
        # For semantic matching, pass AdaptiveConfig(strategy="embedding")
        # as the second argument instead.
        adaptive = AdaptiveCrawler(crawler)

        print("Starting adaptive crawl for Python async programming information...")
        result = await adaptive.digest(
            start_url="https://docs.python.org/3/library/asyncio.html",
            query="async await context managers coroutines",
        )

        # Crawl statistics
        print("\n" + "="*50)
        print("CRAWL STATISTICS")
        print("="*50)
        adaptive.print_stats(detailed=False)

        # Most relevant content found
        print("\n" + "="*50)
        print("MOST RELEVANT PAGES")
        print("="*50)

        for rank, page in enumerate(adaptive.get_relevant_content(top_k=5), 1):
            print(f"\n{rank}. {page['url']}")
            print(f" Relevance Score: {page['score']:.2%}")

            # Only pages with content get a preview line
            text = page['content'] or ""
            if not text:
                continue
            preview = text[:200].replace('\n', ' ')
            if len(text) > 200:
                preview += "..."
            print(f" Preview: {preview}")

        # Closing summary
        print(f"\n{'='*50}")
        print(f"Final Confidence: {adaptive.confidence:.2%}")
        print(f"Total Pages Crawled: {len(result.crawled_urls)}")
        print(f"Knowledge Base Size: {len(adaptive.state.knowledge_base)} documents")

        # Verdict: high at >= 0.8, moderate at >= 0.6, otherwise low
        if adaptive.confidence >= 0.8:
            verdict = "✓ High confidence - can answer detailed questions about async Python"
        elif adaptive.confidence >= 0.6:
            verdict = "~ Moderate confidence - can answer basic questions"
        else:
            verdict = "✗ Low confidence - need more information"
        print(verdict)
149+
150+
151+
152+
if __name__ == "__main__":
    # Script entry point: runs the LLMConfig embedding example. Swap in the
    # commented call below to run the default-settings example instead.
    asyncio.run(llm_embedding())
    # asyncio.run(basic_adaptive_crawling())

docs/md_v2/core/adaptive-crawling.md

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,19 @@ config = AdaptiveConfig(
108108
embedding_min_confidence_threshold=0.1 # Stop if completely irrelevant
109109
)
110110

111-
# With custom embedding provider (e.g., OpenAI)
111+
# With custom LLM provider for embeddings and query expansion, via LLMConfig (recommended)
112+
from crawl4ai import LLMConfig
113+
114+
config = AdaptiveConfig(
115+
strategy="embedding",
116+
embedding_llm_config=LLMConfig(
117+
provider='openai/text-embedding-3-small',
118+
api_token='your-api-key',
119+
temperature=0.7
120+
)
121+
)
122+
123+
# Alternative: Dictionary format (backward compatible)
112124
config = AdaptiveConfig(
113125
strategy="embedding",
114126
embedding_llm_config={

0 commit comments

Comments
 (0)