Skip to content

Commit 27ac8db

Browse files
committed
TT-3084-clarifai-dspy-integration
1 parent 8c6ba34 commit 27ac8db

File tree

5 files changed

+204
-7
lines changed

5 files changed

+204
-7
lines changed

dsp/modules/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from .cohere import *
77
from .sbert import *
88
from .pyserini import *
9+
from .clarifai import *
910

1011
from .hf_client import HFClientTGI
1112
from .hf_client import Anyscale

dsp/modules/clarifai.py

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import math
2+
from typing import Any, Optional
3+
import backoff
4+
5+
from dsp.modules.lm import LM
6+
7+
try:
8+
from clarifai.client.model import Model
9+
except ImportError:
10+
raise ImportError("ClarifaiLLM requires `pip install clarifai`.")
11+
12+
class ClarifaiLLM(LM):
    """LM wrapper for text-completion models hosted on the Clarifai platform.

    Args:
        model: Clarifai model URL. Defaults to mistral-7B-Instruct.
        api_key: Clarifai Personal Access Token (PAT).
        stop_sequences: accepted for interface parity; currently not
            forwarded to the Clarifai API.
        **kwargs: extra options. An ``inference_params`` dict, if given,
            is forwarded to the model on every request; its
            ``temperature`` / ``max_tokens`` entries are mirrored into
            ``self.kwargs`` (defaults 0.0 and 150).
    """

    def __init__(
        self,
        model: str = "https://clarifai.com/mistralai/completion/models/mistral-7B-Instruct",  # defaults to mistral-7B-Instruct
        api_key: Optional[str] = None,
        stop_sequences: Optional[list[str]] = None,  # None, not []: avoid shared mutable default
        **kwargs,
    ):
        super().__init__(model)

        self.provider = "clarifai"
        self.pat = api_key
        self._model = Model(url=model, pat=api_key)
        self.kwargs = {
            "n": 1,
            **kwargs,
        }
        self.history: list[dict[str, Any]] = []

        # Mirror selected inference params into self.kwargs so generic
        # LM machinery sees sensible temperature/max_tokens values.
        inference_params = self.kwargs.get("inference_params", {})
        self.kwargs["temperature"] = inference_params.get("temperature", 0.0)
        self.kwargs["max_tokens"] = inference_params.get("max_tokens", 150)

    def basic_request(self, prompt: str, **kwargs):
        """Send ``prompt`` to the Clarifai model and return the raw text.

        The call and its response are appended to ``self.history``.
        """
        params = self.kwargs.get("inference_params", {})
        response = (
            self._model.predict_by_bytes(
                input_bytes=prompt.encode(encoding="utf-8"),
                input_type="text",
                inference_params=params,
            )
            .outputs[0]
            .data.text.raw
        )

        self.history.append(
            {
                "prompt": prompt,
                "response": response,
                "kwargs": kwargs,
            }
        )
        return response

    def _get_choice_text(self, choice: str) -> str:
        # Clarifai responses are stored as plain strings, so a "choice"
        # is already the completion text.
        return choice

    def request(self, prompt: str, **kwargs):
        """Alias for :meth:`basic_request` (no retry layer yet)."""
        return self.basic_request(prompt, **kwargs)

    def __call__(
        self,
        prompt: str,
        only_completed: bool = True,
        return_sorted: bool = False,
        **kwargs,
    ):
        """Generate ``n`` completions (default 1) for ``prompt``.

        Returns:
            A list of completion strings.
        """
        assert only_completed, "for now"
        assert return_sorted is False, "for now"

        n = kwargs.pop("n", 1)
        # One API call per requested completion.
        return [self.request(prompt, **kwargs) for _ in range(n)]

dsp/modules/lm.py

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -45,14 +45,23 @@ def inspect_history(self, n: int = 1, skip: int = 0):
4545
prompt = x["prompt"]
4646

4747
if prompt != last_prompt:
48-
printed.append(
49-
(
50-
prompt,
51-
x["response"].generations
52-
if provider == "cohere"
53-
else x["response"]["choices"],
48+
49+
if provider=="clarifai":
50+
printed.append(
51+
(
52+
prompt,
53+
x['response']
54+
)
55+
)
56+
else:
57+
printed.append(
58+
(
59+
prompt,
60+
x["response"].generations
61+
if provider == "cohere"
62+
else x["response"]["choices"],
63+
)
5464
)
55-
)
5665

5766
last_prompt = prompt
5867

@@ -71,6 +80,8 @@ def inspect_history(self, n: int = 1, skip: int = 0):
7180
text = choices[0].text
7281
elif provider == "openai" or provider == "ollama":
7382
text = ' ' + self._get_choice_text(choices[0]).strip()
83+
elif provider == "clarifai":
84+
text=choices
7485
else:
7586
text = choices[0]["text"]
7687
self.print_green(text, end="")

dspy/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
OpenAI = dsp.GPT3
1919
ColBERTv2 = dsp.ColBERTv2
2020
Pyserini = dsp.PyseriniRetriever
21+
Clarifai=dsp.ClarifaiLLM
2122

2223
HFClientTGI = dsp.HFClientTGI
2324
HFClientVLLM = HFClientVLLM

dspy/retrieve/clarifai_rm.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
from collections import defaultdict
2+
from typing import List, Union
3+
import os
4+
import dspy
5+
from dsp.utils import dotdict
6+
import requests
7+
from typing import Optional
8+
from concurrent.futures import ThreadPoolExecutor
9+
10+
try:
11+
from clarifai.client.search import Search
12+
from google.protobuf import json_format
13+
except ImportError:
14+
raise ImportError(
15+
"Clarifai is not installed. Install it using `pip install clarifai`"
16+
)
17+
18+
19+
class ClarifaiRM(dspy.Retrieve):
    """Retrieval module that uses Clarifai to return the top-k relevant passages for a query.

    Args:
        clarifai_user_id: Clarifai user id owning the app.
        clarfiai_app_id: Clarifai app id to search in. (Parameter name kept
            as-is — including the existing typo — for backward compatibility.)
        clarifai_pat: Clarifai Personal Access Token; falls back to the
            ``CLARIFAI_PAT`` environment variable when omitted.
        k: default number of passages to retrieve per query.
    """

    def __init__(self,
                 clarifai_user_id: str,
                 clarfiai_app_id: str,
                 clarifai_pat: Optional[str] = None,
                 k: int = 3,
                 ):
        self.app_id = clarfiai_app_id
        self.user_id = clarifai_user_id
        # Explicit arg wins; otherwise require the env var (KeyError if absent).
        self.pat = clarifai_pat if clarifai_pat is not None else os.environ["CLARIFAI_PAT"]
        self.k = k

        super().__init__(k=k)

    def forward(self, query_or_queries: Union[str, List[str]], k: Optional[int] = None):
        """Search the Clarifai app and retrieve the top-k passages per query.

        Args:
            query_or_queries: a single query or a list of queries.
            k: number of passages to return per query; defaults to the
                ``k`` given at construction time.

        Returns:
            A list of ``dotdict`` passages with a ``long_text`` field,
            accumulated across all queries.

        Examples:
            Below is a code snippet that shows how to use Clarifai as the
            default retriever:
            ```python
            import clarifai

            llm = dspy.Clarifai(model=MODEL_URL, api_key="YOUR CLARIFAI_PAT")
            retriever_model = ClarifaiRM(clarifai_user_id="USER_ID", clarfiai_app_id="APP_ID", clarifai_pat="YOUR CLARIFAI_PAT")
            dspy.settings.configure(lm=llm, rm=retriever_model)
            ```
        """
        queries = (
            [query_or_queries]
            if isinstance(query_or_queries, str)
            else query_or_queries
        )
        k = k if k is not None else self.k
        passages: List[dotdict] = []
        queries = [q for q in queries if q]  # drop empty queries

        def _fetch_hit_text(hit) -> str:
            # Each hit only carries a URL; download the passage text itself.
            headers = {"Authorization": f"Key {self.pat}"}
            resp = requests.get(hit.input.data.text.url, headers=headers)
            resp.encoding = resp.apparent_encoding
            return resp.text

        for query in queries:
            clarifai_search = Search(user_id=self.user_id, app_id=self.app_id, top_k=k, pat=self.pat)
            search_response = clarifai_search.query(ranks=[{"text_raw": query}])

            # Flatten hits across all response pages.
            hits = [hit for data in search_response for hit in data.hits]

            # Context manager guarantees worker threads are shut down;
            # map() preserves hit order.
            with ThreadPoolExecutor(max_workers=10) as executor:
                results = list(executor.map(_fetch_hit_text, hits))

            # extend, not assign: the original overwrote passages each
            # iteration, returning results for only the last query.
            passages.extend(dotdict({"long_text": text}) for text in results)

        return passages
87+
88+
89+
90+
91+
92+
93+
94+
95+
96+

0 commit comments

Comments
 (0)