diff --git a/src/dataneuron/core/nlp_helpers/cte_handler.py b/src/dataneuron/core/nlp_helpers/cte_handler.py index f57d6c4..7494228 100644 --- a/src/dataneuron/core/nlp_helpers/cte_handler.py +++ b/src/dataneuron/core/nlp_helpers/cte_handler.py @@ -49,6 +49,12 @@ def extract_main_query(parsed): def filter_cte(cte_part, filter_function, client_id): filtered_ctes = [] + is_recursive = False + + for token in cte_part.tokens: + if token.ttype is Keyword and token.value.upper() == 'RECURSIVE': + is_recursive = True + def process_cte(token): if isinstance(token, sqlparse.sql.Identifier): cte_name = token.get_name() @@ -57,7 +63,7 @@ def process_cte(token): # Remove outer parentheses inner_query_str = str(inner_query)[1:-1] filtered_inner_query = filter_function( - sqlparse.parse(inner_query_str)[0], client_id) + sqlparse.parse(inner_query_str)[0], client_id, cte_name) filtered_ctes.append(f"{cte_name} AS ({filtered_inner_query})") for token in cte_part.tokens: @@ -68,7 +74,10 @@ def process_cte(token): process_cte(token) if filtered_ctes: - filtered_cte_str = "WITH " + ",\n".join(filtered_ctes) + if is_recursive: + filtered_cte_str = "WITH RECURSIVE " + ",\n".join(filtered_ctes) + else: + filtered_cte_str = "WITH " + ",\n".join(filtered_ctes) else: filtered_cte_str = "" return filtered_cte_str diff --git a/src/dataneuron/core/sql_query_filter.py b/src/dataneuron/core/sql_query_filter.py index 7fc1fe1..2927e71 100644 --- a/src/dataneuron/core/sql_query_filter.py +++ b/src/dataneuron/core/sql_query_filter.py @@ -28,20 +28,18 @@ def apply_client_filter(self, sql_query: str, client_id: int) -> str: return self._cleanup_whitespace(str(result)) - def _apply_filter_recursive(self, parsed, client_id): + def _apply_filter_recursive(self, parsed, client_id, cte_name: str = None): if self._is_cte_query(parsed): return handle_cte_query(parsed, self._apply_filter_recursive, client_id) - if isinstance(parsed, Token) and parsed.ttype is DML: - return self._apply_filter_to_single_query(str(parsed), client_id) - elif self._contains_set_operation(parsed): - return self._handle_set_operation(parsed, client_id) - elif self._contains_subquery(parsed): - return self._handle_subquery(parsed, client_id) - else: - filtered_query = self._apply_filter_to_single_query( - str(parsed), client_id) - return self._handle_where_subqueries(sqlparse.parse(filtered_query)[0], client_id) + for token in parsed.tokens: + if isinstance(token, Token) and token.ttype is DML: + if self._contains_set_operation(parsed): + return self._handle_set_operation(parsed, client_id, True, cte_name) if cte_name else self._handle_set_operation(parsed, client_id) + elif self._contains_subquery(parsed): + return self._handle_subquery(parsed, client_id) + else: + return self._apply_filter_to_single_query(str(parsed), client_id) def _contains_set_operation(self, parsed): set_operations = ('UNION', 'INTERSECT', 'EXCEPT') @@ -227,7 +225,7 @@ def _inject_where_clause(self, parsed, where_clause): return str(parsed) - def _handle_set_operation(self, parsed, client_id): + def _handle_set_operation(self, parsed, client_id, is_cte: bool = False, cte_name: str = None): print("Handling set operation") # Split the query into individual SELECT statements statements = [] @@ -253,9 +251,14 @@ def _handle_set_operation(self, parsed, client_id): # Apply the filter to each SELECT statement filtered_statements = [] for stmt in statements: - filtered_stmt = self._apply_filter_to_single_query(stmt, client_id) - filtered_statements.append(filtered_stmt) - print(f"Filtered statement: {filtered_stmt}") + if is_cte: + filtered_stmt = self._apply_filter_to_single_CTE_query(stmt, client_id, cte_name) + filtered_statements.append(filtered_stmt) + print(f"Filtered statement: {filtered_stmt}") + else: + filtered_stmt = self._apply_filter_to_single_query(stmt, client_id) + filtered_statements.append(filtered_stmt) + print(f"Filtered statement: {filtered_stmt}") # Reconstruct the query result = f" {set_operation} ".join(filtered_statements) @@ -363,25 +366,35 @@ def _cleanup_whitespace(self, query: str) -> str: def _handle_subquery(self, parsed, client_id): result = [] tokens = parsed.tokens if hasattr(parsed, 'tokens') else [parsed] + mainquery = [] for token in tokens: if isinstance(token, Identifier) and token.has_alias(): if isinstance(token.tokens[0], Parenthesis): + mainquery.append(" PLACEHOLDER ") subquery = token.tokens[0].tokens[1:-1] subquery_str = ' '.join(str(t) for t in subquery) filtered_subquery = self._apply_filter_recursive( sqlparse.parse(subquery_str)[0], client_id) alias = token.get_alias() - result.append(f"({filtered_subquery}) AS {alias}") + AS_keyword = next((t for t in token.tokens if t.ttype == sqlparse.tokens.Keyword and t.value.upper() == 'AS'), None) # Checks for existence of 'AS' keyword + + if AS_keyword: + result.append(f"({filtered_subquery}) AS {alias}") + else: + result.append(f"({filtered_subquery}) {alias}") else: - result.append(str(token)) + mainquery.append(str(token)) + elif isinstance(token, Parenthesis): + mainquery.append(" PLACEHOLDER ") subquery = token.tokens[1:-1] subquery_str = ' '.join(str(t) for t in subquery) filtered_subquery = self._apply_filter_recursive( sqlparse.parse(subquery_str)[0], client_id) result.append(f"({filtered_subquery})") - elif isinstance(token, Where): + + elif isinstance(token, Where) and 'IN' in str(parsed): try: filtered_where = self._handle_where_subqueries( token, client_id) @@ -389,19 +402,15 @@ def _handle_subquery(self, parsed, client_id): except Exception as e: result.append(str(token)) else: - # Preserve whitespace tokens - if token.is_whitespace: - result.append(str(token)) - else: - # Add space before and after non-whitespace tokens, except for punctuation - if result and not result[-1].endswith(' ') and not str(token).startswith((')', ',', '.')): - result.append(' ') - result.append(str(token)) - if not str(token).endswith(('(', ',')): - result.append(' ') + mainquery.append(str(token)) - final_result = ''.join(result).strip() - return final_result + mainquery = ''.join(mainquery).strip() + if ' IN ' in str(parsed): + return f"{mainquery} {result[0]}" + else: + filtered_mainquery = self._apply_filter_to_single_query(mainquery, client_id) + query = filtered_mainquery.replace("PLACEHOLDER", result[0]) + return query def _handle_where_subqueries(self, where_clause, client_id): if self._is_cte_query(where_clause): @@ -508,3 +517,43 @@ def _extract_main_table(self, where_clause): if isinstance(token, Identifier): return token.get_real_name() return None + + def _apply_filter_to_single_CTE_query(self, sql_query: str, client_id: int, cte_name: str) -> str: + parts = sql_query.split(' GROUP BY ') + main_query = parts[0] + + group_by = f" GROUP BY {parts[1]}" if len(parts) > 1 else "" + parsed = sqlparse.parse(main_query)[0] + tables_info = self._extract_tables_info(parsed) + + filters = [] + _table_ = [] + + for table_info in tables_info: + if table_info['name'] != cte_name: + table_dict = { + "name": table_info['name'], + "alias": table_info['alias'], + "schema": table_info['schema'] + } + _table_.append(table_dict) + + matching_table = self._find_matching_table(_table_[0]['name'], _table_[0]['schema']) + + if matching_table: + client_id_column = self.client_tables[matching_table] + table_reference = _table_[0]['alias'] or _table_[0]['name'] + + filters.append(f'{self._quote_identifier(table_reference)}.{self._quote_identifier(client_id_column)} = {client_id}') + + if filters: + where_clause = " AND ".join(filters) + if 'WHERE' in main_query.upper(): + where_parts = main_query.split('WHERE', 1) + result = f"{where_parts[0]} WHERE {where_parts[1].strip()} AND {where_clause}" + else: + result = f"{main_query} WHERE {where_clause}" + else: + result = main_query + + return result + group_by \ No newline at end of file diff --git a/tests/core/test_sql_query_filter.py b/tests/core/test_sql_query_filter.py index 80b7eb4..7f5e29c 100644 --- a/tests/core/test_sql_query_filter.py +++ b/tests/core/test_sql_query_filter.py @@ -98,10 +98,10 @@ def test_subquery_in_from(self): expected = 'SELECT * FROM (SELECT * FROM orders WHERE "orders"."user_id" = 1) AS subq' self.assertEqual(self.filter.apply_client_filter(query, 1), expected) - # def test_subquery_in_join(self): - # query = 'SELECT o.* FROM orders o JOIN (SELECT * FROM products) p ON o.product_id = p.id' - # expected = 'SELECT o.* FROM orders o JOIN (SELECT * FROM products WHERE "products"."company_id" = 1) p ON o.product_id = p.id WHERE "o"."user_id" = 1' - # self.assertEqual(self.filter.apply_client_filter(query, 1), expected) + def test_subquery_in_join(self): + query = 'SELECT o.* FROM orders o JOIN (SELECT * FROM products) p ON o.product_id = p.id' + expected = 'SELECT o.* FROM orders o JOIN (SELECT * FROM products WHERE "products"."company_id" = 1) p ON o.product_id = p.id WHERE "o"."user_id" = 1' + self.assertEqual(self.filter.apply_client_filter(query, 1), expected) def test_nested_subqueries(self): query = 'SELECT * FROM (SELECT * FROM (SELECT * FROM orders) AS inner_subq) AS outer_subq' @@ -123,7 +123,8 @@ def setUp(self): 'products': 'company_id', 'inventory.items': 'organization_id', 'items': 'organization_id', - 'customers': 'customer_id' + 'customers': 'customer_id', + 'categories': 'company_id' } self.filter = SQLQueryFilter( self.client_tables, schemas=['main', 'inventory']) @@ -217,61 +218,61 @@ def test_multiple_ctes(self): self.assertSQLEqual( self.filter.apply_client_filter(query, 1), expected) - # def test_cte_with_subquery(self): - # query = ''' - # WITH top_products AS ( - # SELECT p.id, p.name, SUM(o.quantity) as total_sold - # FROM products p - # JOIN (SELECT * FROM orders WHERE status = 'completed') o ON p.id = o.product_id - # GROUP BY p.id, p.name - # ORDER BY total_sold DESC - # LIMIT 10 - # ) - # SELECT * FROM top_products - # ''' - # expected = ''' - # WITH top_products AS ( - # SELECT p.id, p.name, SUM(o.quantity) as total_sold - # FROM products p - # JOIN (SELECT * FROM orders WHERE status = 'completed' AND "orders"."user_id" = 1) o ON p.id = o.product_id - # WHERE "p"."company_id" = 1 - # GROUP BY p.id, p.name - # ORDER BY total_sold DESC - # LIMIT 10 - # ) - # SELECT * FROM top_products - # ''' - # self.assertSQLEqual( - # self.filter.apply_client_filter(query, 1), expected) - - # def test_recursive_cte(self): - # query = ''' - # WITH RECURSIVE category_tree AS ( - # SELECT id, name, parent_id, 0 AS level - # FROM categories - # WHERE parent_id IS NULL - # UNION ALL - # SELECT c.id, c.name, c.parent_id, ct.level + 1 - # FROM categories c - # JOIN category_tree ct ON c.parent_id = ct.id - # ) - # SELECT * FROM category_tree - # ''' - # expected = ''' - # WITH RECURSIVE category_tree AS ( - # SELECT id, name, parent_id, 0 AS level - # FROM categories - # WHERE parent_id IS NULL AND "categories"."company_id" = 1 - # UNION ALL - # SELECT c.id, c.name, c.parent_id, ct.level + 1 - # FROM categories c - # JOIN category_tree ct ON c.parent_id = ct.id - # WHERE "c"."company_id" = 1 - # ) - # SELECT * FROM category_tree - # ''' - # self.assertSQLEqual( - # self.filter.apply_client_filter(query, 1), expected) + def test_cte_with_subquery(self): + query = ''' + WITH top_products AS ( + SELECT p.id, p.name, SUM(o.quantity) as total_sold + FROM products p + JOIN (SELECT * FROM orders WHERE status = 'completed') o ON p.id = o.product_id + GROUP BY p.id, p.name + ORDER BY total_sold DESC + LIMIT 10 + ) + SELECT * FROM top_products + ''' + expected = ''' + WITH top_products AS ( + SELECT p.id, p.name, SUM(o.quantity) as total_sold + FROM products p + JOIN (SELECT * FROM orders WHERE status = 'completed' AND "orders"."user_id" = 1) o ON p.id = o.product_id + WHERE "p"."company_id" = 1 + GROUP BY p.id, p.name + ORDER BY total_sold DESC + LIMIT 10 + ) + SELECT * FROM top_products + ''' + self.assertSQLEqual( + self.filter.apply_client_filter(query, 1), expected) + + def test_recursive_cte(self): + query = ''' + WITH RECURSIVE category_tree AS ( + SELECT id, name, parent_id, 0 AS level + FROM categories + WHERE parent_id IS NULL + UNION ALL + SELECT c.id, c.name, c.parent_id, ct.level + 1 + FROM categories c + JOIN category_tree ct ON c.parent_id = ct.id + ) + SELECT * FROM category_tree + ''' + expected = ''' + WITH RECURSIVE category_tree AS ( + SELECT id, name, parent_id, 0 AS level + FROM categories + WHERE parent_id IS NULL AND "categories"."company_id" = 1 + UNION ALL + SELECT c.id, c.name, c.parent_id, ct.level + 1 + FROM categories c + JOIN category_tree ct ON c.parent_id = ct.id + WHERE "c"."company_id" = 1 + ) + SELECT * FROM category_tree + ''' + self.assertSQLEqual( + self.filter.apply_client_filter(query, 1), expected) if __name__ == '__main__':