diff --git a/polars_bio/__init__.py b/polars_bio/__init__.py index 01d66d58..c18c4da3 100644 --- a/polars_bio/__init__.py +++ b/polars_bio/__init__.py @@ -14,7 +14,7 @@ sql, ) from .polars_ext import PolarsRangesOperations as LazyFrame -from .range_op import FilterOp, count_overlaps, coverage, merge, nearest, overlap +from .range_op import FilterOp, cluster, count_overlaps, coverage, merge, nearest, overlap from .range_viz import visualize_intervals POLARS_BIO_MAX_THREADS = "datafusion.execution.target_partitions" @@ -24,7 +24,8 @@ __all__ = [ "overlap", "nearest", "merge", + "cluster", "count_overlaps", "coverage", "ctx", diff --git a/polars_bio/interval_op_helpers.py b/polars_bio/interval_op_helpers.py index 4e0f0495..7e235b3a 100644 --- a/polars_bio/interval_op_helpers.py +++ b/polars_bio/interval_op_helpers.py @@ -16,14 +16,17 @@ def get_py_ctx() -> datafusion.context.SessionContext: def read_df_to_datafusion( py_ctx: datafusion.context.SessionContext, - df: Union[str, pl.DataFrame, pl.LazyFrame, pd.DataFrame], -) -> datafusion.dataframe: + df: Union[str, pl.DataFrame, pl.LazyFrame, pd.DataFrame, datafusion.dataframe.DataFrame] +) -> datafusion.dataframe.DataFrame: + if isinstance(df, pl.DataFrame): return py_ctx.from_polars(df) elif isinstance(df, pd.DataFrame): return py_ctx.from_pandas(df) elif isinstance(df, pl.LazyFrame): return py_ctx.from_polars(df.collect()) + elif isinstance(df, datafusion.dataframe.DataFrame): + return df elif isinstance(df, str): ext = Path(df).suffix if ext == ".csv": @@ -46,8 +49,9 @@ def read_df_to_datafusion( return py_ctx.read_parquet(df) raise ValueError("Invalid `df` argument.") - -def df_to_lazyframe(df: datafusion.DataFrame) -> pl.LazyFrame: +def df_to_lazyframe( + df: datafusion.dataframe.DataFrame +) -> pl.LazyFrame: # TODO: make it actually lazy """ def _get_lazy( @@ -63,8 +67,10 @@ def _get_lazy( def convert_result( - df: datafusion.DataFrame, output_type: str, streaming: bool -) -> Union[pl.LazyFrame, 
pl.DataFrame, pd.DataFrame]: + df: datafusion.dataframe.DataFrame, + output_type: str, + streaming: bool +) -> Union[pl.LazyFrame, pl.DataFrame, pd.DataFrame, datafusion.dataframe.DataFrame]: # TODO: implement streaming if streaming: # raise NotImplementedError("streaming is not implemented") @@ -75,4 +81,6 @@ def convert_result( return df.to_pandas() elif output_type == "polars.LazyFrame": return df_to_lazyframe(df) + elif output_type == "datafusion.DataFrame": + return df raise ValueError("Invalid `output_type` argument") diff --git a/polars_bio/polars_ext.py b/polars_bio/polars_ext.py index 2c4f29ff..1e802bfe 100644 --- a/polars_bio/polars_ext.py +++ b/polars_bio/polars_ext.py @@ -54,6 +54,59 @@ def nearest( cols1=cols1, cols2=cols2, ) + + def coverage( + self, + other_df: pl.LazyFrame, + suffixes: tuple[str, str] = ("", "_"), + cols1=["chrom", "start", "end"], + cols2=["chrom", "start", "end"], + ) -> pl.LazyFrame: + """ + !!! note + Alias for [coverage](api.md#polars_bio.coverage) + """ + return pb.coverage( + self._ldf, + other_df, + suffixes=suffixes, + cols1=cols1, + cols2=cols2, + ) + + def merge( + self, + overlap_filter: FilterOp = FilterOp.Strict, + min_dist: float = 0, + cols: Union[list[str], None] = None, + ) -> pl.LazyFrame: + """ + !!! note + Alias for [merge](api.md#polars_bio.merge) + """ + return pb.merge( + self._ldf, + overlap_filter=overlap_filter, + min_dist=min_dist, + cols=cols + ) + + def cluster( + self, + overlap_filter: FilterOp = FilterOp.Strict, + min_dist: float = 0, + cols: Union[list[str], None] = None, + ) -> pl.LazyFrame: + """ + !!! 
note + Alias for [cluster](api.md#polars_bio.cluster) + """ + return pb.cluster( + self._ldf, + overlap_filter=overlap_filter, + min_dist=min_dist, + cols=cols + ) def count_overlaps( self, diff --git a/polars_bio/range_op.py b/polars_bio/range_op.py index b52b493e..cd1240f9 100644 --- a/polars_bio/range_op.py +++ b/polars_bio/range_op.py @@ -13,7 +13,11 @@ from .interval_op_helpers import convert_result, get_py_ctx, read_df_to_datafusion from .range_op_helpers import _validate_overlap_input, range_operation -__all__ = ["overlap", "nearest", "count_overlaps", "merge"] +import datafusion +from datafusion import col, literal +import pyarrow + +__all__ = ["overlap", "nearest", "merge", "cluster", "coverage", "count_overlaps"] if TYPE_CHECKING: @@ -173,65 +177,347 @@ def nearest( return range_operation(df1, df2, range_options, output_type, ctx, read_options) -def coverage( - df1: Union[str, pl.DataFrame, pl.LazyFrame, pd.DataFrame], - df2: Union[str, pl.DataFrame, pl.LazyFrame, pd.DataFrame], +def merge( + df: Union[str, pl.DataFrame, pl.LazyFrame, pd.DataFrame, datafusion.dataframe.DataFrame], overlap_filter: FilterOp = FilterOp.Strict, - suffixes: tuple[str, str] = ("_1", "_2"), + min_dist: float = 0, + cols: Union[list[str], None] = ["chrom", "start", "end"], + on_cols: Union[list[str], None] = None, + output_type: str = "polars.LazyFrame", + streaming: bool = False, +) -> Union[pl.LazyFrame, pl.DataFrame, pd.DataFrame, datafusion.dataframe.DataFrame]: + """ + Merge overlapping intervals. It is assumed that start < end. + + + Parameters: + df: Can be a path to a file, a polars DataFrame, or a pandas DataFrame. CSV with a header, BED and Parquet are supported. + overlap_filter: FilterOp, optional. The type of overlap to consider(Weak or Strict). + cols: The names of columns containing the chromosome, start and end of the + genomic intervals, provided separately for each set. + on_cols: List of additional column names for clustering. default is None. 
+ output_type: Type of the output. default is "polars.LazyFrame", "polars.DataFrame", or "pandas.DataFrame" are also supported. + streaming: **EXPERIMENTAL** If True, use Polars [streaming](features.md#streaming-out-of-core-processing) engine. + + Returns: + **polars.LazyFrame** or polars.DataFrame or pandas.DataFrame of the overlapping intervals. + + Example: + + Todo: + Support for on_cols. + """ + suffixes = ("_1", "_2") + _validate_overlap_input(cols, cols, on_cols, suffixes, output_type, how="inner") + + + my_ctx = get_py_ctx() + cols = DEFAULT_INTERVAL_COLUMNS if cols is None else cols + contig = cols[0] + start = cols[1] + end = cols[2] + + + on_cols = [] if on_cols is None else on_cols + on_cols = [contig] + on_cols + + df = read_df_to_datafusion(my_ctx, df) + df_schema = df.schema() + start_type = df_schema.field(start).type + end_type = df_schema.field(end).type + # TODO: make sure to avoid conflicting column names + start_end = "start_end" + is_start_end = "is_start_or_end" + current_intervals = "current_intervals" + n_intervals = "n_intervals" + + end_positions = df.select(*([(col(end) + min_dist).alias(start_end), literal(-1).alias(is_start_end)] + on_cols)) + start_positions = df.select(*([col(start).alias(start_end), literal(1).alias(is_start_end)] + on_cols)) + all_positions = start_positions.union(end_positions) + start_end_type = all_positions.schema().field(start_end).type + all_positions = all_positions.select(*([col(start_end).cast(start_end_type), col(is_start_end)] + on_cols)) + + sorting = [col(start_end).sort(), col(is_start_end).sort(ascending=(overlap_filter == FilterOp.Strict))] + all_positions = all_positions.sort(*sorting) + + on_cols_expr = [col(c) for c in on_cols] + + win = datafusion.expr.Window( + partition_by=on_cols_expr, + order_by=sorting, + ) + all_positions = all_positions.select(*([start_end, is_start_end, + datafusion.functions.sum(col(is_start_end)).over(win).alias(current_intervals)] + on_cols + + 
[datafusion.functions.row_number(partition_by = on_cols_expr, order_by=sorting).alias(n_intervals)])) + all_positions = all_positions.filter( + ((col(current_intervals) == 0) & (col(is_start_end) == -1)) | ((col(current_intervals) == 1) & (col(is_start_end) == 1)) + ) + all_positions = all_positions.select(*([start_end, is_start_end] + on_cols + [((col(n_intervals) - datafusion.functions.lag(col(n_intervals), partition_by=on_cols_expr) + 1) / 2).alias(n_intervals)])) + result = all_positions.select(*([(col(start_end) - min_dist).alias(end), is_start_end, + datafusion.functions.lag(col(start_end), partition_by=on_cols_expr).alias(start)] + on_cols + [n_intervals])) + result = result.filter(col(is_start_end) == -1) + result = result.select(*([contig, col(start).cast(start_type), col(end).cast(end_type)] + on_cols[1:] + [n_intervals])) + + return convert_result(result, output_type, streaming) + +def cluster( + df: Union[str, pl.DataFrame, pl.LazyFrame, pd.DataFrame, datafusion.dataframe.DataFrame], + overlap_filter: FilterOp = FilterOp.Strict, + min_dist: float = 0, + cols: Union[list[str], None] = ["chrom", "start", "end"], on_cols: Union[list[str], None] = None, + output_type: str = "polars.LazyFrame", + streaming: bool = False, +) -> Union[pl.LazyFrame, pl.DataFrame, pd.DataFrame, datafusion.dataframe.DataFrame]: + """ + Cluster overlapping intervals. It is assumed that start < end. + + + Parameters: + df: Can be a path to a file, a polars DataFrame, or a pandas DataFrame. CSV with a header, BED and Parquet are supported. + overlap_filter: FilterOp, optional. The type of overlap to consider(Weak or Strict). + cols: The names of columns containing the chromosome, start and end of the + genomic intervals, provided separately for each set. + on_cols: List of additional column names for clustering. default is None. + output_type: Type of the output. default is "polars.LazyFrame", "polars.DataFrame", or "pandas.DataFrame" are also supported. 
+ streaming: **EXPERIMENTAL** If True, use Polars [streaming](features.md#streaming-out-of-core-processing) engine. + + Returns: + **polars.LazyFrame** or polars.DataFrame or pandas.DataFrame of the overlapping intervals. + + Example: + + """ + suffixes = ("_1", "_2") + _validate_overlap_input(cols, cols, on_cols, suffixes, output_type, how="inner") + + my_ctx = get_py_ctx() + cols = DEFAULT_INTERVAL_COLUMNS if cols is None else cols + contig = cols[0] + start = cols[1] + end = cols[2] + + + on_cols = [] if on_cols is None else on_cols + on_cols = [contig] + on_cols + + df = read_df_to_datafusion(my_ctx, df) + df_schema = df.schema() + print(df_schema) + print(start) + start_type = df_schema.field(start).type + end_type = df_schema.field(end).type + # TODO: make sure to avoid conflicting column names + start_end = "start_end" + is_start_end = "is_start_or_end" + current_intervals = "current_intervals" + n_intervals = "n_intervals" + row_no = "row_no" + cluster_start = "cluster_start" + cluster_end = "cluster_end" + does_cluster_start = "does_cluster_start" + does_cluster_end = "does_cluster_end" + cluster_id = "cluster" + + end_positions = df.select(*([(col(end) + min_dist).alias(start_end), literal(-1).alias(is_start_end), start, end, + literal(0).alias(row_no)] + on_cols)) + start_positions = df.select(*([col(start).alias(start_end), literal(1).alias(is_start_end), start, end, + datafusion.functions.row_number().alias(row_no)] + on_cols)) + all_positions = start_positions.union(end_positions) + start_end_type = all_positions.schema().field(start_end).type + all_positions = all_positions.select(*([col(start_end).cast(start_end_type), col(is_start_end), start, end, row_no] + on_cols)) + + sorting = [col(start_end).sort(), col(is_start_end).sort(ascending=(overlap_filter == FilterOp.Strict))] + + on_cols_expr = [col(c) for c in on_cols] + win = datafusion.expr.Window( + partition_by=on_cols_expr, + order_by=sorting, + ) + all_positions = 
all_positions.select(*([start_end, is_start_end, start, end, row_no, + datafusion.functions.sum(col(is_start_end)).over(win).alias(current_intervals)] + on_cols)) + + all_positions = all_positions.select(*([ + start, + end, + start_end, + is_start_end, + current_intervals, + row_no, + ((col(current_intervals) == 1) & (col(is_start_end) == 1)).cast(pyarrow.int64()).alias(does_cluster_start)] + on_cols)) + + all_positions = all_positions.select(*([ + row_no, + start, + end, + start_end, + is_start_end, + does_cluster_start, + current_intervals, + datafusion.functions.sum(col(does_cluster_start)) + .over(datafusion.expr.Window( + order_by = [c.sort() for c in on_cols_expr] + sorting + )) + .alias(cluster_id), + ] + on_cols)) + all_positions = all_positions.filter(col(is_start_end) == 1) + cluster_window = datafusion.expr.Window( + partition_by=[col(cluster_id)], + window_frame=datafusion.expr.WindowFrame( + units='rows', + start_bound=None, + end_bound=None, + ) + ) + + all_positions = all_positions.select(*([ + row_no, + start, + end, + cluster_id, + datafusion.functions.min(col(start)).over(cluster_window).alias(cluster_start), + datafusion.functions.max(col(end)).over(cluster_window).alias(cluster_end)] + on_cols)) + all_positions = all_positions.sort(col(row_no).sort()) + + all_positions = all_positions.select(*(on_cols + [ + start, + end, + (col(cluster_id) - 1).alias(cluster_id), + cluster_start, + cluster_end])) + return convert_result(all_positions, output_type, streaming) + +def coverage( + df1: Union[str, pl.DataFrame, pl.LazyFrame, pd.DataFrame, datafusion.dataframe.DataFrame], + df2: Union[str, pl.DataFrame, pl.LazyFrame, pd.DataFrame, datafusion.dataframe.DataFrame], + suffixes: tuple[str, str] = ("", "_"), + return_input: bool = True, cols1: Union[list[str], None] = ["chrom", "start", "end"], cols2: Union[list[str], None] = ["chrom", "start", "end"], + on_cols: Union[list[str], None] = None, output_type: str = "polars.LazyFrame", streaming: bool = False, 
- read_options: Union[ReadOptions, None] = None, -) -> Union[pl.LazyFrame, pl.DataFrame, pd.DataFrame, datafusion.DataFrame]: +) -> Union[pl.LazyFrame, pl.DataFrame, pd.DataFrame, datafusion.dataframe.DataFrame]: """ - Calculate intervals coverage. + Count coverage of intervals. Bioframe inspired API. Parameters: - df1: Can be a path to a file, a polars DataFrame, or a pandas DataFrame or a registered table (see [register_vcf](api.md#polars_bio.register_vcf)). CSV with a header, BED and Parquet are supported. - df2: Can be a path to a file, a polars DataFrame, or a pandas DataFrame or a registered table. CSV with a header, BED and Parquet are supported. - overlap_filter: FilterOp, optional. The type of overlap to consider(Weak or Strict). + df1: Can be a path to a file, a polars DataFrame, or a pandas DataFrame. CSV with a header, BED and Parquet are supported. + df2: Can be a path to a file, a polars DataFrame, or a pandas DataFrame. CSV with a header, BED and Parquet are supported. + suffixes: Suffixes for the columns of the two overlapped sets. + return_input: If true, return input. cols1: The names of columns containing the chromosome, start and end of the genomic intervals, provided separately for each set. cols2: The names of columns containing the chromosome, start and end of the genomic intervals, provided separately for each set. - suffixes: Suffixes for the columns of the two overlapped sets. on_cols: List of additional column names to join on. default is None. - output_type: Type of the output. default is "polars.LazyFrame", "polars.DataFrame", or "pandas.DataFrame" or "datafusion.DataFrame" are also supported. - streaming: **EXPERIMENTAL** If True, use Polars [streaming](features.md#streaming) engine. - read_options: Additional options for reading the input files. - + output_type: Type of the output. default is "polars.LazyFrame", "polars.DataFrame", or "pandas.DataFrame" are also supported. 
+ streaming: **EXPERIMENTAL** If True, use Polars [streaming](features.md#streaming-out-of-core-processing) engine. Returns: **polars.LazyFrame** or polars.DataFrame or pandas.DataFrame of the overlapping intervals. - Note: - The default output format, i.e. [LazyFrame](https://docs.pola.rs/api/python/stable/reference/lazyframe/index.html), is recommended for large datasets as it supports output streaming and lazy evaluation. - This enables efficient processing of large datasets without loading the entire output dataset into memory. - Example: Todo: - Support for on_cols. - """ - + Support return_input. + """ _validate_overlap_input(cols1, cols2, on_cols, suffixes, output_type, how="inner") + my_ctx = get_py_ctx() + + df1 = read_df_to_datafusion(my_ctx, df1) + df2 = read_df_to_datafusion(my_ctx, df2) + df2 = merge(df2, output_type="datafusion.DataFrame", cols=cols2, on_cols=on_cols) + + on_cols = [] if on_cols is None else on_cols cols1 = DEFAULT_INTERVAL_COLUMNS if cols1 is None else cols1 cols2 = DEFAULT_INTERVAL_COLUMNS if cols2 is None else cols2 - range_options = RangeOptions( - range_op=RangeOp.Coverage, - filter_op=overlap_filter, - suffixes=suffixes, - columns_1=cols1, - columns_2=cols2, - streaming=streaming, + cols1 = list(cols1) + cols2 = list(cols2) + + # TODO: guarantee no collisions + contig = "contig" + row_id = "row_id" + interval_counter = "interval_counter" + interval_sum = "interval_sum" + position = "position" + coverage = "coverage" + + suff, _ = suffixes + + df1 = df1.select(*([(literal(2) * datafusion.functions.row_number()).alias(row_id)] + cols1 + on_cols)) + + df1_starts = df1.select(*([ + row_id, + col(cols1[0]).alias(contig), + col(cols1[1]).alias(position), + literal(0).alias(interval_counter), + literal(0).alias(interval_sum)] + on_cols)) + df1_ends = df1.select(*([ + (col(row_id) + 1).alias(row_id), + col(cols1[0]).alias(contig), + col(cols1[2]).alias(position), + literal(0).alias(interval_counter), + literal(0).alias(interval_sum)] + 
on_cols)) + + df2_starts = df2.select(*([ + literal(0).alias(row_id), + col(cols2[0]).alias(contig), + col(cols2[1]).alias(position), + literal(1).alias(interval_counter), + (literal(0) - col(cols2[1])).alias(interval_sum)] + on_cols)) + df2_ends = df2.select(*([ + literal(0).alias(row_id), + col(cols2[0]).alias(contig), + col(cols2[2]).alias(position), + literal(-1).alias(interval_counter), + col(cols2[2]).alias(interval_sum)] + on_cols)) + + df = df1_starts.union(df1_ends).union(df2_starts).union(df2_ends) + + on_cols = [contig] + on_cols + on_cols_expr = [col(c) for c in on_cols] + + win = datafusion.expr.Window( + partition_by=on_cols_expr, + order_by=[col(position).sort()] ) - return range_operation(df2, df1, range_options, output_type, ctx, read_options) + df = df.select(*([ + row_id, + position, + datafusion.functions.sum(col(interval_counter)).over(win).alias(interval_counter), + datafusion.functions.sum(col(interval_sum)).over(win).alias(interval_sum), + ] + on_cols)) + df = df.select(*([ + row_id, + position, + ((col(interval_counter) * col(position)) + col(interval_sum)).alias(interval_sum), + ] + on_cols)) + df = df.filter(col(row_id) > 0) + df = df.sort(col(row_id)) + + start_result = cols1[1] + suff + end_result = cols1[2] + suff + + df = df.select(*([ + row_id, + col(position).alias(start_result), + datafusion.functions.lead(col(position)).alias(end_result), + (datafusion.functions.lead(col(interval_sum)) - col(interval_sum)).alias(coverage) + ] + on_cols)) + df = df.filter((col(row_id) % 2) == 0) + df = df.select(*([ + col(contig).alias(cols1[0] + suff), + start_result, + end_result] + on_cols[1:] + [ + coverage])) + return convert_result(df, output_type, streaming) + + def count_overlaps( df1: Union[str, pl.DataFrame, pl.LazyFrame, pd.DataFrame], df2: Union[str, pl.DataFrame, pl.LazyFrame, pd.DataFrame], diff --git a/polars_bio/range_op_helpers.py b/polars_bio/range_op_helpers.py index 3257c97f..3189fb4a 100644 --- a/polars_bio/range_op_helpers.py 
+++ b/polars_bio/range_op_helpers.py @@ -149,7 +149,7 @@ def _validate_overlap_input(col1, col2, on_cols, suffixes, output_type, how): "polars.DataFrame", "pandas.DataFrame", "datafusion.DataFrame", - ], "Only polars.LazyFrame, polars.DataFrame, and pandas.DataFrame are supported" + ], "Only polars.LazyFrame, polars.DataFrame, datafusion.DataFrame, and pandas.DataFrame are supported" assert how in ["inner"], "Only inner join is supported" diff --git a/tests/_expected.py b/tests/_expected.py index b2f9be9f..cfe46712 100644 --- a/tests/_expected.py +++ b/tests/_expected.py @@ -60,6 +60,49 @@ | chr2 | 22000 | 22300 | 2 | """ +EXPECTED_CLUSTER = """ +| contig | pos_start | pos_end | cluster | cluster_start | cluster_end | +|:---------|------------:|----------:|----------:|----------------:|--------------:| +| chr1 | 150 | 250 | 0 | 100 | 300 | +| chr1 | 190 | 300 | 0 | 100 | 300 | +| chr1 | 300 | 501 | 1 | 300 | 700 | +| chr1 | 500 | 700 | 1 | 300 | 700 | +| chr1 | 22000 | 22300 | 3 | 22000 | 22300 | +| chr1 | 15000 | 15001 | 2 | 10000 | 20000 | +| chr2 | 150 | 250 | 4 | 100 | 300 | +| chr2 | 190 | 300 | 4 | 100 | 300 | +| chr2 | 300 | 500 | 5 | 300 | 700 | +| chr2 | 500 | 700 | 5 | 300 | 700 | +| chr2 | 22000 | 22300 | 7 | 22000 | 22300 | +| chr2 | 15000 | 15001 | 6 | 10000 | 20000 | +| chr1 | 100 | 190 | 0 | 100 | 300 | +| chr1 | 200 | 290 | 0 | 100 | 300 | +| chr1 | 400 | 600 | 1 | 300 | 700 | +| chr1 | 10000 | 20000 | 2 | 10000 | 20000 | +| chr1 | 22100 | 22101 | 3 | 22000 | 22300 | +| chr2 | 100 | 190 | 4 | 100 | 300 | +| chr2 | 200 | 290 | 4 | 100 | 300 | +| chr2 | 400 | 600 | 5 | 300 | 700 | +| chr2 | 10000 | 20000 | 6 | 10000 | 20000 | +| chr2 | 22100 | 22101 | 7 | 22000 | 22300 | +""" + +EXPECTED_COVERAGE = """ +| contig | pos_start | pos_end | coverage | +|:---------|------------:|----------:|-----------:| +| chr1 | 100 | 190 | 40 | +| chr1 | 200 | 290 | 90 | +| chr1 | 400 | 600 | 200 | +| chr1 | 10000 | 20000 | 0 | +| chr1 | 22100 | 22100 | 0 | +| chr2 | 100 
| 190 | 40 | +| chr2 | 200 | 290 | 90 | +| chr2 | 400 | 600 | 200 | +| chr2 | 10000 | 20000 | 0 | +| chr2 | 22100 | 22100 | 0 | +| chr3 | 100 | 200 | 0 | +""" + EXPECTED_COUNT_OVERLAPS = """ +--------+-----------+---------+-------+ | contig | pos_start | pos_end | count | @@ -107,6 +150,22 @@ .astype({"count": "int64"}) ) +PD_DF_CLUSTER = ( + mdpd.from_md(EXPECTED_CLUSTER) + .astype({"pos_start": "int64"}) + .astype({"pos_end": "int64"}) + .astype({"cluster": "int64"}) + .astype({"cluster_start": "int64"}) + .astype({"cluster_end": "int64"}) +) + +PD_DF_COVERAGE = ( + mdpd.from_md(EXPECTED_COVERAGE).astype({ + "pos_start": "int64", + "pos_end": "int64", + "coverage": "int64", + }) +) PD_DF_OVERLAP = PD_DF_OVERLAP.sort_values(by=list(PD_DF_OVERLAP.columns)).reset_index( drop=True @@ -117,6 +176,15 @@ PD_DF_MERGE = PD_DF_MERGE.sort_values(by=list(PD_DF_MERGE.columns)).reset_index( drop=True ) +PD_DF_CLUSTER = PD_DF_CLUSTER.sort_values(by=list(PD_DF_CLUSTER.columns)).reset_index( + drop=True +) +PD_DF_COVERAGE = PD_DF_COVERAGE.sort_values(by=list(PD_DF_COVERAGE.columns)).reset_index( + drop=True +) +PD_DF_COUNT_OVERLAPS = PD_DF_COUNT_OVERLAPS.sort_values(by=list(PD_DF_COUNT_OVERLAPS.columns)).reset_index( + drop=True +) PD_DF_COUNT_OVERLAPS = PD_DF_COUNT_OVERLAPS.sort_values( by=list(PD_DF_COUNT_OVERLAPS.columns) ).reset_index(drop=True) @@ -133,18 +201,26 @@ DF_MERGE_PATH = f"{DATA_DIR}/merge/input.csv" PD_MERGE_DF = pd.read_csv(DF_MERGE_PATH) + DF_COUNT_OVERLAPS_PATH1 = f"{DATA_DIR}/count_overlaps/targets.csv" DF_COUNT_OVERLAPS_PATH2 = f"{DATA_DIR}/count_overlaps/reads.csv" PD_COUNT_OVERLAPS_DF1 = pd.read_csv(DF_COUNT_OVERLAPS_PATH1) PD_COUNT_OVERLAPS_DF2 = pd.read_csv(DF_COUNT_OVERLAPS_PATH2) +DF_CLUSTER_PATH = f"{DATA_DIR}/cluster/input.csv" +PD_CLUSTER_DF = pd.read_csv(DF_CLUSTER_PATH) + +DF_COVERAGE_PATH1 = f"{DATA_DIR}/coverage/reads.csv" +DF_COVERAGE_PATH2 = f"{DATA_DIR}/coverage/targets.csv" +PD_COVERAGE_DF1 = pd.read_csv(DF_COVERAGE_PATH1) +PD_COVERAGE_DF2 = 
pd.read_csv(DF_COVERAGE_PATH2) + +BIO_PD_DF1 = pd.read_parquet(f"{DATA_DIR}/exons/").astype({"pos_start": "int64", "pos_end": "int64"}) +BIO_PD_DF2 = pd.read_parquet(f"{DATA_DIR}/fBrain-DS14718/").astype({"pos_start": "int64", "pos_end": "int64"}) BIO_DF_PATH1 = f"{DATA_DIR}/exons/*.parquet" BIO_DF_PATH2 = f"{DATA_DIR}/fBrain-DS14718/*.parquet" -BIO_PD_DF1 = pd.read_parquet(f"{DATA_DIR}/exons/") -BIO_PD_DF2 = pd.read_parquet(f"{DATA_DIR}/fBrain-DS14718/") - # Polars PL_DF_OVERLAP = pl.DataFrame(PD_DF_OVERLAP) @@ -158,6 +234,13 @@ PL_DF_MERGE = pl.DataFrame(PD_DF_MERGE) PL_MERGE_DF = pl.DataFrame(PD_MERGE_DF) +PL_DF_CLUSTER = pl.DataFrame(PD_DF_CLUSTER) +PL_CLUSTER_DF = pl.DataFrame(PD_CLUSTER_DF) + +PL_DF_COVERAGE = pl.DataFrame(PD_DF_COVERAGE) +PL_COVERAGE_DF1 = pl.DataFrame(PD_COVERAGE_DF1) +PL_COVERAGE_DF2 = pl.DataFrame(PD_COVERAGE_DF2) + PL_DF_COUNT_OVERLAPS = pl.DataFrame(PD_DF_COUNT_OVERLAPS) PL_COUNT_OVERLAPS_DF1 = pl.DataFrame(PD_COUNT_OVERLAPS_DF1) PL_COUNT_OVERLAPS_DF2 = pl.DataFrame(PD_COUNT_OVERLAPS_DF2) diff --git a/tests/data/cluster/input.csv b/tests/data/cluster/input.csv new file mode 100644 index 00000000..cda46bf0 --- /dev/null +++ b/tests/data/cluster/input.csv @@ -0,0 +1,23 @@ +contig,pos_start,pos_end +chr1,150,250 +chr1,190,300 +chr1,300,501 +chr1,500,700 +chr1,22000,22300 +chr1,15000,15001 +chr2,150,250 +chr2,190,300 +chr2,300,500 +chr2,500,700 +chr2,22000,22300 +chr2,15000,15001 +chr1,100,190 +chr1,200,290 +chr1,400,600 +chr1,10000,20000 +chr1,22100,22101 +chr2,100,190 +chr2,200,290 +chr2,400,600 +chr2,10000,20000 +chr2,22100,22101 diff --git a/tests/data/coverage/reads.csv b/tests/data/coverage/reads.csv new file mode 100644 index 00000000..5d34c380 --- /dev/null +++ b/tests/data/coverage/reads.csv @@ -0,0 +1,12 @@ +contig,pos_start,pos_end +chr1,100,190 +chr1,200,290 +chr1,400,600 +chr1,10000,20000 +chr1,22100,22100 +chr2,100,190 +chr2,200,290 +chr2,400,600 +chr2,10000,20000 +chr2,22100,22100 +chr3,100,200 diff --git 
a/tests/data/coverage/targets.csv b/tests/data/coverage/targets.csv new file mode 100644 index 00000000..1d7ba9ee --- /dev/null +++ b/tests/data/coverage/targets.csv @@ -0,0 +1,14 @@ +contig,pos_start,pos_end +chr1,150,250 +chr1,190,300 +chr1,300,501 +chr1,500,700 +chr1,22000,22300 +chr1,15000,15000 +chr2,150,250 +chr2,190,300 +chr2,300,500 +chr2,500,700 +chr2,22000,22300 +chr2,15000,15000 +chr3,234,300 diff --git a/tests/test_bioframe.py b/tests/test_bioframe.py index 7b791b5d..eba33d20 100644 --- a/tests/test_bioframe.py +++ b/tests/test_bioframe.py @@ -92,10 +92,48 @@ class TestBioframe: output_type="polars.LazyFrame", ) result_bio_merge = bf.merge( - BIO_PD_DF1, cols=("contig", "pos_start", "pos_end"), min_dist=None - ).astype( - {"pos_start": "int32", "pos_end": "int32"} - ) # bioframe changes input types + BIO_PD_DF1, + cols=("contig", "pos_start", "pos_end"), + min_dist=None + ) + + result_cluster = pb.cluster( + BIO_PD_DF1, + cols=("contig", "pos_start", "pos_end"), + output_type="pandas.DataFrame", + ) + result_cluster_lf = pb.cluster( + BIO_PD_DF1, + cols=("contig", "pos_start", "pos_end"), + output_type="polars.LazyFrame", + ) + result_bio_cluster = bf.cluster( + BIO_PD_DF1, + cols=("contig", "pos_start", "pos_end"), + min_dist=None + ) + + result_coverage = pb.coverage( + BIO_PD_DF1, + BIO_PD_DF2, + cols1=("contig", "pos_start", "pos_end"), + cols2=("contig", "pos_start", "pos_end"), + output_type="pandas.DataFrame", + ) + result_coverage_lf = pb.coverage( + BIO_PD_DF1, + BIO_PD_DF2, + cols1=("contig", "pos_start", "pos_end"), + cols2=("contig", "pos_start", "pos_end"), + output_type="polars.LazyFrame", + ) + + result_bio_coverage = bf.coverage( + BIO_PD_DF1, + BIO_PD_DF2, + cols1=("contig", "pos_start", "pos_end"), + cols2=("contig", "pos_start", "pos_end"), + ) def test_overlap_count(self): assert len(self.result_overlap) == len(self.result_bio_overlap) @@ -156,7 +194,7 @@ def test_overlaps_schema_rows(self): .reset_index(drop=True) ) 
pd.testing.assert_frame_equal(result, expected) - pd.testing.assert_frame_equal(result_naive, expected, check_dtype=True) + pd.testing.assert_frame_equal(result_naive, expected, check_dtype=False) def test_merge_count(self): assert len(self.result_merge) == len(self.result_bio_merge) @@ -170,45 +208,29 @@ def test_merge_schema_rows(self): by=list(self.result_merge.columns) ).reset_index(drop=True) pd.testing.assert_frame_equal(result, expected) - + + def test_cluster_count(self): + assert len(self.result_cluster) == len(self.result_bio_cluster) + assert len(self.result_cluster_lf.collect()) == len(self.result_bio_cluster) + + def test_cluster_schema_rows(self): + expected = self.result_bio_cluster.sort_values( + by=list(self.result_cluster.columns) + ).reset_index(drop=True) + result = self.result_cluster.sort_values( + by=list(self.result_cluster.columns) + ).reset_index(drop=True) + pd.testing.assert_frame_equal(result, expected) + def test_coverage_count(self): - result = pb.coverage( - BIO_PD_DF1, - BIO_PD_DF2, - cols1=("contig", "pos_start", "pos_end"), - cols2=("contig", "pos_start", "pos_end"), - output_type="pandas.DataFrame", - overlap_filter=FilterOp.Strict, - ) - result_bio = bf.coverage( - BIO_PD_DF1, - BIO_PD_DF2, - cols1=("contig", "pos_start", "pos_end"), - cols2=("contig", "pos_start", "pos_end"), - suffixes=("_1", "_2"), - ) - assert len(result) == len(result_bio) + assert len(self.result_coverage) == len(self.result_bio_coverage) + assert len(self.result_coverage_lf.collect()) == len(self.result_bio_coverage) def test_coverage_schema_rows(self): - result = pb.coverage( - BIO_PD_DF1, - BIO_PD_DF2, - cols1=("contig", "pos_start", "pos_end"), - cols2=("contig", "pos_start", "pos_end"), - output_type="pandas.DataFrame", - overlap_filter=FilterOp.Strict, - ) - result_bio = bf.coverage( - BIO_PD_DF1, - BIO_PD_DF2, - cols1=("contig", "pos_start", "pos_end"), - cols2=("contig", "pos_start", "pos_end"), - suffixes=("_1", "_2"), - ) - expected = ( - 
result_bio.sort_values(by=list(result.columns)) - .reset_index(drop=True) - .astype({"coverage": "int64"}) - ) - result = result.sort_values(by=list(result.columns)).reset_index(drop=True) + expected = self.result_bio_coverage.sort_values( + by=list(self.result_coverage.columns) + ).reset_index(drop=True) + result = self.result_coverage.sort_values( + by=list(self.result_coverage.columns) + ).reset_index(drop=True) pd.testing.assert_frame_equal(result, expected) diff --git a/tests/test_native.py b/tests/test_native.py index d4b1c397..728ec666 100644 --- a/tests/test_native.py +++ b/tests/test_native.py @@ -10,12 +10,20 @@ DF_MERGE_PATH, DF_NEAREST_PATH1, DF_NEAREST_PATH2, + DF_COVERAGE_PATH1, + DF_COVERAGE_PATH2, + DF_MERGE_PATH, + DF_CLUSTER_PATH, DF_OVER_PATH1, DF_OVER_PATH2, PD_DF_COUNT_OVERLAPS, PD_DF_MERGE, PD_DF_NEAREST, + PD_DF_COVERAGE, PD_DF_OVERLAP, + PD_DF_MERGE, + PD_DF_CLUSTER, + PD_DF_COUNT_OVERLAPS, ) import polars_bio as pb @@ -64,6 +72,63 @@ def test_nearest_schema_rows(self): expected = PD_DF_NEAREST pd.testing.assert_frame_equal(result, expected) +class TestMergeNative: + result = pb.merge( + DF_MERGE_PATH, + cols=("contig", "pos_start", "pos_end"), + output_type="pandas.DataFrame", + overlap_filter=FilterOp.Strict, + ) + + def test_merge_count(self): + print(self.result) + assert len(self.result) == len(PD_DF_MERGE) + + def test_merge_schema_rows(self): + result = self.result.sort_values(by=list(self.result.columns)).reset_index( + drop=True + ) + expected = PD_DF_MERGE + pd.testing.assert_frame_equal(result, expected) + +class TestClusterNative: + result = pb.cluster( + DF_CLUSTER_PATH, + cols=("contig", "pos_start", "pos_end"), + output_type="pandas.DataFrame", + overlap_filter=FilterOp.Strict, + ) + + def test_cluster_count(self): + print(self.result) + assert len(self.result) == len(PD_DF_CLUSTER) + + def test_cluster_schema_rows(self): + result = self.result.sort_values(by=list(self.result.columns)).reset_index( + drop=True + ) + expected = 
PD_DF_CLUSTER + pd.testing.assert_frame_equal(result, expected) + +class TestCoverageNative: + result = pb.coverage( + DF_COVERAGE_PATH1, + DF_COVERAGE_PATH2, + cols1=("contig", "pos_start", "pos_end"), + cols2=("contig", "pos_start", "pos_end"), + output_type="pandas.DataFrame", + ) + + def test_coverage_count(self): + print(self.result) + assert len(self.result) == len(PD_DF_COVERAGE) + + def test_coverage_schema_rows(self): + result = self.result.sort_values(by=list(self.result.columns)).reset_index( + drop=True + ) + expected = PD_DF_COVERAGE + pd.testing.assert_frame_equal(result, expected) class TestCountOverlapsNative: result = pb.count_overlaps( @@ -115,14 +180,12 @@ class TestCoverageNative: cols1=("contig", "pos_start", "pos_end"), cols2=("contig", "pos_start", "pos_end"), output_type="pandas.DataFrame", - overlap_filter=FilterOp.Strict, ) result_bio = bf.coverage( BIO_PD_DF1, BIO_PD_DF2, cols1=("contig", "pos_start", "pos_end"), cols2=("contig", "pos_start", "pos_end"), - suffixes=("_1", "_2"), ) def test_coverage_count(self): @@ -133,5 +196,5 @@ def test_coverage_schema_rows(self): result = self.result.sort_values(by=list(self.result.columns)).reset_index( drop=True ) - expected = self.result_bio.astype({"coverage": "int64"}) - pd.testing.assert_frame_equal(result, expected) + expected = self.result_bio + pd.testing.assert_frame_equal(result, expected, check_dtype=False) diff --git a/tests/test_pandas.py b/tests/test_pandas.py index 28258eab..85f6cb75 100644 --- a/tests/test_pandas.py +++ b/tests/test_pandas.py @@ -5,10 +5,17 @@ PD_DF_COUNT_OVERLAPS, PD_DF_MERGE, PD_DF_NEAREST, + PD_DF_COVERAGE, + PD_DF_MERGE, + PD_DF_CLUSTER, PD_DF_OVERLAP, PD_MERGE_DF, PD_NEAREST_DF1, PD_NEAREST_DF2, + PD_COVERAGE_DF1, + PD_COVERAGE_DF2, + PD_MERGE_DF, + PD_CLUSTER_DF, PD_OVERLAP_DF1, PD_OVERLAP_DF2, ) @@ -58,6 +65,58 @@ def test_nearest_schema_rows(self): expected = PD_DF_NEAREST pd.testing.assert_frame_equal(result, expected) +class TestMergePandas: + result = 
pb.merge( + PD_MERGE_DF, + cols=("contig", "pos_start", "pos_end"), + output_type="pandas.DataFrame", + ) + + def test_merge_count(self): + assert len(self.result) == len(PD_DF_MERGE) + + def test_merge_schema_rows(self): + result = self.result.sort_values(by=list(self.result.columns)).reset_index( + drop=True + ) + expected = PD_DF_MERGE + pd.testing.assert_frame_equal(result, expected) + +class TestClusterPandas: + result = pb.cluster( + PD_CLUSTER_DF, + cols=("contig", "pos_start", "pos_end"), + output_type="pandas.DataFrame", + ) + + def test_cluster_count(self): + assert len(self.result) == len(PD_DF_CLUSTER) + + def test_cluster_schema_rows(self): + result = self.result.sort_values(by=list(self.result.columns)).reset_index( + drop=True + ) + expected = PD_DF_CLUSTER + pd.testing.assert_frame_equal(result, expected) + +class TestCoveragePandas: + result = pb.coverage( + PD_COVERAGE_DF1, + PD_COVERAGE_DF2, + cols1=("contig", "pos_start", "pos_end"), + cols2=("contig", "pos_start", "pos_end"), + output_type="pandas.DataFrame", + ) + + def test_coverage_count(self): + assert len(self.result) == len(PD_DF_COVERAGE) + + def test_coverage_schema_rows(self): + result = self.result.sort_values(by=list(self.result.columns)).reset_index( + drop=True + ) + expected = PD_DF_COVERAGE + pd.testing.assert_frame_equal(result, expected) class TestCountOverlapsPandas: result_optim = pb.count_overlaps( diff --git a/tests/test_polars.py b/tests/test_polars.py index 026ab21e..6f38b4ca 100644 --- a/tests/test_polars.py +++ b/tests/test_polars.py @@ -3,13 +3,23 @@ PL_COUNT_OVERLAPS_DF2, PL_DF1, PL_DF2, + PL_DF_NEAREST, + PL_DF_COVERAGE, + PL_DF_MERGE, PL_DF_COUNT_OVERLAPS, PL_DF_MERGE, PL_DF_NEAREST, PL_DF_OVERLAP, - PL_MERGE_DF, + PL_DF_CLUSTER, PL_NEAREST_DF1, PL_NEAREST_DF2, + PL_COVERAGE_DF1, + PL_COVERAGE_DF2, + PL_MERGE_DF, + PL_CLUSTER_DF, + PL_COUNT_OVERLAPS_DF1, + PL_COUNT_OVERLAPS_DF2, + ) import polars_bio as pb @@ -77,6 +87,84 @@ def test_nearest_schema_rows_lazy(self): 
result = self.result_lazy.sort(by=self.result_lazy.columns) assert self.expected.equals(result) +class TestMergePolars: + result_frame = pb.merge( + PL_MERGE_DF, + output_type="polars.DataFrame", + cols=("contig", "pos_start", "pos_end"), + ) + result_lazy = pb.merge( + PL_MERGE_DF, + output_type="polars.LazyFrame", + cols=("contig", "pos_start", "pos_end"), + ).collect() + expected = PL_DF_MERGE + + def test_merge_count(self): + assert len(self.result_frame) == len(PL_DF_MERGE) + assert len(self.result_lazy) == len(PL_DF_MERGE) + + def test_merge_schema_rows(self): + result = self.result_frame.sort(by=self.result_frame.columns) + assert self.expected.equals(result) + + def test_merge_schema_rows_lazy(self): + result = self.result_lazy.sort(by=self.result_lazy.columns) + assert self.expected.equals(result) + +class TestClusterPolars: + result_frame = pb.cluster( + PL_CLUSTER_DF, + output_type="polars.DataFrame", + cols=("contig", "pos_start", "pos_end"), + ) + result_lazy = pb.cluster( + PL_CLUSTER_DF, + output_type="polars.LazyFrame", + cols=("contig", "pos_start", "pos_end"), + ).collect() + expected = PL_DF_CLUSTER + + def test_cluster_count(self): + assert len(self.result_frame) == len(PL_DF_CLUSTER) + assert len(self.result_lazy) == len(PL_DF_CLUSTER) + + def test_cluster_schema_rows(self): + result = self.result_frame.sort(by=self.result_frame.columns) + assert self.expected.equals(result) + + def test_cluster_schema_rows_lazy(self): + result = self.result_lazy.sort(by=self.result_lazy.columns) + assert self.expected.equals(result) + +class TestCoveragePolars: + result_frame = pb.coverage( + PL_COVERAGE_DF1, + PL_COVERAGE_DF2, + output_type="polars.DataFrame", + cols1=("contig", "pos_start", "pos_end"), + cols2=("contig", "pos_start", "pos_end"), + ) + result_lazy = pb.coverage( + PL_COVERAGE_DF1, + PL_COVERAGE_DF2, + output_type="polars.LazyFrame", + cols1=("contig", "pos_start", "pos_end"), + cols2=("contig", "pos_start", "pos_end"), + ).collect() + 
expected = PL_DF_COVERAGE + + def test_coverage_count(self): + assert len(self.result_frame) == len(PL_DF_COVERAGE) + assert len(self.result_lazy) == len(PL_DF_COVERAGE) + + def test_coverage_schema_rows(self): + result = self.result_frame.sort(by=self.result_frame.columns) + assert self.expected.equals(result) + + def test_coverage_schema_rows_lazy(self): + result = self.result_lazy.sort(by=self.result_lazy.columns) + assert self.expected.equals(result) class TestCountOverlapsPolars: result_frame = pb.count_overlaps( diff --git a/tests/test_polars_ext.py b/tests/test_polars_ext.py index b635cef2..6ba877f6 100644 --- a/tests/test_polars_ext.py +++ b/tests/test_polars_ext.py @@ -116,6 +116,7 @@ def test_merge(self): .to_pandas() .reset_index(drop=True) )""" + df_3 = ( bf.merge(df_1, min_dist=None) .sort_values(by=["chrom", "start", "end"]) @@ -135,6 +136,76 @@ def test_merge(self): print(df_4.columns) pd.testing.assert_frame_equal(df_3, df_4, check_dtype=False) + def test_cluster(self): + cols = ("chrom", "start", "end") + df_1 = ( + pb.read_table(self.file, schema="bed9") + .select(cols) + .collect() + .to_pandas() + .reset_index(drop=True) + ) + ''' + df_2 = ( + pb.read_table(self.file, schema="bed9") + .select(cols) + .collect() + .to_pandas() + .reset_index(drop=True) + )''' + df_3 = ( + bf.cluster(df_1, min_dist=None) + .sort_values(by=["chrom", "start", "end"]) + .reset_index(drop=True) + ) + # + df_4 = ( + pl.DataFrame(df_1) + .lazy() + .pb.cluster() + .collect() + .to_pandas() + .sort_values(by=["chrom", "start", "end"]) + .reset_index(drop=True) + ) + print(df_3.columns) + print(df_4.columns) + pd.testing.assert_frame_equal(df_3, df_4, check_dtype=False) + def test_coverage(self): + cols = ("chrom", "start", "end") + df_1 = ( + pb.read_table(self.file, schema="bed9") + .select(cols) + .collect() + .to_pandas() + .reset_index(drop=True) + ) + df_2 = ( + pb.read_table(self.file, schema="bed9") + .select(cols) + .collect() + .to_pandas() + 
.reset_index(drop=True) + ) + df_3 = ( + bf.coverage(df_1, df_2, suffixes=("", "_")) + .sort_values(by=["chrom", "start", "end"]) + .reset_index(drop=True) + ) + # + df_4 = ( + pl.DataFrame(df_1) + .lazy() + .pb.coverage(pl.DataFrame(df_2).lazy(), suffixes=("", "_")) + .collect() + .to_pandas() + .sort_values(by=["chrom", "start", "end"]) + .reset_index(drop=True) + ) + print(df_3.columns) + print(df_4.columns) + pd.testing.assert_frame_equal(df_3, df_4, check_dtype=False) + def test_count_overlaps(self): cols = ("chrom", "start", "end") df_1 = (