Skip to content

Commit 7f44d3b

Browse files
authored
chore: Update WatsonxChatGenerator default model to ibm/granite-4-h-small (#2558)
* Update WatsonxChatGenerator default model to ibm/granite-4.0-h-small * Lint * Update model name
1 parent 9a47156 commit 7f44d3b

File tree

4 files changed

+28
-42
lines changed

4 files changed

+28
-42
lines changed

integrations/watsonx/src/haystack_integrations/components/generators/watsonx/chat/chat_generator.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ class WatsonxChatGenerator:
5959
6060
client = WatsonxChatGenerator(
6161
api_key=Secret.from_env_var("WATSONX_API_KEY"),
62-
model="ibm/granite-13b-chat-v2",
62+
model="ibm/granite-4-h-small",
6363
project_id=Secret.from_env_var("WATSONX_PROJECT_ID"),
6464
)
6565
response = client.run(messages)
@@ -92,7 +92,7 @@ def __init__(
9292
self,
9393
*,
9494
api_key: Secret = Secret.from_env_var("WATSONX_API_KEY"), # noqa: B008
95-
model: str = "ibm/granite-3-3-8b-instruct",
95+
model: str = "ibm/granite-4-h-small",
9696
project_id: Secret = Secret.from_env_var("WATSONX_PROJECT_ID"), # noqa: B008
9797
api_base_url: str = "https://us-south.ml.cloud.ibm.com",
9898
generation_kwargs: dict[str, Any] | None = None,
@@ -110,7 +110,7 @@ def __init__(
110110
111111
:param api_key: IBM Cloud API key for watsonx.ai access.
112112
Can be set via `WATSONX_API_KEY` environment variable or passed directly.
113-
:param model: The model ID to use for completions. Defaults to "ibm/granite-13b-chat-v2".
113+
:param model: The model ID to use for completions. Defaults to "ibm/granite-4-h-small".
114114
Available models can be found in your IBM Cloud account.
115115
:param project_id: IBM Cloud project ID
116116
:param api_base_url: Custom base URL for the API endpoint.

integrations/watsonx/src/haystack_integrations/components/generators/watsonx/generator.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ class WatsonxGenerator(WatsonxChatGenerator):
3838
3939
generator = WatsonxGenerator(
4040
api_key=Secret.from_env_var("WATSONX_API_KEY"),
41-
model="ibm/granite-13b-chat-v2",
41+
model="ibm/granite-4-h-small",
4242
project_id=Secret.from_env_var("WATSONX_PROJECT_ID"),
4343
)
4444
@@ -54,7 +54,7 @@ class WatsonxGenerator(WatsonxChatGenerator):
5454
"replies": ["Quantum computing uses quantum-mechanical phenomena like...."],
5555
"meta": [
5656
{
57-
"model": "ibm/granite-13b-chat-v2",
57+
"model": "ibm/granite-4-h-small",
5858
"project_id": "your-project-id",
5959
"usage": {
6060
"prompt_tokens": 12,
@@ -71,7 +71,7 @@ def __init__(
7171
self,
7272
*,
7373
api_key: Secret = Secret.from_env_var("WATSONX_API_KEY"), # noqa: B008
74-
model: str = "ibm/granite-3-3-8b-instruct",
74+
model: str = "ibm/granite-4-h-small",
7575
project_id: Secret = Secret.from_env_var("WATSONX_PROJECT_ID"), # noqa: B008
7676
api_base_url: str = "https://us-south.ml.cloud.ibm.com",
7777
system_prompt: str | None = None,
@@ -90,7 +90,7 @@ def __init__(
9090
9191
:param api_key: IBM Cloud API key for watsonx.ai access.
9292
Can be set via `WATSONX_API_KEY` environment variable or passed directly.
93-
:param model: The model ID to use for completions. Defaults to "ibm/granite-13b-chat-v2".
93+
:param model: The model ID to use for completions. Defaults to "ibm/granite-4-h-small".
9494
Available models can be found in your IBM Cloud account.
9595
:param project_id: IBM Cloud project ID
9696
:param api_base_url: Custom base URL for the API endpoint.

integrations/watsonx/tests/test_chat_generator.py

Lines changed: 12 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -99,16 +99,14 @@ async def __anext__(self):
9999
yield {"model": mock_model, "model_instance": mock_model_instance, "select_callback": mock_select_callback}
100100

101101
def test_init_default(self, mock_watsonx):
102-
generator = WatsonxChatGenerator(
103-
model="ibm/granite-3-3-8b-instruct", project_id=Secret.from_token("fake-project-id")
104-
)
102+
generator = WatsonxChatGenerator(project_id=Secret.from_token("fake-project-id"))
105103

106104
_, kwargs = mock_watsonx["model"].call_args
107-
assert kwargs["model_id"] == "ibm/granite-3-3-8b-instruct"
105+
assert kwargs["model_id"] == "ibm/granite-4-h-small"
108106
assert kwargs["project_id"] == "fake-project-id"
109107
assert kwargs["verify"] is None
110108

111-
assert generator.model == "ibm/granite-3-3-8b-instruct"
109+
assert generator.model == "ibm/granite-4-h-small"
112110
assert isinstance(generator.project_id, Secret)
113111
assert generator.project_id.resolve_value() == "fake-project-id"
114112
assert generator.api_base_url == "https://us-south.ml.cloud.ibm.com"
@@ -123,7 +121,7 @@ def test_init_with_all_params(self, mock_watsonx):
123121
)
124122

125123
_, kwargs = mock_watsonx["model"].call_args
126-
assert kwargs["model_id"] == "ibm/granite-3-3-8b-instruct"
124+
assert kwargs["model_id"] == "ibm/granite-4-h-small"
127125
assert kwargs["project_id"] == "test-project"
128126
assert kwargs["verify"] is False
129127

@@ -148,7 +146,7 @@ def test_to_dict(self, mock_watsonx):
148146
"type": "haystack_integrations.components.generators.watsonx.chat.chat_generator.WatsonxChatGenerator",
149147
"init_parameters": {
150148
"api_key": {"env_vars": ["WATSONX_API_KEY"], "strict": True, "type": "env_var"},
151-
"model": "ibm/granite-3-3-8b-instruct",
149+
"model": "ibm/granite-4-h-small",
152150
"project_id": {"env_vars": ["WATSONX_PROJECT_ID"], "strict": True, "type": "env_var"},
153151
"api_base_url": "https://us-south.ml.cloud.ibm.com",
154152
"generation_kwargs": {"max_tokens": 100},
@@ -173,7 +171,7 @@ def test_to_dict_with_params(self, mock_watsonx):
173171
"type": "haystack_integrations.components.generators.watsonx.chat.chat_generator.WatsonxChatGenerator",
174172
"init_parameters": {
175173
"api_key": {"env_vars": ["WATSONX_API_KEY"], "strict": True, "type": "env_var"},
176-
"model": "ibm/granite-3-3-8b-instruct",
174+
"model": "ibm/granite-4-h-small",
177175
"project_id": {"env_vars": ["WATSONX_PROJECT_ID"], "strict": True, "type": "env_var"},
178176
"api_base_url": "https://us-south.ml.cloud.ibm.com",
179177
"generation_kwargs": {"max_tokens": 100},
@@ -191,14 +189,14 @@ def test_from_dict(self, mock_watsonx):
191189
"type": "haystack_integrations.components.generators.watsonx.chat.chat_generator.WatsonxChatGenerator",
192190
"init_parameters": {
193191
"api_key": {"env_vars": ["WATSONX_API_KEY"], "strict": True, "type": "env_var"},
194-
"model": "ibm/granite-3-3-8b-instruct",
192+
"model": "ibm/granite-4-h-small",
195193
"project_id": {"env_vars": ["WATSONX_PROJECT_ID"], "strict": True, "type": "env_var"},
196194
"generation_kwargs": {"max_tokens": 100},
197195
},
198196
}
199197

200198
generator = WatsonxChatGenerator.from_dict(data)
201-
assert generator.model == "ibm/granite-3-3-8b-instruct"
199+
assert generator.model == "ibm/granite-4-h-small"
202200
assert isinstance(generator.project_id, Secret)
203201
assert generator.project_id.resolve_value() == "fake-project-id"
204202
assert generator.generation_kwargs == {"max_tokens": 100}
@@ -209,7 +207,7 @@ def test_from_dict_with_callback(self, mock_watsonx):
209207
"type": "haystack_integrations.components.generators.watsonx.chat.chat_generator.WatsonxChatGenerator",
210208
"init_parameters": {
211209
"api_key": {"env_vars": ["WATSONX_API_KEY"], "strict": True, "type": "env_var"},
212-
"model": "ibm/granite-3-3-8b-instruct",
210+
"model": "ibm/granite-4-h-small",
213211
"project_id": {"env_vars": ["WATSONX_PROJECT_ID"], "strict": True, "type": "env_var"},
214212
"streaming_callback": callback_str,
215213
},
@@ -253,9 +251,7 @@ def test_run_with_generation_params(self, mock_watsonx):
253251

254252
def test_run_with_streaming(self, mock_watsonx):
255253
"""Test streaming with callback through parent class"""
256-
generator = WatsonxChatGenerator(
257-
model="ibm/granite-13b-instruct-v2", project_id=Secret.from_token("test-project")
258-
)
254+
generator = WatsonxChatGenerator(project_id=Secret.from_token("test-project"))
259255

260256
mock_callback = MagicMock()
261257
messages = [ChatMessage.from_user("Test prompt")]
@@ -540,7 +536,6 @@ class TestWatsonxChatGeneratorIntegration:
540536
)
541537
def test_live_run(self):
542538
generator = WatsonxChatGenerator(
543-
model="ibm/granite-3-3-8b-instruct",
544539
project_id=Secret.from_env_var("WATSONX_PROJECT_ID"),
545540
generation_kwargs={"max_tokens": 50, "temperature": 0.7, "top_p": 0.9},
546541
)
@@ -560,9 +555,7 @@ def test_live_run(self):
560555
reason="WATSONX_API_KEY or WATSONX_PROJECT_ID not set",
561556
)
562557
def test_live_run_streaming(self):
563-
generator = WatsonxChatGenerator(
564-
model="ibm/granite-3-3-8b-instruct", project_id=Secret.from_env_var("WATSONX_PROJECT_ID")
565-
)
558+
generator = WatsonxChatGenerator(project_id=Secret.from_env_var("WATSONX_PROJECT_ID"))
566559
collected_chunks = []
567560

568561
def callback(chunk: StreamingChunk):
@@ -585,9 +578,7 @@ def callback(chunk: StreamingChunk):
585578
reason="WATSONX_API_KEY or WATSONX_PROJECT_ID not set",
586579
)
587580
async def test_live_run_async(self):
588-
generator = WatsonxChatGenerator(
589-
model="ibm/granite-3-3-8b-instruct", project_id=Secret.from_env_var("WATSONX_PROJECT_ID")
590-
)
581+
generator = WatsonxChatGenerator(project_id=Secret.from_env_var("WATSONX_PROJECT_ID"))
591582
messages = [ChatMessage.from_user("What's the capital of Germany? Answer concisely.")]
592583
results = await generator.run_async(messages=messages)
593584

integrations/watsonx/tests/test_generator.py

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -85,11 +85,11 @@ def test_init_default(self, mock_watsonx):
8585
generator = WatsonxGenerator(project_id=Secret.from_token("fake-project-id"))
8686

8787
_, kwargs = mock_watsonx["model"].call_args
88-
assert kwargs["model_id"] == "ibm/granite-3-3-8b-instruct"
88+
assert kwargs["model_id"] == "ibm/granite-4-h-small"
8989
assert kwargs["project_id"] == "fake-project-id"
9090
assert kwargs["verify"] is None
9191

92-
assert generator.model == "ibm/granite-3-3-8b-instruct"
92+
assert generator.model == "ibm/granite-4-h-small"
9393
assert isinstance(generator.project_id, Secret)
9494
assert generator.project_id.resolve_value() == "fake-project-id"
9595
assert generator.api_base_url == "https://us-south.ml.cloud.ibm.com"
@@ -104,7 +104,7 @@ def test_init_with_all_params(self, mock_watsonx):
104104
)
105105

106106
_, kwargs = mock_watsonx["model"].call_args
107-
assert kwargs["model_id"] == "ibm/granite-3-3-8b-instruct"
107+
assert kwargs["model_id"] == "ibm/granite-4-h-small"
108108
assert kwargs["project_id"] == "test-project"
109109
assert kwargs["verify"] is False
110110

@@ -126,7 +126,7 @@ def test_to_dict(self, mock_watsonx):
126126
"type": "haystack_integrations.components.generators.watsonx.generator.WatsonxGenerator",
127127
"init_parameters": {
128128
"api_key": {"env_vars": ["WATSONX_API_KEY"], "strict": True, "type": "env_var"},
129-
"model": "ibm/granite-3-3-8b-instruct",
129+
"model": "ibm/granite-4-h-small",
130130
"project_id": {"env_vars": ["WATSONX_PROJECT_ID"], "strict": True, "type": "env_var"},
131131
"api_base_url": "https://us-south.ml.cloud.ibm.com",
132132
"generation_kwargs": {"max_tokens": 100},
@@ -144,7 +144,7 @@ def test_from_dict(self, mock_watsonx):
144144
"type": "haystack_integrations.components.generators.watsonx.generator.WatsonxGenerator",
145145
"init_parameters": {
146146
"api_key": {"env_vars": ["WATSONX_API_KEY"], "strict": True, "type": "env_var"},
147-
"model": "ibm/granite-3-3-8b-instruct",
147+
"model": "ibm/granite-4-h-small",
148148
"project_id": {"env_vars": ["WATSONX_PROJECT_ID"], "strict": True, "type": "env_var"},
149149
"api_base_url": "https://us-south.ml.cloud.ibm.com",
150150
"generation_kwargs": {"max_tokens": 100},
@@ -158,7 +158,7 @@ def test_from_dict(self, mock_watsonx):
158158

159159
generator = WatsonxGenerator.from_dict(data)
160160
assert generator.api_key == Secret.from_env_var("WATSONX_API_KEY")
161-
assert generator.model == "ibm/granite-3-3-8b-instruct"
161+
assert generator.model == "ibm/granite-4-h-small"
162162
assert generator.project_id == Secret.from_env_var("WATSONX_PROJECT_ID")
163163
assert generator.api_base_url == "https://us-south.ml.cloud.ibm.com"
164164
assert generator.generation_kwargs == {"max_tokens": 100}
@@ -221,7 +221,7 @@ def test_run_with_generation_kwargs(self, mock_watsonx):
221221
)
222222

223223
def test_run_with_streaming(self, mock_watsonx):
224-
generator = WatsonxGenerator(model="ibm/granite-13b-instruct-v2", project_id=Secret.from_token("test-project"))
224+
generator = WatsonxGenerator(project_id=Secret.from_token("test-project"))
225225

226226
mock_callback = MagicMock()
227227
result = generator.run(prompt="Test prompt", streaming_callback=mock_callback)
@@ -357,7 +357,6 @@ class TestWatsonxGeneratorIntegration:
357357
)
358358
def test_live_run(self):
359359
generator = WatsonxGenerator(
360-
model="ibm/granite-3-3-8b-instruct",
361360
project_id=Secret.from_env_var("WATSONX_PROJECT_ID"),
362361
generation_kwargs={"max_tokens": 50, "temperature": 0.7, "top_p": 0.9},
363362
)
@@ -383,9 +382,7 @@ def test_live_run(self):
383382
reason="WATSONX_API_KEY or WATSONX_PROJECT_ID not set",
384383
)
385384
def test_live_run_streaming(self):
386-
generator = WatsonxGenerator(
387-
model="ibm/granite-3-3-8b-instruct", project_id=Secret.from_env_var("WATSONX_PROJECT_ID")
388-
)
385+
generator = WatsonxGenerator(project_id=Secret.from_env_var("WATSONX_PROJECT_ID"))
389386

390387
collected_chunks = []
391388

@@ -411,9 +408,7 @@ def callback(chunk: StreamingChunk):
411408
reason="WATSONX_API_KEY or WATSONX_PROJECT_ID not set",
412409
)
413410
async def test_live_run_async(self):
414-
generator = WatsonxGenerator(
415-
model="ibm/granite-3-3-8b-instruct", project_id=Secret.from_env_var("WATSONX_PROJECT_ID")
416-
)
411+
generator = WatsonxGenerator(project_id=Secret.from_env_var("WATSONX_PROJECT_ID"))
417412

418413
result = await generator.run_async(prompt="What's the capital of Germany? Answer concisely.")
419414

0 commit comments

Comments
 (0)