Commit 0349ea8

Handle JDBC passwords containing special chars (#2146)
## Changes

### Relevant implementation details

* Pass the username and password as Spark reader options instead of embedding them in the JDBC URL.
* This ensures that passwords containing non-URL-compliant characters are supported.
1 parent 660a23d commit 0349ea8
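
For context, here is a minimal sketch of the before/after pattern this commit adopts (host, database, and credential values are illustrative placeholders, not the project's configuration). Credentials spliced into a JDBC URL collide with reserved characters such as `;`, `@`, or `%`, while credentials passed as Spark reader options reach the driver verbatim:

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()

# Before: the password is spliced into the connection string, so a value
# like "p@ss;word%" breaks the URL or demands fragile escaping.
unsafe_url = "jdbc:sqlserver://host:1433;databaseName=db;user=my_user;password=p@ss;word%;"

# After: the URL carries no credentials; user/password travel as reader
# options and reach the JDBC driver untouched. (Actually loading requires a
# reachable server and the JDBC driver on the classpath.)
df = (
    spark.read.format("jdbc")
    .option("url", "jdbc:sqlserver://host:1433;databaseName=db;")
    .option("driver", "com.microsoft.sqlserver.jdbc.SQLServerDriver")
    .option("dbtable", "(SELECT 1 AS x) tmp")
    .options(user="my_user", password="p@ss;word%")
    .load()
)
```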

File tree

7 files changed, +54 −49 lines changed
src/databricks/labs/lakebridge/reconcile/connectors/jdbc_reader.py

Lines changed: 1 addition & 2 deletions
```diff
@@ -22,8 +22,7 @@ def _get_jdbc_reader(self, query, jdbc_url, driver, additional_options: dict | None
             .option("dbtable", f"({query}) tmp")
         )
         if isinstance(additional_options, dict):
-            for key, value in additional_options.items():
-                reader = reader.option(key, value)
+            reader = reader.options(**additional_options)
         return reader
 
     @staticmethod
```
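
This refactor is behavior-preserving: `DataFrameReader.options(**d)` applies every key in one call, equivalent to folding `.option(k, v)` over the dict. A quick sketch (the option names here are arbitrary examples):

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
opts = {"fetchsize": 100, "numPartitions": 4}

# Chained form: one .option() call per key.
chained = spark.read.format("jdbc")
for key, value in opts.items():
    chained = chained.option(key, value)

# Bulk form: a single .options() call -- same resulting configuration.
bulk = spark.read.format("jdbc").options(**opts)
```

The one observable difference is the call sequence on the reader (a single bulk `.options` call instead of one `.option` call per key), which is why the mock-chain assertions in the unit tests further down change shape.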

src/databricks/labs/lakebridge/reconcile/connectors/oracle.py

Lines changed: 8 additions & 5 deletions
```diff
@@ -1,5 +1,6 @@
 import re
 import logging
+from collections.abc import Mapping
 from datetime import datetime
 
 from pyspark.errors import PySparkException
@@ -12,7 +13,7 @@
 from databricks.labs.lakebridge.reconcile.connectors.models import NormalizedIdentifier
 from databricks.labs.lakebridge.reconcile.connectors.secrets import SecretsMixin
 from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils
-from databricks.labs.lakebridge.reconcile.recon_config import JdbcReaderOptions, Schema
+from databricks.labs.lakebridge.reconcile.recon_config import JdbcReaderOptions, Schema, OptionalPrimitiveType
 from databricks.sdk import WorkspaceClient
 
 logger = logging.getLogger(__name__)
@@ -64,9 +65,9 @@ def read_data(
         table_query = query.replace(":tbl", f"{schema}.{table}")
         try:
             if options is None:
-                return self.reader(table_query).options(**self._get_timestamp_options()).load()
+                return self.reader(table_query, self._get_timestamp_options()).load()
             reader_options = self._get_jdbc_reader_options(options) | self._get_timestamp_options()
-            df = self.reader(table_query).options(**reader_options).load()
+            df = self.reader(table_query, reader_options).load()
             logger.warning(f"Fetching data using query: \n`{table_query}`")
 
             # Convert all column names to lower case
@@ -107,12 +108,14 @@ def _get_timestamp_options() -> dict[str, str]:
             "HH24:MI:SS''');END;",
         }
 
-    def reader(self, query: str) -> DataFrameReader:
+    def reader(self, query: str, options: Mapping[str, OptionalPrimitiveType] | None = None) -> DataFrameReader:
+        if options is None:
+            options = {}
         user = self._get_secret('user')
         password = self._get_secret('password')
         logger.debug(f"Using user: {user} to connect to Oracle")
         return self._get_jdbc_reader(
-            query, self.get_jdbc_url, OracleDataSource._DRIVER, {"user": user, "password": password}
+            query, self.get_jdbc_url, OracleDataSource._DRIVER, {**options, "user": user, "password": password}
         )
 
     def normalize_identifier(self, identifier: str) -> NormalizedIdentifier:
```
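
A detail worth noting in `{**options, "user": user, "password": password}`: in a dict literal, later keys win, so the secret-scope credentials always override any `user` or `password` a caller might slip into `options`. A one-line illustration with placeholder values:

```python
options = {"fetchsize": 100, "user": "caller_supplied"}
merged = {**options, "user": "from_secrets", "password": "s3cr3t"}
assert merged == {"fetchsize": 100, "user": "from_secrets", "password": "s3cr3t"}
```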

src/databricks/labs/lakebridge/reconcile/connectors/tsql.py

Lines changed: 17 additions & 8 deletions
```diff
@@ -1,6 +1,7 @@
 import re
 import logging
 from datetime import datetime
+from collections.abc import Mapping
 
 from pyspark.errors import PySparkException
 from pyspark.sql import DataFrame, DataFrameReader, SparkSession
@@ -12,7 +13,7 @@
 from databricks.labs.lakebridge.reconcile.connectors.models import NormalizedIdentifier
 from databricks.labs.lakebridge.reconcile.connectors.secrets import SecretsMixin
 from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils
-from databricks.labs.lakebridge.reconcile.recon_config import JdbcReaderOptions, Schema
+from databricks.labs.lakebridge.reconcile.recon_config import JdbcReaderOptions, Schema, OptionalPrimitiveType
 from databricks.sdk import WorkspaceClient
 
 logger = logging.getLogger(__name__)
@@ -71,8 +72,6 @@ def get_jdbc_url(self) -> str:
         return (
             f"jdbc:{self._DRIVER}://{self._get_secret('host')}:{self._get_secret('port')};"
             f"databaseName={self._get_secret('database')};"
-            f"user={self._get_secret('user')};"
-            f"password={self._get_secret('password')};"
             f"encrypt={self._get_secret('encrypt')};"
             f"trustServerCertificate={self._get_secret('trustServerCertificate')};"
         )
@@ -96,10 +95,10 @@
         prepare_query_string = ""
         try:
             if options is None:
-                df = self.reader(query, prepare_query_string).load()
+                df = self.reader(query, {"prepareQuery": prepare_query_string}).load()
             else:
-                options = self._get_jdbc_reader_options(options)
-                df = self._get_jdbc_reader(table_query, self.get_jdbc_url, self._DRIVER).options(**options).load()
+                spark_options = self._get_jdbc_reader_options(options)
+                df = self.reader(table_query, spark_options).load()
             return df.select([col(column).alias(column.lower()) for column in df.columns])
         except (RuntimeError, PySparkException) as e:
             return self.log_and_throw_exception(e, "data", table_query)
@@ -133,8 +132,18 @@ def get_schema(
         except (RuntimeError, PySparkException) as e:
             return self.log_and_throw_exception(e, "schema", schema_query)
 
-    def reader(self, query: str, prepare_query_str="") -> DataFrameReader:
-        return self._get_jdbc_reader(query, self.get_jdbc_url, self._DRIVER, {"prepareQuery": prepare_query_str})
+    def reader(self, query: str, options: Mapping[str, OptionalPrimitiveType] | None = None) -> DataFrameReader:
+        if options is None:
+            options = {}
+
+        creds = self._get_user_password()
+        return self._get_jdbc_reader(query, self.get_jdbc_url, self._DRIVER, {**options, **creds})
+
+    def _get_user_password(self) -> Mapping[str, str]:
+        return {
+            "user": self._get_secret("user"),
+            "password": self._get_secret("password"),
+        }
 
     def normalize_identifier(self, identifier: str) -> NormalizedIdentifier:
         return DialectUtils.normalize_identifier(
```
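
The SQL Server change is the heart of the fix: in a JDBC connection string, `;` delimits properties, so a password containing `;` (or other reserved characters) is cut short or misread as extra properties. A sketch of the failure mode, with illustrative values:

```python
# Embedded in the URL, "p@ss;encrypt=false" parses as password=p@ss plus a
# stray encrypt=false property -- authentication fails or, worse, the stray
# property silently changes connection behavior.
url = "jdbc:sqlserver://host:1433;databaseName=db;user=u;password=p@ss;encrypt=false;"

# As reader options, the whole string is one opaque value; nothing is parsed.
creds = {"user": "u", "password": "p@ss;encrypt=false"}
```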

src/databricks/labs/lakebridge/reconcile/recon_config.py

Lines changed: 3 additions & 1 deletion
```diff
@@ -4,7 +4,6 @@
 
 from dataclasses import dataclass
 from collections.abc import Callable
-
 from sqlglot import expressions as exp
 
 from databricks.labs.lakebridge.reconcile.connectors.dialect_utils import DialectUtils
@@ -28,6 +27,9 @@
 RECONCILE_OPERATION_NAME = "reconcile"
 AGG_RECONCILE_OPERATION_NAME = "aggregates-reconcile"
 
+PrimitiveType = bool | int | float | str
+OptionalPrimitiveType = PrimitiveType | None
+
 
 class TableThresholdBoundsException(ValueError):
     """Raise the error when the bounds for table threshold are invalid"""
```

tests/integration/reconcile/connectors/test_read_schema.py

Lines changed: 14 additions & 9 deletions
```diff
@@ -1,3 +1,4 @@
+from collections.abc import Mapping
 from unittest.mock import create_autospec
 
 import pytest
@@ -8,6 +9,7 @@
 from databricks.labs.lakebridge.reconcile.connectors.oracle import OracleDataSource
 from databricks.labs.lakebridge.reconcile.connectors.snowflake import SnowflakeDataSource
 from databricks.labs.lakebridge.reconcile.connectors.tsql import TSQLServerDataSource
+from databricks.labs.lakebridge.reconcile.recon_config import OptionalPrimitiveType
 from databricks.labs.lakebridge.transpiler.sqlglot.dialect_utils import get_dialect
 
 from databricks.sdk import WorkspaceClient
@@ -22,11 +24,12 @@ def __init__(self, spark, ws):
 
     @property
     def get_jdbc_url(self) -> str:
-        return (
-            self._test_env.get("TEST_TSQL_JDBC")
-            + f"user={self._test_env.get('TEST_TSQL_USER')};"
-            + f"password={self._test_env.get('TEST_TSQL_PASS')};"
-        )
+        return self._test_env.get("TEST_TSQL_JDBC")
+
+    def _get_user_password(self) -> dict:
+        user = self._test_env.get("TEST_TSQL_USER")
+        password = self._test_env.get("TEST_TSQL_PASS")
+        return {"user": user, "password": password}
 
 
 class OracleDataSourceUnderTest(OracleDataSource):
@@ -38,11 +41,13 @@ def __init__(self, spark, ws):
     def get_jdbc_url(self) -> str:
         return self._test_env.get("TEST_ORACLE_JDBC")
 
-    def reader(self, query: str) -> DataFrameReader:
+    def reader(self, query: str, options: Mapping[str, OptionalPrimitiveType] | None = None) -> DataFrameReader:
+        if options is None:
+            options = {}
         user = self._test_env.get("TEST_ORACLE_USER")
         password = self._test_env.get("TEST_ORACLE_PASSWORD")
         return self._get_jdbc_reader(
-            query, self.get_jdbc_url, OracleDataSource._DRIVER, {"user": user, "password": password}
+            query, self.get_jdbc_url, OracleDataSource._DRIVER, {**options, "user": user, "password": password}
         )
 
 
@@ -75,12 +80,12 @@ def _get_snowflake_options(self):
         return opts
 
 
-@pytest.mark.skip(reason="Add the creds to Github secrets and populate the actions' env to enable this test")
+@pytest.mark.skip(reason="Run in acceptance environment only")
 def test_sql_server_read_schema_happy(mock_spark):
     mock_ws = create_autospec(WorkspaceClient)
     connector = TSQLServerDataSourceUnderTest(mock_spark, mock_ws)
 
-    columns = connector.get_schema("labs_azure_sandbox_remorph", "dbo", "Employees")
+    columns = connector.get_schema("labs_azure_sandbox_remorph", "dbo", "reconcile_in")
    assert columns
```
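Note how the `*UnderTest` subclasses follow the production change: instead of appending credentials to the test JDBC URL, they now override `_get_user_password` (TSQL) or the widened `reader` signature (Oracle), keeping the URL credential-free in tests too. A stripped-down sketch of the pattern (class and variable names are illustrative, not the project's):

```python
import os

class SourceUnderTest:  # stands in for a production DataSource subclass
    # Tests resolve credentials from the environment, not a secret scope.
    def _get_user_password(self) -> dict:
        return {
            "user": os.environ.get("TEST_USER", ""),
            "password": os.environ.get("TEST_PASS", ""),
        }
```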
tests/unit/reconcile/connectors/test_oracle.py

Lines changed: 6 additions & 8 deletions
```diff
@@ -76,9 +76,7 @@ def test_read_data_with_options():
     )
     spark.read.format().option().option.assert_called_with("driver", "oracle.jdbc.OracleDriver")
     spark.read.format().option().option().option.assert_called_with("dbtable", "(select 1 from data.employee) tmp")
-    spark.read.format().option().option().option().option.assert_called_with("user", "my_user")
-    spark.read.format().option().option().option().option().option.assert_called_with("password", "my_password")
-    jdbc_actual_args = spark.read.format().option().option().option().option().option().options.call_args.kwargs
+    jdbc_actual_args = spark.read.format().option().option().option().options.call_args.kwargs
     jdbc_expected_args = {
         "numPartitions": 50,
         "partitionColumn": "s_nationkey",
@@ -89,9 +87,11 @@
         "sessionInitStatement": r"BEGIN dbms_session.set_nls('nls_date_format', "
         r"'''YYYY-MM-DD''');dbms_session.set_nls('nls_timestamp_format', '''YYYY-MM-DD "
         r"HH24:MI:SS''');END;",
+        "user": "my_user",
+        "password": "my_password",
     }
     assert jdbc_actual_args == jdbc_expected_args
-    spark.read.format().option().option().option().option().option().options().load.assert_called_once()
+    spark.read.format().option().option().option().options().load.assert_called_once()
 
 
 def test_get_schema():
@@ -143,9 +143,7 @@ def test_read_data_exception_handling():
         filters=None,
     )
 
-    spark.read.format().option().option().option().option().option().options().load.side_effect = RuntimeError(
-        "Test Exception"
-    )
+    spark.read.format().option().option().option().options().load.side_effect = RuntimeError("Test Exception")
 
     # Call the read_data method with the Tables configuration and assert that a PySparkException is raised
     with pytest.raises(
@@ -160,7 +158,7 @@ def test_get_schema_exception_handling():
     engine, spark, ws, scope = initial_setup()
     ords = OracleDataSource(engine, spark, ws, scope)
 
-    spark.read.format().option().option().option().option().option().load.side_effect = RuntimeError("Test Exception")
+    spark.read.format().option().option().option().options().load.side_effect = RuntimeError("Test Exception")
 
     # Call the get_schema method with predefined table, schema, and catalog names and assert that a PySparkException
     # is raised
```
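
These assertions lean on the `MagicMock` chaining idiom: `spark.read.format().option().option()...` addresses the return value of each prior call, so each chained `.option` in an assertion corresponds to one `.option()` call in production code (url, driver, dbtable), and the trailing `.options` captures the single bulk call that now carries the credentials. A stripped-down sketch of the idiom:

```python
from unittest.mock import MagicMock

spark = MagicMock()

# Production-style call chain: three .option() calls, then one .options().
spark.read.format("jdbc").option("url", "u").option("driver", "d").option(
    "dbtable", "t"
).options(user="my_user", password="my_password").load()

# Each assertion addresses the return value of the previous call in the chain.
spark.read.format().option().option().option.assert_called_with("dbtable", "t")
assert spark.read.format().option().option().option().options.call_args.kwargs == {
    "user": "my_user",
    "password": "my_password",
}
spark.read.format().option().option().option().options().load.assert_called_once()
```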

tests/unit/reconcile/connectors/test_sql_server.py

Lines changed: 5 additions & 16 deletions
```diff
@@ -55,20 +55,7 @@ def test_get_jdbc_url_happy():
     url = data_source.get_jdbc_url
     # Assert that the URL is generated correctly
     assert url == (
-        """jdbc:sqlserver://my_host:777;databaseName=my_database;user=my_user;password=my_password;encrypt=true;trustServerCertificate=true;"""
-    )
-
-
-def test_get_jdbc_url_fail():
-    # initial setup
-    engine, spark, ws, scope = initial_setup()
-    ws.secrets.get_secret.side_effect = mock_secret
-    # create object for TSQLServerDataSource
-    data_source = TSQLServerDataSource(engine, spark, ws, scope)
-    url = data_source.get_jdbc_url
-    # Assert that the URL is generated correctly
-    assert url == (
-        """jdbc:sqlserver://my_host:777;databaseName=my_database;user=my_user;password=my_password;encrypt=true;trustServerCertificate=true;"""
+        """jdbc:sqlserver://my_host:777;databaseName=my_database;encrypt=true;trustServerCertificate=true;"""
     )
 
 
@@ -96,7 +83,7 @@ def test_read_data_with_options():
     spark.read.format.assert_called_with("jdbc")
     spark.read.format().option.assert_called_with(
         "url",
-        "jdbc:sqlserver://my_host:777;databaseName=my_database;user=my_user;password=my_password;encrypt=true;trustServerCertificate=true;",
+        "jdbc:sqlserver://my_host:777;databaseName=my_database;encrypt=true;trustServerCertificate=true;",
     )
     spark.read.format().option().option.assert_called_with("driver", "com.microsoft.sqlserver.jdbc.SQLServerDriver")
     spark.read.format().option().option().option.assert_called_with(
@@ -109,6 +96,8 @@
         "lowerBound": '0',
         "upperBound": "100",
         "fetchsize": 100,
+        "user": "my_user",
+        "password": "my_password",
     }
     assert actual_args == expected_args
     spark.read.format().option().option().option().options().load.assert_called_once()
@@ -166,7 +155,7 @@ def test_get_schema_exception_handling():
     engine, spark, ws, scope = initial_setup()
     data_source = TSQLServerDataSource(engine, spark, ws, scope)
 
-    spark.read.format().option().option().option().option().load.side_effect = RuntimeError("Test Exception")
+    spark.read.format().option().option().option().options().load.side_effect = RuntimeError("Test Exception")
 
     # Call the get_schema method with predefined table, schema, and catalog names and assert that a PySparkException
     # is raised
```
