Skip to content

Commit 097e724

Browse files
committed
support before_request
1 parent 46de544 commit 097e724

File tree

7 files changed

+95
-110
lines changed

7 files changed

+95
-110
lines changed

README.md

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -124,6 +124,27 @@ def get_user_post(self, user_id: int, post_id: int) -> PostResponse:
124124
- For GET and DELETE methods, remaining arguments are sent as query parameters
125125
- For POST, PUT, and PATCH methods, remaining arguments are sent in the request body as JSON
126126

127+
```python
128+
129+
# you can cal signature by your self, overwrite the function `before_request`
130+
from pydantic_client.schema import RequestInfo
131+
132+
133+
class MyAPIClient(RequestsWebClient):
134+
# some code
135+
136+
def before_request(self, request_params: Dict[str, Any]) -> Dict[str, Any]:
137+
# the request params before send: body, header, etc...
138+
sign = cal_signature(request_params)
139+
request_params["headers].update(dict(signature=sign))
140+
return request_params
141+
142+
143+
# will send your new request_params
144+
user = client.get_user("123")
145+
146+
```
147+
127148

128149
### Timing Context Manager
129150

pydantic_client/async_client.py

Lines changed: 13 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,14 @@
22

33
from pydantic import BaseModel
44

5-
from .base import BaseWebClient
5+
from .base import BaseWebClient, RequestInfo
66

77
try:
88
import aiohttp
99
except ImportError:
1010
raise ImportError("please install aiohttp: `pip install aiohttp`")
1111

1212

13-
1413
T = TypeVar('T', bound=BaseModel)
1514

1615

@@ -25,32 +24,15 @@ def __init__(
2524
):
2625
super().__init__(base_url, headers, timeout, session, statsd_address)
2726

28-
async def _request(
29-
self,
30-
method: str,
31-
path: str,
32-
*,
33-
params: Optional[Dict[str, Any]] = None,
34-
json: Optional[Dict[str, Any]] = None,
35-
data: Optional[Dict[str, Any]] = None,
36-
headers: Optional[Dict[str, str]] = None,
37-
response_model: Optional[Type[T]] = None
38-
) -> Any:
39-
url = self._make_url(path)
40-
41-
# Merge headers
42-
request_headers = self.headers.copy()
43-
if headers:
44-
request_headers.update(headers)
27+
async def _request(self, request_info: RequestInfo) -> Any:
28+
request_params = self.dump_request_params(request_info)
29+
response_model = request_params.pop("response_model")
30+
31+
request_params = self.before_request(request_params)
4532

4633
async with aiohttp.ClientSession() as session:
4734
async with session.request(
48-
method=method,
49-
url=url,
50-
params=params,
51-
json=json,
52-
data=data,
53-
headers=request_headers,
35+
**request_params,
5436
timeout=aiohttp.ClientTimeout(total=self.timeout)
5537
) as response:
5638
response.raise_for_status()
@@ -77,37 +59,17 @@ def __init__(
7759
except ImportError:
7860
raise ImportError("please install httpx: `pip install httpx`")
7961

80-
async def _request(
81-
self,
82-
method: str,
83-
path: str,
84-
*,
85-
params: Optional[Dict[str, Any]] = None,
86-
json: Optional[Dict[str, Any]] = None,
87-
data: Optional[Dict[str, Any]] = None,
88-
headers: Optional[Dict[str, str]] = None,
89-
response_model: Optional[Type[T]] = None
90-
) -> Any:
62+
async def _request(self, request_info: RequestInfo) -> Any:
9163
import httpx
92-
url = self._make_url(path)
93-
94-
# Merge headers
95-
request_headers = self.headers.copy()
96-
if headers:
97-
request_headers.update(headers)
64+
request_params = self.dump_request_params(request_info)
65+
response_model = request_params.pop("response_model")
66+
67+
request_params = self.before_request(request_params)
9868

9969
async with httpx.AsyncClient(timeout=self.timeout) as client:
100-
response = await client.request(
101-
method=method,
102-
url=url,
103-
params=params,
104-
json=json,
105-
data=data,
106-
headers=request_headers
107-
)
70+
response = await client.request(**request_params)
10871
response.raise_for_status()
10972
data = response.json()
110-
11173
if response_model is not None:
11274
return response_model.model_validate(data)
11375
return data

pydantic_client/base.py

Lines changed: 19 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from typing import Any, Dict, Optional, Type, TypeVar, List
88

99
from pydantic import BaseModel
10+
from .schema import RequestInfo
1011

1112

1213
T = TypeVar('T', bound=BaseModel)
@@ -59,25 +60,32 @@ def from_config(cls, config: Dict[str, Any]) -> 'BaseWebClient':
5960
timeout=config.get('timeout', 30),
6061
session=config.get('session', None)
6162
)
63+
64+
def before_request(self, request_params: Dict[str, Any]) -> Dict[str, Any]:
65+
"""before request, you can do something by yourself
66+
such as: cal signature, etc."""
67+
return request_params
6268

6369
def _make_url(self, path: str) -> str:
6470
return f"{self.base_url}/{path.lstrip('/')}"
6571

6672
def span(self, prefix: Optional[str] = None):
6773
return SpanContext(self, prefix)
74+
75+
def dump_request_params(self, request_info: RequestInfo) -> Dict[str, Any]:
76+
request_params = request_info.model_dump()
77+
url = self._make_url(request_params.pop("path"))
78+
# Merge headers
79+
request_headers = self.headers.copy()
80+
if request_info.headers:
81+
request_headers.update(request_info.headers)
82+
83+
request_params["headers"] = request_headers
84+
request_params["url"] = url
85+
return request_params
6886

6987
@abstractmethod
70-
def _request(
71-
self,
72-
method: str,
73-
path: str,
74-
*,
75-
params: Optional[Dict[str, Any]] = None,
76-
json: Optional[Dict[str, Any]] = None,
77-
data: Optional[Dict[str, Any]] = None,
78-
headers: Optional[Dict[str, str]] = None,
79-
response_model: Optional[Type[T]] = None
80-
) -> Any:
88+
def _request(self, request_info: RequestInfo) -> Any:
8189
...
8290

8391

pydantic_client/decorators.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,12 @@
55
from pydantic import BaseModel
66

77
from .tools.agno import register_agno_tool
8+
from .schema import RequestInfo
89

910

1011
def _process_request_params(
1112
func: Callable, method: str, path: str, form_body: bool, *args, **kwargs
12-
) -> Dict[str, Any]:
13+
) -> RequestInfo:
1314
"""
1415
Extract and process request parameters from function arguments.
1516
@@ -21,7 +22,7 @@ def _process_request_params(
2122
*args, **kwargs: Function arguments
2223
2324
Returns:
24-
Dictionary containing processed request parameters
25+
RequestInfo containing processed request parameters
2526
"""
2627
sig = inspect.signature(func)
2728
bound_args = sig.bind(*args, **kwargs)
@@ -61,7 +62,7 @@ def _process_request_params(
6162
if method in ["GET", "DELETE"]:
6263
query_params[param_name] = param_value
6364

64-
return {
65+
info = {
6566
"method": method,
6667
"path": formatted_path,
6768
"params": query_params if method in ["GET", "DELETE"] else None,
@@ -78,6 +79,7 @@ def _process_request_params(
7879
"headers": request_headers,
7980
"response_model": response_model,
8081
}
82+
return RequestInfo.model_validate(info)
8183

8284

8385
def rest(
@@ -131,15 +133,15 @@ async def async_wrapped(self, *args, **kwargs):
131133
request_params = _process_request_params(
132134
func, method, path, form_body, self, *args, **kwargs
133135
)
134-
return await self._request(**request_params)
136+
return await self._request(request_params)
135137

136138
@wraps(func)
137139
def sync_wrapped(self, *args, **kwargs):
138140
"""Sync wrapper for handling HTTP requests."""
139141
request_params = _process_request_params(
140142
func, method, path, form_body, self, *args, **kwargs
141143
)
142-
return self._request(**request_params)
144+
return self._request(request_params)
143145

144146
@wraps(func)
145147
def choose_wrapper(self, *args, **kwargs):

pydantic_client/schema.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
from typing import Dict, Optional, Any, Type
2+
3+
from pydantic import BaseModel
4+
5+
6+
class RequestInfo(BaseModel):
7+
method: str
8+
path: str
9+
params: Optional[Dict[str, Any]] = {}
10+
json: Optional[Dict[str, Any]] = None
11+
data: Optional[Dict[str, Any]] = None
12+
headers: Optional[Dict[str, Any]] = {}
13+
response_model: Optional[Type[BaseModel]] = None

pydantic_client/sync_client.py

Lines changed: 8 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
import requests
44
from pydantic import BaseModel
55

6-
from .base import BaseWebClient
6+
from .base import BaseWebClient, RequestInfo
77

88
T = TypeVar('T', bound=BaseModel)
99

@@ -21,35 +21,14 @@ def __init__(
2121
if not self.session:
2222
self.session = requests.Session()
2323

24-
def _request(
25-
self,
26-
method: str,
27-
path: str,
28-
*,
29-
params: Optional[Dict[str, Any]] = None,
30-
json: Optional[Dict[str, Any]] = None,
31-
data: Optional[Dict[str, Any]] = None,
32-
headers: Optional[Dict[str, str]] = None,
33-
response_model: Optional[Type[T]] = None
34-
) -> Any:
35-
url = self._make_url(path)
36-
37-
# Merge headers
38-
request_headers = self.headers.copy()
39-
if headers:
40-
request_headers.update(headers)
41-
42-
response = requests.request(
43-
method=method,
44-
url=url,
45-
params=params,
46-
json=json,
47-
data=data,
48-
headers=request_headers,
49-
timeout=self.timeout
50-
)
51-
response.raise_for_status()
24+
def _request(self, request_info: RequestInfo) -> Any:
25+
request_params = self.dump_request_params(request_info)
26+
response_model = request_params.pop("response_model")
5227

28+
request_params = self.before_request(request_params)
29+
30+
response = requests.request(**request_params, timeout=self.timeout)
31+
response.raise_for_status()
5332
data = response.json()
5433
if response_model is not None:
5534
return response_model.model_validate(data)

tests/test_async_new_features.py

Lines changed: 14 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,12 @@ async def test_async_pydantic_model_json_body():
3838
# For this test, we'll use a simple mock by patching the _request method
3939

4040
class MockAsyncClient(TestAsyncFormBodyClient):
41-
async def _request(self, method, path, **kwargs):
41+
async def _request(self, request_info):
4242
# Verify that the correct parameters are passed
43-
assert method == "POST"
44-
assert path == "/users"
45-
assert kwargs.get('json') == {"name": "Test User", "email": "test@example.com"}
46-
assert kwargs.get('data') is None # Should be None for JSON requests
43+
assert request_info.method == "POST"
44+
assert request_info.path == "/users"
45+
assert request_info.json == {"name": "Test User", "email": "test@example.com"}
46+
assert request_info.data is None # Should be None for JSON requests
4747

4848
# Return mock response
4949
return {"id": "123", "name": "Test User", "email": "test@example.com"}
@@ -60,12 +60,12 @@ async def test_async_pydantic_model_form_body():
6060
"""Test that async client correctly handles Pydantic models as form data"""
6161

6262
class MockAsyncClient(TestAsyncFormBodyClient):
63-
async def _request(self, method, path, **kwargs):
63+
async def _request(self, request_info):
6464
# Verify that the correct parameters are passed
65-
assert method == "POST"
66-
assert path == "/users"
67-
assert kwargs.get('data') == {"name": "Test User", "email": "test@example.com"}
68-
assert kwargs.get('json') is None # Should be None for form requests
65+
assert request_info.method == "POST"
66+
assert request_info.path == "/users"
67+
assert request_info.data == {"name": "Test User", "email": "test@example.com"}
68+
assert request_info.json is None # Should be None for form requests
6969

7070
# Return mock response
7171
return {"id": "123", "name": "Test User", "email": "test@example.com"}
@@ -82,11 +82,11 @@ async def test_async_custom_headers():
8282
"""Test that async client correctly handles custom headers"""
8383

8484
class MockAsyncClient(TestAsyncFormBodyClient):
85-
async def _request(self, method, path, **kwargs):
85+
async def _request(self, request_info):
8686
# Verify that the correct parameters are passed
87-
assert method == "POST"
88-
assert path == "/users/custom"
89-
assert kwargs.get('headers') == {"X-Custom-Header": "custom-value", "Authorization": "Bearer token123"}
87+
assert request_info.method == "POST"
88+
assert request_info.path == "/users/custom"
89+
assert request_info.headers == {"X-Custom-Header": "custom-value", "Authorization": "Bearer token123"}
9090

9191
# Return mock response
9292
return {"id": "123", "name": "Test User", "email": "test@example.com"}

0 commit comments

Comments
 (0)