From 0407fcfc3f946e1824af503392b4f1c9e9340966 Mon Sep 17 00:00:00 2001 From: ChengHui Chen <27797326+chenghuichen@users.noreply.github.com> Date: Thu, 20 Mar 2025 11:17:18 +0800 Subject: [PATCH 1/2] #44 Make Split and Predicate Serializable --- pypaimon/py4j/java_implementation.py | 53 +++++++++++++++++----------- pypaimon/py4j/util/java_utils.py | 26 ++++++++++++++ 2 files changed, 58 insertions(+), 21 deletions(-) diff --git a/pypaimon/py4j/java_implementation.py b/pypaimon/py4j/java_implementation.py index 9f378b7..9a13037 100644 --- a/pypaimon/py4j/java_implementation.py +++ b/pypaimon/py4j/java_implementation.py @@ -23,6 +23,7 @@ from pypaimon.py4j.java_gateway import get_gateway from pypaimon.py4j.util import java_utils, constants +from pypaimon.py4j.util.java_utils import serialize_java_object, deserialize_java_object from pypaimon.api import \ (catalog, table, read_builder, table_scan, split, row_type, table_read, write_builder, table_write, commit_message, @@ -145,33 +146,41 @@ def __init__(self, j_splits): self._j_splits = j_splits def splits(self) -> List['Split']: - return list(map(lambda s: Split(s), self._j_splits)) + return list(map(lambda s: self._build_single_split(s), self._j_splits)) + + def _build_single_split(self, j_split) -> 'Split': + j_split_bytes = serialize_java_object(j_split) + row_count = j_split.rowCount() + files_optional = j_split.convertToRawFiles() + if not files_optional.isPresent(): + file_size = 0 + file_paths = [] + else: + files = files_optional.get() + file_size = sum(file.length() for file in files) + file_paths = [file.path() for file in files] + return Split(j_split_bytes, row_count, file_size, file_paths) class Split(split.Split): - def __init__(self, j_split): - self._j_split = j_split + def __init__(self, j_split_bytes, row_count: int, file_size: int, file_paths: List[str]): + self._j_split_bytes = j_split_bytes + self._row_count = row_count + self._file_size = file_size + self._file_paths = file_paths def to_j_split(self): - return self._j_split + return deserialize_java_object(self._j_split_bytes) def row_count(self) -> int: - return self._j_split.rowCount() + return self._row_count def file_size(self) -> int: - files_optional = self._j_split.convertToRawFiles() - if not files_optional.isPresent(): - return 0 - files = files_optional.get() - return sum(file.length() for file in files) + return self._file_size def file_paths(self) -> List[str]: - files_optional = self._j_split.convertToRawFiles() - if not files_optional.isPresent(): - return [] - files = files_optional.get() - return [file.path() for file in files] + return self._file_paths class TableRead(table_read.TableRead): @@ -317,11 +326,11 @@ def close(self): class Predicate(predicate.Predicate): - def __init__(self, j_predicate): - self._j_predicate = j_predicate + def __init__(self, j_predicate_bytes): + self._j_predicate_bytes = j_predicate_bytes def to_j_predicate(self): - return self._j_predicate + return deserialize_java_object(self._j_predicate_bytes) class PredicateBuilder(predicate.PredicateBuilder): @@ -350,7 +359,7 @@ def _build(self, method: str, field: str, literals: Optional[List[Any]] = None): index, literals ) - return Predicate(j_predicate) + return Predicate(serialize_java_object(j_predicate)) def equal(self, field: str, literal: Any) -> Predicate: return self._build('equal', field, [literal]) @@ -397,8 +406,10 @@ def between(self, field: str, included_lower_bound: Any, included_upper_bound: A def and_predicates(self, predicates: List[Predicate]) -> Predicate: predicates = list(map(lambda p: p.to_j_predicate(), predicates)) - return Predicate(get_gateway().jvm.PredicationUtil.buildAnd(predicates)) + j_predicate = get_gateway().jvm.PredicationUtil.buildAnd(predicates) + return Predicate(serialize_java_object(j_predicate)) def or_predicates(self, predicates: List[Predicate]) -> Predicate: predicates = list(map(lambda p: p.to_j_predicate(), predicates)) - return Predicate(get_gateway().jvm.PredicationUtil.buildOr(predicates)) + j_predicate = get_gateway().jvm.PredicationUtil.buildOr(predicates) + return Predicate(serialize_java_object(j_predicate)) diff --git a/pypaimon/py4j/util/java_utils.py b/pypaimon/py4j/util/java_utils.py index 0beb527..5e2252e 100644 --- a/pypaimon/py4j/util/java_utils.py +++ b/pypaimon/py4j/util/java_utils.py @@ -100,3 +100,29 @@ def to_arrow_schema(j_row_type): arrow_schema = schema_reader.schema schema_reader.close() return arrow_schema + + +def serialize_java_object(java_obj) -> bytes: + gateway = get_gateway() + util = gateway.jvm.org.apache.paimon.utils.InstantiationUtil + try: + java_bytes = util.serializeObject(java_obj) + return bytes(java_bytes) + except Exception as e: + raise RuntimeError(f"Java serialization failed: {e}") + + +def deserialize_java_object(bytes_data): + gateway = get_gateway() + cl = get_gateway().jvm.Thread.currentThread().getContextClassLoader() + util = gateway.jvm.org.apache.paimon.utils.InstantiationUtil + try: + byte_buffer = gateway.jvm.java.nio.ByteBuffer.allocate(len(bytes_data)) + for b in bytes_data: + byte_buffer.put(b if b >= 0 else b + 256) + byte_buffer.flip() + java_bytes = byte_buffer.array() + + return util.deserializeObject(java_bytes, cl) + except Exception as e: + raise RuntimeError(f"Java deserialization failed: {e}") From ee572476d028aac79cca8b773dae852fa0757c43 Mon Sep 17 00:00:00 2001 From: ChengHui Chen <27797326+chenghuichen@users.noreply.github.com> Date: Thu, 20 Mar 2025 11:25:19 +0800 Subject: [PATCH 2/2] #44 Make Split and Predicate Serializable --- pypaimon/py4j/util/java_utils.py | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/pypaimon/py4j/util/java_utils.py b/pypaimon/py4j/util/java_utils.py index 5e2252e..2a2aac9 100644 --- a/pypaimon/py4j/util/java_utils.py +++ b/pypaimon/py4j/util/java_utils.py @@ -116,13 +116,4 @@ def deserialize_java_object(bytes_data): gateway = get_gateway() cl = get_gateway().jvm.Thread.currentThread().getContextClassLoader() util = gateway.jvm.org.apache.paimon.utils.InstantiationUtil - try: - byte_buffer = gateway.jvm.java.nio.ByteBuffer.allocate(len(bytes_data)) - for b in bytes_data: - byte_buffer.put(b if b >= 0 else b + 256) - byte_buffer.flip() - java_bytes = byte_buffer.array() - - return util.deserializeObject(java_bytes, cl) - except Exception as e: - raise RuntimeError(f"Java deserialization failed: {e}") + return util.deserializeObject(bytes_data, cl)