Skip to content

Commit b23b721

Browse files
committed
fix tests
1 parent c0c2054 commit b23b721

File tree

4 files changed

+18
-30
lines changed

4 files changed

+18
-30
lines changed

src/databricks/labs/lakebridge/reconcile/connectors/jdbc_reader.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,7 @@ def _get_jdbc_reader(self, query, jdbc_url, driver, additional_options: dict | N
2222
.option("dbtable", f"({query}) tmp")
2323
)
2424
if isinstance(additional_options, dict):
25-
for key, value in additional_options.items():
26-
reader = reader.option(key, value)
25+
reader = reader.options(**additional_options)
2726
return reader
2827

2928
@staticmethod

src/databricks/labs/lakebridge/reconcile/connectors/oracle.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -64,9 +64,9 @@ def read_data(
6464
table_query = query.replace(":tbl", f"{schema}.{table}")
6565
try:
6666
if options is None:
67-
return self.reader(table_query).options(**self._get_timestamp_options()).load()
67+
return self.reader(table_query, self._get_timestamp_options()).load()
6868
reader_options = self._get_jdbc_reader_options(options) | self._get_timestamp_options()
69-
df = self.reader(table_query).options(**reader_options).load()
69+
df = self.reader(table_query, reader_options).load()
7070
logger.warning(f"Fetching data using query: \n`{table_query}`")
7171

7272
# Convert all column names to lower case
@@ -107,12 +107,14 @@ def _get_timestamp_options() -> dict[str, str]:
107107
"HH24:MI:SS''');END;",
108108
}
109109

110-
def reader(self, query: str) -> DataFrameReader:
110+
def reader(self, query: str, options: dict | None = None) -> DataFrameReader:
111+
if options is None:
112+
options = {}
111113
user = self._get_secret('user')
112114
password = self._get_secret('password')
113115
logger.debug(f"Using user: {user} to connect to Oracle")
114116
return self._get_jdbc_reader(
115-
query, self.get_jdbc_url, OracleDataSource._DRIVER, {"user": user, "password": password}
117+
query, self.get_jdbc_url, OracleDataSource._DRIVER, {**options, "user": user, "password": password}
116118
)
117119

118120
def normalize_identifier(self, identifier: str) -> NormalizedIdentifier:

tests/unit/reconcile/connectors/test_oracle.py

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -76,9 +76,7 @@ def test_read_data_with_options():
7676
)
7777
spark.read.format().option().option.assert_called_with("driver", "oracle.jdbc.OracleDriver")
7878
spark.read.format().option().option().option.assert_called_with("dbtable", "(select 1 from data.employee) tmp")
79-
spark.read.format().option().option().option().option.assert_called_with("user", "my_user")
80-
spark.read.format().option().option().option().option().option.assert_called_with("password", "my_password")
81-
jdbc_actual_args = spark.read.format().option().option().option().option().option().options.call_args.kwargs
79+
jdbc_actual_args = spark.read.format().option().option().option().options.call_args.kwargs
8280
jdbc_expected_args = {
8381
"numPartitions": 50,
8482
"partitionColumn": "s_nationkey",
@@ -89,9 +87,11 @@ def test_read_data_with_options():
8987
"sessionInitStatement": r"BEGIN dbms_session.set_nls('nls_date_format', "
9088
r"'''YYYY-MM-DD''');dbms_session.set_nls('nls_timestamp_format', '''YYYY-MM-DD "
9189
r"HH24:MI:SS''');END;",
90+
"user": "my_user",
91+
"password": "my_password",
9292
}
9393
assert jdbc_actual_args == jdbc_expected_args
94-
spark.read.format().option().option().option().option().option().options().load.assert_called_once()
94+
spark.read.format().option().option().option().options().load.assert_called_once()
9595

9696

9797
def test_get_schema():
@@ -143,9 +143,7 @@ def test_read_data_exception_handling():
143143
filters=None,
144144
)
145145

146-
spark.read.format().option().option().option().option().option().options().load.side_effect = RuntimeError(
147-
"Test Exception"
148-
)
146+
spark.read.format().option().option().option().options().load.side_effect = RuntimeError("Test Exception")
149147

150148
# Call the read_data method with the Tables configuration and assert that a PySparkException is raised
151149
with pytest.raises(
@@ -160,7 +158,7 @@ def test_get_schema_exception_handling():
160158
engine, spark, ws, scope = initial_setup()
161159
ords = OracleDataSource(engine, spark, ws, scope)
162160

163-
spark.read.format().option().option().option().option().option().load.side_effect = RuntimeError("Test Exception")
161+
spark.read.format().option().option().option().options().load.side_effect = RuntimeError("Test Exception")
164162

165163
# Call the get_schema method with predefined table, schema, and catalog names and assert that a PySparkException
166164
# is raised

tests/unit/reconcile/connectors/test_sql_server.py

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -55,20 +55,7 @@ def test_get_jdbc_url_happy():
5555
url = data_source.get_jdbc_url
5656
# Assert that the URL is generated correctly
5757
assert url == (
58-
"""jdbc:sqlserver://my_host:777;databaseName=my_database;user=my_user;password=my_password;encrypt=true;trustServerCertificate=true;"""
59-
)
60-
61-
62-
def test_get_jdbc_url_fail():
63-
# initial setup
64-
engine, spark, ws, scope = initial_setup()
65-
ws.secrets.get_secret.side_effect = mock_secret
66-
# create object for TSQLServerDataSource
67-
data_source = TSQLServerDataSource(engine, spark, ws, scope)
68-
url = data_source.get_jdbc_url
69-
# Assert that the URL is generated correctly
70-
assert url == (
71-
"""jdbc:sqlserver://my_host:777;databaseName=my_database;user=my_user;password=my_password;encrypt=true;trustServerCertificate=true;"""
58+
"""jdbc:sqlserver://my_host:777;databaseName=my_database;encrypt=true;trustServerCertificate=true;"""
7259
)
7360

7461

@@ -96,7 +83,7 @@ def test_read_data_with_options():
9683
spark.read.format.assert_called_with("jdbc")
9784
spark.read.format().option.assert_called_with(
9885
"url",
99-
"jdbc:sqlserver://my_host:777;databaseName=my_database;user=my_user;password=my_password;encrypt=true;trustServerCertificate=true;",
86+
"jdbc:sqlserver://my_host:777;databaseName=my_database;encrypt=true;trustServerCertificate=true;",
10087
)
10188
spark.read.format().option().option.assert_called_with("driver", "com.microsoft.sqlserver.jdbc.SQLServerDriver")
10289
spark.read.format().option().option().option.assert_called_with(
@@ -109,6 +96,8 @@ def test_read_data_with_options():
10996
"lowerBound": '0',
11097
"upperBound": "100",
11198
"fetchsize": 100,
99+
"user": "my_user",
100+
"password": "my_password",
112101
}
113102
assert actual_args == expected_args
114103
spark.read.format().option().option().option().options().load.assert_called_once()
@@ -166,7 +155,7 @@ def test_get_schema_exception_handling():
166155
engine, spark, ws, scope = initial_setup()
167156
data_source = TSQLServerDataSource(engine, spark, ws, scope)
168157

169-
spark.read.format().option().option().option().option().load.side_effect = RuntimeError("Test Exception")
158+
spark.read.format().option().option().option().options().load.side_effect = RuntimeError("Test Exception")
170159

171160
# Call the get_schema method with predefined table, schema, and catalog names and assert that a PySparkException
172161
# is raised

0 commit comments

Comments
 (0)