Skip to content

Commit 2603ba6

Browse files
committed
Fix bugs
1 parent 337e0bd commit 2603ba6

File tree

3 files changed

+38
-22
lines changed

3 files changed

+38
-22
lines changed

aifn/client.py

Lines changed: 23 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import nest_asyncio
23
import base64
34
import os
45
import warnings
@@ -613,15 +614,15 @@ def _validate_batch_query(self, batch_inputs: List[Dict[str, Any]]):
613614
if not all(isinstance(img, str) for inp in batch_inputs for img in inp.get("images_input", [])):
614615
raise ValueError("Images input must be a list of strings.")
615616

616-
def _batch_query(
617+
async def _batch_query(
617618
self,
618619
fn_name: str,
619620
version: Optional[Union[str, int]],
620621
batch_inputs: List[Dict[str, Any]],
621622
return_reasoning: bool,
622623
strict: bool,
623624
) -> List[NamedTuple]:
624-
"""Internal method to handle both synchronous and asynchronous batch query requests.
625+
"""Internal method to handle asynchronous batch query requests.
625626
626627
Parameters
627628
----------
@@ -644,21 +645,18 @@ def _batch_query(
644645
645646
Returns
646647
-------
647-
Union[List[NamedTuple], Coroutine[Any, Any, List[NamedTuple]]]
648-
A list of NamedTuples, each containing the output and metadata of the response, or a coroutine that returns such a list.
648+
List[NamedTuple]
649+
A list of NamedTuples, each containing the output and metadata of the response.
649650
"""
650651

651652
# Validate the batch inputs
652653
self._validate_batch_query(batch_inputs=batch_inputs)
653654

654-
async def run_queries():
655-
tasks = [
656-
self.aquery(fn_name=fn_name, version=version, return_reasoning=return_reasoning, strict=strict, **fn_input)
657-
for fn_input in batch_inputs
658-
]
659-
return await asyncio.gather(*tasks)
660-
661-
return run_queries()
655+
tasks = [
656+
self.aquery(fn_name=fn_name, version=version, return_reasoning=return_reasoning, strict=strict, **fn_input)
657+
for fn_input in batch_inputs
658+
]
659+
return await asyncio.gather(*tasks)
662660

663661
async def abatch_query(
664662
self,
@@ -732,8 +730,17 @@ def batch_query(
732730
List[NamedTuple]
733731
A list of NamedTuples, each containing the output and metadata of the response.
734732
"""
735-
return asyncio.run(
736-
self._batch_query(
733+
try:
734+
# Check if an event loop is already running
735+
loop = asyncio.get_running_loop()
736+
if loop.is_running():
737+
nest_asyncio.apply()
738+
task = loop.create_task(self._batch_query(
737739
fn_name=fn_name, version=version, batch_inputs=batch_inputs, return_reasoning=return_reasoning, strict=strict
738-
)
739-
)
740+
))
741+
return loop.run_until_complete(task)
742+
except RuntimeError:
743+
# If no event loop is running, use asyncio.run
744+
return asyncio.run(self._batch_query(
745+
fn_name=fn_name, version=version, batch_inputs=batch_inputs, return_reasoning=return_reasoning, strict=strict
746+
))

examples/cookbook.ipynb

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,11 +38,12 @@
3838
},
3939
{
4040
"cell_type": "code",
41-
"execution_count": null,
41+
"execution_count": 1,
4242
"metadata": {},
4343
"outputs": [],
4444
"source": [
45-
"%env WECO_API_KEY=<YOUR_WECO_API_KEY>"
45+
"import os\n",
46+
"os.environ[\"WECO_API_KEY\"] = \"YOUR_WECO_API_KEY\""
4647
]
4748
},
4849
{
@@ -214,13 +215,11 @@
214215
"source": [
215216
"task_evaluator = build(task_description=\"I want to know if AI can solve a problem for me, how easy it is to arrive at a solution and whether any helpful tips for me along the way. Help me understand this through - 'feasibility', 'justification', and 'suggestions'.\")\n",
216217
"\n",
217-
"\n",
218218
"task1 = {\n",
219219
" \"text_input\": \"I want to train a model to predict house prices using the Boston Housing dataset hosted on Kaggle.\"\n",
220220
"}\n",
221221
"task2 = {\n",
222222
" \"text_input\": \"I want to train a model to classify digits using the MNIST dataset hosted on Kaggle using a Google Colab notebook. Attached is an example of what some of the digits would look like.\",\n",
223-
" \"images_input\": [\"https://machinelearningmastery.com/wp-content/uploads/2019/02/Plot-of-a-Subset-of-Images-from-the-MNIST-Dataset-1024x768.png\"]\n",
224223
"}\n",
225224
"responses = task_evaluator.batch([task1, task2])\n",
226225
"for response in responses:\n",
@@ -352,10 +351,16 @@
352351
"task_evaluator = build(task_description=\"I want to know if AI can solve a problem for me, how easy it is to arrive at a solution and whether any helpful tips for me along the way. Help me understand this through - 'feasibility', 'justification', and 'suggestions'.\")\n",
353352
"\n",
354353
"output, metadata = task_evaluator(\"I want to train a model to predict house prices using the Boston Housing dataset hosted on Kaggle.\", return_reasoning=True)\n",
355-
"reasoning_steps = response[\"reasoning_steps\"]\n",
356354
"for key, value in output.items(): print(f\"{key}: {value}\")\n",
357355
"for i, step in enumerate(metadata[\"reasoning_steps\"]): print(f\"Step {i+1}: {step}\")"
358356
]
357+
},
358+
{
359+
"cell_type": "code",
360+
"execution_count": null,
361+
"metadata": {},
362+
"outputs": [],
363+
"source": []
359364
}
360365
],
361366
"metadata": {

pyproject.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ readme = "README.md"
1313
version = "0.2.0"
1414
license = {text = "MIT"}
1515
requires-python = ">=3.8"
16-
dependencies = ["requests", "asyncio", "httpx[http2]", "pillow"]
16+
dependencies = ["requests", "asyncio", "nest_asyncio", "httpx[http2]", "pillow"]
1717
keywords = ["AI", "LLM", "VLM", "AI functions", "Prompt Engineering", "NLP"]
1818
classifiers = [
1919
"Programming Language :: Python :: 3",
@@ -33,6 +33,10 @@ packages = ["aifn"]
3333

3434
[tool.setuptools_scm]
3535

36+
# Test configuration
37+
[tool.pytest.ini_options]
38+
asyncio_default_fixture_loop_scope = "function"
39+
3640
# Lint and format code
3741
[tool.ruff]
3842
fix = false

0 commit comments

Comments
 (0)