Skip to content

Commit 5533ede

Browse files
authored
Merge pull request #264 from demitryfly/some-refactorings-1
refactor utils, add tests, move exceptions into separate module
2 parents 2ef4718 + d2413d5 commit 5533ede

File tree

10 files changed

+177
-80
lines changed

10 files changed

+177
-80
lines changed

docs/README.rst

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,10 @@ How to use
9696
Extract additional information from HQL (& other dialects)
9797
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
9898

99-
In some dialects like HQL there is a lot of additional information about table like, fore example, is it external table, STORED AS, location & etc. This property will be always empty in 'classic' SQL DB like PostgreSQL or MySQL and this is the reason, why by default this information are 'hidden'.
100-
Also some fields hidden in HQL, because they are simple not exists in HIVE, for example 'deferrable_initially'
99+
In some dialects like HQL there is a lot of additional information about a table, for example, whether it is an external table,
100+
STORED AS, location, etc. These properties will always be empty in a 'classic' SQL DB like PostgreSQL or MySQL
101+
and this is why this information is 'hidden' by default.
102+
Also, some fields are hidden in HQL because they simply do not exist in Hive, for example 'deferrable_initially'.
101103
To get these HQL-specific details about a table in the output, use the 'output_mode' argument of the run() method.
102104

103105
example:

simple_ddl_parser/ddl_parser.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,13 +16,11 @@
1616
Snowflake,
1717
SparkSQL,
1818
)
19+
# "DDLParserError" is an alias for backward compatibility
20+
from simple_ddl_parser.exception import SimpleDDLParserException as DDLParserError
1921
from simple_ddl_parser.parser import Parser
2022

2123

22-
class DDLParserError(Exception):
23-
pass
24-
25-
2624
class Dialects(
2725
SparkSQL,
2826
Snowflake,

simple_ddl_parser/exception.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
__all__ = [
2+
"SimpleDDLParserException",
3+
]
4+
5+
6+
class SimpleDDLParserException(Exception):
7+
""" Base exception in simple ddl parser library """
8+
pass
9+

simple_ddl_parser/output/core.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -123,7 +123,7 @@ def group_by_type_result(self) -> None:
123123
else:
124124
_type.extend(item["comments"])
125125
break
126-
if result_as_dict["comments"] == []:
126+
if not result_as_dict["comments"]:
127127
del result_as_dict["comments"]
128128

129129
self.final_result = result_as_dict

simple_ddl_parser/output/table_data.py

Lines changed: 19 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,21 @@
33
from simple_ddl_parser.output.base_data import BaseData
44
from simple_ddl_parser.output.dialects import CommonDialectsFieldsMixin, dialect_by_name
55

6+
__all__ = [
7+
"TableData",
8+
]
9+
10+
11+
def _pre_process_kwargs(kwargs: dict, aliased_fields: dict) -> None:
12+
for alias, field_name in aliased_fields.items():
13+
if alias in kwargs:
14+
kwargs[field_name] = kwargs[alias]
15+
del kwargs[alias]
16+
17+
# TODO: figure out a proper way to handle this instead of this workaround
18+
if kwargs.get("fields_terminated_by") == "_ddl_parser_comma_only_str":
19+
kwargs["fields_terminated_by"] = "','"
20+
621

722
class TableData:
823
cls_prefix = "Dialect"
@@ -13,34 +28,18 @@ def get_dialect_class(cls, kwargs: dict):
1328

1429
if output_mode and output_mode != "sql":
1530
main_cls = dialect_by_name.get(output_mode)
16-
cls = dataclass(
31+
return dataclass(
1732
type(
1833
f"{main_cls.__name__}{cls.cls_prefix}",
1934
(main_cls, CommonDialectsFieldsMixin),
2035
{},
2136
)
2237
)
23-
else:
24-
cls = BaseData
25-
26-
return cls
27-
28-
@staticmethod
29-
def pre_process_kwargs(kwargs: dict, aliased_fields: dict) -> dict:
30-
for alias, field_name in aliased_fields.items():
31-
if alias in kwargs:
32-
kwargs[field_name] = kwargs[alias]
33-
del kwargs[alias]
3438

35-
# todo: need to figure out how workaround it normally
36-
if (
37-
"fields_terminated_by" in kwargs
38-
and "_ddl_parser_comma_only_str" == kwargs["fields_terminated_by"]
39-
):
40-
kwargs["fields_terminated_by"] = "','"
39+
return BaseData
4140

4241
@classmethod
43-
def pre_load_mods(cls, main_cls, kwargs):
42+
def pre_load_mods(cls, main_cls, kwargs) -> dict:
4443
if kwargs.get("output_mode") == "bigquery":
4544
if kwargs.get("schema"):
4645
kwargs["dataset"] = kwargs["schema"]
@@ -55,7 +54,7 @@ def pre_load_mods(cls, main_cls, kwargs):
5554
for name, value in cls_fields.items()
5655
if value.metadata and "alias" in value.metadata
5756
}
58-
cls.pre_process_kwargs(kwargs, aliased_fields)
57+
_pre_process_kwargs(kwargs, aliased_fields)
5958
table_main_args = {
6059
k.lower(): v for k, v in kwargs.items() if k.lower() in cls_fields
6160
}

simple_ddl_parser/parser.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,10 @@
66

77
from ply import lex, yacc
88

9+
from simple_ddl_parser.exception import SimpleDDLParserException
910
from simple_ddl_parser.output.core import Output, dump_data_to_file
1011
from simple_ddl_parser.output.dialects import dialect_by_name
11-
from simple_ddl_parser.utils import (
12-
SimpleDDLParserException,
13-
find_first_unpair_closed_par,
14-
)
12+
from simple_ddl_parser.utils import find_first_unpair_closed_par
1513

1614
# open comment
1715
OP_COM = "/*"

simple_ddl_parser/tokens.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -150,8 +150,8 @@
150150

151151

152152
tokens = tuple(
153-
set(
154-
[
153+
{
154+
*[
155155
"ID",
156156
"DOT",
157157
"STRING_BASE",
@@ -161,14 +161,14 @@
161161
"LT",
162162
"RT",
163163
"COMMAT",
164-
]
165-
+ list(definition_statements.values())
166-
+ list(common_statements.values())
167-
+ list(columns_definition.values())
168-
+ list(sequence_reserved.values())
169-
+ list(after_columns_tokens.values())
170-
+ list(alter_tokens.values())
171-
)
164+
],
165+
*definition_statements.values(),
166+
*common_statements.values(),
167+
*columns_definition.values(),
168+
*sequence_reserved.values(),
169+
*after_columns_tokens.values(),
170+
*alter_tokens.values(),
171+
}
172172
)
173173

174174
symbol_tokens = {

simple_ddl_parser/utils.py

Lines changed: 59 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,61 +1,83 @@
11
import re
2-
from typing import List
2+
from typing import List, Tuple, Optional, Union, Any
33

4+
# Backward compatibility import
5+
from simple_ddl_parser.exception import SimpleDDLParserException
46

5-
def remove_par(p_list: List[str]) -> List[str]:
6-
remove_list = ["(", ")"]
7-
for symbol in remove_list:
8-
while symbol in p_list:
9-
p_list.remove(symbol)
7+
__all__ = [
8+
"remove_par",
9+
"check_spec",
10+
"find_first_unpair_closed_par",
11+
"normalize_name",
12+
"get_table_id",
13+
"SimpleDDLParserException"
14+
]
15+
16+
_parentheses = ('(', ')')
17+
18+
19+
def remove_par(p_list: List[Union[str, Any]]) -> List[Union[str, Any]]:
20+
"""
21+
Remove the parentheses from the given list
22+
23+
Warn: p_list may contain unhashable types, such as 'dict'.
24+
"""
25+
j = 0
26+
for i in range(len(p_list)):
27+
if p_list[i] not in _parentheses:
28+
p_list[j] = p_list[i]
29+
j += 1
30+
while j < len(p_list):
31+
p_list.pop()
1032
return p_list
1133

1234

13-
spec_mapper = {
35+
_spec_mapper = {
1436
"'pars_m_t'": "'\t'",
1537
"'pars_m_n'": "'\n'",
1638
"'pars_m_dq'": '"',
1739
"pars_m_single": "'",
1840
}
1941

2042

21-
def check_spec(value: str) -> str:
22-
replace_value = spec_mapper.get(value)
23-
if not replace_value:
24-
for item in spec_mapper:
25-
if item in value:
26-
replace_value = value.replace(item, spec_mapper[item])
27-
break
28-
else:
29-
replace_value = value
30-
return replace_value
31-
32-
33-
def find_first_unpair_closed_par(str_: str) -> int:
34-
stack = []
35-
n = -1
36-
for i in str_:
37-
n += 1
38-
if i == ")":
39-
if not stack:
40-
return n
41-
else:
42-
stack.pop(-1)
43-
elif i == "(":
44-
stack.append(i)
43+
def check_spec(string: str) -> str:
44+
"""
45+
Replace escape tokens with the characters they represent
46+
"""
47+
if string in _spec_mapper:
48+
return _spec_mapper[string]
49+
for replace_from, replace_to in _spec_mapper.items():
50+
if replace_from in string:
51+
return string.replace(replace_from, replace_to)
52+
return string
53+
54+
55+
def find_first_unpair_closed_par(str_: str) -> Optional[int]:
56+
"""
57+
Return the index of the first unpaired closing parenthesis.
58+
Return None if there is none.
59+
"""
60+
count_open = 0
61+
for i, char in enumerate(str_):
62+
if char == '(':
63+
count_open += 1
64+
if char == ')':
65+
count_open -= 1
66+
if count_open < 0:
67+
return i
68+
return None
4569

4670

4771
def normalize_name(name: str) -> str:
48-
# clean up [] and " symbols from names
72+
"""
73+
Clean up [] and " characters from the given name
74+
"""
4975
clean_up_re = r'[\[\]"]'
5076
return re.sub(clean_up_re, "", name).lower()
5177

5278

53-
def get_table_id(schema_name: str, table_name: str):
79+
def get_table_id(schema_name: str, table_name: str) -> Tuple[str, str]:
5480
table_name = normalize_name(table_name)
5581
if schema_name:
5682
schema_name = normalize_name(schema_name)
5783
return (table_name, schema_name)
58-
59-
60-
class SimpleDDLParserException(Exception):
61-
pass

tests/non_statement_tests/test_common.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import pytest
22

3-
from simple_ddl_parser import DDLParser, DDLParserError
3+
from simple_ddl_parser import DDLParser, SimpleDDLParserException
44
from simple_ddl_parser.output.core import get_table_id
55

66

@@ -29,7 +29,7 @@ def test_silent_false_flag():
2929
created_timestamp TIMESTAMPTZ NOT NULL DEFAULT ALTER (now() at time zone 'utc')
3030
);
3131
"""
32-
with pytest.raises(DDLParserError) as e:
32+
with pytest.raises(SimpleDDLParserException) as e:
3333
DDLParser(ddl, silent=False).run(group_by_type=True)
3434

3535
assert "Unknown statement" in e.value[1]

tests/test_utils.py

Lines changed: 69 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,69 @@
1+
import pytest
2+
3+
from simple_ddl_parser import utils
4+
5+
6+
@pytest.mark.parametrize(
7+
"expression, expected_result",
8+
[
9+
([], []),
10+
(["("], []),
11+
([")"], []),
12+
(["(", ")"], []),
13+
([")", "("], []),
14+
(["(", "A"], ["A"]),
15+
(["A", ")"], ["A"]),
16+
(["(", "A", ")"], ["A"]),
17+
(["A", ")", ")"], ["A"]),
18+
(["(", "(", "A"], ["A"]),
19+
(["A", "B", "C"], ["A", "B", "C"]),
20+
(["A", "(", "(", "B", "C", "("], ["A", "B", "C"]),
21+
(["A", ")", "B", ")", "(", "C"], ["A", "B", "C"]),
22+
(["(", "A", ")", "B", "C", ")"], ["A", "B", "C"]),
23+
([dict()], [dict()]), # Edge case (unhashable types)
24+
]
25+
)
26+
def test_remove_par(expression, expected_result):
27+
assert utils.remove_par(expression) == expected_result
28+
29+
30+
@pytest.mark.parametrize(
31+
"expression, expected_result",
32+
[
33+
("", ""),
34+
("simple", "simple"),
35+
36+
("'pars_m_t'", "'\t'"),
37+
("'pars_m_n'", "'\n'"),
38+
("'pars_m_dq'", '"'),
39+
("pars_m_single", "'"),
40+
41+
("STRING_'pars_m_t'STRING", "STRING_'\t'STRING"),
42+
("STRING_'pars_m_n'STRING", "STRING_'\n'STRING"),
43+
("STRING_'pars_m_dq'STRING", "STRING_\"STRING"),
44+
("STRING_pars_m_singleSTRING", "STRING_'STRING"),
45+
46+
("pars_m_single pars_m_single", "' '"),
47+
("'pars_m_t''pars_m_n'", "'\t''pars_m_n'"), # determined by dict element order
48+
]
49+
)
50+
def test_check_spec(expression, expected_result):
51+
assert utils.check_spec(expression) == expected_result
52+
53+
54+
@pytest.mark.parametrize(
55+
"expression, expected_result",
56+
[
57+
(")", 0),
58+
(")()", 0),
59+
("())", 2),
60+
("()())", 4),
61+
("", None),
62+
("text", None),
63+
("()", None),
64+
("(balanced) (brackets)", None),
65+
("(not)) (balanced) (brackets", 5)
66+
]
67+
)
68+
def test_find_first_unpair_closed_par(expression, expected_result):
69+
assert utils.find_first_unpair_closed_par(expression) == expected_result

0 commit comments

Comments
 (0)