Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 84 additions & 10 deletions Modules/_base64/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,9 @@ use std::mem::MaybeUninit;
use std::ptr;
use std::slice;

use cpython_sys::_Py_DecRef;
use cpython_sys::_Py_IncRef;
use cpython_sys::METH_FASTCALL;
use cpython_sys::Py_DecRef;
use cpython_sys::Py_buffer;
use cpython_sys::Py_ssize_t;
use cpython_sys::PyBuffer_Release;
Expand All @@ -26,6 +27,78 @@ const PYBUF_SIMPLE: c_int = 0;
const PAD_BYTE: u8 = b'=';
const ENCODE_TABLE: [u8; 64] = *b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";

pub struct PyRc {
ptr: std::ptr::NonNull<PyObject>,
}

impl PyRc {
/// # Safety
/// `ptr` must be a valid pointer to a PyObject.
pub unsafe fn from_raw(ptr: *mut PyObject) -> Option<Self> {
let ptr = std::ptr::NonNull::new(ptr)?;
Some(Self { ptr })
}

pub fn into_non_null(zelf: Self) -> std::ptr::NonNull<PyObject> {
let ptr = zelf.ptr;
std::mem::forget(zelf);
ptr
}

pub fn into_raw(zelf: Self) -> *mut PyObject {
let ptr = zelf.ptr.as_ptr();
std::mem::forget(zelf);
ptr
}

pub fn as_raw(&self) -> *mut PyObject {
self.ptr.as_ptr()
}
}

impl AsRef<PyObject> for PyRc {
fn as_ref(&self) -> &PyObject {
unsafe { self.ptr.as_ref() }
}
}

impl AsMut<PyObject> for PyRc {
fn as_mut(&mut self) -> &mut PyObject {
unsafe { self.ptr.as_mut() }
}
}

impl std::ops::Deref for PyRc {
type Target = PyObject;

fn deref(&self) -> &Self::Target {
self.as_ref()
}
}

impl std::ops::DerefMut for PyRc {
fn deref_mut(&mut self) -> &mut Self::Target {
self.as_mut()
}
}

impl Clone for PyRc {
fn clone(&self) -> Self {
unsafe {
_Py_IncRef(self.ptr.as_ptr());
}
Self { ptr: self.ptr }
}
}

impl Drop for PyRc {
fn drop(&mut self) {
unsafe {
_Py_DecRef(self.ptr.as_ptr());
}
}
}

#[inline]
fn encoded_output_len(input_len: usize) -> Option<usize> {
input_len
Expand Down Expand Up @@ -135,12 +208,12 @@ pub unsafe extern "C" fn standard_b64encode(

// Safe cast by Safety
match standard_b64encode_impl(source) {
Ok(result) => result,
Ok(result) => PyRc::into_raw(result),
Err(_) => ptr::null_mut(),
}
}

fn standard_b64encode_impl(source: &PyObject) -> Result<*mut PyObject, ()> {
fn standard_b64encode_impl(source: &PyObject) -> Result<PyRc, ()> {
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

&PyObject means it guarantees the RC count is never changed in this logic. (not 100% because it can be changed in C side)
PyRc means it can be increased or decreased.

let buffer = match BorrowedBuffer::from_object(source) {
Ok(buf) => buf,
Err(_) => return Err(()),
Expand Down Expand Up @@ -174,16 +247,17 @@ fn standard_b64encode_impl(source: &PyObject) -> Result<*mut PyObject, ()> {
return Err(());
}

let result = unsafe { PyBytes_FromStringAndSize(ptr::null(), output_len as Py_ssize_t) };
if result.is_null() {
let Some(result) = (unsafe {
PyRc::from_raw(PyBytes_FromStringAndSize(
ptr::null(),
output_len as Py_ssize_t,
))
}) else {
return Err(());
}
};

let dest_ptr = unsafe { PyBytes_AsString(result) };
let dest_ptr = unsafe { PyBytes_AsString(result.as_raw()) };
if dest_ptr.is_null() {
unsafe {
Py_DecRef(result);
}
return Err(());
}
let dest = unsafe { slice::from_raw_parts_mut(dest_ptr.cast::<u8>(), output_len) };
Expand Down