diff --git a/src/dataneuron/core/nlp_helpers/cte_handler.py b/src/dataneuron/core/nlp_helpers/cte_handler.py index f57d6c4..7494228 100644 --- a/src/dataneuron/core/nlp_helpers/cte_handler.py +++ b/src/dataneuron/core/nlp_helpers/cte_handler.py @@ -49,6 +49,12 @@ def extract_main_query(parsed): def filter_cte(cte_part, filter_function, client_id): filtered_ctes = [] + is_recursive = False + + for token in cte_part.tokens: + if token.ttype is Keyword and token.value.upper() == 'RECURSIVE': + is_recursive = True + def process_cte(token): if isinstance(token, sqlparse.sql.Identifier): cte_name = token.get_name() @@ -57,7 +63,7 @@ def process_cte(token): # Remove outer parentheses inner_query_str = str(inner_query)[1:-1] filtered_inner_query = filter_function( - sqlparse.parse(inner_query_str)[0], client_id) + sqlparse.parse(inner_query_str)[0], client_id, cte_name) filtered_ctes.append(f"{cte_name} AS ({filtered_inner_query})") for token in cte_part.tokens: @@ -68,7 +74,10 @@ def process_cte(token): process_cte(token) if filtered_ctes: - filtered_cte_str = "WITH " + ",\n".join(filtered_ctes) + if is_recursive: + filtered_cte_str = "WITH RECURSIVE " + ",\n".join(filtered_ctes) + else: + filtered_cte_str = "WITH " + ",\n".join(filtered_ctes) else: filtered_cte_str = "" return filtered_cte_str diff --git a/src/dataneuron/core/nlp_helpers/is_subquery.py b/src/dataneuron/core/nlp_helpers/is_subquery.py new file mode 100644 index 0000000..1653cd4 --- /dev/null +++ b/src/dataneuron/core/nlp_helpers/is_subquery.py @@ -0,0 +1,148 @@ +import re +from sqlparse.sql import Token +from sqlparse.tokens import DML, Keyword, Whitespace, Newline +from query_cleanup import _cleanup_whitespace + +def _contains_subquery(parsed): + tokens = parsed.tokens if hasattr(parsed, 'tokens') else [parsed] + end_keywords = {'GROUP BY', 'HAVING', 'ORDER BY'} + joins = {'JOIN', 'LEFT JOIN', 'RIGHT JOIN', 'INNER JOIN'} + where_keywords = {'IN', 'NOT IN', 'EXISTS', 'ALL', 'ANY'} + where_keyword_pattern = '|'.join(where_keywords) + + select_index = None + from_index = None + where_index = None + end_index = None + + select_block = [] + from_block = [] + where_block = [] + end_keywords_block = [] + results = [] + join_statement = [] + join_found = False + + i = 0 + while i < len(tokens): + token = tokens[i] + + if isinstance(token, Token) and token.ttype is DML and token.value.upper() == 'SELECT': + select_index = i + k = i + 1 + while k < len(tokens) and not (isinstance(tokens[k], Token) and tokens[k].ttype == Keyword and tokens[k].value.upper() == 'FROM'): + k += 1 + + from_index = k + k = from_index + 1 + while k < len(tokens): + if isinstance(tokens[k], Token) and 'WHERE' in str(tokens[k]) and not \ + re.match(r'\(\s*SELECT.*?\bWHERE\b.*?\)', str(tokens[k])): + where_index = k + elif isinstance(tokens[k], Token) and str(tokens[k]) in end_keywords: + end_index = k + break + + k += 1 + i += 1 + + where_end = end_index if end_index else len(tokens) + from_end = min( + index for index in [where_index, end_index] if index is not None) if any([where_index, end_index]) \ + else len(tokens) + + for j in range(select_index + 1, from_index): # Between SELECT and FROM block + select_block.append(_cleanup_whitespace(str(tokens[j]))) + + select_elements = ' '.join(select_block).strip().split(',') # Split by commas to handle multiple elements in the SELECT block + for element in select_elements: + element = element.replace('\n', ' ').strip() # Clean up any extra whitespace + + if re.search(r'\bCASE\b((\s+WHEN\b.*?\bTHEN\b.*?)+)(\s+ELSE\b.*)?(?=\s+END\b)', element, re.DOTALL): + + for match in re.findall(r'\bWHEN\b.*?\bTHEN\b.*?\bELSE\b.*?(?=\bWHEN\b|\bELSE\b|\bEND\b)', element, re.DOTALL): #Split them into WHEN, THEN and ELSE blocks: # Check for subquery inside WHEN THEN + if re.search(r'\(.*?\bSELECT\b.*?\)', match, re.DOTALL): + results.append("Subquery exists inside CASE WHEN THEN ELSE block") + + elif '(' in element and ')' in element: # Find if any element has parenthesis + if re.search(r'\(.*?\bSELECT\b.*?\)', element, re.DOTALL): + results.append("Inline Subquery exists inside SELECT block") + + + for j in range(from_index + 1, from_end): + if isinstance(tokens[j], Token) and tokens[j].ttype not in [Whitespace, Newline]: + from_block.append(tokens[j]) + + for i, element in enumerate(from_block): + if isinstance(element, Token) and element.ttype == Keyword and element.value.upper() in joins: + join_found = True + + if i == 1: + join_statement.append(str(from_block[i - 1])) + join_statement.append(str(from_block[i + 1])) + elif i > 1: + join_statement.append(str(from_block[i + 1])) + + elif not join_found and re.match(r'\(\s*([\s\S]*?)\s*\)', str(element), re.DOTALL): + if re.findall(r'(UNION\s+ALL|UNION|INTERSECT\s+ALL|INTERSECT|EXCEPT\s+ALL|EXCEPT)', str(element), re.IGNORECASE | re.DOTALL): + results.append("Contains set operation - Subquery found inside FROM block") + elif re.match(r'\(\s*SELECT.*\)\s+\w+', str(element), re.IGNORECASE | re.DOTALL): + results.append("Inline subquery inside FROM block") + + if join_found: + for stmt in join_statement: + join_statement_str = _cleanup_whitespace(str(stmt)) + if re.findall(r'\(\s*([\s\S]*?)\s*\)', join_statement_str): + if re.findall(r'(UNION\s+ALL|UNION|INTERSECT\s+ALL|INTERSECT|EXCEPT\s+ALL|EXCEPT)', join_statement_str, re.IGNORECASE | re.DOTALL): + results.append("Set operation - Subquery inside JOIN") + elif re.match(r'\(\s*SELECT.*\)\s+\w+', join_statement_str, re.IGNORECASE | re.DOTALL): + results.append("Inline subquery inside JOIN") + + if where_index: + for j in range(where_index, where_end): + where_block.append(_cleanup_whitespace(str(tokens[j]).strip('WHERE '))) + + for i in where_block: + for clause in re.split(r'\bAND\b(?![^()]*\))', i): + clause = clause.strip() + + if re.search(fr'\b({where_keyword_pattern})\b\s*\(.*?\bSELECT\b.*?\)', clause, re.DOTALL): + found_keyword = re.search(fr'\b({where_keyword_pattern})\b', clause).group() + results.append(f"Subquery with special keyword found in WHERE block: {found_keyword} \n") + elif re.search(r'\(.*?\bSELECT\b.*?\)', clause, re.DOTALL): + results.append("Inline subquery found in WHERE block \n") + + if end_index: + for j in range(end_index, len(tokens)): + if isinstance(tokens[j], Token) and tokens[j].ttype not in [Whitespace, Newline]: + end_keywords_block.append(_cleanup_whitespace(str(tokens[j]))) + + endsubquery_block = [] + count = 0 + indices = [] + + for index, token in enumerate(end_keywords_block): + if str(token).upper() in end_keywords: + count += 1 + indices.append(index) + + if count >= 1: # If there is at least one end keyword + for i in range(len(indices)): + start_idx = indices[i] # Start and end indices of each block + if i < len(indices) - 1: + end_idx = indices[i + 1] # Until the next keyword + else: + end_idx = len(end_keywords_block) # Until the end of the block + + # Extract the block between start_idx and end_idx + endsubquery_block = end_keywords_block[start_idx:end_idx] + endsubquery_block_str = ' '.join(endsubquery_block) + + if re.search(r'\((SELECT [\s\S]*?)\)', str(endsubquery_block_str), re.IGNORECASE): + if re.search(r'\(((?:[^()]+|\([^()]*\))*)\)\s*(?:AS\s+)?(\w+)?', str(endsubquery_block_str), re.IGNORECASE).group(1): + results.append("Subquery in END keywords") + + if len(results) >= 1: + return True + else: + return False \ No newline at end of file diff --git a/src/dataneuron/core/nlp_helpers/query_cleanup.py b/src/dataneuron/core/nlp_helpers/query_cleanup.py new file mode 100644 index 0000000..dedc2dd --- /dev/null +++ b/src/dataneuron/core/nlp_helpers/query_cleanup.py @@ -0,0 +1,17 @@ +import re + +def _cleanup_whitespace(query: str) -> str: + # Split the query into lines + lines = query.split('\n') + cleaned_lines = [] + for line in lines: + # Remove leading/trailing whitespace from each line + line = line.strip() + # Replace multiple spaces with a single space, but not in quoted strings + line = re.sub(r'\s+(?=(?:[^\']*\'[^\']*\')*[^\']*$)', ' ', line) + # Ensure single space after commas, but not in quoted strings + line = re.sub( + r'\s*,\s*(?=(?:[^\']*\'[^\']*\')*[^\']*$)', ', ', line) + cleaned_lines.append(line) + # Join the lines back together + return '\n'.join(cleaned_lines) \ No newline at end of file diff --git a/src/dataneuron/core/nlp_helpers/subquery_handler.py b/src/dataneuron/core/nlp_helpers/subquery_handler.py new file mode 100644 index 0000000..fda8d36 --- /dev/null +++ b/src/dataneuron/core/nlp_helpers/subquery_handler.py @@ -0,0 +1,361 @@ +import sqlparse +from sqlparse.sql import Token, Identifier +from sqlparse.tokens import Keyword, DML, Whitespace, Newline +import re +from query_cleanup import _cleanup_whitespace + + +class SubqueryHandler: + def __init__(self, query_filter=None, setop_query_filter=None, matching_table_finder=None): + self.SQLQueryFilter = query_filter + self.SetOP_QueryFilter = setop_query_filter + self._find_matching_table = matching_table_finder + self._cleanup_whitespace = _cleanup_whitespace + self.client_id = 1 + self.schemas=['main', 'inventory'] + + + def SELECT_subquery(self, SELECT_block): + select_elements = ' '.join(SELECT_block).strip().split(',') + filtered_dict = { + 'subquery_list': [], + 'filtered_subquery': [], + 'placeholder_value': [] + } + + for element in select_elements: + element = element.replace('\n', ' ').strip() + + # Detect CASE WHEN THEN ELSE + case_match = re.search(r'\bCASE\b(.*?\bEND\b)', element, re.DOTALL) + if case_match: + case_block = case_match.group(1) + when_then_else_blocks = re.findall(r'\bWHEN\b(.*?)\bTHEN\b(.*?)(?=\bWHEN\b|\bELSE\b|\bEND\b)', case_block, re.DOTALL) + else_clause = re.search(r'\bELSE\b(.*?)(?=\bEND\b)', case_block, re.DOTALL) + + # Process WHEN-THEN pairs + for when, then in when_then_else_blocks: + if re.search(r'\(.*?\bSELECT\b.*?\)', when, re.DOTALL): #WHEN has a subquery + filtered_dict['subquery_list'].append(when) + if re.search(r'\(.*?\bSELECT\b.*?\)', then, re.DOTALL): #THEN has a subquery + filtered_dict['subquery_list'].append(then) + + if else_clause and re.search(r'\(.*?\bSELECT\b.*?\)', else_clause.group(1), re.DOTALL): #ELSE has a subquery + filtered_dict['subquery_list'].append(else_clause.group(1)) + + elif '(' in element and ')' in element: + if re.search(r'\(.*?\bSELECT\b.*?\)', element, re.DOTALL): + filtered_dict['subquery_list'].append(element) + + for i, subquery in enumerate(filtered_dict['subquery_list']): + placeholder = f"" + + filtered_subquery = self.SQLQueryFilter( + sqlparse.parse( + re.search(r'\(((?:[^()]+|\([^()]*\))*)\)', subquery).group(1))[0], self.client_id) + + filtered_dict['placeholder_value'].append(placeholder) + filtered_dict['filtered_subquery'].append(filtered_subquery) + + return filtered_dict + + + def FROM_subquery(self, FROM_block): + joins = {'JOIN', 'LEFT JOIN', 'RIGHT JOIN'} + join_found = False + join_statements = [] + exit_early = False + + join_dict = { + "matching_table": [], + "filtered_matching_table": [], + "alias": [] + } + + def _handle_joins(): + alias = None + for i, token in enumerate(FROM_block): + if join_found and isinstance(token, Token) and token.ttype == Keyword and token.value.upper() in joins: + previous_token = FROM_block[i - 1] if i > 0 else None + next_token = FROM_block[i + 1] if i + 1 < len(FROM_block) else None + if previous_token: + join_statements.append(previous_token.value.strip()) + if next_token: + join_statements.append(next_token.value.strip()) + + for statement in join_statements: + join_statement_str = _cleanup_whitespace(statement) + + for t in sqlparse.parse(join_statement_str)[0].tokens: + if isinstance(t, Identifier): + alias = t.get_alias() + name = t.get_real_name() + + if alias and self._find_matching_table(str(name), self.schemas) or \ + self._find_matching_table(join_statement_str, self.schemas): + + filtered_table = self.SQLQueryFilter( + sqlparse.parse(f'SELECT * FROM {join_statement_str}')[0], self.client_id) + join_dict['filtered_matching_table'].append(f'({filtered_table})') + + if alias: + join_dict['alias'].append(f"AS {alias}") + join_dict['matching_table'].append(join_statement_str) + else: + + join_dict['alias'].append(f"AS {join_statement_str}") + join_dict['matching_table'].append(join_statement_str) + + else: + if re.match(r'\(\s*([\s\S]*?)\s*\)', join_statement_str): + if re.findall(r'(UNION\s+ALL|UNION|INTERSECT\s+ALL|INTERSECT|EXCEPT\s+ALL|EXCEPT)', join_statement_str, re.IGNORECASE | re.DOTALL): + match = re.search(r'\(((?:[^()]+|\([^()]*\))*)\)', join_statement_str) + inner_parentheses = match.group(1) + start, end = match.span() + alias = join_statement_str[end + 1:] # +1 for WHITESPACEEEE + + filtered_subquery = self.SetOP_QueryFilter(sqlparse.parse(inner_parentheses)[0], self.client_id) + join_dict['filtered_matching_table'].append(f'({filtered_subquery})') + join_dict['matching_table'].append(join_statement_str) + join_dict['alias'].append(alias) + + elif re.match(r'\(\s*SELECT.*?\)\s*(?:AS\s+)?(\w+)?', join_statement_str, re.IGNORECASE | re.DOTALL): + subquery_match = re.match(r'\(\s*SELECT.*?\)\s*(?:AS\s+)?(\w+)?', join_statement_str, re.IGNORECASE | re.DOTALL) + inner_parentheses = re.search(r'\(((?:[^()]+|\([^()]*\))*)\)', join_statement_str).group(1) + alias = subquery_match.group(1) if subquery_match else '' + start, end = subquery_match.span() + + filtered_subquery = self.SetOP_QueryFilter(sqlparse.parse(inner_parentheses)[0], self.client_id) + + join_dict['matching_table'].append(join_statement_str) + join_dict['filtered_matching_table'].append(f'({filtered_subquery})') + join_dict['alias'].append(f"{alias}" if alias else "") + + def _not_handle_joins(): + nonlocal exit_early + for token in FROM_block: + FROM_block_str = _cleanup_whitespace(str(token)) + if re.match(r'\(\s*([\s\S]*?)\s*\)', FROM_block_str) and re.findall(r'(UNION\s+ALL|UNION|INTERSECT\s+ALL|INTERSECT|EXCEPT\s+ALL|EXCEPT)', FROM_block_str, re.IGNORECASE | re.DOTALL): + match = re.search(r'\(((?:[^()]+|\([^()]*\))*)\)', FROM_block_str) + inner_parentheses = match.group(1) + start, end = match.span() + alias = FROM_block_str[end + 1:] # +1 for WHITESPACEEEE + + filtered_subquery = self.SetOP_QueryFilter(sqlparse.parse(inner_parentheses)[0], self.client_id) + join_dict['filtered_matching_table'].append(f'({filtered_subquery})') + join_dict['matching_table'].append(FROM_block_str) + join_dict['alias'].append(alias) + + elif re.match(r'\(\s*SELECT.*?\)\s*(?:AS\s+)?(\w+)?', FROM_block_str, re.IGNORECASE | re.DOTALL): + subquery_match = re.match(r'\(\s*SELECT.*?\)\s*(?:AS\s+)?(\w+)?', FROM_block_str, re.IGNORECASE | re.DOTALL) + inner_parentheses = re.search(r'\(((?:[^()]+|\([^()]*\))*)\)', FROM_block_str).group(1) + alias_match = re.search(r'\)\s*(AS\s+\w+)?', FROM_block_str, re.IGNORECASE) + alias = alias_match.group(1) if alias_match and alias_match.group(1) else '' + start, end = subquery_match.span() + + filtered_subquery = self.SetOP_QueryFilter(sqlparse.parse(inner_parentheses)[0], self.client_id) + + join_dict['matching_table'].append(FROM_block_str) + join_dict['filtered_matching_table'].append(f'({filtered_subquery})') + join_dict['alias'].append(alias) + + else: + for t in sqlparse.parse(FROM_block_str)[0].tokens: + if isinstance(t, Identifier): + name = t.get_real_name() + if self._find_matching_table(str(name), self.schemas): + exit_early = True + + for token in FROM_block: + if isinstance(token, Token) and token.ttype == Keyword and token.value.upper() in joins: + join_found = True + break + if join_found: + _handle_joins() + else: + _not_handle_joins() + + if exit_early: + return 0 + else: + reconstructed_from_clause = [] + for token in FROM_block: + if isinstance(token, Token) and token.value.strip() in join_dict["matching_table"]: + table_index = join_dict["matching_table"].index(token.value.strip()) + filtered_table = join_dict["filtered_matching_table"][table_index] + added_alias = join_dict["alias"][table_index] + reconstructed_from_clause.append(f"{filtered_table} {added_alias}") + else: + reconstructed_from_clause.append(token.value.strip()) + + reconstructed_query = " ".join(reconstructed_from_clause) + return reconstructed_query + + + def WHERE_subquery(self, WHERE_block): + where_keywords = {'IN', 'NOT IN', 'EXISTS', 'ALL', 'ANY'} + where_keyword_pattern = '|'.join(where_keywords) + filtered_dict = { + 'subquery_list': [], + 'filtered_subquery': [], + 'placeholder_value': [] + } + + for i in WHERE_block: + for clause in re.split(r'\bAND\b(?![^()]*\))', i): + clause = clause.strip() + + if re.search(fr'\b({where_keyword_pattern})\b\s*\(.*?\bSELECT\b.*?\)', clause, re.DOTALL): + filtered_dict['subquery_list'].append(clause) + elif re.search(r'\(.*?\bSELECT\b.*?\)', clause, re.DOTALL): + filtered_dict['subquery_list'].append(clause) + + for j in range(len(filtered_dict['subquery_list'])): + placeholder = f"" + filtered_subquery = self.SQLQueryFilter( sqlparse.parse( re.search(r'\(((?:[^()]+|\([^()]*\))*)\)', (filtered_dict['subquery_list'][j])).group(1) )[0], self.client_id ) + filtered_dict['placeholder_value'].append(placeholder) + filtered_dict['filtered_subquery'].append(filtered_subquery) + + return filtered_dict + + + def END_subqueries(self, end_keywords_block): + + end_keywords = {'GROUP BY', 'HAVING', 'ORDER BY'} + filtered_dict = { + 'subquery_list': [], + 'filtered_subquery': [], + 'placeholder_value': [] + } + + endsubquery_block = [] + count = 0 + indices = [] + + for index, token in enumerate(end_keywords_block): + if str(token).upper() in end_keywords: + count += 1 + indices.append(index) + + if count >= 1: # If there is at least one end keyword + for i in range(len(indices)): + start_idx = indices[i] + + if i < len(indices) - 1: + end_idx = indices[i + 1] # Until the next keyword + else: + end_idx = len(end_keywords_block) # Until the end of the block + + endsubquery_block = end_keywords_block[start_idx:end_idx] + endsubquery_block_str = ' '.join(endsubquery_block) + + if re.search(r'\((SELECT [\s\S]*?)\)', str(endsubquery_block_str), re.IGNORECASE): + subquery_match = re.search(r'\(((?:[^()]+|\([^()]*\))*)\)\s*(?:AS\s+)?(\w+)?', str(endsubquery_block_str), re.IGNORECASE).group(1) + print(subquery_match) + filtered_dict['subquery_list'].append(subquery_match) + placeholder = f"" + filtered_dict['filtered_subquery'].append(self.SQLQueryFilter(sqlparse.parse(subquery_match)[0], self.client_id)) + filtered_dict['placeholder_value'].append(placeholder) + + return filtered_dict + + + def handle_subquery(self, parsed): + tokens = parsed.tokens if hasattr(parsed, 'tokens') else [parsed] + end_keywords = {'GROUP BY', 'HAVING', 'ORDER BY'} + + select_index = None + from_index = None + where_index = None + end_index = None + + select_block = [] + from_block = [] + where_block = [] + end_keywords_block = [] + + i = 0 + while i < len(tokens): + token = tokens[i] + + if isinstance(token, Token) and token.ttype is DML and token.value.upper() == 'SELECT': + select_index = i + k = i + 1 + while k < len(tokens) and not (isinstance(tokens[k], Token) and tokens[k].ttype == Keyword and tokens[k].value.upper() == 'FROM'): + k += 1 + + from_index = k + k += 1 + while k < len(tokens): + if isinstance(tokens[k], Token) and 'WHERE' in str(tokens[k]) and not \ + re.match(r'\(\s*SELECT.*?\bWHERE\b.*?\)', str(tokens[k])): + where_index = k + elif isinstance(tokens[k], Token) and str(tokens[k]) in end_keywords: + end_index = k + break + + k += 1 + i += 1 + + where_end = end_index if end_index else len(tokens) + from_end = min( + index for index in [where_index, end_index] if index is not None) if any([where_index, end_index]) \ + else len(tokens) + + for j in range(select_index + 1, from_index): + select_block.append(self._cleanup_whitespace(str(tokens[j]))) + + for j in range(from_index + 1, from_end): + if isinstance(tokens[j], Token) and tokens[j].ttype not in [Whitespace, Newline]: + from_block.append(tokens[j]) + + WHERE_dict = {'subquery_list': [], 'filtered_subquery': [], 'placeholder_value': []} # For cases where WHERE_dict is empty and leads to [UnboundLocalError: cannot access local variable 'WHERE_dict' where it is not associated with a value] + if where_index: + for j in range(where_index, where_end): + where_block.append(self._cleanup_whitespace(str(tokens[j]).strip('WHERE '))) + WHERE_dict = self.WHERE_subquery(where_block) + + END_dict = {'subquery_list': [], 'filtered_subquery': [], 'placeholder_value': []} + if end_index: + for j in range(end_index, len(tokens)): + if isinstance(tokens[j], Token) and tokens[j].ttype not in [Whitespace, Newline]: + end_keywords_block.append(self._cleanup_whitespace(str(tokens[j]))) + END_dict = self.END_subqueries(end_keywords_block) + + SELECT_dict = self.SELECT_subquery(select_block) + FROM_filtering = self.FROM_subquery(from_block) + + subquery_dict = { + "subqueries": SELECT_dict['subquery_list'] + WHERE_dict['subquery_list'] + END_dict['subquery_list'], + "filtered subqueries": SELECT_dict['filtered_subquery'] + WHERE_dict['filtered_subquery'] + END_dict['filtered_subquery'], + "placeholder names": SELECT_dict['placeholder_value'] + WHERE_dict['placeholder_value'] + END_dict['placeholder_value'] + } + + if FROM_filtering == 0: + for i in range(len(subquery_dict['filtered subqueries'])): + pattern = re.search(r'\(((?:[^()]+|\([^()]*\))*)\)\s*(?:AS\s+)?(\w+)?', subquery_dict['subqueries'][i], re.IGNORECASE) + if pattern: + subquery_with_alias = pattern.group(1) + mainquery_str = str(parsed).replace(subquery_with_alias, subquery_dict["placeholder names"][i]) if i == 0 \ + else mainquery_str.replace(subquery_with_alias, subquery_dict["placeholder names"][i]) + + if len(subquery_dict['subqueries']) == 1: + filtered_mainquery = self.SQLQueryFilter(sqlparse.parse(mainquery_str)[0], self.client_id) + else: + if i == 0: + filtered_mainquery = mainquery_str + elif i == len(subquery_dict['subqueries']) - 1: + filtered_mainquery = self.SQLQueryFilter(sqlparse.parse(mainquery_str)[0], self.client_id) + else: + mainquery_str = str(parsed) + + from_start = mainquery_str.upper().find('FROM') + where_start = mainquery_str.upper().find('WHERE') + + part_to_replace = mainquery_str[from_start:where_start].strip() + filtered_mainquery = mainquery_str.replace(part_to_replace, f"FROM {FROM_filtering}") + + for placeholder, filtered_subquery in zip(subquery_dict['placeholder names'], subquery_dict['filtered subqueries']): + filtered_mainquery = filtered_mainquery.replace(placeholder, filtered_subquery) + + return filtered_mainquery \ No newline at end of file diff --git a/src/dataneuron/core/sql_query_filter.py b/src/dataneuron/core/sql_query_filter.py index 7fc1fe1..aebf140 100644 --- a/src/dataneuron/core/sql_query_filter.py +++ b/src/dataneuron/core/sql_query_filter.py @@ -1,10 +1,14 @@ import re import sqlparse -from sqlparse.sql import IdentifierList, Identifier, Token, TokenList, Parenthesis, Where, Comparison -from sqlparse.tokens import Keyword, DML, Name, Whitespace, Punctuation +from sqlparse.sql import IdentifierList, Identifier, Token, Parenthesis, Where, Comparison +from sqlparse.tokens import Keyword, DML from typing import List, Dict, Optional + +from .nlp_helpers.query_cleanup import _cleanup_whitespace from .nlp_helpers.cte_handler import handle_cte_query from .nlp_helpers.is_cte import is_cte_query +from .nlp_helpers.is_subquery import _contains_subquery +from .nlp_helpers.subquery_handler import SubqueryHandler class SQLQueryFilter: @@ -13,63 +17,173 @@ def __init__(self, client_tables: Dict[str, str], schemas: List[str] = ['main'], self.schemas = schemas self.case_sensitive = case_sensitive self.filtered_tables = set() - self._is_cte_query = is_cte_query + self._cleanup_whitespace = _cleanup_whitespace + self.subquery_handler = SubqueryHandler(self._apply_filter_recursive, self._handle_set_operation, self._find_matching_table) + def apply_client_filter(self, sql_query: str, client_id: int) -> str: self.filtered_tables = set() parsed = sqlparse.parse(sql_query)[0] - - is_cte = self._is_cte_query(parsed) + is_cte = is_cte_query(parsed) if is_cte: return handle_cte_query(parsed, self._apply_filter_recursive, client_id) else: result = self._apply_filter_recursive(parsed, client_id) - return self._cleanup_whitespace(str(result)) + - def _apply_filter_recursive(self, parsed, client_id): - if self._is_cte_query(parsed): - return handle_cte_query(parsed, self._apply_filter_recursive, client_id) + def _apply_filter_recursive(self, parsed, client_id, cte_name: str = None): - if isinstance(parsed, Token) and parsed.ttype is DML: - return self._apply_filter_to_single_query(str(parsed), client_id) - elif self._contains_set_operation(parsed): - return self._handle_set_operation(parsed, client_id) - elif self._contains_subquery(parsed): - return self._handle_subquery(parsed, client_id) + if is_cte_query(parsed): + return handle_cte_query(parsed, self._apply_filter_recursive, client_id) else: - filtered_query = self._apply_filter_to_single_query( - str(parsed), client_id) - return self._handle_where_subqueries(sqlparse.parse(filtered_query)[0], client_id) + for token in parsed.tokens: + if isinstance(token, Token) and token.ttype is DML: + if self._contains_set_operation(parsed) and not _contains_subquery(parsed): + return self._handle_set_operation(parsed, client_id, True, cte_name) if cte_name else self._handle_set_operation(parsed, client_id) + elif _contains_subquery(parsed): + return self.subquery_handler.handle_subquery(parsed) + else: + return self._apply_filter_to_single_query(str(parsed), client_id) + def _contains_set_operation(self, parsed): set_operations = ('UNION', 'INTERSECT', 'EXCEPT') + + for token in parsed.tokens: + if token.ttype is Keyword and (token.value.upper() in set_operations or token.value.upper() in {op + ' ALL' for op in set_operations}): + return True + return False + + + def _handle_set_operation(self, parsed, client_id, is_cte: bool = False, cte_name: str = None): + set_operations = {'UNION', 'INTERSECT', 'EXCEPT'} + statements = [] + current_statement = [] + set_operation = None + for token in parsed.tokens: + if token.ttype is Keyword and (token.value.upper() in set_operations or token.value.upper() in {op + ' ALL' for op in set_operations}): + if current_statement: + statements.append(''.join(str(t) + for t in current_statement).strip()) + current_statement = [] + set_operation = token.value.upper() + else: + current_statement.append(token) + + if current_statement: + statements.append(''.join(str(t) + for t in current_statement).strip()) + + filtered_statements = [] + for stmt in statements: + if is_cte: + filtered_stmt = self._apply_filter_to_single_CTE_query(stmt, client_id, cte_name) + filtered_statements.append(filtered_stmt) + print(f"Filtered statement: {filtered_stmt}") + else: + match = re.search(r'\(([^()]*)\)', stmt) + if match: + extracted_part = match.group(1) + filtered_stmt = stmt.replace(extracted_part, self._apply_filter_to_single_query(extracted_part, client_id)) + filtered_statements.append(filtered_stmt) + else: + filtered_stmt = self._apply_filter_to_single_query(str(stmt), client_id) + filtered_statements.append(filtered_stmt) + + result = f" {set_operation} ".join(filtered_statements) + return result + + + def _apply_filter_to_single_query(self, sql_query: str, client_id: int) -> str: + parts = sql_query.split(' GROUP BY ') + main_query = parts[0] + group_by = f" GROUP BY {parts[1]}" if len(parts) > 1 else "" + + tables_info = self._extract_tables_info(sqlparse.parse(main_query)[0]) + + filters = [] + for table_info in tables_info: + table_name = table_info['name'] + table_alias = table_info['alias'] + schema = table_info['schema'] + + matching_table = self._find_matching_table(table_name, schema) + + if matching_table: + client_id_column = self.client_tables[matching_table] + table_reference = table_alias or table_name + filters.append(f'{self._quote_identifier(table_reference)}.{self._quote_identifier(client_id_column)} = {client_id}') + self.filtered_tables.add(matching_table) - # Check if parsed is a TokenList (has tokens attribute) - if hasattr(parsed, 'tokens'): - tokens = parsed.tokens + if filters: + where_clause = " AND ".join(filters) + if 'WHERE' in main_query.upper(): + where_parts = main_query.split('WHERE', 1) + result = f"{where_parts[0]} WHERE {where_parts[1].strip()} AND {where_clause}" + else: + result = f"{main_query} WHERE {where_clause}" else: - # If it's a single Token, wrap it in a list - tokens = [parsed] - - for i, token in enumerate(tokens): - if token.ttype is Keyword: - # Check for 'UNION ALL' as a single token - if token.value.upper() == 'UNION ALL': - print("Set operation found: UNION ALL") - return True - # Check for 'UNION', 'INTERSECT', 'EXCEPT' followed by 'ALL' - if token.value.upper() in set_operations: - next_token = parsed.token_next(i) if hasattr( - parsed, 'token_next') else None - if next_token and next_token[1].value.upper() == 'ALL': - print(f"Set operation found: {token.value} ALL") - return True - else: - print(f"Set operation found: {token.value}") - return True - return False + result = main_query + + return result + group_by + + + def _find_matching_table(self, table_name: str, schema: Optional[str] = None) -> Optional[str]: + possible_names = [ + f"{schema}.{table_name}" if schema else table_name, + table_name, + ] + [f"{s}.{table_name}" for s in self.schemas] + + for name in possible_names: + if self._case_insensitive_get(self.client_tables, name) is not None: + return name + return None + + + def _quote_identifier(self, identifier: str) -> str: + return f'"{identifier}"' + + + def _strip_quotes(self, identifier: str) -> str: + return identifier.strip('"').strip("'").strip('`') + + + def _case_insensitive_get(self, dict_obj: Dict[str, str], key: str) -> Optional[str]: + if self.case_sensitive: + return dict_obj.get(key) + return next((v for k, v in dict_obj.items() if k.lower() == key.lower()), None) + + + def _parse_table_identifier(self, identifier): + schema = None + alias = None + name = self._strip_quotes(str(identifier)) + + if identifier.has_alias(): + alias = self._strip_quotes(identifier.get_alias()) + name = self._strip_quotes(identifier.get_real_name()) + + if '.' in name: + parts = name.split('.') + if len(parts) == 2: + schema, name = parts + name = f"{schema}.{name}" if schema else name + + return {'name': name, 'schema': schema, 'alias': alias} + + + def _extract_tables_info(self, parsed, tables_info=None): + if tables_info is None: + tables_info = [] + + self._extract_from_clause_tables(parsed, tables_info) + self._extract_where_clause_tables(parsed, tables_info) + self._extract_cte_tables(parsed, tables_info) + + return tables_info + def _extract_from_clause_tables(self, parsed, tables_info): from_seen = False @@ -90,6 +204,7 @@ def _extract_from_clause_tables(self, parsed, tables_info): tables_info.append(self._parse_table_identifier( parsed.token_next(token)[1])) + def _extract_where_clause_tables(self, parsed, tables_info): where_clause = next( (token for token in parsed.tokens if isinstance(token, Where)), None) @@ -109,6 +224,7 @@ def _extract_where_clause_tables(self, parsed, tables_info): self._extract_from_clause_tables( subquery_parsed, tables_info) + def _extract_cte_tables(self, parsed, tables_info): cte_start = next((i for i, token in enumerate( parsed.tokens) if token.ttype is Keyword and token.value.upper() == 'WITH'), None) @@ -127,164 +243,34 @@ def _extract_cte_tables(self, parsed, tables_info): elif token.ttype is DML and token.value.upper() == 'SELECT': break - def _extract_tables_info(self, parsed, tables_info=None): - if tables_info is None: - tables_info = [] - - self._extract_from_clause_tables(parsed, tables_info) - self._extract_where_clause_tables(parsed, tables_info) - self._extract_cte_tables(parsed, tables_info) - - return tables_info - - def _extract_nested_subqueries(self, parsed, tables_info): - for token in parsed.tokens: - if isinstance(token, Identifier) and token.has_alias(): - if isinstance(token.tokens[0], Parenthesis): - subquery = token.tokens[0].tokens[1:-1] - subquery_str = ' '.join(str(t) for t in subquery) - subquery_parsed = sqlparse.parse(subquery_str)[0] - self._extract_from_clause_tables( - subquery_parsed, tables_info) - self._extract_where_clause_tables( - subquery_parsed, tables_info) - self._extract_nested_subqueries( - subquery_parsed, tables_info) - - def _parse_table_identifier(self, identifier): - schema = None - alias = None - name = self._strip_quotes(str(identifier)) - - if identifier.has_alias(): - alias = self._strip_quotes(identifier.get_alias()) - name = self._strip_quotes(identifier.get_real_name()) - - if '.' in name: - parts = name.split('.') - if len(parts) == 2: - schema, name = parts - name = f"{schema}.{name}" if schema else name - - return {'name': name, 'schema': schema, 'alias': alias} - - def _find_matching_table(self, table_name: str, schema: Optional[str] = None) -> Optional[str]: - possible_names = [ - f"{schema}.{table_name}" if schema else table_name, - table_name, - ] + [f"{s}.{table_name}" for s in self.schemas] - - for name in possible_names: - if self._case_insensitive_get(self.client_tables, name) is not None: - return name - return None - - def _case_insensitive_get(self, dict_obj: Dict[str, str], key: str) -> Optional[str]: - if self.case_sensitive: - return dict_obj.get(key) - return next((v for k, v in dict_obj.items() if k.lower() == key.lower()), None) - - def _strip_quotes(self, identifier: str) -> str: - return identifier.strip('"').strip("'").strip('`') - - def _quote_identifier(self, identifier: str) -> str: - return f'"{identifier}"' - - def _inject_where_clause(self, parsed, where_clause): - - where_index = next((i for i, token in enumerate(parsed.tokens) - if token.ttype is Keyword and token.value.upper() == 'WHERE'), None) - - if where_index is not None: - # Find the end of the existing WHERE clause - end_where_index = len(parsed.tokens) - 1 - for i in range(where_index + 1, len(parsed.tokens)): - token = parsed.tokens[i] - if token.ttype is Keyword and token.value.upper() in ('GROUP', 'ORDER', 'LIMIT'): - end_where_index = i - 1 - break - - # Insert our condition at the end of the existing WHERE clause - parsed.tokens.insert(end_where_index + 1, Token(Whitespace, ' ')) - parsed.tokens.insert(end_where_index + 2, Token(Keyword, 'AND')) - parsed.tokens.insert(end_where_index + 3, Token(Whitespace, ' ')) - parsed.tokens.insert(end_where_index + 4, - Token(Name, where_clause)) - else: - # Find the position to insert the WHERE clause - insert_position = len(parsed.tokens) - for i, token in enumerate(parsed.tokens): - if token.ttype is Keyword and token.value.upper() in ('GROUP', 'ORDER', 'LIMIT'): - insert_position = i - break - - # Insert the new WHERE clause - parsed.tokens.insert(insert_position, Token(Whitespace, ' ')) - parsed.tokens.insert(insert_position + 1, Token(Keyword, 'WHERE')) - parsed.tokens.insert(insert_position + 2, Token(Whitespace, ' ')) - parsed.tokens.insert(insert_position + 3, - Token(Name, where_clause)) - - return str(parsed) - - def _handle_set_operation(self, parsed, client_id): - print("Handling set operation") - # Split the query into individual SELECT statements - statements = [] - current_statement = [] - set_operation = None - for token in parsed.tokens: - if token.ttype is Keyword and token.value.upper() in ('UNION', 'INTERSECT', 'EXCEPT', 'UNION ALL'): - if current_statement: - statements.append(''.join(str(t) - for t in current_statement).strip()) - current_statement = [] - set_operation = token.value.upper() - else: - current_statement.append(token) - - if current_statement: - statements.append(''.join(str(t) - for t in current_statement).strip()) - - print(f"Split statements: {statements}") - print(f"Set operation: {set_operation}") - - # Apply the filter to each SELECT statement - filtered_statements = [] - for stmt in statements: - filtered_stmt = self._apply_filter_to_single_query(stmt, client_id) - filtered_statements.append(filtered_stmt) - print(f"Filtered statement: {filtered_stmt}") - - # Reconstruct the query - result = f" {set_operation} ".join(filtered_statements) - print(f"Final result: {result}") - return result - - def _apply_filter_to_single_query(self, sql_query: str, client_id: int) -> str: + def _apply_filter_to_single_CTE_query(self, sql_query: str, client_id: int, cte_name: str) -> str: parts = sql_query.split(' GROUP BY ') main_query = parts[0] + group_by = f" GROUP BY {parts[1]}" if len(parts) > 1 else "" - parsed = sqlparse.parse(main_query)[0] tables_info = self._extract_tables_info(parsed) filters = [] + _table_ = [] + for table_info in tables_info: - table_name = table_info['name'] - table_alias = table_info['alias'] - schema = table_info['schema'] + if table_info['name'] != cte_name: + table_dict = { + "name": table_info['name'], + "alias": table_info['alias'], + "schema": table_info['schema'] + } + _table_.append(table_dict) + + matching_table = self._find_matching_table(_table_[0]['name'], _table_[0]['schema']) - matching_table = self._find_matching_table(table_name, schema) + if matching_table: + client_id_column = self.client_tables[matching_table] + table_reference = _table_[0]['alias'] or _table_[0]['name'] - if matching_table and matching_table not in self.filtered_tables: - client_id_column = self.client_tables[matching_table] - table_reference = table_alias or table_name - filters.append( - f'{self._quote_identifier(table_reference)}.{self._quote_identifier(client_id_column)} = {client_id}') - self.filtered_tables.add(matching_table) + filters.append(f'{self._quote_identifier(table_reference)}.{self._quote_identifier(client_id_column)} = {client_id}') if filters: where_clause = " AND ".join(filters) @@ -294,217 +280,6 @@ def _apply_filter_to_single_query(self, sql_query: str, client_id: int) -> str: else: result = f"{main_query} WHERE {where_clause}" else: - result = main_query - - return result + group_by - - def _contains_subquery(self, parsed): - tokens = parsed.tokens if hasattr(parsed, 'tokens') else [parsed] - - for i, token in enumerate(tokens): - if isinstance(token, Identifier) and token.has_alias(): - if isinstance(token.tokens[0], Parenthesis): - return True - elif isinstance(token, Parenthesis): - if any(t.ttype is DML and t.value.upper() == 'SELECT' for t in token.tokens): - return True - # Recursively check inside parentheses - if self._contains_subquery(token): - return True - elif isinstance(token, Where): - in_found = False - for j, sub_token in enumerate(token.tokens): - if in_found: - if isinstance(sub_token, Parenthesis): - if any(t.ttype is DML and t.value.upper() == 'SELECT' for t in sub_token.tokens): - return True - elif hasattr(sub_token, 'ttype') and not sub_token.is_whitespace: - # Check if the token is a parenthesis-like structure - if '(' in sub_token.value and ')' in sub_token.value: - if 'SELECT' in sub_token.value.upper(): - return True - # If we find a non-whitespace token that's not a parenthesis, reset in_found - in_found = False - elif hasattr(sub_token, 'ttype') and sub_token.ttype is Keyword and sub_token.value.upper() == 'IN': - in_found = True - elif isinstance(sub_token, Comparison): - for item in sub_token.tokens: - if isinstance(item, Parenthesis): - if self._contains_subquery(item): - return True - elif hasattr(token, 'ttype') and token.ttype is Keyword and token.value.upper() == 'IN': - next_token = tokens[i+1] if i+1 < len(tokens) else None - if next_token: - if isinstance(next_token, Parenthesis): - if any(t.ttype is DML and t.value.upper() == 'SELECT' for t in next_token.tokens): - return True - elif hasattr(next_token, 'value') and '(' in next_token.value and ')' in next_token.value: - if 'SELECT' in next_token.value.upper(): - return True - - return False - - def _cleanup_whitespace(self, query: str) -> str: - # Split the query into lines - lines = query.split('\n') - cleaned_lines = [] - for line in lines: - # Remove leading/trailing whitespace from each line - line = line.strip() - # Replace multiple spaces with a single space, but not in quoted strings - line = re.sub(r'\s+(?=(?:[^\']*\'[^\']*\')*[^\']*$)', ' ', line) - # Ensure single space after commas, but not in quoted strings - line = re.sub( - r'\s*,\s*(?=(?:[^\']*\'[^\']*\')*[^\']*$)', ', ', line) - cleaned_lines.append(line) - # Join the lines back together - return '\n'.join(cleaned_lines) - - def _handle_subquery(self, parsed, client_id): - result = [] - tokens = parsed.tokens if hasattr(parsed, 'tokens') else [parsed] - - for token in tokens: - if isinstance(token, Identifier) and token.has_alias(): - if isinstance(token.tokens[0], Parenthesis): - subquery = token.tokens[0].tokens[1:-1] - subquery_str = ' '.join(str(t) for t in subquery) - filtered_subquery = self._apply_filter_recursive( - sqlparse.parse(subquery_str)[0], client_id) - alias = token.get_alias() - result.append(f"({filtered_subquery}) AS {alias}") - else: - result.append(str(token)) - elif isinstance(token, Parenthesis): - subquery = token.tokens[1:-1] - subquery_str = ' '.join(str(t) for t in subquery) - filtered_subquery = self._apply_filter_recursive( - sqlparse.parse(subquery_str)[0], client_id) - result.append(f"({filtered_subquery})") - elif isinstance(token, Where): - try: - filtered_where = self._handle_where_subqueries( - token, client_id) - result.append(str(filtered_where)) - except Exception as e: - result.append(str(token)) - else: - # Preserve whitespace tokens - if token.is_whitespace: - result.append(str(token)) - else: - # Add space before and after non-whitespace tokens, except for punctuation - if result and not result[-1].endswith(' ') and not str(token).startswith((')', ',', '.')): - result.append(' ') - result.append(str(token)) - if not str(token).endswith(('(', ',')): - result.append(' ') - - final_result = ''.join(result).strip() - return final_result + result = main_query - def _handle_where_subqueries(self, where_clause, client_id): - if self._is_cte_query(where_clause): - cte_part = self._extract_cte_definition(where_clause) - main_query = self._extract_main_query(where_clause) - - filtered_cte = self._apply_filter_recursive(cte_part, client_id) - - if 'WHERE' not in str(main_query).upper(): - main_query = self._add_where_clause_to_main_query( - main_query, client_id) - - return f"{filtered_cte} {main_query}" - else: - new_where_tokens = [] - i = 0 - while i < len(where_clause.tokens): - token = where_clause.tokens[i] - if token.ttype is Keyword and token.value.upper() == 'IN': - next_token = where_clause.token_next(i) - if next_token and isinstance(next_token[1], Parenthesis): - subquery = next_token[1].tokens[1:-1] - subquery_str = ' '.join(str(t) for t in subquery) - filtered_subquery = self._apply_filter_recursive( - sqlparse.parse(subquery_str)[0], client_id) - filtered_subquery_str = str(filtered_subquery) - try: - new_subquery_tokens = [ - Token(Whitespace, ' '), - Token(Punctuation, '(') - ] + sqlparse.parse(filtered_subquery_str)[0].tokens + [Token(Punctuation, ')')] - new_where_tokens.extend( - [token] + new_subquery_tokens) - except Exception as e: - # Fallback to original subquery with space - new_where_tokens.extend( - [token, Token(Whitespace, ' '), next_token[1]]) - i += 2 # Skip the next token as we've handled it - else: - new_where_tokens.append(token) - elif isinstance(token, Parenthesis): - subquery = token.tokens[1:-1] - subquery_str = ' '.join(str(t) for t in subquery) - if self._contains_subquery(sqlparse.parse(subquery_str)[0]): - filtered_subquery = self._apply_filter_recursive( - sqlparse.parse(subquery_str)[0], client_id) - filtered_subquery_str = str(filtered_subquery) - try: - new_subquery_tokens = sqlparse.parse( - f"({filtered_subquery_str})")[0].tokens - new_where_tokens.extend(new_subquery_tokens) - except Exception as e: - # Fallback to original subquery - new_where_tokens.append(token) - else: - new_where_tokens.append(token) - else: - new_where_tokens.append(token) - i += 1 - - # Add the client filter for the main table - try: - main_table = self._extract_main_table(where_clause) - if main_table: - main_table_filter = self._generate_client_filter( - main_table, client_id) - if main_table_filter: - filter_tokens = [ - Token(Whitespace, ' '), - Token(Keyword, 'AND'), - Token(Whitespace, ' ') - ] + sqlparse.parse(main_table_filter)[0].tokens - new_where_tokens.extend(filter_tokens) - except Exception as e: - print(f"error: {e}") - - where_clause.tokens = new_where_tokens - return where_clause - - def _generate_client_filter(self, table_name, client_id): - matching_table = self._find_matching_table(table_name) - if matching_table: - client_id_column = self.client_tables[matching_table] - return f'{self._quote_identifier(table_name)}.{self._quote_identifier(client_id_column)} = {client_id}' - return None - - def _extract_main_query(self, parsed): - main_query_tokens = [] - main_query_started = False - - for token in parsed.tokens: - if main_query_started: - main_query_tokens.append(token) - elif token.ttype is DML and token.value.upper() == 'SELECT': - main_query_started = True - main_query_tokens.append(token) - - return TokenList(main_query_tokens) - - def _extract_main_table(self, where_clause): - if where_clause.parent is None: - return None - for token in where_clause.parent.tokens: - if isinstance(token, Identifier): - return token.get_real_name() - return None + return result + group_by \ No newline at end of file diff --git a/tests/core/test_sql_query_filter.py b/tests/core/test_sql_query_filter.py index 80b7eb4..27bdd3a 100644 --- a/tests/core/test_sql_query_filter.py +++ b/tests/core/test_sql_query_filter.py @@ -98,10 +98,10 @@ def test_subquery_in_from(self): expected = 'SELECT * FROM (SELECT * FROM orders WHERE "orders"."user_id" = 1) AS subq' self.assertEqual(self.filter.apply_client_filter(query, 1), expected) - # def test_subquery_in_join(self): - # query = 'SELECT o.* FROM orders o JOIN (SELECT * FROM products) p ON o.product_id = p.id' - # expected = 'SELECT o.* FROM orders o JOIN (SELECT * FROM products WHERE "products"."company_id" = 1) p ON o.product_id = p.id WHERE "o"."user_id" = 1' - # self.assertEqual(self.filter.apply_client_filter(query, 1), expected) + def test_subquery_in_join(self): + query = 'SELECT o.* FROM orders o JOIN (SELECT * FROM products) p ON o.product_id = p.id' + expected = 'SELECT o.* FROM orders o JOIN (SELECT * FROM products WHERE "products"."company_id" = 1) p ON o.product_id = p.id WHERE "o"."user_id" = 1' + self.assertEqual(self.filter.apply_client_filter(query, 1), expected) def test_nested_subqueries(self): query = 'SELECT * FROM (SELECT * FROM (SELECT * FROM orders) AS inner_subq) AS outer_subq' @@ -123,7 +123,8 @@ def setUp(self): 'products': 'company_id', 'inventory.items': 'organization_id', 'items': 'organization_id', - 'customers': 'customer_id' + 'customers': 'customer_id', + 'categories': 'company_id' } self.filter = SQLQueryFilter( self.client_tables, schemas=['main', 'inventory']) @@ -217,62 +218,61 @@ def test_multiple_ctes(self): self.assertSQLEqual( self.filter.apply_client_filter(query, 1), expected) - # def test_cte_with_subquery(self): - # query = ''' - # WITH top_products AS ( - # SELECT p.id, p.name, SUM(o.quantity) as total_sold - # FROM products p - # JOIN (SELECT * FROM orders WHERE status = 'completed') o ON p.id = o.product_id - # GROUP BY p.id, p.name - # ORDER BY total_sold DESC - # LIMIT 10 - # ) - # SELECT * FROM top_products - # ''' - # expected = ''' - # WITH top_products AS ( - # SELECT p.id, p.name, SUM(o.quantity) as total_sold - # FROM products p - # JOIN (SELECT * FROM orders WHERE status = 'completed' AND "orders"."user_id" = 1) o ON p.id = o.product_id - # WHERE "p"."company_id" = 1 - # GROUP BY p.id, p.name - # ORDER BY total_sold DESC - # LIMIT 10 - # ) - # SELECT * FROM top_products - # ''' - # self.assertSQLEqual( - # self.filter.apply_client_filter(query, 1), expected) - - # def test_recursive_cte(self): - # query = ''' - # WITH RECURSIVE category_tree AS ( - # SELECT id, name, parent_id, 0 AS level - # FROM categories - # WHERE parent_id IS NULL - # UNION ALL - # SELECT c.id, c.name, c.parent_id, ct.level + 1 - # FROM categories c - # JOIN category_tree ct ON c.parent_id = ct.id - # ) - # SELECT * FROM category_tree - # ''' - # expected = ''' - # WITH RECURSIVE category_tree AS ( - # SELECT id, name, parent_id, 0 AS level - # FROM categories - # WHERE parent_id IS NULL AND "categories"."company_id" = 1 - # UNION ALL - # SELECT c.id, c.name, c.parent_id, ct.level + 1 - # FROM categories c - # JOIN category_tree ct ON c.parent_id = ct.id - # WHERE "c"."company_id" = 1 - # ) - # SELECT * FROM category_tree - # ''' - # self.assertSQLEqual( - # self.filter.apply_client_filter(query, 1), expected) + def test_cte_with_subquery(self): + query = ''' + WITH top_products AS ( + SELECT p.id, p.name, SUM(o.quantity) as total_sold + FROM products p + JOIN (SELECT * FROM orders WHERE status = 'completed') o ON p.id = o.product_id + GROUP BY p.id, p.name + ORDER BY total_sold DESC + LIMIT 10 + ) + SELECT * FROM top_products + ''' + expected = ''' + WITH top_products AS ( + SELECT p.id, p.name, SUM(o.quantity) as total_sold + FROM products p + JOIN (SELECT * FROM orders WHERE status = 'completed' AND "orders"."user_id" = 1) o ON p.id = o.product_id + WHERE "p"."company_id" = 1 + GROUP BY p.id, p.name + ORDER BY total_sold DESC + LIMIT 10 + ) + SELECT * FROM top_products + ''' + self.assertSQLEqual( + self.filter.apply_client_filter(query, 1), expected) + def test_recursive_cte(self): + query = ''' + WITH RECURSIVE category_tree AS ( + SELECT id, name, parent_id, 0 AS level + FROM categories + WHERE parent_id IS NULL + UNION ALL + SELECT c.id, c.name, c.parent_id, ct.level + 1 + FROM categories c + JOIN category_tree ct ON c.parent_id = ct.id + ) + SELECT * FROM category_tree + ''' + expected = ''' + WITH RECURSIVE category_tree AS ( + SELECT id, name, parent_id, 0 AS level + FROM categories + WHERE parent_id IS NULL AND "categories"."company_id" = 1 + UNION ALL + SELECT c.id, c.name, c.parent_id, ct.level + 1 + FROM categories c + JOIN category_tree ct ON c.parent_id = ct.id + WHERE "c"."company_id" = 1 + ) + SELECT * FROM category_tree + ''' + self.assertSQLEqual( + self.filter.apply_client_filter(query, 1), expected) if __name__ == '__main__': unittest.main()