1- from collections import defaultdict
1+ """Clarifai as retriever to retrieve hits"""
22from typing import List , Union
33import os
44import dspy
99
1010try :
1111 from clarifai .client .search import Search
12- from google .protobuf import json_format
13- except ImportError :
12+ except ImportError as err :
1413 raise ImportError (
1514 "Clarifai is not installed. Install it using `pip install clarifai`"
16- )
15+ ) from err
1716
1817
1918class ClarifaiRM (dspy .Retrieve ):
2019 """
2120 Retrieval module that uses Clarifai to return the top K relevant passages for the given query.
21+ Assumes that you have ingested the source documents into a Clarifai App, where they are indexed and stored.
22+
23+ Args:
24+ clarifai_user_id (str): Clarifai unique user_id.
25+ clarfiai_app_id (str): Clarifai App ID, where the documents are stored.
26+ clarifai_pat (str): Clarifai PAT key.
27+ k (int): Top K documents to retrieve.
28+
29+ Examples:
30+ TODO
2231 """
2332
2433 def __init__ (self ,
@@ -32,10 +41,18 @@ def __init__(self,
3241 self .user_id = clarifai_user_id
3342 self .pat = clarifai_pat if clarifai_pat is not None else os .environ ["CLARIFAI_PAT" ]
3443 self .k = k
35-
44+
3645 super ().__init__ (k = k )
3746
38- def forward (self , query_or_queries : Union [str , List [str ]], k : Optional [int ]):
47+ def retrieve_hits (self , hits ):
48+ header = {"Authorization" : f"Key { self .pat } " }
49+ request = requests .get (hits .input .data .text .url , headers = header )
50+ request .encoding = request .apparent_encoding
51+ requested_text = request .text
52+ return requested_text
53+
54+ def forward (self , query_or_queries : Union [str , List [str ]], k : Optional [int ] = None
55+ ) -> dspy .Prediction :
3956
4057 """Uses clarifai-python SDK search function and retrieves top_k similar passages for given query,
4158 Args:
@@ -49,7 +66,6 @@ def forward(self, query_or_queries: Union[str, List[str]], k: Optional[int]):
4966 Below is a code snippet that shows how to use Clarifai as the default retriever:
5067 ```python
5168 import clarifai
52-
5369 llm = dspy.Clarifai(model=MODEL_URL, api_key="YOUR CLARIFAI_PAT")
5470 retriever_model = ClarifaiRM(clarifai_user_id="USER_ID", clarfiai_app_id="APP_ID", clarifai_pat="YOUR CLARIFAI_PAT")
5571 dspy.settings.configure(lm=llm, rm=retriever_model)
@@ -60,37 +76,19 @@ def forward(self, query_or_queries: Union[str, List[str]], k: Optional[int]):
6076 if isinstance (query_or_queries , str )
6177 else query_or_queries
6278 )
63- k = k if k is not None else self .k
79+
80+ k = self .k if self .k is not None else k
6481 passages = []
6582 queries = [q for q in queries if q ]
83+ clarifai_search = Search (user_id = self .user_id , app_id = self .app_id , top_k = k , pat = self .pat )
6684
6785 for query in queries :
68- clarifai_search = Search (user_id = self .user_id , app_id = self .app_id , top_k = k , pat = self .pat )
6986 search_response = clarifai_search .query (ranks = [{"text_raw" : query }])
7087
7188 # Retrieve hits
7289 hits = [hit for data in search_response for hit in data .hits ]
73- executor = ThreadPoolExecutor (max_workers = 10 )
74-
75- def retrieve_hits (hits ):
76- header = {"Authorization" : f"Key { self .pat } " }
77- request = requests .get (hits .input .data .text .url , headers = header )
78- request .encoding = request .apparent_encoding
79- requested_text = request .text
80- return requested_text
81-
82- futures = [executor .submit (retrieve_hits , hit ) for hit in hits ]
83- results = [future .result () for future in futures ]
84- passages = [dotdict ({"long_text" : d }) for d in results ]
85-
86- return passages
87-
88-
89-
90+ with ThreadPoolExecutor (max_workers = 10 ) as executor :
91+ results = list (executor .map (self .retrieve_hits , hits ))
92+ passages .extend (dotdict ({"long_text" : d }) for d in results )
9093
91-
92-
93-
94-
95-
96-
94+ return passages
0 commit comments