Skip to content

Commit e6006ae

Browse files
authored
Fix normalize case input in cli (#258)
The Databricks cli currently only supports string input
1 parent 3996e67 commit e6006ae

File tree

3 files changed

+8
-7
lines changed

3 files changed

+8
-7
lines changed

src/databricks/labs/lsql/cli.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,13 +43,14 @@ def create_dashboard(
4343

4444

4545
@lsql.command(is_unauthenticated=True)
46-
def fmt(folder: Path = Path.cwd(), normalize_case: bool = True):
46+
def fmt(folder: Path = Path.cwd(), *, normalize_case: str = "true"):
4747
"""Format SQL files in a folder"""
4848
logger.debug("Formatting SQL files ...")
4949
folder = Path(folder)
50+
should_normalize_case = normalize_case in STRING_AFFIRMATIVES
5051
for sql_file in folder.glob("**/*.sql"):
5152
sql = sql_file.read_text()
52-
formatted_sql = QueryTile.format(sql, normalize_case)
53+
formatted_sql = QueryTile.format(sql, normalize_case=should_normalize_case)
5354
sql_file.write_text(formatted_sql)
5455
logger.debug(f"Formatted {sql_file}")
5556

src/databricks/labs/lsql/dashboards.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -418,16 +418,16 @@ def validate(self) -> None:
418418
raise ValueError(f"Invalid query content: {self.content}") from e
419419

420420
@staticmethod
421-
def format(content: str, normalize_case: bool = True, *, max_text_width: int = 120) -> str:
421+
def format(content: str, *, max_text_width: int = 120, normalize_case: bool = True) -> str:
422422
"""Format the content
423423
424424
Args:
425425
content : str
426426
The content to format
427-
max_text_width : int
427+
max_text_width : int, optional (default: 120)
428428
The maximum text width to wrap at
429-
normalize_case : bool
430-
If the query should be normalized to lower case
429+
normalize_case : bool, optional (default: True)
430+
If the query identifiers should be normalized to lower case
431431
"""
432432
try:
433433
parsed_query = sqlglot.parse(content, dialect=_SQL_DIALECT)

tests/unit/test_dashboards.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -843,7 +843,7 @@ def test_query_formats_no_normalize():
843843
FROM system.access.audit AS a
844844
LEFT OUTER JOIN inventory.clusters AS c
845845
ON a.request_params.clusterId = c.cluster_id AND a.action_name = 'runCommand'"""
846-
assert QueryTile.format(query, False) == query_formatted
846+
assert QueryTile.format(query, normalize_case=False) == query_formatted
847847

848848

849849
@pytest.mark.parametrize(

0 commit comments

Comments
 (0)