Skip to content

Commit 7cae189

Browse files
authored
fix: count_overlaps for non-parquet sources (#110)
1 parent 68dd32c commit 7cae189

File tree

8 files changed

+255
-130
lines changed

8 files changed

+255
-130
lines changed

polars_bio/range_op_helpers.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@ def range_operation(
5858
merged_schema = pl.Schema(
5959
{**_get_schema(df1, ctx, None, read_options1), **{"count": pl.Int32}}
6060
)
61-
# print(merged_schema)
6261
else:
6362
df_schema1 = _get_schema(df1, ctx, range_options.suffixes[0], read_options1)
6463
df_schema2 = _get_schema(df2, ctx, range_options.suffixes[1], read_options2)
@@ -121,9 +120,10 @@ def range_operation(
121120
raise ValueError(
122121
"Input and output dataframes must be of the same type: either polars or pandas"
123122
)
124-
return range_operation_frame_wrapper(
125-
ctx, df1, df2, range_options
126-
).to_pandas()
123+
df = range_operation_frame_wrapper(ctx, df1, df2, range_options)
124+
print(range_options.range_op)
125+
print(df.schema())
126+
return df.to_pandas()
127127
else:
128128
raise ValueError(
129129
"Both dataframes must be of the same type: either polars or pandas or a path to a file"

polars_bio/range_wrappers.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def range_operation_frame_wrapper(
2121
df2,
2222
range_options: RangeOptions,
2323
limit: Union[int, None] = None,
24-
) -> datafusion.dataframe:
24+
) -> datafusion.DataFrame:
2525
if range_options.range_op != RangeOp.CountOverlaps:
2626
return range_operation_frame(ctx, df1, df2, range_options)
2727
py_ctx = datafusion.SessionContext()

src/context.rs

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
use std::collections::HashMap;
2-
use std::sync::Arc;
32

43
use datafusion::config::ConfigOptions;
54
use datafusion::prelude::SessionConfig;
@@ -9,8 +8,6 @@ use log::debug;
98
use pyo3::{pyclass, pymethods, PyResult};
109
use sequila_core::session_context::SequilaConfig;
1110

12-
use crate::udtf::CountOverlapsFunction;
13-
1411
#[pyclass(name = "BioSessionContext")]
1512
// #[derive(Clone)]
1613
pub struct PyBioSessionContext {
@@ -28,10 +25,6 @@ impl PyBioSessionContext {
2825
pub fn new(seed: String, catalog_dir: String) -> PyResult<Self> {
2926
let ctx = create_context().unwrap();
3027
let session_config: HashMap<String, String> = HashMap::new();
31-
ctx.session.register_udtf(
32-
"count_overlaps",
33-
Arc::new(CountOverlapsFunction::new(ctx.session.clone())),
34-
);
3528

3629
Ok(PyBioSessionContext {
3730
ctx,

src/lib.rs

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -60,13 +60,17 @@ fn range_operation_frame(
6060
)
6161
.limit(0, Some(l))?,
6262
)),
63-
_ => Ok(PyDataFrame::new(do_range_operation(
64-
ctx,
65-
&rt,
66-
range_options,
67-
LEFT_TABLE.to_string(),
68-
RIGHT_TABLE.to_string(),
69-
))),
63+
_ => {
64+
let df = do_range_operation(
65+
ctx,
66+
&rt,
67+
range_options,
68+
LEFT_TABLE.to_string(),
69+
RIGHT_TABLE.to_string(),
70+
);
71+
let py_df = PyDataFrame::new(df);
72+
Ok(py_df)
73+
},
7074
}
7175
}
7276

src/operation.rs

Lines changed: 24 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
use std::sync::Arc;
2+
13
use datafusion::catalog_common::TableReference;
24
use exon::ExonSession;
35
use log::{debug, info};
@@ -7,6 +9,7 @@ use tokio::runtime::Runtime;
79
use crate::context::set_option_internal;
810
use crate::option::{FilterOp, RangeOp, RangeOptions};
911
use crate::query::{count_overlaps_query, nearest_query, overlap_query};
12+
use crate::udtf::CountOverlapsProvider;
1013
use crate::utils::default_cols_to_string;
1114
use crate::DEFAULT_COLUMN_NAMES;
1215

@@ -150,17 +153,30 @@ async fn do_count_overlaps_naive(
150153
) -> datafusion::dataframe::DataFrame {
151154
let columns_1 = range_opts.columns_1.unwrap();
152155
let columns_2 = range_opts.columns_2.unwrap();
153-
let query = format!(
154-
"SELECT * FROM count_overlaps('{}', '{}', '{}', '{}', '{}', '{}', '{}', '{}' , false )",
156+
let session = &ctx.session;
157+
let right_table_ref = TableReference::from(right_table.clone());
158+
let right_schema = session
159+
.table(right_table_ref.clone())
160+
.await
161+
.unwrap()
162+
.schema()
163+
.as_arrow()
164+
.clone();
165+
let count_overlaps_provider = CountOverlapsProvider::new(
166+
Arc::new(session.clone()),
155167
left_table,
156168
right_table,
157-
columns_1[0],
158-
columns_1[1],
159-
columns_1[2],
160-
columns_2[0],
161-
columns_2[1],
162-
columns_2[2]
169+
right_schema,
170+
columns_1,
171+
columns_2,
172+
range_opts.filter_op.unwrap(),
173+
false,
163174
);
175+
session.deregister_table("count_overlaps").unwrap();
176+
session
177+
.register_table("count_overlaps", Arc::new(count_overlaps_provider))
178+
.unwrap();
179+
let query = "SELECT * FROM count_overlaps";
164180
debug!("Query: {}", query);
165181
ctx.sql(&query).await.unwrap()
166182
}

0 commit comments

Comments
 (0)