@@ -83,7 +83,7 @@ struct BorrowedBuffer {
8383}
8484
8585impl BorrowedBuffer {
86- unsafe fn from_object ( obj : * mut PyObject ) -> Result < Self , ( ) > {
86+ fn from_object ( obj : & mut PyObject ) -> Result < Self , ( ) > {
8787 let mut view = MaybeUninit :: < Py_buffer > :: uninit ( ) ;
8888 if unsafe { PyObject_GetBuffer ( obj, view. as_mut_ptr ( ) , PYBUF_SIMPLE ) } != 0 {
8989 return Err ( ( ) ) ;
@@ -110,6 +110,9 @@ impl Drop for BorrowedBuffer {
110110 }
111111}
112112
113+ /// # Safety
114+ /// `module` must be a valid pointer of PyObject representing the module.
115+ /// `args` must be a valid pointer to an array of valid PyObject pointers with length `nargs`.
113116#[ unsafe( no_mangle) ]
114117pub unsafe extern "C" fn standard_b64encode (
115118 _module : * mut PyObject ,
@@ -123,13 +126,21 @@ pub unsafe extern "C" fn standard_b64encode(
123126 c"standard_b64encode() takes exactly one argument" . as_ptr ( ) ,
124127 ) ;
125128 }
126- return ptr:: null_mut ( ) ;
127129 }
128130
129- let source = unsafe { * args } ;
130- let buffer = match unsafe { BorrowedBuffer :: from_object ( source) } {
131+ let source = unsafe { & mut * * args } ;
132+
133+ // Safe cast by Safety
134+ match standard_b64encode_impl ( source) {
135+ Ok ( result) => result,
136+ Err ( _) => ptr:: null_mut ( ) ,
137+ }
138+ }
139+
140+ fn standard_b64encode_impl ( source : & mut PyObject ) -> Result < * mut PyObject , ( ) > {
141+ let buffer = match BorrowedBuffer :: from_object ( source) {
131142 Ok ( buf) => buf,
132- Err ( _) => return ptr :: null_mut ( ) ,
143+ Err ( _) => return Err ( ( ) ) ,
133144 } ;
134145
135146 let view_len = buffer. len ( ) ;
@@ -140,44 +151,43 @@ pub unsafe extern "C" fn standard_b64encode(
140151 c"standard_b64encode() argument has negative length" . as_ptr ( ) ,
141152 ) ;
142153 }
143- return ptr :: null_mut ( ) ;
154+ return Err ( ( ) ) ;
144155 }
156+
145157 let input_len = view_len as usize ;
146158 let input = unsafe { slice:: from_raw_parts ( buffer. as_ptr ( ) , input_len) } ;
147159
148160 let Some ( output_len) = encoded_output_len ( input_len) else {
149161 unsafe {
150162 PyErr_NoMemory ( ) ;
151163 }
152- return ptr :: null_mut ( ) ;
164+ return Err ( ( ) ) ;
153165 } ;
154166
155167 if output_len > isize:: MAX as usize {
156168 unsafe {
157169 PyErr_NoMemory ( ) ;
158170 }
159- return ptr :: null_mut ( ) ;
171+ return Err ( ( ) ) ;
160172 }
161173
162- let result = unsafe {
163- PyBytes_FromStringAndSize ( ptr:: null ( ) , output_len as Py_ssize_t )
164- } ;
174+ let result = unsafe { PyBytes_FromStringAndSize ( ptr:: null ( ) , output_len as Py_ssize_t ) } ;
165175 if result. is_null ( ) {
166- return ptr :: null_mut ( ) ;
176+ return Err ( ( ) ) ;
167177 }
168178
169179 let dest_ptr = unsafe { PyBytes_AsString ( result) } ;
170180 if dest_ptr. is_null ( ) {
171181 unsafe {
172182 Py_DecRef ( result) ;
173183 }
174- return ptr :: null_mut ( ) ;
184+ return Err ( ( ) ) ;
175185 }
176186 let dest = unsafe { slice:: from_raw_parts_mut ( dest_ptr. cast :: < u8 > ( ) , output_len) } ;
177187
178188 let written = encode_into ( input, dest) ;
179189 debug_assert_eq ! ( written, output_len) ;
180- result
190+ Ok ( result)
181191}
182192
183193#[ unsafe( no_mangle) ]
0 commit comments