Skip to content

Commit ed31a89

Browse files
committed
Split standard_b64encode_impl
1 parent c9deee6 commit ed31a89

File tree

1 file changed

+24
-14
lines changed

1 file changed

+24
-14
lines changed

Modules/_base64/src/lib.rs

Lines changed: 24 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ struct BorrowedBuffer {
8383
}
8484

8585
impl BorrowedBuffer {
86-
unsafe fn from_object(obj: *mut PyObject) -> Result<Self, ()> {
86+
fn from_object(obj: &mut PyObject) -> Result<Self, ()> {
8787
let mut view = MaybeUninit::<Py_buffer>::uninit();
8888
if unsafe { PyObject_GetBuffer(obj, view.as_mut_ptr(), PYBUF_SIMPLE) } != 0 {
8989
return Err(());
@@ -110,6 +110,9 @@ impl Drop for BorrowedBuffer {
110110
}
111111
}
112112

113+
/// # Safety
114+
/// `module` must be a valid pointer of PyObject representing the module.
115+
/// `args` must be a valid pointer to an array of valid PyObject pointers with length `nargs`.
113116
#[unsafe(no_mangle)]
114117
pub unsafe extern "C" fn standard_b64encode(
115118
_module: *mut PyObject,
@@ -123,13 +126,21 @@ pub unsafe extern "C" fn standard_b64encode(
123126
c"standard_b64encode() takes exactly one argument".as_ptr(),
124127
);
125128
}
126-
return ptr::null_mut();
127129
}
128130

129-
let source = unsafe { *args };
130-
let buffer = match unsafe { BorrowedBuffer::from_object(source) } {
131+
let source = unsafe { &mut **args };
132+
133+
// Safe cast by Safety
134+
match standard_b64encode_impl(source) {
135+
Ok(result) => result,
136+
Err(_) => ptr::null_mut(),
137+
}
138+
}
139+
140+
fn standard_b64encode_impl(source: &mut PyObject) -> Result<*mut PyObject, ()> {
141+
let buffer = match BorrowedBuffer::from_object(source) {
131142
Ok(buf) => buf,
132-
Err(_) => return ptr::null_mut(),
143+
Err(_) => return Err(()),
133144
};
134145

135146
let view_len = buffer.len();
@@ -140,44 +151,43 @@ pub unsafe extern "C" fn standard_b64encode(
140151
c"standard_b64encode() argument has negative length".as_ptr(),
141152
);
142153
}
143-
return ptr::null_mut();
154+
return Err(());
144155
}
156+
145157
let input_len = view_len as usize;
146158
let input = unsafe { slice::from_raw_parts(buffer.as_ptr(), input_len) };
147159

148160
let Some(output_len) = encoded_output_len(input_len) else {
149161
unsafe {
150162
PyErr_NoMemory();
151163
}
152-
return ptr::null_mut();
164+
return Err(());
153165
};
154166

155167
if output_len > isize::MAX as usize {
156168
unsafe {
157169
PyErr_NoMemory();
158170
}
159-
return ptr::null_mut();
171+
return Err(());
160172
}
161173

162-
let result = unsafe {
163-
PyBytes_FromStringAndSize(ptr::null(), output_len as Py_ssize_t)
164-
};
174+
let result = unsafe { PyBytes_FromStringAndSize(ptr::null(), output_len as Py_ssize_t) };
165175
if result.is_null() {
166-
return ptr::null_mut();
176+
return Err(());
167177
}
168178

169179
let dest_ptr = unsafe { PyBytes_AsString(result) };
170180
if dest_ptr.is_null() {
171181
unsafe {
172182
Py_DecRef(result);
173183
}
174-
return ptr::null_mut();
184+
return Err(());
175185
}
176186
let dest = unsafe { slice::from_raw_parts_mut(dest_ptr.cast::<u8>(), output_len) };
177187

178188
let written = encode_into(input, dest);
179189
debug_assert_eq!(written, output_len);
180-
result
190+
Ok(result)
181191
}
182192

183193
#[unsafe(no_mangle)]

0 commit comments

Comments
 (0)