Skip to content

Commit a8643ca

Browse files
Remove unsupported kargs from Bedrock calls.
1 parent 251e4aa commit a8643ca

File tree

1 file changed

+39
-0
lines changed

1 file changed

+39
-0
lines changed

singlestoredb/ai/chat.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,31 @@ def _maybe_inject_headers(self, kwargs: dict[str, Any]) -> None:
257257
if call_headers:
258258
kwargs['headers'] = call_headers
259259

260+
# ------------------------------------------------------------------
261+
# Bedrock kwargs sanitation
262+
# ------------------------------------------------------------------
263+
def _sanitize_bedrock_kwargs(self, kwargs: dict[str, Any]) -> None:
264+
"""Remove or adapt kwargs not supported by ChatBedrockConverse.
265+
266+
Currently strips keys that would raise TypeError in
267+
ChatBedrockConverse._converse_params (e.g. 'parallelToolCalls').
268+
This guards against passing OpenAI/other provider specific
269+
parameters straight through to Bedrock.
270+
"""
271+
if self._backend_type != 'bedrock': # only relevant for bedrock backend
272+
return
273+
unsupported = {'parallelToolCalls', 'parallel_tool_calls'}
274+
# Direct kwargs
275+
for key in list(kwargs.keys()):
276+
if key in unsupported:
277+
kwargs.pop(key)
278+
# Nested model_kwargs if present
279+
mk = kwargs.get('model_kwargs')
280+
if isinstance(mk, dict):
281+
for key in list(mk.keys()):
282+
if key in unsupported:
283+
mk.pop(key)
284+
260285
def as_base(self) -> Any:
261286
"""Return the underlying backend client instance.
262287
@@ -267,14 +292,17 @@ def as_base(self) -> Any:
267292

268293
def invoke(self, *args: Any, **kwargs: Any) -> Any:
269294
self._maybe_inject_headers(kwargs)
295+
self._sanitize_bedrock_kwargs(kwargs)
270296
return self._client.invoke(*args, **kwargs)
271297

272298
async def ainvoke(self, *args: Any, **kwargs: Any) -> Any:
273299
self._maybe_inject_headers(kwargs)
300+
self._sanitize_bedrock_kwargs(kwargs)
274301
return await self._client.ainvoke(*args, **kwargs)
275302

276303
def stream(self, *args: Any, **kwargs: Any) -> Any:
277304
self._maybe_inject_headers(kwargs)
305+
self._sanitize_bedrock_kwargs(kwargs)
278306
return self._client.stream(*args, **kwargs)
279307

280308
async def astream(
@@ -283,6 +311,7 @@ async def astream(
283311
**kwargs: Any,
284312
) -> AsyncIterator[Any]:
285313
self._maybe_inject_headers(kwargs)
314+
self._sanitize_bedrock_kwargs(kwargs)
286315
async for chunk in self._client.astream(*args, **kwargs):
287316
yield chunk
288317

@@ -292,14 +321,17 @@ async def astream(
292321
# ------------------------------------------------------------------
293322
def generate(self, *args: Any, **kwargs: Any) -> Any:
294323
self._maybe_inject_headers(kwargs)
324+
self._sanitize_bedrock_kwargs(kwargs)
295325
return self._client.generate(*args, **kwargs)
296326

297327
async def agenerate(self, *args: Any, **kwargs: Any) -> Any:
298328
self._maybe_inject_headers(kwargs)
329+
self._sanitize_bedrock_kwargs(kwargs)
299330
return await self._client.agenerate(*args, **kwargs)
300331

301332
def predict(self, *args: Any, **kwargs: Any) -> Any:
302333
self._maybe_inject_headers(kwargs)
334+
self._sanitize_bedrock_kwargs(kwargs)
303335
return self._client.predict(*args, **kwargs)
304336

305337
async def apredict(
@@ -308,6 +340,7 @@ async def apredict(
308340
**kwargs: Any,
309341
) -> Any:
310342
self._maybe_inject_headers(kwargs)
343+
self._sanitize_bedrock_kwargs(kwargs)
311344
return await self._client.apredict(*args, **kwargs)
312345

313346
def predict_messages(
@@ -316,6 +349,7 @@ def predict_messages(
316349
**kwargs: Any,
317350
) -> Any:
318351
self._maybe_inject_headers(kwargs)
352+
self._sanitize_bedrock_kwargs(kwargs)
319353
return self._client.predict_messages(*args, **kwargs)
320354

321355
async def apredict_messages(
@@ -324,18 +358,22 @@ async def apredict_messages(
324358
**kwargs: Any,
325359
) -> Any:
326360
self._maybe_inject_headers(kwargs)
361+
self._sanitize_bedrock_kwargs(kwargs)
327362
return await self._client.apredict_messages(*args, **kwargs)
328363

329364
def batch(self, *args: Any, **kwargs: Any) -> Any:
330365
self._maybe_inject_headers(kwargs)
366+
self._sanitize_bedrock_kwargs(kwargs)
331367
return self._client.batch(*args, **kwargs)
332368

333369
async def abatch(self, *args: Any, **kwargs: Any) -> Any:
334370
self._maybe_inject_headers(kwargs)
371+
self._sanitize_bedrock_kwargs(kwargs)
335372
return await self._client.abatch(*args, **kwargs)
336373

337374
def apply(self, *args: Any, **kwargs: Any) -> Any:
338375
self._maybe_inject_headers(kwargs)
376+
self._sanitize_bedrock_kwargs(kwargs)
339377
return self._client.apply(*args, **kwargs)
340378

341379
async def aapply(
@@ -344,6 +382,7 @@ async def aapply(
344382
**kwargs: Any,
345383
) -> Any:
346384
self._maybe_inject_headers(kwargs)
385+
self._sanitize_bedrock_kwargs(kwargs)
347386
return await self._client.aapply(*args, **kwargs)
348387

349388
def __repr__(self) -> str:

0 commit comments

Comments
 (0)