@@ -22,8 +22,7 @@ def _get_jdbc_reader(self, query, jdbc_url, driver, additional_options: dict | None
             .option("dbtable", f"({query}) tmp")
         )
         if isinstance(additional_options, dict):
-            for key, value in additional_options.items():
-                reader = reader.option(key, value)
+            reader = reader.options(**additional_options)
         return reader

     @staticmethod
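The change above replaces a per-key loop with DataFrameReader.options(), which applies a whole mapping in one call. A minimal sketch of the equivalence, assuming an existing SparkSession named spark:

    opts = {"fetchsize": "100", "numPartitions": "4"}

    # Old form: apply each option individually.
    reader_a = spark.read.format("jdbc")
    for key, value in opts.items():
        reader_a = reader_a.option(key, value)

    # New form: one call configures the reader identically.
    reader_b = spark.read.format("jdbc").options(**opts)
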
10 changes: 6 additions & 4 deletions src/databricks/labs/lakebridge/reconcile/connectors/oracle.py
@@ -64,9 +64,9 @@ def read_data(
         table_query = query.replace(":tbl", f"{schema}.{table}")
         try:
             if options is None:
-                return self.reader(table_query).options(**self._get_timestamp_options()).load()
+                return self.reader(table_query, self._get_timestamp_options()).load()
             reader_options = self._get_jdbc_reader_options(options) | self._get_timestamp_options()
-            df = self.reader(table_query).options(**reader_options).load()
+            df = self.reader(table_query, reader_options).load()
             logger.warning(f"Fetching data using query: \n`{table_query}`")

             # Convert all column names to lower case
@@ -107,12 +107,14 @@ def _get_timestamp_options() -> dict[str, str]:
             "HH24:MI:SS''');END;",
         }

-    def reader(self, query: str) -> DataFrameReader:
+    def reader(self, query: str, options: dict | None = None) -> DataFrameReader:
Contributor: Suggested change:

-    def reader(self, query: str, options: dict | None = None) -> DataFrameReader:
+    def reader(self, query: str, options: Mapping[str, OptionalPrimitiveType] | None = None) -> DataFrameReader:

Contributor Author: I can't import OptionalPrimitiveType, so I'm using object.

Contributor: Try:

    from pyspark.sql.readwriter import OptionalPrimitiveType

Contributor Author: That cannot be imported either. I copied it so we can use it.
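
An aside on that workaround: PySpark declares OptionalPrimitiveType in a typing stub (pyspark/sql/_typing.pyi) that exists only for static type checkers, which is why neither import above works at runtime. A sketch of what the copied alias presumably looks like, assuming the stub's definition:

    from typing import Optional, Union

    # Local copy of PySpark's stub-only aliases; the stub module is not
    # importable at runtime, so the definitions are duplicated here.
    PrimitiveType = Union[bool, float, int, str]
    OptionalPrimitiveType = Optional[PrimitiveType]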

+        if options is None:
+            options = {}
         user = self._get_secret('user')
         password = self._get_secret('password')
         logger.debug(f"Using user: {user} to connect to Oracle")
         return self._get_jdbc_reader(
-            query, self.get_jdbc_url, OracleDataSource._DRIVER, {"user": user, "password": password}
+            query, self.get_jdbc_url, OracleDataSource._DRIVER, {**options, "user": user, "password": password}
         )

     def normalize_identifier(self, identifier: str) -> NormalizedIdentifier:
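One detail in the new reader() is worth calling out: in {**options, "user": user, "password": password} the later keys win, so the secret-backed credentials always override anything a caller passes through options. A small illustration with hypothetical values:

    options = {"fetchsize": "100", "user": "caller_supplied"}
    merged = {**options, "user": "my_user", "password": "my_password"}
    assert merged["user"] == "my_user"  # the credential from secrets wins
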
21 changes: 13 additions & 8 deletions src/databricks/labs/lakebridge/reconcile/connectors/tsql.py
@@ -71,8 +71,6 @@ def get_jdbc_url(self) -> str:
         return (
             f"jdbc:{self._DRIVER}://{self._get_secret('host')}:{self._get_secret('port')};"
             f"databaseName={self._get_secret('database')};"
-            f"user={self._get_secret('user')};"
-            f"password={self._get_secret('password')};"
             f"encrypt={self._get_secret('encrypt')};"
             f"trustServerCertificate={self._get_secret('trustServerCertificate')};"
         )
@@ -96,10 +94,10 @@ def read_data(
             prepare_query_string = ""
         try:
             if options is None:
-                df = self.reader(query, prepare_query_string).load()
+                df = self.reader(query, {"prepareQuery": prepare_query_string}).load()
             else:
-                options = self._get_jdbc_reader_options(options)
-                df = self._get_jdbc_reader(table_query, self.get_jdbc_url, self._DRIVER).options(**options).load()
+                spark_options = self._get_jdbc_reader_options(options)
+                df = self.reader(table_query, spark_options).load()
             return df.select([col(column).alias(column.lower()) for column in df.columns])
         except (RuntimeError, PySparkException) as e:
             return self.log_and_throw_exception(e, "data", table_query)
@@ -126,15 +124,22 @@ def get_schema(
         try:
             logger.debug(f"Fetching schema using query: \n`{schema_query}`")
             logger.info(f"Fetching Schema: Started at: {datetime.now()}")
-            df = self.reader(schema_query).load()
+            df = self.reader(schema_query, {}).load()
             schema_metadata = df.select([col(c).alias(c.lower()) for c in df.columns]).collect()
             logger.info(f"Schema fetched successfully. Completed at: {datetime.now()}")
             return [self._map_meta_column(field, normalize) for field in schema_metadata]
         except (RuntimeError, PySparkException) as e:
             return self.log_and_throw_exception(e, "schema", schema_query)

-    def reader(self, query: str, prepare_query_str="") -> DataFrameReader:
-        return self._get_jdbc_reader(query, self.get_jdbc_url, self._DRIVER, {"prepareQuery": prepare_query_str})
+    def reader(self, query: str, options: dict) -> DataFrameReader:
+        creds = self._get_user_password()
+        return self._get_jdbc_reader(query, self.get_jdbc_url, self._DRIVER, {**options, **creds})
+
+    def _get_user_password(self) -> dict:
+        return {
+            "user": self._get_secret("user"),
+            "password": self._get_secret("password"),
+        }

     def normalize_identifier(self, identifier: str) -> NormalizedIdentifier:
         return DialectUtils.normalize_identifier(
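The net effect in tsql.py is that credentials move out of the JDBC URL and into the reader options, so the URL can be logged or asserted on without leaking secrets. A rough sketch of the resulting call, with hypothetical values and an existing SparkSession named spark:

    url = "jdbc:sqlserver://my_host:777;databaseName=my_database;encrypt=true;trustServerCertificate=true;"
    options = {"prepareQuery": ""}
    creds = {"user": "my_user", "password": "my_password"}

    reader = (
        spark.read.format("jdbc")
        .option("url", url)
        .option("driver", "com.microsoft.sqlserver.jdbc.SQLServerDriver")
        .option("dbtable", "(SELECT 1 AS x) tmp")
        .options(**{**options, **creds})  # credentials merged in last
    )
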
21 changes: 12 additions & 9 deletions tests/integration/reconcile/connectors/test_read_schema.py
@@ -22,11 +22,12 @@ def __init__(self, spark, ws):

     @property
     def get_jdbc_url(self) -> str:
-        return (
-            self._test_env.get("TEST_TSQL_JDBC")
-            + f"user={self._test_env.get('TEST_TSQL_USER')};"
-            + f"password={self._test_env.get('TEST_TSQL_PASS')};"
-        )
+        return self._test_env.get("TEST_TSQL_JDBC")
+
+    def _get_user_password(self) -> dict:
+        user = self._test_env.get("TEST_TSQL_USER")
+        password = self._test_env.get("TEST_TSQL_PASS")
+        return {"user": user, "password": password}


class OracleDataSourceUnderTest(OracleDataSource):
@@ -38,11 +39,13 @@ def __init__(self, spark, ws):
     def get_jdbc_url(self) -> str:
         return self._test_env.get("TEST_ORACLE_JDBC")

-    def reader(self, query: str) -> DataFrameReader:
+    def reader(self, query: str, options: dict | None = None) -> DataFrameReader:
+        if options is None:
+            options = {}
         user = self._test_env.get("TEST_ORACLE_USER")
         password = self._test_env.get("TEST_ORACLE_PASSWORD")
         return self._get_jdbc_reader(
-            query, self.get_jdbc_url, OracleDataSource._DRIVER, {"user": user, "password": password}
+            query, self.get_jdbc_url, OracleDataSource._DRIVER, {**options, "user": user, "password": password}
         )


@@ -75,12 +78,12 @@ def _get_snowflake_options(self):
         return opts


-@pytest.mark.skip(reason="Add the creds to Github secrets and populate the actions' env to enable this test")
+@pytest.mark.skip(reason="Run in acceptance environment only")
 def test_sql_server_read_schema_happy(mock_spark):
     mock_ws = create_autospec(WorkspaceClient)
     connector = TSQLServerDataSourceUnderTest(mock_spark, mock_ws)

-    columns = connector.get_schema("labs_azure_sandbox_remorph", "dbo", "Employees")
+    columns = connector.get_schema("labs_azure_sandbox_remorph", "dbo", "reconcile_in")
Contributor: Where does this come from?

Contributor Author: Manually created by me.

     assert columns


14 changes: 6 additions & 8 deletions tests/unit/reconcile/connectors/test_oracle.py
@@ -76,9 +76,7 @@ def test_read_data_with_options():
     )
     spark.read.format().option().option.assert_called_with("driver", "oracle.jdbc.OracleDriver")
     spark.read.format().option().option().option.assert_called_with("dbtable", "(select 1 from data.employee) tmp")
-    spark.read.format().option().option().option().option.assert_called_with("user", "my_user")
-    spark.read.format().option().option().option().option().option.assert_called_with("password", "my_password")
-    jdbc_actual_args = spark.read.format().option().option().option().option().option().options.call_args.kwargs
+    jdbc_actual_args = spark.read.format().option().option().option().options.call_args.kwargs
     jdbc_expected_args = {
         "numPartitions": 50,
         "partitionColumn": "s_nationkey",
@@ -89,9 +87,11 @@
         "sessionInitStatement": r"BEGIN dbms_session.set_nls('nls_date_format', "
         r"'''YYYY-MM-DD''');dbms_session.set_nls('nls_timestamp_format', '''YYYY-MM-DD "
         r"HH24:MI:SS''');END;",
+        "user": "my_user",
+        "password": "my_password",
     }
     assert jdbc_actual_args == jdbc_expected_args
-    spark.read.format().option().option().option().option().option().options().load.assert_called_once()
+    spark.read.format().option().option().option().options().load.assert_called_once()


def test_get_schema():
@@ -143,9 +143,7 @@ def test_read_data_exception_handling():
         filters=None,
     )

-    spark.read.format().option().option().option().option().option().options().load.side_effect = RuntimeError(
-        "Test Exception"
-    )
+    spark.read.format().option().option().option().options().load.side_effect = RuntimeError("Test Exception")

     # Call the read_data method with the Tables configuration and assert that a PySparkException is raised
     with pytest.raises(
@@ -160,7 +158,7 @@ def test_get_schema_exception_handling():
     engine, spark, ws, scope = initial_setup()
     ords = OracleDataSource(engine, spark, ws, scope)

-    spark.read.format().option().option().option().option().option().load.side_effect = RuntimeError("Test Exception")
+    spark.read.format().option().option().option().options().load.side_effect = RuntimeError("Test Exception")

     # Call the get_schema method with predefined table, schema, and catalog names and assert that a PySparkException
     # is raised
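The shortened assertion chains in these tests track the production change directly: with MagicMock, spark.read.format().option() returns the same child mock regardless of arguments, so each .option() link in the assertion must correspond to one .option() call in the code under test. Moving user and password into the single .options(**...) call removes two links. A minimal sketch:

    from unittest.mock import MagicMock

    spark = MagicMock()

    # Code under test: three .option() calls, then one .options() call.
    reader = spark.read.format("jdbc").option("url", "u").option("driver", "d").option("dbtable", "t")
    reader.options(fetchsize=100, user="my_user", password="my_password")

    # The assertion chain mirrors that shape exactly.
    spark.read.format().option().option().option().options.assert_called_once_with(
        fetchsize=100, user="my_user", password="my_password"
    )
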
21 changes: 5 additions & 16 deletions tests/unit/reconcile/connectors/test_sql_server.py
@@ -55,20 +55,7 @@ def test_get_jdbc_url_happy():
     url = data_source.get_jdbc_url
     # Assert that the URL is generated correctly
     assert url == (
-        """jdbc:sqlserver://my_host:777;databaseName=my_database;user=my_user;password=my_password;encrypt=true;trustServerCertificate=true;"""
-    )
-
-
-def test_get_jdbc_url_fail():
-    # initial setup
-    engine, spark, ws, scope = initial_setup()
-    ws.secrets.get_secret.side_effect = mock_secret
-    # create object for TSQLServerDataSource
-    data_source = TSQLServerDataSource(engine, spark, ws, scope)
-    url = data_source.get_jdbc_url
-    # Assert that the URL is generated correctly
-    assert url == (
-        """jdbc:sqlserver://my_host:777;databaseName=my_database;user=my_user;password=my_password;encrypt=true;trustServerCertificate=true;"""
+        """jdbc:sqlserver://my_host:777;databaseName=my_database;encrypt=true;trustServerCertificate=true;"""
     )


@@ -96,7 +83,7 @@ def test_read_data_with_options():
     spark.read.format.assert_called_with("jdbc")
     spark.read.format().option.assert_called_with(
         "url",
-        "jdbc:sqlserver://my_host:777;databaseName=my_database;user=my_user;password=my_password;encrypt=true;trustServerCertificate=true;",
+        "jdbc:sqlserver://my_host:777;databaseName=my_database;encrypt=true;trustServerCertificate=true;",
     )
     spark.read.format().option().option.assert_called_with("driver", "com.microsoft.sqlserver.jdbc.SQLServerDriver")
     spark.read.format().option().option().option.assert_called_with(
@@ -109,6 +96,8 @@
         "lowerBound": '0',
         "upperBound": "100",
         "fetchsize": 100,
+        "user": "my_user",
+        "password": "my_password",
     }
     assert actual_args == expected_args
     spark.read.format().option().option().option().options().load.assert_called_once()
@@ -166,7 +155,7 @@ def test_get_schema_exception_handling():
     engine, spark, ws, scope = initial_setup()
     data_source = TSQLServerDataSource(engine, spark, ws, scope)

-    spark.read.format().option().option().option().option().load.side_effect = RuntimeError("Test Exception")
+    spark.read.format().option().option().option().options().load.side_effect = RuntimeError("Test Exception")

     # Call the get_schema method with predefined table, schema, and catalog names and assert that a PySparkException
     # is raised