Skip to content

Commit e835ff0

Browse files
authored
handle #[pyo3(from_py_with = ...)] on dunder (__magic__) methods (#4117)
* handle `#[pyo3(from_py_with = ...)]` on dunder (__magic__) methods * add newsfragment
1 parent c10c742 commit e835ff0

File tree

6 files changed

+103
-43
lines changed

6 files changed

+103
-43
lines changed

newsfragments/4117.fixed.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Correctly handle `#[pyo3(from_py_with = ...)]` attribute on dunder (`__magic__`) method arguments instead of silently ignoring it.

pyo3-macros-backend/src/params.rs

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use syn::spanned::Spanned;
1010

1111
pub struct Holders {
1212
holders: Vec<syn::Ident>,
13-
gil_refs_checkers: Vec<syn::Ident>,
13+
gil_refs_checkers: Vec<GilRefChecker>,
1414
}
1515

1616
impl Holders {
@@ -32,14 +32,28 @@ impl Holders {
3232
&format!("gil_refs_checker_{}", self.gil_refs_checkers.len()),
3333
span,
3434
);
35-
self.gil_refs_checkers.push(gil_refs_checker.clone());
35+
self.gil_refs_checkers
36+
.push(GilRefChecker::FunctionArg(gil_refs_checker.clone()));
37+
gil_refs_checker
38+
}
39+
40+
pub fn push_from_py_with_checker(&mut self, span: Span) -> syn::Ident {
41+
let gil_refs_checker = syn::Ident::new(
42+
&format!("gil_refs_checker_{}", self.gil_refs_checkers.len()),
43+
span,
44+
);
45+
self.gil_refs_checkers
46+
.push(GilRefChecker::FromPyWith(gil_refs_checker.clone()));
3647
gil_refs_checker
3748
}
3849

3950
pub fn init_holders(&self, ctx: &Ctx) -> TokenStream {
4051
let Ctx { pyo3_path } = ctx;
4152
let holders = &self.holders;
42-
let gil_refs_checkers = &self.gil_refs_checkers;
53+
let gil_refs_checkers = self.gil_refs_checkers.iter().map(|checker| match checker {
54+
GilRefChecker::FunctionArg(ident) => ident,
55+
GilRefChecker::FromPyWith(ident) => ident,
56+
});
4357
quote! {
4458
#[allow(clippy::let_unit_value)]
4559
#(let mut #holders = #pyo3_path::impl_::extract_argument::FunctionArgumentHolder::INIT;)*
@@ -50,11 +64,23 @@ impl Holders {
5064
pub fn check_gil_refs(&self) -> TokenStream {
5165
self.gil_refs_checkers
5266
.iter()
53-
.map(|e| quote_spanned! { e.span() => #e.function_arg(); })
67+
.map(|checker| match checker {
68+
GilRefChecker::FunctionArg(ident) => {
69+
quote_spanned! { ident.span() => #ident.function_arg(); }
70+
}
71+
GilRefChecker::FromPyWith(ident) => {
72+
quote_spanned! { ident.span() => #ident.from_py_with_arg(); }
73+
}
74+
})
5475
.collect()
5576
}
5677
}
5778

79+
enum GilRefChecker {
80+
FunctionArg(syn::Ident),
81+
FromPyWith(syn::Ident),
82+
}
83+
5884
/// Return true if the argument list is simply (*args, **kwds).
5985
pub fn is_forwarded_args(signature: &FunctionSignature<'_>) -> bool {
6086
matches!(

pyo3-macros-backend/src/pymethod.rs

Lines changed: 25 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1053,44 +1053,39 @@ impl Ty {
10531053
ctx: &Ctx,
10541054
) -> TokenStream {
10551055
let Ctx { pyo3_path } = ctx;
1056-
let name_str = arg.name().unraw().to_string();
10571056
match self {
10581057
Ty::Object => extract_object(
10591058
extract_error_mode,
10601059
holders,
1061-
&name_str,
1060+
arg,
10621061
quote! { #ident },
1063-
arg.ty().span(),
10641062
ctx
10651063
),
10661064
Ty::MaybeNullObject => extract_object(
10671065
extract_error_mode,
10681066
holders,
1069-
&name_str,
1067+
arg,
10701068
quote! {
10711069
if #ident.is_null() {
10721070
#pyo3_path::ffi::Py_None()
10731071
} else {
10741072
#ident
10751073
}
10761074
},
1077-
arg.ty().span(),
10781075
ctx
10791076
),
10801077
Ty::NonNullObject => extract_object(
10811078
extract_error_mode,
10821079
holders,
1083-
&name_str,
1080+
arg,
10841081
quote! { #ident.as_ptr() },
1085-
arg.ty().span(),
10861082
ctx
10871083
),
10881084
Ty::IPowModulo => extract_object(
10891085
extract_error_mode,
10901086
holders,
1091-
&name_str,
1087+
arg,
10921088
quote! { #ident.as_ptr() },
1093-
arg.ty().span(),
10941089
ctx
10951090
),
10961091
Ty::CompareOp => extract_error_mode.handle_error(
@@ -1118,24 +1113,37 @@ impl Ty {
11181113
fn extract_object(
11191114
extract_error_mode: ExtractErrorMode,
11201115
holders: &mut Holders,
1121-
name: &str,
1116+
arg: &FnArg<'_>,
11221117
source_ptr: TokenStream,
1123-
span: Span,
11241118
ctx: &Ctx,
11251119
) -> TokenStream {
11261120
let Ctx { pyo3_path } = ctx;
1127-
let holder = holders.push_holder(Span::call_site());
1128-
let gil_refs_checker = holders.push_gil_refs_checker(span);
1129-
let extracted = extract_error_mode.handle_error(
1121+
let gil_refs_checker = holders.push_gil_refs_checker(arg.ty().span());
1122+
let name = arg.name().unraw().to_string();
1123+
1124+
let extract = if let Some(from_py_with) =
1125+
arg.from_py_with().map(|from_py_with| &from_py_with.value)
1126+
{
1127+
let from_py_with_checker = holders.push_from_py_with_checker(from_py_with.span());
1128+
quote! {
1129+
#pyo3_path::impl_::extract_argument::from_py_with(
1130+
#pyo3_path::impl_::pymethods::BoundRef::ref_from_ptr(py, &#source_ptr).0,
1131+
#name,
1132+
#pyo3_path::impl_::deprecations::inspect_fn(#from_py_with, &#from_py_with_checker) as fn(_) -> _,
1133+
)
1134+
}
1135+
} else {
1136+
let holder = holders.push_holder(Span::call_site());
11301137
quote! {
11311138
#pyo3_path::impl_::extract_argument::extract_argument(
11321139
#pyo3_path::impl_::pymethods::BoundRef::ref_from_ptr(py, &#source_ptr).0,
11331140
&mut #holder,
11341141
#name
11351142
)
1136-
},
1137-
ctx,
1138-
);
1143+
}
1144+
};
1145+
1146+
let extracted = extract_error_mode.handle_error(extract, ctx);
11391147
quote! {
11401148
#pyo3_path::impl_::deprecations::inspect_type(#extracted, &#gil_refs_checker)
11411149
}

tests/test_class_basics.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,10 @@ fn get_length(obj: &Bound<'_, PyAny>) -> PyResult<usize> {
290290
Ok(length)
291291
}
292292

293+
fn is_even(obj: &Bound<'_, PyAny>) -> PyResult<bool> {
294+
obj.extract::<i32>().map(|i| i % 2 == 0)
295+
}
296+
293297
#[pyclass]
294298
struct ClassWithFromPyWithMethods {}
295299

@@ -319,6 +323,10 @@ impl ClassWithFromPyWithMethods {
319323
fn staticmethod(#[pyo3(from_py_with = "get_length")] argument: usize) -> usize {
320324
argument
321325
}
326+
327+
fn __contains__(&self, #[pyo3(from_py_with = "is_even")] obj: bool) -> bool {
328+
obj
329+
}
322330
}
323331

324332
#[test]
@@ -339,6 +347,9 @@ fn test_pymethods_from_py_with() {
339347
if has_gil_refs:
340348
assert instance.classmethod_gil_ref(arg) == 2
341349
assert instance.staticmethod(arg) == 2
350+
351+
assert 42 in instance
352+
assert 73 not in instance
342353
"#
343354
);
344355
})

tests/ui/deprecations.rs

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,14 @@ impl MyClass {
3838

3939
#[setter]
4040
fn set_bar_bound(&self, _value: &Bound<'_, PyAny>) {}
41+
42+
fn __eq__(&self, #[pyo3(from_py_with = "extract_gil_ref")] _other: i32) -> bool {
43+
true
44+
}
45+
46+
fn __contains__(&self, #[pyo3(from_py_with = "extract_bound")] _value: i32) -> bool {
47+
true
48+
}
4149
}
4250

4351
fn main() {}

tests/ui/deprecations.stderr

Lines changed: 28 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,12 @@ error: use of deprecated struct `pyo3::PyCell`: `PyCell` was merged into `Bound`
1616
23 | fn method_gil_ref(_slf: &PyCell<Self>) {}
1717
| ^^^^^^
1818

19+
error: use of deprecated method `pyo3::deprecations::GilRefs::<T>::from_py_with_arg`: use `&Bound<'_, PyAny>` as the argument for this `from_py_with` extractor
20+
--> tests/ui/deprecations.rs:42:44
21+
|
22+
42 | fn __eq__(&self, #[pyo3(from_py_with = "extract_gil_ref")] _other: i32) -> bool {
23+
| ^^^^^^^^^^^^^^^^^
24+
1925
error: use of deprecated method `pyo3::deprecations::GilRefs::<T>::function_arg`: use `&Bound<'_, T>` instead for this function argument
2026
--> tests/ui/deprecations.rs:18:33
2127
|
@@ -47,69 +53,69 @@ error: use of deprecated method `pyo3::deprecations::GilRefs::<T>::function_arg`
4753
| ^
4854

4955
error: use of deprecated method `pyo3::deprecations::GilRefs::<T>::function_arg`: use `&Bound<'_, T>` instead for this function argument
50-
--> tests/ui/deprecations.rs:53:43
56+
--> tests/ui/deprecations.rs:61:43
5157
|
52-
53 | fn pyfunction_with_module_gil_ref(module: &PyModule) -> PyResult<&str> {
58+
61 | fn pyfunction_with_module_gil_ref(module: &PyModule) -> PyResult<&str> {
5359
| ^
5460

5561
error: use of deprecated method `pyo3::deprecations::GilRefs::<T>::function_arg`: use `&Bound<'_, T>` instead for this function argument
56-
--> tests/ui/deprecations.rs:63:19
62+
--> tests/ui/deprecations.rs:71:19
5763
|
58-
63 | fn module_gil_ref(m: &PyModule) -> PyResult<()> {
64+
71 | fn module_gil_ref(m: &PyModule) -> PyResult<()> {
5965
| ^
6066

6167
error: use of deprecated method `pyo3::deprecations::GilRefs::<T>::function_arg`: use `&Bound<'_, T>` instead for this function argument
62-
--> tests/ui/deprecations.rs:69:57
68+
--> tests/ui/deprecations.rs:77:57
6369
|
64-
69 | fn module_gil_ref_with_explicit_py_arg(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
70+
77 | fn module_gil_ref_with_explicit_py_arg(_py: Python<'_>, m: &PyModule) -> PyResult<()> {
6571
| ^
6672

6773
error: use of deprecated method `pyo3::deprecations::GilRefs::<T>::from_py_with_arg`: use `&Bound<'_, PyAny>` as the argument for this `from_py_with` extractor
68-
--> tests/ui/deprecations.rs:102:27
74+
--> tests/ui/deprecations.rs:110:27
6975
|
70-
102 | #[pyo3(from_py_with = "extract_gil_ref")] _gil_ref: i32,
76+
110 | #[pyo3(from_py_with = "extract_gil_ref")] _gil_ref: i32,
7177
| ^^^^^^^^^^^^^^^^^
7278

7379
error: use of deprecated method `pyo3::deprecations::GilRefs::<T>::function_arg`: use `&Bound<'_, T>` instead for this function argument
74-
--> tests/ui/deprecations.rs:108:29
80+
--> tests/ui/deprecations.rs:116:29
7581
|
76-
108 | fn pyfunction_gil_ref(_any: &PyAny) {}
82+
116 | fn pyfunction_gil_ref(_any: &PyAny) {}
7783
| ^
7884

7985
error: use of deprecated method `pyo3::deprecations::OptionGilRefs::<std::option::Option<T>>::function_arg`: use `Option<&Bound<'_, T>>` instead for this function argument
80-
--> tests/ui/deprecations.rs:111:36
86+
--> tests/ui/deprecations.rs:119:36
8187
|
82-
111 | fn pyfunction_option_gil_ref(_any: Option<&PyAny>) {}
88+
119 | fn pyfunction_option_gil_ref(_any: Option<&PyAny>) {}
8389
| ^^^^^^
8490

8591
error: use of deprecated method `pyo3::deprecations::GilRefs::<T>::from_py_with_arg`: use `&Bound<'_, PyAny>` as the argument for this `from_py_with` extractor
86-
--> tests/ui/deprecations.rs:118:27
92+
--> tests/ui/deprecations.rs:126:27
8793
|
88-
118 | #[pyo3(from_py_with = "PyAny::len", item("my_object"))]
94+
126 | #[pyo3(from_py_with = "PyAny::len", item("my_object"))]
8995
| ^^^^^^^^^^^^
9096

9197
error: use of deprecated method `pyo3::deprecations::GilRefs::<T>::from_py_with_arg`: use `&Bound<'_, PyAny>` as the argument for this `from_py_with` extractor
92-
--> tests/ui/deprecations.rs:128:27
98+
--> tests/ui/deprecations.rs:136:27
9399
|
94-
128 | #[pyo3(from_py_with = "PyAny::len")] usize,
100+
136 | #[pyo3(from_py_with = "PyAny::len")] usize,
95101
| ^^^^^^^^^^^^
96102

97103
error: use of deprecated method `pyo3::deprecations::GilRefs::<T>::from_py_with_arg`: use `&Bound<'_, PyAny>` as the argument for this `from_py_with` extractor
98-
--> tests/ui/deprecations.rs:134:31
104+
--> tests/ui/deprecations.rs:142:31
99105
|
100-
134 | Zip(#[pyo3(from_py_with = "extract_gil_ref")] i32),
106+
142 | Zip(#[pyo3(from_py_with = "extract_gil_ref")] i32),
101107
| ^^^^^^^^^^^^^^^^^
102108

103109
error: use of deprecated method `pyo3::deprecations::GilRefs::<T>::from_py_with_arg`: use `&Bound<'_, PyAny>` as the argument for this `from_py_with` extractor
104-
--> tests/ui/deprecations.rs:141:27
110+
--> tests/ui/deprecations.rs:149:27
105111
|
106-
141 | #[pyo3(from_py_with = "extract_gil_ref")]
112+
149 | #[pyo3(from_py_with = "extract_gil_ref")]
107113
| ^^^^^^^^^^^^^^^^^
108114

109115
error: use of deprecated method `pyo3::deprecations::GilRefs::<pyo3::Python<'_>>::is_python`: use `wrap_pyfunction_bound!` instead
110-
--> tests/ui/deprecations.rs:154:13
116+
--> tests/ui/deprecations.rs:162:13
111117
|
112-
154 | let _ = wrap_pyfunction!(double, py);
118+
162 | let _ = wrap_pyfunction!(double, py);
113119
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
114120
|
115121
= note: this error originates in the macro `wrap_pyfunction` (in Nightly builds, run with -Z macro-backtrace for more info)

0 commit comments

Comments
 (0)