Skip to content

Commit 50c446a

Browse files
committed
addressed comments
1 parent 27ac8db commit 50c446a

File tree

2 files changed

+49
-48
lines changed

2 files changed

+49
-48
lines changed

dsp/modules/clarifai.py

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,30 @@
1-
import math
1+
"""Clarifai LM integration"""
22
from typing import Any, Optional
3-
import backoff
4-
53
from dsp.modules.lm import LM
64

75
try:
86
from clarifai.client.model import Model
9-
except ImportError:
10-
raise ImportError("ClarifaiLLM requires `pip install clarifai`.")
7+
except ImportError as err:
8+
raise ImportError("ClarifaiLLM requires `pip install clarifai`.") from err
119

1210
class ClarifaiLLM(LM):
13-
"""Integration to call models hosted in clarifai platform."""
11+
"""Integration to call models hosted in clarifai platform.
12+
13+
Args:
14+
model (str, optional): Clarifai URL of the model. Defaults to "Mistral-7B-Instruct".
15+
api_key (Optional[str], optional): CLARIFAI_PAT token. Defaults to None.
16+
**kwargs: Additional arguments to pass to the API provider.
17+
Example:
18+
import dspy
19+
dspy.configure(lm=dspy.Clarifai(model=MODEL_URL,
20+
api_key=CLARIFAI_PAT,
21+
inference_params={"max_tokens":100,'temperature':0.6}))
22+
"""
1423

1524
def __init__(
1625
self,
17-
model: str = "https://clarifai.com/mistralai/completion/models/mistral-7B-Instruct", #defaults to mistral-7B-Instruct
26+
model: str = "https://clarifai.com/mistralai/completion/models/mistral-7B-Instruct",
1827
api_key: Optional[str] = None,
19-
stop_sequences: list[str] = [],
2028
**kwargs,
2129
):
2230
super().__init__(model)
@@ -41,7 +49,6 @@ def __init__(
4149
)
4250

4351
def basic_request(self, prompt, **kwargs):
44-
4552
params = (
4653
self.kwargs['inference_params'] if 'inference_params' in self.kwargs
4754
else {}
@@ -50,11 +57,10 @@ def basic_request(self, prompt, **kwargs):
5057
self._model.predict_by_bytes(
5158
input_bytes= prompt.encode(encoding="utf-8"),
5259
input_type= "text",
53-
inference_params= params
54-
).outputs[0].data.text.raw
55-
60+
inference_params= params,
61+
).outputs[0].data.text.raw
5662
)
57-
63+
kwargs = {**self.kwargs, **kwargs}
5864
history = {
5965
"prompt": prompt,
6066
"response": response,
@@ -63,9 +69,6 @@ def basic_request(self, prompt, **kwargs):
6369
self.history.append(history)
6470
return response
6571

66-
def _get_choice_text(self, choice: dict[str, Any]) -> str:
67-
return choice
68-
6972
def request(self, prompt: str, **kwargs):
7073
return self.basic_request(prompt, **kwargs)
7174

dspy/retrieve/clarifai_rm.py

Lines changed: 30 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from collections import defaultdict
1+
"""Clarifai as retriver to retrieve hits"""
22
from typing import List, Union
33
import os
44
import dspy
@@ -9,16 +9,25 @@
99

1010
try:
1111
from clarifai.client.search import Search
12-
from google.protobuf import json_format
13-
except ImportError:
12+
except ImportError as err:
1413
raise ImportError(
1514
"Clarifai is not installed. Install it using `pip install clarifai`"
16-
)
15+
) from err
1716

1817

1918
class ClarifaiRM(dspy.Retrieve):
2019
"""
2120
Retrieval module uses clarifai to return the Top K relevant pasages for the given query.
21+
Assuming that you have ingested the source documents into clarifai App, where it is indexed and stored.
22+
23+
Args:
24+
clarifai_user_id (str): Clarifai unique user_id.
25+
clarfiai_app_id (str): Clarifai App ID, where the documents are stored.
26+
clarifai_pat (str): Clarifai PAT key.
27+
k (int): Top K documents to retrieve.
28+
29+
Examples:
30+
TODO
2231
"""
2332

2433
def __init__(self,
@@ -32,10 +41,18 @@ def __init__(self,
3241
self.user_id = clarifai_user_id
3342
self.pat = clarifai_pat if clarifai_pat is not None else os.environ["CLARIFAI_PAT"]
3443
self.k=k
35-
44+
3645
super().__init__(k=k)
3746

38-
def forward(self, query_or_queries: Union[str, List[str]], k: Optional[int]):
47+
def retrieve_hits(self, hits):
48+
header = {"Authorization": f"Key {self.pat}"}
49+
request = requests.get(hits.input.data.text.url, headers=header)
50+
request.encoding = request.apparent_encoding
51+
requested_text = request.text
52+
return requested_text
53+
54+
def forward(self, query_or_queries: Union[str, List[str]], k: Optional[int] = None
55+
) -> dspy.Prediction:
3956

4057
"""Uses clarifai-python SDK search function and retrieves top_k similar passages for given query,
4158
Args:
@@ -49,7 +66,6 @@ def forward(self, query_or_queries: Union[str, List[str]], k: Optional[int]):
4966
Below is a code snippet that shows how to use Marqo as the default retriver:
5067
```python
5168
import clarifai
52-
5369
llm = dspy.Clarifai(model=MODEL_URL, api_key="YOUR CLARIFAI_PAT")
5470
retriever_model = ClarifaiRM(clarifai_user_id="USER_ID", clarfiai_app_id="APP_ID", clarifai_pat="YOUR CLARIFAI_PAT")
5571
dspy.settings.configure(lm=llm, rm=retriever_model)
@@ -60,37 +76,19 @@ def forward(self, query_or_queries: Union[str, List[str]], k: Optional[int]):
6076
if isinstance(query_or_queries, str)
6177
else query_or_queries
6278
)
63-
k = k if k is not None else self.k
79+
80+
k = self.k if self.k is not None else k
6481
passages = []
6582
queries = [q for q in queries if q]
83+
clarifai_search = Search(user_id=self.user_id, app_id=self.app_id, top_k=k, pat=self.pat)
6684

6785
for query in queries:
68-
clarifai_search = Search(user_id=self.user_id, app_id=self.app_id, top_k=k, pat=self.pat)
6986
search_response= clarifai_search.query(ranks=[{"text_raw": query}])
7087

7188
# Retrieve hits
7289
hits=[hit for data in search_response for hit in data.hits]
73-
executor = ThreadPoolExecutor(max_workers=10)
74-
75-
def retrieve_hits(hits):
76-
header = {"Authorization": f"Key {self.pat}"}
77-
request = requests.get(hits.input.data.text.url, headers=header)
78-
request.encoding = request.apparent_encoding
79-
requested_text = request.text
80-
return requested_text
81-
82-
futures = [executor.submit(retrieve_hits, hit) for hit in hits]
83-
results = [future.result() for future in futures]
84-
passages=[dotdict({"long_text": d}) for d in results]
85-
86-
return passages
87-
88-
89-
90+
with ThreadPoolExecutor(max_workers=10) as executor:
91+
results = list(executor.map(self.retrieve_hits, hits))
92+
passages.extend(dotdict({"long_text": d}) for d in results)
9093

91-
92-
93-
94-
95-
96-
94+
return passages

0 commit comments

Comments
 (0)