diff --git a/.gitignore b/.gitignore index 8fdda9a..1b92254 100644 --- a/.gitignore +++ b/.gitignore @@ -46,3 +46,7 @@ build/ ### Log ### *.log /logs + +### Local Test Files ### +set-test-env.sh +**/S1.scala diff --git a/pom.xml b/pom.xml index afe6d4e..7107c65 100644 --- a/pom.xml +++ b/pom.xml @@ -89,7 +89,7 @@ under the License. com.oceanbase oceanbase-client - 2.4.12 + 2.4.16 diff --git a/spark-connector-oceanbase-common/src/test/java/com/oceanbase/spark/OceanBaseOracleTestBase.java b/spark-connector-oceanbase-common/src/test/java/com/oceanbase/spark/OceanBaseOracleTestBase.java index 7964f83..e7b3c1f 100644 --- a/spark-connector-oceanbase-common/src/test/java/com/oceanbase/spark/OceanBaseOracleTestBase.java +++ b/spark-connector-oceanbase-common/src/test/java/com/oceanbase/spark/OceanBaseOracleTestBase.java @@ -69,7 +69,7 @@ public String getSysPassword() { @Override public String getUsername() { - return System.getenv("USERNAME"); + return System.getenv("USER_NAME"); } @Override diff --git a/spark-connector-oceanbase-e2e-tests/src/test/resources/sql/oracle/products.sql b/spark-connector-oceanbase-e2e-tests/src/test/resources/sql/oracle/products.sql new file mode 100644 index 0000000..2e79b17 --- /dev/null +++ b/spark-connector-oceanbase-e2e-tests/src/test/resources/sql/oracle/products.sql @@ -0,0 +1,155 @@ +-- Copyright 2024 OceanBase. +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- http://www.apache.org/licenses/LICENSE-2.0 +-- Unless required by applicable law or agreed to in writing, +-- software distributed under the License is distributed on an +-- "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +-- KIND, either express or implied. See the License for the +-- specific language governing permissions and limitations +-- under the License. 
+ +-- Oracle mode test data initialization script +-- This script creates test tables and data for OceanBase Oracle mode testing + +-- Create test schema if not exists +CREATE USER test_schema IDENTIFIED BY 'password'; +GRANT CONNECT, RESOURCE TO test_schema; + +-- Switch to test schema +ALTER SESSION SET CURRENT_SCHEMA = test_schema; + +-- Create products table with various Oracle data types +CREATE TABLE products ( + product_id NUMBER(10) NOT NULL, + product_name VARCHAR2(255) NOT NULL, + description CLOB, + price NUMBER(19,4), + category VARCHAR2(100), + is_active NUMBER(1) DEFAULT 1, + created_date DATE DEFAULT SYSDATE, + updated_timestamp TIMESTAMP DEFAULT CURRENT_TIMESTAMP, + product_code RAW(16), + CONSTRAINT pk_products PRIMARY KEY (product_id), + CONSTRAINT uk_products_code UNIQUE (product_code) +); + +-- Create partitioned products table +CREATE TABLE products_partitioned ( + product_id NUMBER(10) NOT NULL, + product_name VARCHAR2(255) NOT NULL, + price NUMBER(19,4), + category VARCHAR2(100), + created_date DATE DEFAULT SYSDATE, + CONSTRAINT pk_products_partitioned PRIMARY KEY (product_id) +) PARTITION BY HASH(product_id) PARTITIONS 4; + +-- Create customers table +CREATE TABLE customers ( + customer_id NUMBER(10) NOT NULL, + customer_name VARCHAR2(255) NOT NULL, + email VARCHAR2(255), + phone VARCHAR2(20), + address CLOB, + registration_date DATE DEFAULT SYSDATE, + CONSTRAINT pk_customers PRIMARY KEY (customer_id), + CONSTRAINT uk_customers_email UNIQUE (email) +); + +-- Create orders table with foreign key +CREATE TABLE orders ( + order_id NUMBER(10) NOT NULL, + customer_id NUMBER(10) NOT NULL, + order_date DATE DEFAULT SYSDATE, + total_amount NUMBER(19,4), + status VARCHAR2(50) DEFAULT 'PENDING', + notes CLOB, + CONSTRAINT pk_orders PRIMARY KEY (order_id), + CONSTRAINT fk_orders_customer FOREIGN KEY (customer_id) REFERENCES customers(customer_id) +); + +-- Create order_items table +CREATE TABLE order_items ( + order_item_id NUMBER(10) NOT NULL, + order_id NUMBER(10) NOT NULL, + product_id NUMBER(10) NOT NULL, + quantity NUMBER(10) NOT NULL, + unit_price NUMBER(19,4) NOT NULL, + total_price NUMBER(19,4) NOT NULL, + CONSTRAINT pk_order_items PRIMARY KEY (order_item_id), + CONSTRAINT fk_order_items_order FOREIGN KEY (order_id) REFERENCES orders(order_id), + CONSTRAINT fk_order_items_product FOREIGN KEY (product_id) REFERENCES products(product_id) +); + +-- Insert test data +INSERT INTO products (product_id, product_name, description, price, category, product_code) VALUES +(1, 'Laptop Computer', 'High-performance laptop with 16GB RAM and 512GB SSD', 1299.99, 'Electronics', UTL_RAW.CAST_TO_RAW('LAPTOP001')), +(2, 'Wireless Mouse', 'Ergonomic wireless mouse with USB receiver', 29.99, 'Electronics', UTL_RAW.CAST_TO_RAW('MOUSE001')), +(3, 'Office Chair', 'Comfortable ergonomic office chair with lumbar support', 199.99, 'Furniture', UTL_RAW.CAST_TO_RAW('CHAIR001')), +(4, 'Coffee Mug', 'Ceramic coffee mug with company logo', 12.99, 'Accessories', UTL_RAW.CAST_TO_RAW('MUG001')), +(5, 'Notebook', 'Spiral-bound notebook with 200 pages', 8.99, 'Stationery', UTL_RAW.CAST_TO_RAW('NOTE001')); + +INSERT INTO products_partitioned (product_id, product_name, price, category) VALUES +(1, 'Laptop Computer', 1299.99, 'Electronics'), +(2, 'Wireless Mouse', 29.99, 'Electronics'), +(3, 'Office Chair', 199.99, 'Furniture'), +(4, 'Coffee Mug', 12.99, 'Accessories'), +(5, 'Notebook', 8.99, 'Stationery'); + +INSERT INTO customers (customer_id, customer_name, email, phone, address) VALUES +(1, 'John Smith', 
'john.smith@email.com', '555-0101', '123 Main Street, Anytown, USA'), +(2, 'Jane Doe', 'jane.doe@email.com', '555-0102', '456 Oak Avenue, Somewhere, USA'), +(3, 'Bob Johnson', 'bob.johnson@email.com', '555-0103', '789 Pine Road, Elsewhere, USA'), +(4, 'Alice Brown', 'alice.brown@email.com', '555-0104', '321 Elm Street, Nowhere, USA'), +(5, 'Charlie Wilson', 'charlie.wilson@email.com', '555-0105', '654 Maple Drive, Anywhere, USA'); + +INSERT INTO orders (order_id, customer_id, total_amount, status, notes) VALUES +(1, 1, 1329.98, 'COMPLETED', 'Order placed online'), +(2, 2, 229.98, 'PENDING', 'Order awaiting payment'), +(3, 3, 212.98, 'SHIPPED', 'Order shipped via express delivery'), +(4, 4, 21.98, 'COMPLETED', 'Order completed successfully'), +(5, 5, 8.99, 'CANCELLED', 'Order cancelled by customer'); + +INSERT INTO order_items (order_item_id, order_id, product_id, quantity, unit_price, total_price) VALUES +(1, 1, 1, 1, 1299.99, 1299.99), +(2, 1, 2, 1, 29.99, 29.99), +(3, 2, 1, 1, 1299.99, 1299.99), +(4, 2, 3, 1, 199.99, 199.99), +(5, 3, 3, 1, 199.99, 199.99), +(6, 3, 2, 1, 29.99, 29.99), +(7, 4, 2, 1, 29.99, 29.99), +(8, 4, 4, 1, 12.99, 12.99), +(9, 5, 5, 1, 8.99, 8.99); + +-- Create indexes for better performance +CREATE INDEX idx_products_category ON products(category); +CREATE INDEX idx_products_price ON products(price); +CREATE INDEX idx_customers_email ON customers(email); +CREATE INDEX idx_orders_customer ON orders(customer_id); +CREATE INDEX idx_orders_date ON orders(order_date); +CREATE INDEX idx_order_items_order ON order_items(order_id); +CREATE INDEX idx_order_items_product ON order_items(product_id); + +-- Create a view for order summary +CREATE VIEW order_summary AS +SELECT + o.order_id, + c.customer_name, + o.order_date, + o.total_amount, + o.status, + COUNT(oi.order_item_id) as item_count +FROM orders o +JOIN customers c ON o.customer_id = c.customer_id +LEFT JOIN order_items oi ON o.order_id = oi.order_id +GROUP BY o.order_id, c.customer_name, o.order_date, o.total_amount, o.status; + +-- Grant permissions +GRANT SELECT, INSERT, UPDATE, DELETE ON products TO test_schema; +GRANT SELECT, INSERT, UPDATE, DELETE ON products_partitioned TO test_schema; +GRANT SELECT, INSERT, UPDATE, DELETE ON customers TO test_schema; +GRANT SELECT, INSERT, UPDATE, DELETE ON orders TO test_schema; +GRANT SELECT, INSERT, UPDATE, DELETE ON order_items TO test_schema; +GRANT SELECT ON order_summary TO test_schema; diff --git a/spark-connector-oceanbase/spark-connector-oceanbase-base/src/main/scala/com/oceanbase/spark/dialect/OceanBaseOracleDialect.scala b/spark-connector-oceanbase/spark-connector-oceanbase-base/src/main/scala/com/oceanbase/spark/dialect/OceanBaseOracleDialect.scala index 9296f06..85e2a72 100644 --- a/spark-connector-oceanbase/spark-connector-oceanbase-base/src/main/scala/com/oceanbase/spark/dialect/OceanBaseOracleDialect.scala +++ b/spark-connector-oceanbase/spark-connector-oceanbase-base/src/main/scala/com/oceanbase/spark/dialect/OceanBaseOracleDialect.scala @@ -17,16 +17,21 @@ package com.oceanbase.spark.dialect import com.oceanbase.spark.config.OceanBaseConfig +import com.oceanbase.spark.utils.OBJdbcUtils +import com.oceanbase.spark.utils.OBJdbcUtils.executeStatement +import org.apache.commons.lang3.StringUtils +import org.apache.spark.sql.ExprUtils import org.apache.spark.sql.catalyst.util.DateTimeUtils import org.apache.spark.sql.connector.expressions.{Expression, Transform} import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{BooleanType, ByteType, 
DataType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MetadataBuilder, ShortType, StringType, StructType, TimestampType} +import org.apache.spark.sql.types.{BinaryType, BooleanType, ByteType, DataType, DateType, DecimalType, DoubleType, FloatType, IntegerType, LongType, MetadataBuilder, ShortType, StringType, StructType, TimestampType, VarcharType} import java.sql.{Connection, Date, Timestamp, Types} import java.util import java.util.TimeZone +import scala.collection.mutable import scala.collection.mutable.ArrayBuffer import scala.util.control.NonFatal @@ -45,37 +50,200 @@ class OceanBaseOracleDialect extends OceanBaseDialect { schema: StructType, partitions: Array[Transform], config: OceanBaseConfig, - properties: util.Map[String, String]): Unit = - throw new UnsupportedOperationException("Not currently supported in oracle mode") + properties: util.Map[String, String]): Unit = { + + // Collect column and table comments (use COMMENT ON statements in Oracle mode) + val columnComments: mutable.ArrayBuffer[(String, String)] = mutable.ArrayBuffer.empty + val tableCommentOpt: Option[String] = Option(config.getTableComment) + + def buildCreateTableSQL( + tableName: String, + schema: StructType, + transforms: Array[Transform], + config: OceanBaseConfig): String = { + val partitionClause = buildPartitionClause(transforms, config) + val columnClause = schema.fields + .map { + field => + val obType = toOceanBaseOracleType(field.dataType, config) + val nullability = if (field.nullable) StringUtils.EMPTY else "NOT NULL" + field.getComment().foreach(c => columnComments += ((field.name, c))) + s"${quoteIdentifier(field.name)} $obType $nullability".trim + } + .mkString(",\n ") + var primaryKey = "" + val tableOption = scala.collection.JavaConverters + .mapAsScalaMapConverter(properties) + .asScala + .map(tuple => (tuple._1.toLowerCase, tuple._2)) + .flatMap { + case ("tablespace", value) => Some(s"TABLESPACE $value") + case ("compression", value) => + logWarning(s"Ignored unsupported table property on Oracle mode: compression=$value"); + None + case ("replica_num", value) => + logWarning(s"Ignored unsupported table property on Oracle mode: replica_num=$value"); + None + case ("primary_key", value) => + primaryKey = s", CONSTRAINT pk_${config.getTableName} PRIMARY KEY($value)" + None + case (k, _) => + logWarning(s"Ignored unsupported table property: $k") + None + } + .mkString(" ", " ", "") + s""" + |CREATE TABLE $tableName ( + | $columnClause + | $primaryKey + |) $tableOption + |$partitionClause; + |""".stripMargin.trim + } + + def toOceanBaseOracleType(dataType: DataType, config: OceanBaseConfig): String = { + var stringConvertType = s"VARCHAR2(${config.getLengthString2Varchar})" + if (config.getEnableString2Text) stringConvertType = "CLOB" + dataType match { + case BooleanType => "NUMBER(1)" + case ByteType => "NUMBER(3)" + case ShortType => "NUMBER(5)" + case IntegerType => "NUMBER(10)" + case LongType => "NUMBER(19)" + case FloatType => "BINARY_FLOAT" + case DoubleType => "BINARY_DOUBLE" + case d: DecimalType => s"NUMBER(${d.precision},${d.scale})" + case StringType => stringConvertType + case BinaryType => "RAW(2000)" + case DateType => "DATE" + case TimestampType => "TIMESTAMP" + case v: VarcharType => s"VARCHAR2(${v.length})" + case _ => throw new UnsupportedOperationException(s"Unsupported type: $dataType") + } + } + + def buildPartitionClause(transforms: Array[Transform], config: OceanBaseConfig): String = { + transforms match { + case transforms if transforms.nonEmpty => + 
ExprUtils.toOBOraclePartition(transforms.head, config) + case _ => "" + } + } + + val sql = buildCreateTableSQL(tableName, schema, partitions, config) + executeStatement(conn, config, sql) + + // Apply column comments + columnComments.foreach { + case (col, comment) => + val colName = quoteIdentifier(col) + val commentSql = s"COMMENT ON COLUMN $tableName.$colName IS '${escapeSql(comment)}'" + executeStatement(conn, config, commentSql) + } + + // Apply table comment + tableCommentOpt.foreach { + comment => + val tblCommentSql = s"COMMENT ON TABLE $tableName IS '${escapeSql(comment)}'" + executeStatement(conn, config, tblCommentSql) + } + + // In Oracle mode, complex table properties (compression/replica count) are not supported; ignored + } /** Creates a schema. */ override def createSchema( conn: Connection, config: OceanBaseConfig, schema: String, - comment: String): Unit = - throw new UnsupportedOperationException("Not currently supported in oracle mode") + comment: String): Unit = { + // In Oracle mode, schema equals user; create a user + val statement = conn.createStatement + try { + statement.setQueryTimeout(config.getJdbcQueryTimeout) + // Note: a real password is required; using a default here + // In Oracle mode, password should not use single quotes + statement.executeUpdate(s"CREATE USER ${quoteIdentifier(schema)} IDENTIFIED BY password") + // Grant basic privileges + statement.executeUpdate(s"GRANT CONNECT, RESOURCE TO ${quoteIdentifier(schema)}") + } finally { + statement.close() + } + } - override def schemaExists(conn: Connection, config: OceanBaseConfig, schema: String): Boolean = - throw new UnsupportedOperationException("Not currently supported in oracle mode") + override def schemaExists(conn: Connection, config: OceanBaseConfig, schema: String): Boolean = { + listSchemas(conn, config).exists(_.head == schema) + } - override def listSchemas(conn: Connection, config: OceanBaseConfig): Array[Array[String]] = - throw new UnsupportedOperationException("Not currently supported in oracle mode") + override def listSchemas(conn: Connection, config: OceanBaseConfig): Array[Array[String]] = { + val schemaBuilder = mutable.ArrayBuilder.make[Array[String]] + try { + OBJdbcUtils.executeQuery(conn, config, "SELECT USERNAME FROM ALL_USERS ORDER BY USERNAME") { + rs => + while (rs.next()) { + schemaBuilder += Array(rs.getString("USERNAME")) + } + } + } catch { + case _: Exception => + logWarning("Cannot list schemas.") + } + schemaBuilder.result + } /** Drops a schema from OceanBase. 
*/ override def dropSchema( conn: Connection, config: OceanBaseConfig, schema: String, - cascade: Boolean): Unit = throw new UnsupportedOperationException( - "Not currently supported in oracle mode") + cascade: Boolean): Unit = { + // In OceanBase Oracle mode, DROP USER without CASCADE is not supported + // Always use CASCADE regardless of the cascade parameter + executeStatement(conn, config, s"DROP USER ${quoteIdentifier(schema)} CASCADE") + } override def getPriKeyInfo( connection: Connection, schemaName: String, tableName: String, config: OceanBaseConfig): ArrayBuffer[PriKeyColumnInfo] = { - throw new UnsupportedOperationException("Not currently supported in oracle mode") + val sql = + s""" + |SELECT + | cc.COLUMN_NAME, + | tc.DATA_TYPE, + | cc.CONSTRAINT_NAME, + | tc.DATA_TYPE, + | cc.CONSTRAINT_NAME + |FROM + | ALL_CONSTRAINTS c + | JOIN ALL_CONS_COLUMNS cc ON c.CONSTRAINT_NAME = cc.CONSTRAINT_NAME + | AND c.OWNER = cc.OWNER + | JOIN ALL_TAB_COLUMNS tc ON cc.TABLE_NAME = tc.TABLE_NAME + | AND cc.COLUMN_NAME = tc.COLUMN_NAME + | AND cc.OWNER = tc.OWNER + |WHERE + | c.OWNER = '$schemaName' + | AND c.TABLE_NAME = '$tableName' + | AND c.CONSTRAINT_TYPE = 'P' + |ORDER BY cc.POSITION + |""".stripMargin + + val arrayBuffer = ArrayBuffer[PriKeyColumnInfo]() + OBJdbcUtils.executeQuery(connection, config, sql) { + rs => + { + while (rs.next()) { + arrayBuffer += PriKeyColumnInfo( + quoteIdentifier(rs.getString(1)), + rs.getString(2), + "PRI", + rs.getString(4), + rs.getString(5)) + } + arrayBuffer + } + } } override def getUniqueKeyInfo( @@ -83,11 +251,53 @@ class OceanBaseOracleDialect extends OceanBaseDialect { schemaName: String, tableName: String, config: OceanBaseConfig): ArrayBuffer[PriKeyColumnInfo] = { - throw new UnsupportedOperationException("Not currently supported in oracle mode") + val sql = + s""" + |SELECT + | cc.COLUMN_NAME, + | tc.DATA_TYPE, + | cc.CONSTRAINT_NAME, + | tc.DATA_TYPE, + | cc.CONSTRAINT_NAME + |FROM + | ALL_CONSTRAINTS c + | JOIN ALL_CONS_COLUMNS cc ON c.CONSTRAINT_NAME = cc.CONSTRAINT_NAME + | AND c.OWNER = cc.OWNER + | JOIN ALL_TAB_COLUMNS tc ON cc.TABLE_NAME = tc.TABLE_NAME + | AND cc.COLUMN_NAME = tc.COLUMN_NAME + | AND cc.OWNER = tc.OWNER + |WHERE + | c.OWNER = '$schemaName' + | AND c.TABLE_NAME = '$tableName' + | AND c.CONSTRAINT_TYPE = 'U' + |ORDER BY cc.POSITION + |""".stripMargin + + val arrayBuffer = ArrayBuffer[PriKeyColumnInfo]() + OBJdbcUtils.executeQuery(connection, config, sql) { + rs => + { + while (rs.next()) { + arrayBuffer += PriKeyColumnInfo( + quoteIdentifier(rs.getString(1)), + rs.getString(2), + "UNI", + rs.getString(4), + rs.getString(5)) + } + arrayBuffer + } + } } override def getInsertIntoStatement(tableName: String, schema: StructType): String = { - throw new UnsupportedOperationException("Not currently supported in oracle mode") + val columnClause = + schema.fieldNames.map(columnName => quoteIdentifier(columnName)).mkString(", ") + val placeholders = schema.fieldNames.map(_ => "?").mkString(", ") + s""" + |INSERT INTO $tableName ($columnClause) + |VALUES ($placeholders) + |""".stripMargin } override def getUpsertIntoStatement( @@ -95,7 +305,37 @@ class OceanBaseOracleDialect extends OceanBaseDialect { schema: StructType, priKeyColumnInfo: ArrayBuffer[PriKeyColumnInfo], config: OceanBaseConfig): String = { - throw new UnsupportedOperationException("Not currently supported in oracle mode") + val uniqueKeys = priKeyColumnInfo.map(_.columnName).toSet + val nonUniqueFields = + schema.fieldNames.filterNot(fieldName => 
uniqueKeys.contains(quoteIdentifier(fieldName))) + + val columns = schema.fieldNames.map(quoteIdentifier).mkString(", ") + val keyColumns = priKeyColumnInfo.map(_.columnName).mkString(", ") + + // Build SELECT ? AS col1, ? AS col2, ... for USING clause + val selectClause = schema.fieldNames.map(f => s"? AS ${quoteIdentifier(f)}").mkString(", ") + + // For VALUES clause, reference the subquery columns + val valuesClause = schema.fieldNames.map(f => s"s.${quoteIdentifier(f)}").mkString(", ") + + val whenMatchedClause = if (nonUniqueFields.nonEmpty) { + val updateClause = nonUniqueFields + .map(f => s"t.${quoteIdentifier(f)} = s.${quoteIdentifier(f)}") + .mkString(", ") + s"WHEN MATCHED THEN UPDATE SET $updateClause" + } else { + // In Oracle mode, columns referenced in the ON clause must not be updated (even self-assignment). + // If there are no updatable non-key columns, omit the WHEN MATCHED THEN UPDATE clause. + "" + } + + s""" + |MERGE INTO $tableName t + |USING (SELECT $selectClause FROM DUAL) s + |ON (${keyColumns.split(", ").map(col => s"t.$col = s.$col").mkString(" AND ")}) + |$whenMatchedClause + |WHEN NOT MATCHED THEN INSERT ($columns) VALUES ($valuesClause) + |""".stripMargin } /** diff --git a/spark-connector-oceanbase/spark-connector-oceanbase-base/src/main/scala/com/oceanbase/spark/reader/v2/OBJdbcReader.scala b/spark-connector-oceanbase/spark-connector-oceanbase-base/src/main/scala/com/oceanbase/spark/reader/v2/OBJdbcReader.scala index 24e8613..e6ea0f9 100644 --- a/spark-connector-oceanbase/spark-connector-oceanbase-base/src/main/scala/com/oceanbase/spark/reader/v2/OBJdbcReader.scala +++ b/spark-connector-oceanbase/spark-connector-oceanbase-base/src/main/scala/com/oceanbase/spark/reader/v2/OBJdbcReader.scala @@ -17,7 +17,7 @@ package com.oceanbase.spark.reader.v2 import com.oceanbase.spark.config.OceanBaseConfig -import com.oceanbase.spark.dialect.OceanBaseDialect +import com.oceanbase.spark.dialect.{OceanBaseDialect, OceanBaseOracleDialect} import com.oceanbase.spark.reader.v2.OBJdbcReader.{makeGetters, OBValueGetter} import com.oceanbase.spark.utils.OBJdbcUtils @@ -61,8 +61,12 @@ class OBJdbcReader( part.unevenlyWhereValue.zipWithIndex.foreach { case (value, index) => stmt.setObject(index + 1, value) } + case part: OBOraclePartition => + part.unevenlyWhereValue.zipWithIndex.foreach { + case (value, index) => stmt.setObject(index + 1, value) + } + case _ => } stmt.setFetchSize(config.getJdbcFetchSize) stmt.setQueryTimeout(config.getJdbcQueryTimeout) stmt.executeQuery() @@ -101,9 +106,31 @@ class OBJdbcReader( private def buildQuerySql(): String = { var columns = schema.map(col => dialect.quoteIdentifier(col.name)).toArray if (requiredColumns != null && requiredColumns.nonEmpty) { - columns = requiredColumns; + columns = requiredColumns + } + + // For Oracle mode, avoid JDBC NUMBER precision/scale issues by casting numeric columns + // to BINARY_DOUBLE in the projection. Only apply to fractional numeric types to avoid + // altering integral semantics (e.g., NUMBER(10,0)).
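+ // Illustrative example (assuming the Oracle dialect quotes identifiers with double quotes): a DoubleType column C is projected as CAST("C" AS BINARY_DOUBLE) AS "C"; all other columns are selected as-is.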
+ val columnStr: String = { + dialect match { + case _: OceanBaseOracleDialect => + val nameToType = schema.fields.map(f => f.name -> f.dataType).toMap + val projected = columns.map { + raw => + val unquoted = dialect.unQuoteIdentifier(raw) + val quoted = dialect.quoteIdentifier(unquoted) + nameToType.get(unquoted) match { + // Keep DecimalType as JDBC BigDecimal to avoid losing fractional formatting + case Some(dt) if dt == DoubleType || dt == FloatType => + s"CAST($quoted AS BINARY_DOUBLE) AS $quoted" + case _ => quoted + } + } + if (projected.isEmpty) "1" else projected.mkString(",") + case _ => if (columns.isEmpty) "1" else columns.mkString(",") + } } - val columnStr: String = if (columns.isEmpty) "1" else columns.mkString(",") val filterWhereClause: String = pushedFilter @@ -111,17 +138,28 @@ class OBJdbcReader( .map(p => s"($p)") .mkString(" AND ") - val part: OBMySQLPartition = partition.asInstanceOf[OBMySQLPartition] - val whereClause = { - if (part.whereClause != null && filterWhereClause.nonEmpty) { - "WHERE " + s"($filterWhereClause)" + " AND " + s"(${part.whereClause})" - } else if (part.whereClause != null) { - "WHERE " + part.whereClause - } else if (filterWhereClause.nonEmpty) { - "WHERE " + filterWhereClause - } else { - "" - } + val whereClause = partition match { + case part: OBMySQLPartition => + if (part.whereClause != null && filterWhereClause.nonEmpty) { + "WHERE " + s"($filterWhereClause)" + " AND " + s"(${part.whereClause})" + } else if (part.whereClause != null) { + "WHERE " + part.whereClause + } else if (filterWhereClause.nonEmpty) { + "WHERE " + filterWhereClause + } else { + "" + } + case part: OBOraclePartition => + if (part.whereClause != null && filterWhereClause.nonEmpty) { + "WHERE " + s"($filterWhereClause)" + " AND " + s"(${part.whereClause})" + } else if (part.whereClause != null) { + "WHERE " + part.whereClause + } else if (filterWhereClause.nonEmpty) { + "WHERE " + filterWhereClause + } else { + "" + } + case _ => throw new RuntimeException(s"Unsupported partition type: ${partition.getClass}") } /** A GROUP BY clause representing pushed-down grouping columns. 
*/ @@ -134,17 +172,27 @@ class OBJdbcReader( } } - val myLimitClause: String = { - if (part.limitOffsetClause == null || part.limitOffsetClause.isEmpty) - dialect.getLimitClause(pushDownLimit) - else - "" - } - - val useHiddenPKColumnHint = if (part.useHiddenPKColumn) { - s", opt_param('hidden_column_visible', 'true') " - } else { - "" + val (limitClause, useHiddenPKColumnHint) = partition match { + case part: OBMySQLPartition => + val myLimitClause = + if (part.limitOffsetClause == null || part.limitOffsetClause.isEmpty) + dialect.getLimitClause(pushDownLimit) + else + "" + val useHiddenPKColumnHint = if (part.useHiddenPKColumn) { + s", opt_param('hidden_column_visible', 'true') " + } else { + "" + } + (myLimitClause, useHiddenPKColumnHint) + case part: OBOraclePartition => + val useHiddenPKColumnHint = if (part.useHiddenPKColumn) { + s", opt_param('hidden_column_visible', 'true') " + } else { + "" + } + ("", useHiddenPKColumnHint) + case _ => throw new RuntimeException(s"Unsupported partition type: ${partition.getClass}") } val queryTimeoutHint = if (config.getQueryTimeoutHintDegree > 0) { s", query_timeout(${config.getQueryTimeoutHintDegree}) " @@ -154,9 +202,24 @@ class OBJdbcReader( val hint = s"/*+ PARALLEL(${config.getJdbcParallelHintDegree}) $useHiddenPKColumnHint $queryTimeoutHint */" + val partitionClause = partition match { + case part: OBMySQLPartition => part.partitionClause + case part: OBOraclePartition => part.partitionClause + case _ => "" + } + + val finalLimitClause = partition match { + case part: OBMySQLPartition => + if (part.limitOffsetClause != null && part.limitOffsetClause.nonEmpty) + part.limitOffsetClause + else limitClause + case part: OBOraclePartition => "" + case _ => "" + } + s""" - |SELECT $hint $columnStr FROM ${config.getDbTable} ${part.partitionClause} - |$whereClause $getGroupByClause $getOrderByClause ${part.limitOffsetClause} $myLimitClause + |SELECT $hint $columnStr FROM ${config.getDbTable} $partitionClause + |$whereClause $getGroupByClause $getOrderByClause $finalLimitClause |""".stripMargin } @@ -243,12 +306,12 @@ object OBJdbcReader extends SQLConfHelper { // DecimalType(12, 2). Thus, after saving the dataframe into parquet file and then // retrieve it, you will get wrong result 199.99. // So it is needed to set precision and scale for Decimal based on JDBC metadata. 
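+ // Use the precision and scale declared by the column's DecimalType (t below) rather than the runtime BigDecimal's own precision/scale, so the produced Decimal always matches the read schema.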
- case _: DecimalType => + case t: DecimalType => (rs: ResultSet, row: InternalRow, pos: Int) => val decimal = nullSafeConvert[java.math.BigDecimal]( rs.getBigDecimal(pos + 1), - d => Decimal(d, d.precision(), d.scale())) + d => Decimal(d, t.precision, t.scale)) row.update(pos, decimal) case DoubleType => diff --git a/spark-connector-oceanbase/spark-connector-oceanbase-base/src/main/scala/com/oceanbase/spark/reader/v2/OBJdbcScanBuilder.scala b/spark-connector-oceanbase/spark-connector-oceanbase-base/src/main/scala/com/oceanbase/spark/reader/v2/OBJdbcScanBuilder.scala index 3eb0675..ef8e147 100644 --- a/spark-connector-oceanbase/spark-connector-oceanbase-base/src/main/scala/com/oceanbase/spark/reader/v2/OBJdbcScanBuilder.scala +++ b/spark-connector-oceanbase/spark-connector-oceanbase-base/src/main/scala/com/oceanbase/spark/reader/v2/OBJdbcScanBuilder.scala @@ -17,7 +17,7 @@ package com.oceanbase.spark.reader.v2 import com.oceanbase.spark.config.OceanBaseConfig -import com.oceanbase.spark.dialect.OceanBaseDialect +import com.oceanbase.spark.dialect.{OceanBaseDialect, OceanBaseOracleDialect} import com.oceanbase.spark.utils.OBJdbcUtils import org.apache.spark.internal.Logging @@ -64,6 +64,18 @@ case class OBJdbcScanBuilder(schema: StructType, config: OceanBaseConfig, dialec override def pushAggregation(aggregation: Aggregation): Boolean = { if (!config.getPushDownAggregate) return false + // TODO(oracle-agg-pushdown): Disable aggregate pushdown for Oracle mode. + // Reasons: + // 1) NUMBER precision/scale metadata from JDBC can be unstable in some aggregate cases, + // leading to value mis-scaling (e.g., 5.3 -> 0.53). + // 2) MAX/MIN semantics on non-numeric types (e.g., strings) are lexicographical; blindly + // casting to numeric changes semantics or throws errors. + // 3) Needs type-aware rewrite strategy; enable later once fully implemented. + if (dialect.isInstanceOf[OceanBaseOracleDialect]) { + logInfo("Disable aggregate pushdown on Oracle mode due to precision/semantic concerns.") + return false + } + val compiledAggs = aggregation.aggregateExpressions.flatMap(dialect.compileExpression(_)) if (compiledAggs.length != aggregation.aggregateExpressions.length) return false @@ -179,8 +191,12 @@ class OBJdbcBatch( pushedGroupBys: Option[Array[String]], dialect: OceanBaseDialect) extends Batch { - private lazy val inputPartitions: Array[InputPartition] = - OBMySQLPartition.columnPartition(config, dialect) + private lazy val inputPartitions: Array[InputPartition] = { + OBJdbcUtils.getCompatibleMode(config).map(_.toLowerCase) match { + case Some("oracle") => OBOraclePartition.columnPartition(config, dialect) + case _ => OBMySQLPartition.columnPartition(config, dialect) + } + } override def planInputPartitions(): Array[InputPartition] = inputPartitions diff --git a/spark-connector-oceanbase/spark-connector-oceanbase-base/src/main/scala/com/oceanbase/spark/reader/v2/OBOraclePartition.scala b/spark-connector-oceanbase/spark-connector-oceanbase-base/src/main/scala/com/oceanbase/spark/reader/v2/OBOraclePartition.scala new file mode 100644 index 0000000..fbf69bc --- /dev/null +++ b/spark-connector-oceanbase/spark-connector-oceanbase-base/src/main/scala/com/oceanbase/spark/reader/v2/OBOraclePartition.scala @@ -0,0 +1,518 @@ +/* + * Copyright 2024 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package com.oceanbase.spark.reader.v2 + +import com.oceanbase.spark.config.OceanBaseConfig +import com.oceanbase.spark.dialect.{OceanBaseDialect, PriKeyColumnInfo} +import com.oceanbase.spark.utils.OBJdbcUtils + +import org.apache.spark.internal.Logging +import org.apache.spark.sql.connector.read.InputPartition + +import java.sql.Connection +import java.util.Objects + +import scala.collection.mutable.ArrayBuffer + +case class OBOraclePartition( + partitionClause: String, + limitOffsetClause: String, + whereClause: String, + useHiddenPKColumn: Boolean = false, + unevenlyWhereValue: Seq[Object] = Seq[Object](), + idx: Int) + extends InputPartition {} + +object OBOraclePartition extends Logging { + + private val EMPTY_STRING = "" + private val PARTITION_QUERY_FORMAT = "PARTITION(%s)" + private val HIDDEN_PK_INCREMENT = "__pk_increment" + + private def normalizePkNameForSql(column: String): String = { + if (column == null) column + else if (column == HIDDEN_PK_INCREMENT) "\"__pk_increment\"" + else column // Column name from dialect.getPriKeyInfo is already quoted + } + + def columnPartition(config: OceanBaseConfig, dialect: OceanBaseDialect): Array[InputPartition] = { + OBJdbcUtils + .withConnection(config) { + connection => + val obPartInfos: Array[OBOraclePartInfo] = obtainPartInfo(connection, config) + + val priKeyColInfos = + dialect.getPriKeyInfo(connection, config.getSchemaName, config.getTableName, config) + + if (priKeyColInfos == null || priKeyColInfos.isEmpty) { + // Non-primary key table: use hidden pk column evenly sized where partitioning + if (obPartInfos.isEmpty) { + computeWherePartInfoForNonPartTable(connection, config, HIDDEN_PK_INCREMENT) + } else { + computeWherePartInfoForPartTable(connection, config, obPartInfos, HIDDEN_PK_INCREMENT) + } + } else { + val numericPriKey: PriKeyColumnInfo = selectNumericPriKey(priKeyColInfos) + val priKeyColumnName = + if (numericPriKey != null) numericPriKey.columnName + else priKeyColInfos.head.columnName + + if (obPartInfos.isEmpty) { + if (isNumericType(numericPriKey)) { + computeWherePartInfoForNonPartTable(connection, config, priKeyColumnName) + } else { + computeUnevenlyWherePartInfoForNonPartTable(connection, config, priKeyColumnName) + } + } else { + if (isNumericType(numericPriKey)) { + computeWherePartInfoForPartTable(connection, config, obPartInfos, priKeyColumnName) + } else { + computeUnevenlyWherePartInfoForPartTable(config, obPartInfos, priKeyColumnName) + } + } + } + } + .asInstanceOf[Array[InputPartition]] + } + + private def obtainPartInfo( + connection: Connection, + config: OceanBaseConfig): Array[OBOraclePartInfo] = { + val subPartSql = + s""" + |SELECT + | PARTITION_NAME, + | SUBPARTITION_NAME + |FROM + | ALL_TAB_SUBPARTITIONS + |WHERE + | TABLE_OWNER = '${config.getSchemaName}' + | AND TABLE_NAME = '${config.getTableName}' + |ORDER BY PARTITION_NAME, SUBPARTITION_POSITION + |""".stripMargin + + val subPartitions = ArrayBuffer[OBOraclePartInfo]() + OBJdbcUtils.executeQuery(connection, config, subPartSql) { + rs => + while (rs.next()) { + subPartitions += OBOraclePartInfo( + 
rs.getString("PARTITION_NAME"), + rs.getString("SUBPARTITION_NAME")) + } + subPartitions + } + + if (subPartitions.nonEmpty) return subPartitions.toArray + + val partSql = + s""" + |SELECT + | PARTITION_NAME + |FROM + | ALL_TAB_PARTITIONS + |WHERE + | TABLE_OWNER = '${config.getSchemaName}' + | AND TABLE_NAME = '${config.getTableName}' + |ORDER BY PARTITION_POSITION + |""".stripMargin + + val partitions = ArrayBuffer[OBOraclePartInfo]() + OBJdbcUtils.executeQuery(connection, config, partSql) { + rs => + while (rs.next()) { + partitions += OBOraclePartInfo(rs.getString("PARTITION_NAME"), null) + } + partitions + } + + partitions.toArray + } + + private def obtainCount( + connection: Connection, + config: OceanBaseConfig, + partName: String): Long = { + val statement = connection.createStatement() + val tableName = config.getDbTable + val sql = + s"SELECT /*+ PARALLEL(${config.getJdbcStatsParallelHintDegree}) ${queryTimeoutHint( + config)} */ count(1) AS cnt FROM $tableName $partName" + try { + val rs = statement.executeQuery(sql) + if (rs.next()) rs.getLong(1) + else throw new RuntimeException(s"Failed to obtain count of $tableName.") + } finally { + statement.close() + } + } + + private val calPartitionSize: Long => Long = { + case count if count <= 100000 => 10000 + case count if count > 100000 && count <= 10000000 => 100000 + case count if count > 10000000 && count <= 100000000 => 200000 + case count if count > 100000000 && count <= 1000000000 => 250000 + case _ => 500000 + } + + private def isNumericType(priKey: PriKeyColumnInfo): Boolean = { + if (priKey == null) return false + val t = Option(priKey.dataType).map(_.toUpperCase).getOrElse("") + t.contains("NUMBER") || t == "INTEGER" + } + + private def selectNumericPriKey( + priKeyColInfos: ArrayBuffer[PriKeyColumnInfo]): PriKeyColumnInfo = { + if (priKeyColInfos == null || priKeyColInfos.isEmpty) return null + val numeric = priKeyColInfos.find(info => isNumericType(info)) + numeric.orNull + } + + private def computeWherePartInfoForNonPartTable( + connection: Connection, + config: OceanBaseConfig, + priKeyColumnName: String): Array[InputPartition] = { + val pkForSql = normalizePkNameForSql(priKeyColumnName) + val priKeyColumnInfo = + obtainIntPriKeyTableInfo(connection, config, EMPTY_STRING, pkForSql) + if (priKeyColumnInfo.count <= 0) Array.empty + computeWhereSparkPart(priKeyColumnInfo, EMPTY_STRING, pkForSql, config) + .asInstanceOf[Array[InputPartition]] + } + + private def computeWherePartInfoForPartTable( + connection: Connection, + config: OceanBaseConfig, + obPartInfos: Array[OBOraclePartInfo], + priKeyColumnName: String): Array[InputPartition] = { + val arr = ArrayBuffer[OBOraclePartition]() + obPartInfos.foreach { + obPartInfo => + val partitionName = if (obPartInfo.subPartName != null) { + PARTITION_QUERY_FORMAT.format(obPartInfo.subPartName) + } else { + PARTITION_QUERY_FORMAT.format(obPartInfo.partName) + } + val pkForSql = normalizePkNameForSql(priKeyColumnName) + val keyTableInfo = + obtainIntPriKeyTableInfo(connection, config, partitionName, pkForSql) + val partitions = + computeWhereSparkPart(keyTableInfo, partitionName, pkForSql, config) + arr ++= partitions + } + arr.zipWithIndex.map { + case (partInfo, index) => + OBOraclePartition( + partInfo.partitionClause, + limitOffsetClause = EMPTY_STRING, + whereClause = partInfo.whereClause, + useHiddenPKColumn = partInfo.useHiddenPKColumn, + unevenlyWhereValue = partInfo.unevenlyWhereValue, + idx = index + ) + }.toArray + } + + private case class IntPriKeyTableInfo(count: 
Long, min: Long, max: Long) + + private def obtainIntPriKeyTableInfo( + connection: Connection, + config: OceanBaseConfig, + partName: String, + priKeyColumnName: String): IntPriKeyTableInfo = { + val statement = connection.createStatement() + val tableName = config.getDbTable + val useHiddenPKColHint = + if (priKeyColumnName.replace("\"", "") == HIDDEN_PK_INCREMENT) + s", opt_param('hidden_column_visible', 'true') " + else EMPTY_STRING + val hint = + s"/*+ PARALLEL(${config.getJdbcStatsParallelHintDegree}) $useHiddenPKColHint ${queryTimeoutHint(config)} */" + + val sql = + s""" + |SELECT + | $hint count(1) AS cnt, min(%s), max(%s) + |FROM $tableName $partName + |""".stripMargin + .format(normalizePkNameForSql(priKeyColumnName), normalizePkNameForSql(priKeyColumnName)) + try { + val rs = statement.executeQuery(sql) + if (rs.next()) IntPriKeyTableInfo(rs.getLong(1), rs.getLong(2), rs.getLong(3)) + else throw new RuntimeException(s"Failed to obtain count of $tableName.") + } finally { + statement.close() + } + } + + private def computeWhereSparkPart( + keyTableInfo: IntPriKeyTableInfo, + partitionClause: String, + priKeyColumnName: String, + config: OceanBaseConfig): Array[OBOraclePartition] = { + if (keyTableInfo.count <= 0) return Array.empty[OBOraclePartition] + + val desiredRowsPerPartition = + config.getJdbcMaxRecordsPrePartition.orElse(calPartitionSize(keyTableInfo.count)) + val numPartitions = + Math.ceil(keyTableInfo.count.toDouble / desiredRowsPerPartition).toInt.max(1) + val idRange = keyTableInfo.max - keyTableInfo.min + val step = (idRange + numPartitions - 1) / numPartitions + val useHidden = priKeyColumnName.replace("\"", "") == HIDDEN_PK_INCREMENT + + (0 until numPartitions).map { + i => + val lower = keyTableInfo.min + i * step + val upper = if (i == numPartitions - 1) keyTableInfo.max + 1 else lower + step + val whereClause = s"($priKeyColumnName >= $lower AND $priKeyColumnName < $upper)" + OBOraclePartition( + partitionClause = partitionClause, + limitOffsetClause = EMPTY_STRING, + whereClause = whereClause, + useHiddenPKColumn = useHidden, + unevenlyWhereValue = Seq.empty, + idx = i) + }.toArray + } + + private case class UnevenlyPriKeyTableInfo(count: Long, min: Object, max: Object) + + private def obtainUnevenlyPriKeyTableInfo( + conn: Connection, + config: OceanBaseConfig, + partName: String, + priKeyColumnName: String): UnevenlyPriKeyTableInfo = { + val statement = conn.createStatement() + val tableName = config.getDbTable + val hint = + s"/*+ PARALLEL(${config.getJdbcStatsParallelHintDegree}) ${queryTimeoutHint(config)} */" + val sql = + s""" + |SELECT $hint + | count(1) AS cnt, min(%s), max(%s) + |FROM $tableName $partName + |""".stripMargin + .format(normalizePkNameForSql(priKeyColumnName), normalizePkNameForSql(priKeyColumnName)) + try { + val rs = statement.executeQuery(sql) + if (rs.next()) UnevenlyPriKeyTableInfo(rs.getLong(1), rs.getObject(2), rs.getObject(3)) + else throw new RuntimeException(s"Failed to obtain count of $tableName.") + } finally { + statement.close() + } + } + + private def computeUnevenlyWherePartInfoForNonPartTable( + conn: Connection, + config: OceanBaseConfig, + priKeyColumnName: String): Array[InputPartition] = { + val unevenlyPriKeyTableInfo = + obtainUnevenlyPriKeyTableInfo( + conn, + config, + EMPTY_STRING, + normalizePkNameForSql(priKeyColumnName)) + if (unevenlyPriKeyTableInfo.count <= 0) Array.empty + else + computeUnevenlyWhereSparkPart( + conn, + unevenlyPriKeyTableInfo, + EMPTY_STRING, + priKeyColumnName, + config) + 
.asInstanceOf[Array[InputPartition]] + } + + private def computeUnevenlyWherePartInfoForPartTable( + config: OceanBaseConfig, + obPartInfos: Array[OBOraclePartInfo], + priKeyColumnName: String): Array[InputPartition] = { + val arr = ArrayBuffer[OBOraclePartition]() + obPartInfos.foreach { + obPartInfo => + val conn = OBJdbcUtils.getConnection(config) + try { + val partitionName = if (obPartInfo.subPartName != null) { + PARTITION_QUERY_FORMAT.format(obPartInfo.subPartName) + } else { + PARTITION_QUERY_FORMAT.format(obPartInfo.partName) + } + val unevenlyPriKeyTableInfo = + obtainUnevenlyPriKeyTableInfo( + conn, + config, + partitionName, + normalizePkNameForSql(priKeyColumnName)) + val partitions = computeUnevenlyWhereSparkPart( + conn, + unevenlyPriKeyTableInfo, + partitionName, + normalizePkNameForSql(priKeyColumnName), + config) + arr ++= partitions + } finally { + conn.close() + } + } + arr.zipWithIndex.map { + case (partInfo, index) => + OBOraclePartition( + partInfo.partitionClause, + limitOffsetClause = EMPTY_STRING, + whereClause = partInfo.whereClause, + useHiddenPKColumn = partInfo.useHiddenPKColumn, + unevenlyWhereValue = partInfo.unevenlyWhereValue, + idx = index + ) + }.toArray + } + + private def computeUnevenlyWhereSparkPart( + conn: Connection, + keyTableInfo: UnevenlyPriKeyTableInfo, + partitionClause: String, + priKeyColumnName: String, + config: OceanBaseConfig): Array[OBOraclePartition] = { + if (keyTableInfo.count <= 0) return Array.empty[OBOraclePartition] + + val desiredRowsPerPartition = + config.getJdbcMaxRecordsPrePartition.orElse(calPartitionSize(keyTableInfo.count)) + var previousChunkEnd = keyTableInfo.min + var chunkEnd = nextChunkEnd( + conn, + desiredRowsPerPartition, + previousChunkEnd, + keyTableInfo.max, + partitionClause, + priKeyColumnName, + config) + val arrayBuffer = ArrayBuffer[OBOraclePartition]() + var idx = 0 + if (Objects.nonNull(chunkEnd)) { + val whereClause = s"($priKeyColumnName < ?)" + arrayBuffer += OBOraclePartition( + partitionClause = partitionClause, + limitOffsetClause = EMPTY_STRING, + whereClause = whereClause, + unevenlyWhereValue = Seq(chunkEnd), + idx = idx) + } + while (Objects.nonNull(chunkEnd)) { + previousChunkEnd = chunkEnd + chunkEnd = nextChunkEnd( + conn, + desiredRowsPerPartition, + previousChunkEnd, + keyTableInfo.max, + partitionClause, + priKeyColumnName, + config) + if (Objects.nonNull(chunkEnd)) { + val whereClause = s"($priKeyColumnName >= ? 
AND $priKeyColumnName < ?)" + idx = idx + 1 + arrayBuffer += OBOraclePartition( + partitionClause = partitionClause, + limitOffsetClause = EMPTY_STRING, + whereClause = whereClause, + unevenlyWhereValue = Seq(previousChunkEnd, chunkEnd), + idx = idx) + } + } + if (Objects.isNull(chunkEnd)) { + val whereClause = s"($priKeyColumnName >= ?)" + idx = idx + 1 + arrayBuffer += OBOraclePartition( + partitionClause = partitionClause, + limitOffsetClause = EMPTY_STRING, + whereClause = whereClause, + unevenlyWhereValue = Seq(previousChunkEnd), + idx = idx) + } + arrayBuffer.toArray + } + + private def nextChunkEnd( + conn: Connection, + chunkSize: Long, + includedLowerBound: Object, + max: Object, + partitionClause: String, + priKeyColumnName: String, + config: OceanBaseConfig): Object = { + var chunkEnd: Object = queryNextChunkMax( + conn, + chunkSize, + includedLowerBound, + partitionClause, + priKeyColumnName, + config) + if (Objects.isNull(chunkEnd)) return chunkEnd + if (compare(chunkEnd, max) >= 0) null else chunkEnd + } + + private def queryNextChunkMax( + conn: Connection, + chunkSize: Long, + includedLowerBound: Object, + partitionClause: String, + priKeyColumnName: String, + config: OceanBaseConfig): Object = { + val tableName = config.getDbTable + val hint = + s"/*+ PARALLEL(${config.getJdbcStatsParallelHintDegree}) ${queryTimeoutHint(config)} */" + val pkForSql = normalizePkNameForSql(priKeyColumnName) + val sql = + s""" + |SELECT MAX(%s) AS chunk_high FROM ( + | SELECT %s FROM %s %s WHERE %s > ? ORDER BY %s ASC + |) WHERE ROWNUM <= %d + |""".stripMargin.format( + pkForSql, + pkForSql, + tableName, + partitionClause, + pkForSql, + pkForSql, + chunkSize) + // attach hint at the beginning of inner SELECT + val finalSql = sql.replaceFirst("SELECT ", s"SELECT $hint ") + val statement = conn.prepareStatement(finalSql) + try { + statement.setObject(1, includedLowerBound) + val rs = statement.executeQuery() + if (rs.next()) rs.getObject(1) + else throw new RuntimeException("Failed to query next chunk max.") + } finally { + statement.close() + } + } + + private def compare(obj1: Any, obj2: Any): Int = (obj1, obj2) match { + case (c1: Comparable[_], c2) if c1.getClass == c2.getClass => + c1.asInstanceOf[Comparable[Any]].compareTo(c2) + case _ => obj1.toString.compareTo(obj2.toString) + } + + private def queryTimeoutHint(config: OceanBaseConfig): String = if ( + config.getQueryTimeoutHintDegree > 0 + ) { + s", query_timeout(${config.getQueryTimeoutHintDegree}) " + } else { + "" + } +} + +case class OBOraclePartInfo(partName: String, subPartName: String) diff --git a/spark-connector-oceanbase/spark-connector-oceanbase-base/src/main/scala/org/apache/spark/sql/ExprUtils.scala b/spark-connector-oceanbase/spark-connector-oceanbase-base/src/main/scala/org/apache/spark/sql/ExprUtils.scala index 17dabb3..0d892a2 100644 --- a/spark-connector-oceanbase/spark-connector-oceanbase-base/src/main/scala/org/apache/spark/sql/ExprUtils.scala +++ b/spark-connector-oceanbase/spark-connector-oceanbase-base/src/main/scala/org/apache/spark/sql/ExprUtils.scala @@ -38,6 +38,19 @@ object ExprUtils extends SQLConfHelper with Serializable { throw new UnsupportedOperationException(s"Unsupported transform: ${other.name()}") } + def toOBOraclePartition(transform: Transform, config: OceanBaseConfig): String = + transform match { + case bucket: BucketTransform => + val identities = bucket.columns + .map(col => s""""${col.fieldNames().head}"""") + .mkString(",") + s"PARTITION BY HASH($identities) PARTITIONS 
${bucket.numBuckets.value()}".stripMargin + case _: YearsTransform | _: DaysTransform | _: HoursTransform | _: IdentityTransform => + throw new UnsupportedOperationException("OceanBase does not support dynamic partitions.") + case other: Transform => + throw new UnsupportedOperationException(s"Unsupported transform: ${other.name()}") + } + /** * Turns a single Filter into a String representing a SQL expression. Returns None for an * unhandled filter. diff --git a/spark-connector-oceanbase/spark-connector-oceanbase-base/src/test/resources/sql/oracle/products_simple.sql b/spark-connector-oceanbase/spark-connector-oceanbase-base/src/test/resources/sql/oracle/products_simple.sql new file mode 100644 index 0000000..3969ebd --- /dev/null +++ b/spark-connector-oceanbase/spark-connector-oceanbase-base/src/test/resources/sql/oracle/products_simple.sql @@ -0,0 +1,63 @@ +-- Copyright 2024 OceanBase. +-- +-- Licensed under the Apache License, Version 2.0 (the "License"); +-- you may not use this file except in compliance with the License. +-- You may obtain a copy of the License at +-- +-- http://www.apache.org/licenses/LICENSE-2.0 +-- +-- Unless required by applicable law or agreed to in writing, software +-- distributed under the License is distributed on an "AS IS" BASIS, +-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +-- See the License for the specific language governing permissions and +-- limitations under the License. + +-- Create products table for Oracle mode testing +CREATE TABLE products ( + id NUMBER(10) NOT NULL, + name VARCHAR2(255), + description VARCHAR2(1000), + weight NUMBER(10,2), + CONSTRAINT pk_products PRIMARY KEY (id) +); + +CREATE TABLE products_no_pri_key ( + id NUMBER(10) NOT NULL, + name VARCHAR2(255) NOT NULL, + description VARCHAR2(1000), + weight NUMBER(10,2) +) PARTITION BY HASH(id) PARTITIONS 3; + +CREATE TABLE products_full_pri_key ( + id NUMBER(10) NOT NULL, + name VARCHAR2(255) NOT NULL, + description VARCHAR2(1000), + weight NUMBER(10,2), + CONSTRAINT pk_products_full PRIMARY KEY (id, name, description, weight) +); + +CREATE TABLE products_no_int_pri_key ( + id VARCHAR2(255) NOT NULL, + name VARCHAR2(255) NOT NULL, + description VARCHAR2(1000), + weight NUMBER(10,2), + CONSTRAINT pk_products_no_int PRIMARY KEY (id, name) +); + +CREATE TABLE products_unique_key ( + id NUMBER(10), + name VARCHAR2(255), + description VARCHAR2(1000), + weight NUMBER(10,2), + CONSTRAINT uk_products_unique UNIQUE (id, name) +) PARTITION BY HASH(id) PARTITIONS 3; + +CREATE TABLE products_full_unique_key ( + id NUMBER(10), + name VARCHAR2(255), + description VARCHAR2(1000), + weight NUMBER(10,2), + CONSTRAINT uk_products_full_unique UNIQUE (id, name, description, weight) +); + + diff --git a/spark-connector-oceanbase/spark-connector-oceanbase-base/src/test/scala/com/oceanbase/spark/OBCatalogOracleITCase.scala b/spark-connector-oceanbase/spark-connector-oceanbase-base/src/test/scala/com/oceanbase/spark/OBCatalogOracleITCase.scala new file mode 100644 index 0000000..34c15f9 --- /dev/null +++ b/spark-connector-oceanbase/spark-connector-oceanbase-base/src/test/scala/com/oceanbase/spark/OBCatalogOracleITCase.scala @@ -0,0 +1,689 @@ +/* + * Copyright 2024 OceanBase. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package com.oceanbase.spark + +import com.oceanbase.spark.OBCatalogOracleITCase.expected +import com.oceanbase.spark.OceanBaseTestBase.assertEqualsInAnyOrder +import com.oceanbase.spark.config.OceanBaseConfig +import com.oceanbase.spark.dialect.OceanBaseOracleDialect +import com.oceanbase.spark.utils.OBJdbcUtils + +import org.apache.spark.sql.SparkSession +import org.junit.jupiter.api.{AfterEach, Assertions, BeforeEach, Disabled, Test} +import org.junit.jupiter.api.function.ThrowingSupplier + +import java.util + +@Disabled("Skipping Oracle tests, only running in local test environment.") +class OBCatalogOracleITCase extends OceanBaseOracleTestBase { + + @BeforeEach + def initEach(): Unit = { + // Check if table exists and drop it if it does + val config = new OceanBaseConfig(getOptions) + val dialect = new OceanBaseOracleDialect + OBJdbcUtils.withConnection(config) { + conn => + if (dialect.tableExists(conn, config)) { + dialect.dropTable(conn, config.getDbTable, config) + } + } + + initialize("sql/oracle/products_simple.sql") + } + + @AfterEach + def afterEach(): Unit = { + dropTables( + "PRODUCTS", + "PRODUCTS_NO_PRI_KEY", + "PRODUCTS_FULL_PRI_KEY", + "PRODUCTS_NO_INT_PRI_KEY", + "PRODUCTS_UNIQUE_KEY", + "PRODUCTS_FULL_UNIQUE_KEY" + ) + } + + val OB_CATALOG_CLASS = "com.oceanbase.spark.catalog.OceanBaseCatalog" + + @Test + def testCatalogBase(): Unit = { + val session = SparkSession + .builder() + .master("local[*]") + .config("spark.sql.catalog.ob", OB_CATALOG_CLASS) + .config("spark.sql.catalog.ob.url", getJdbcUrl) + .config("spark.sql.catalog.ob.username", getUsername) + .config("spark.sql.catalog.ob.password", getPassword) + .config("spark.sql.catalog.ob.schema-name", getSchemaName) + .getOrCreate() + + session.sql("use ob;") + insertTestData(session, "PRODUCTS") + queryAndVerifyTableData(session, "PRODUCTS", expected) + + insertTestData(session, "PRODUCTS_NO_PRI_KEY") + queryAndVerifyTableData(session, "PRODUCTS_NO_PRI_KEY", expected) + + session.stop() + } + + @Test + def testWhereClause(): Unit = { + val session = SparkSession + .builder() + .master("local[*]") + .config("spark.sql.catalog.ob", OB_CATALOG_CLASS) + .config("spark.sql.catalog.ob.url", getJdbcUrl) + .config("spark.sql.catalog.ob.username", getUsername) + .config("spark.sql.catalog.ob.password", getPassword) + .config("spark.sql.catalog.ob.schema-name", getSchemaName) + .config("spark.sql.catalog.ob.jdbc.max-records-per-partition", 2) + .getOrCreate() + + session.sql("use ob;") + insertTestData(session, "PRODUCTS") + + import scala.collection.JavaConverters._ + val expect1 = Seq( + "101,scooter,Small 2-wheel scooter,3.14", + "102,car battery,12V car battery,8.10", + "103,12-pack drill bits,12-pack of drill bits with sizes ranging from #40 to #3,0.80" + ).toList.asJava + queryAndVerify(session, "select * from PRODUCTS where ID >= 101 and ID < 104", expect1) + session.stop() + } + + @Test + def testJdbcInsertWithAutoCommit(): Unit = { + val session = SparkSession + .builder() + .master("local[*]") + .config("spark.sql.catalog.ob", OB_CATALOG_CLASS) + .config("spark.sql.catalog.ob.url", getJdbcUrl)
+ .config("spark.sql.catalog.ob.username", getUsername) + .config("spark.sql.catalog.ob.password", getPassword) + .config("spark.sql.catalog.ob.schema-name", getSchemaName) + .config("spark.sql.catalog.ob.jdbc.enable-autocommit", true.toString) + .getOrCreate() + + session.sql("use ob;") + insertTestData(session, "PRODUCTS") + queryAndVerifyTableData(session, "PRODUCTS", expected) + + insertTestData(session, "PRODUCTS_NO_PRI_KEY") + queryAndVerifyTableData(session, "PRODUCTS_NO_PRI_KEY", expected) + + session.stop() + } + + @Test + def testCatalogOp(): Unit = { + val session = SparkSession + .builder() + .master("local[*]") + .config("spark.sql.catalog.ob", OB_CATALOG_CLASS) + .config("spark.sql.catalog.ob.url", getJdbcUrl) + .config("spark.sql.catalog.ob.username", getUsername) + .config("spark.sql.catalog.ob.password", getPassword) + .config("spark.sql.catalog.ob.schema-name", getSchemaName) + .getOrCreate() + + session.sql("use ob;") + import scala.collection.JavaConverters._ + + val dbList = session.sql("show databases").collect().map(_.toString()).toList.asJava + assert(!dbList.isEmpty) + val tableList = session.sql("show tables").collect().map(_.toString()).toList.asJava + assert(!tableList.isEmpty) + + // test create/drop namespace + Assertions.assertDoesNotThrow(new ThrowingSupplier[Unit] { + override def get(): Unit = { + session.sql("create database TEST_ONLY") + session.sql("use TEST_ONLY") + } + }) + val dbList1 = session.sql("show databases").collect().map(_.toString()).toList.asJava + println(dbList1) + assert(dbList1.contains("[TEST_ONLY]")) + + session.sql("drop database TEST_ONLY") + val dbList2 = session.sql("show databases").collect().map(_.toString()).toList.asJava + assert(!dbList2.contains("[TEST_ONLY]")) + + session.stop() + } + + @Test + def testCatalogJdbcInsertWithNoPriKeyTable(): Unit = { + val session = SparkSession + .builder() + .master("local[*]") + .config("spark.sql.catalog.ob", OB_CATALOG_CLASS) + .config("spark.sql.catalog.ob.url", getJdbcUrl) + .config("spark.sql.catalog.ob.username", getUsername) + .config("spark.sql.catalog.ob.password", getPassword) + .config("spark.sql.catalog.ob.schema-name", getSchemaName) + .getOrCreate() + + session.sql("use ob;") + insertTestData(session, "PRODUCTS_NO_PRI_KEY") + queryAndVerifyTableData(session, "PRODUCTS_NO_PRI_KEY", expected) + session.stop() + } + + @Test + def testCatalogJdbcInsertWithFullPriKeyTable(): Unit = { + val session = SparkSession + .builder() + .master("local[*]") + .config("spark.sql.catalog.ob", OB_CATALOG_CLASS) + .config("spark.sql.catalog.ob.url", getJdbcUrl) + .config("spark.sql.catalog.ob.username", getUsername) + .config("spark.sql.catalog.ob.password", getPassword) + .config("spark.sql.catalog.ob.schema-name", getSchemaName) + .getOrCreate() + + session.sql("use ob;") + insertTestData(session, "PRODUCTS_FULL_PRI_KEY") + + queryAndVerifyTableData(session, "PRODUCTS_FULL_PRI_KEY", expected) + + session.stop() + } + + @Test + def testCatalogDirectLoadWrite(): Unit = { + val session = SparkSession + .builder() + .master("local[*]") + .config("spark.sql.catalog.ob", OB_CATALOG_CLASS) + .config("spark.sql.catalog.ob.url", getJdbcUrl) + .config("spark.sql.catalog.ob.username", getUsername) + .config("spark.sql.catalog.ob.password", getPassword) + .config("spark.sql.catalog.ob.schema-name", getSchemaName) + .config("spark.sql.defaultCatalog", "ob") + .config("spark.sql.catalog.ob.direct-load.enabled", "true") + .config("spark.sql.catalog.ob.direct-load.host", getHost) + 
.config("spark.sql.catalog.ob.direct-load.rpc-port", getRpcPort) + .config("spark.sql.catalog.ob.direct-load.username", getUsername) + .getOrCreate() + + insertTestData(session, "PRODUCTS") + queryAndVerifyTableData(session, "PRODUCTS", expected) + + insertTestData(session, "PRODUCTS_NO_PRI_KEY") + session.sql("insert overwrite PRODUCTS select * from PRODUCTS_NO_PRI_KEY") + queryAndVerifyTableData(session, "PRODUCTS", expected) + session.stop() + } + + @Test + def testTableCreate(): Unit = { + val session = SparkSession + .builder() + .master("local[1]") + .config("spark.sql.catalog.ob", OB_CATALOG_CLASS) + .config("spark.sql.catalog.ob.url", getJdbcUrl) + .config("spark.sql.catalog.ob.username", getUsername) + .config("spark.sql.catalog.ob.password", getPassword) + .config("spark.sql.catalog.ob.schema-name", getSchemaName) + .config("spark.sql.defaultCatalog", "ob") + .getOrCreate() + // Defensive cleanup to avoid leftover tables from previous failures + try { + val conn = getJdbcConnection() + val stmt = conn.createStatement() + try { + stmt.execute(s"DROP TABLE $getSchemaName.TEST1") + stmt.execute(s"DROP TABLE $getSchemaName.TEST2") + } catch { + case _: Throwable => // ignore + } finally { + stmt.close() + conn.close() + } + } catch { + case _: Throwable => // ignore + } + + insertTestData(session, "PRODUCTS") + // Test CTAS + session.sql("create table TEST1 as select * from PRODUCTS") + queryAndVerifyTableData(session, "TEST1", expected) + + // test bucket partition table: + // 1. column comment test + // 2. table comment test + // 3. table options test + session.sql( + """ + |CREATE TABLE TEST2( + | USER_ID DECIMAL(19, 0) COMMENT 'test_for_key', + | NAME VARCHAR(255) + |) + |PARTITIONED BY (bucket(16, USER_ID)) + |COMMENT 'test_for_table_create' + |TBLPROPERTIES('replica_num' = 2, COMPRESSION = 'zstd_1.0', primary_key = 'USER_ID, NAME'); + |""".stripMargin) + val showCreateTable = getShowCreateTable(s"$getSchemaName.TEST2") + val showLC = showCreateTable.toLowerCase + val hasPartition = showLC.contains("partition by") && showLC.contains("user_id") + val hasPrimaryKey = + showLC.contains("primary key") && showLC.contains("user_id") && showLC.contains("name") + // In Oracle mode, SHOW CREATE may not return column/table comments; only check partition and primary key + Assertions.assertTrue(hasPartition && hasPrimaryKey) + dropTables("TEST1", "TEST2") + session.stop() + } + + @Test + def testTruncateAndOverWriteTable(): Unit = { + val session = SparkSession + .builder() + .master("local[*]") + .config("spark.sql.catalog.ob", OB_CATALOG_CLASS) + .config("spark.sql.catalog.ob.url", getJdbcUrl) + .config("spark.sql.catalog.ob.username", getUsername) + .config("spark.sql.catalog.ob.password", getPassword) + .config("spark.sql.catalog.ob.schema-name", getSchemaName) + .getOrCreate() + + session.sql("use ob;") + insertTestData(session, "PRODUCTS") + session.sql("truncate table PRODUCTS") + val expect = new util.ArrayList[String]() + queryAndVerifyTableData(session, "PRODUCTS", expect) + + insertTestData(session, "PRODUCTS_NO_PRI_KEY") + session.sql("insert overwrite PRODUCTS select * from PRODUCTS_NO_PRI_KEY") + queryAndVerifyTableData(session, "PRODUCTS", expected) + + session.stop() + } + + @Test + def testDeleteWhere(): Unit = { + val session = SparkSession + .builder() + .master("local[*]") + .config("spark.sql.catalog.ob", OB_CATALOG_CLASS) + .config("spark.sql.catalog.ob.url", getJdbcUrl) + .config("spark.sql.catalog.ob.username", getUsername) + .config("spark.sql.catalog.ob.password", 
+      .config("spark.sql.catalog.ob.schema-name", getSchemaName)
+      .getOrCreate()
+
+    session.sql("use ob;")
+    insertTestData(session, "PRODUCTS")
+    // The next three predicates match no rows (IDs start at 101 and DESCRIPTION is never null),
+    // so the full data set should remain after each delete.
+    session.sql("delete from PRODUCTS where 1 = 0")
+    queryAndVerifyTableData(session, "PRODUCTS", expected)
+
+    session.sql("delete from PRODUCTS where ID = 1")
+    queryAndVerifyTableData(session, "PRODUCTS", expected)
+
+    session.sql("delete from PRODUCTS where DESCRIPTION is null")
+    queryAndVerifyTableData(session, "PRODUCTS", expected)
+
+    // The remaining predicates together delete every row.
+    session.sql("delete from PRODUCTS where ID in (101, 102, 103)")
+    session.sql("delete from PRODUCTS where NAME = 'hammer'")
+
+    session.sql("delete from PRODUCTS where NAME like 'rock%'")
+    session.sql("delete from PRODUCTS where NAME like '%jack%' and ID = 108 or WEIGHT = 5.3")
+    session.sql("delete from PRODUCTS where ID >= 109")
+
+    val expect = new util.ArrayList[String]()
+    queryAndVerifyTableData(session, "PRODUCTS", expect)
+
+    session.stop()
+  }
+
+  @Test
+  def testLimitAndTopNPushDown(): Unit = {
+    val session = SparkSession
+      .builder()
+      .master("local[*]")
+      .config("spark.sql.catalog.ob", OB_CATALOG_CLASS)
+      .config("spark.sql.catalog.ob.url", getJdbcUrl)
+      .config("spark.sql.catalog.ob.username", getUsername)
+      .config("spark.sql.catalog.ob.password", getPassword)
+      .config("spark.sql.catalog.ob.schema-name", getSchemaName)
+      .getOrCreate()
+
+    session.sql("use ob;")
+    insertTestData(session, "PRODUCTS")
+
+    import scala.collection.JavaConverters._
+    // Case limit
+    val actual = session
+      .sql("select * from PRODUCTS limit 3")
+      .collect()
+      .map(
+        _.toString().drop(1).dropRight(1)
+      )
+      .toList
+      .asJava
+    val expected: util.List[String] = util.Arrays.asList(
+      "101,scooter,Small 2-wheel scooter,3.14",
+      "102,car battery,12V car battery,8.10",
+      "103,12-pack drill bits,12-pack of drill bits with sizes ranging from #40 to #3,0.80"
+    )
+    assertEqualsInAnyOrder(expected, actual)
+
+    // Case top N
+    val actual1 = session
+      .sql("select * from PRODUCTS order by ID desc limit 3")
+      .collect()
+      .map(
+        _.toString().drop(1).dropRight(1)
+      )
+      .toList
+      .asJava
+    val expected1: util.List[String] = util.Arrays.asList(
+      "109,spare tire,24 inch spare tire,22.20",
+      "108,jacket,water resistent black wind breaker,0.10",
+      "107,rocks,box of assorted rocks,5.30"
+    )
+    assertEqualsInAnyOrder(expected1, actual1)
+
+    val actual2 = session
+      .sql("select * from PRODUCTS order by ID desc, NAME asc limit 3")
+      .collect()
+      .map(
+        _.toString().drop(1).dropRight(1)
+      )
+      .toList
+      .asJava
+    println(actual2)
+    val expected2: util.List[String] = util.Arrays.asList(
+      "109,spare tire,24 inch spare tire,22.20",
+      "108,jacket,water resistent black wind breaker,0.10",
+      "107,rocks,box of assorted rocks,5.30"
+    )
+    assertEqualsInAnyOrder(expected2, actual2)
+
+    session.stop()
+  }
+
+  @Test
+  def testUpsertUniqueKey(): Unit = {
+    val session = SparkSession
+      .builder()
+      .master("local[*]")
+      .config("spark.sql.catalog.ob", OB_CATALOG_CLASS)
+      .config("spark.sql.catalog.ob.url", getJdbcUrl)
+      .config("spark.sql.catalog.ob.username", getUsername)
+      .config("spark.sql.catalog.ob.password", getPassword)
+      .config("spark.sql.catalog.ob.schema-name", getSchemaName)
+      .getOrCreate()
+
+    session.sql("use ob;")
+    insertTestData(session, "PRODUCTS_UNIQUE_KEY")
+    queryAndVerifyTableData(session, "PRODUCTS_UNIQUE_KEY", expected)
+
+    insertTestData(session, "PRODUCTS_FULL_UNIQUE_KEY")
+    queryAndVerifyTableData(session, "PRODUCTS_FULL_UNIQUE_KEY", expected)
+    session.stop()
+  }
+
+  @Test
+  def testUpsertUniqueKeyWithNullValue(): Unit = {
+    val session = SparkSession
+      .builder()
+      .master("local[1]")
+      .config("spark.sql.catalog.ob", OB_CATALOG_CLASS)
+      .config("spark.sql.catalog.ob.url", getJdbcUrl)
+      .config("spark.sql.catalog.ob.username", getUsername)
+      .config("spark.sql.catalog.ob.password", getPassword)
+      .config("spark.sql.catalog.ob.schema-name", getSchemaName)
+      .getOrCreate()
+
+    session.sql("use ob;")
+    insertTestDataWithNullValue(session, "PRODUCTS_UNIQUE_KEY")
+    val expectedWithNullValue: util.List[String] = util.Arrays.asList(
+      "null,null,Small 2-wheel scooter,3.14",
+      "102,car battery,12V car battery,8.10",
+      "103,12-pack drill bits,12-pack of drill bits with sizes ranging from #40 to #3,0.80",
+      "104,hammer,12oz carpenter's hammer,0.75",
+      "null,null,14oz carpenter's hammer,0.88",
+      "106,hammer,box of assorted rocks,null",
+      "108,jacket,null,0.10",
+      "109,spare tire,24 inch spare tire,22.20"
+    )
+    queryAndVerifyTableData(session, "PRODUCTS_UNIQUE_KEY", expectedWithNullValue)
+
+    insertTestDataWithNullValue(session, "PRODUCTS_FULL_UNIQUE_KEY")
+    val expectedWithNullValue1: util.List[String] = util.Arrays.asList(
+      "null,null,Small 2-wheel scooter,3.14",
+      "102,car battery,12V car battery,8.10",
+      "103,12-pack drill bits,12-pack of drill bits with sizes ranging from #40 to #3,0.80",
+      "104,hammer,12oz carpenter's hammer,0.75",
+      "null,null,14oz carpenter's hammer,0.88",
+      "106,hammer,box of assorted rocks,null",
+      "106,hammer,16oz carpenter's hammer,1.00",
+      "108,jacket,null,0.10",
+      "109,spare tire,24 inch spare tire,22.20"
+    )
+    queryAndVerifyTableData(session, "PRODUCTS_FULL_UNIQUE_KEY", expectedWithNullValue1)
+    session.stop()
+  }
+
+  @Test
+  def testAggregatePushdown(): Unit = {
+    val session = SparkSession
+      .builder()
+      .master("local[*]")
+      .config("spark.sql.catalog.ob", OB_CATALOG_CLASS)
+      .config("spark.sql.catalog.ob.url", getJdbcUrl)
+      .config("spark.sql.catalog.ob.username", getUsername)
+      .config("spark.sql.catalog.ob.password", getPassword)
+      .config("spark.sql.catalog.ob.schema-name", getSchemaName)
+      .getOrCreate()
+
+    session.sql("use ob;")
+    insertTestData(session, "PRODUCTS")
+
+    import scala.collection.JavaConverters._
+
+    /**
+     * The SQL generated and pushed down to OceanBase:
+     *
+     * SELECT /*+ PARALLEL(1) */ `NAME`,MIN(`ID`),MAX(`WEIGHT`) FROM `TEST`.`PRODUCTS` WHERE (`ID`
+     * >= 101 AND `ID` < 110) GROUP BY `NAME`
+     *
+     * In this case, Spark does not push down the top-N (the ORDER BY ... LIMIT), but it does push
+     * down the aggregate.
+     */
+    val expect1 = Seq("spare tire,109,22.20", "scooter,101,3.14", "rocks,107,5.30").toList.asJava
+    queryAndVerify(
+      session,
+      "select NAME, min(ID), max(WEIGHT) from PRODUCTS group by NAME order by NAME desc limit 3",
+      expect1)
+
+    session.stop()
+  }
+
+  @Test
+  def testCredentialAliasPassword(): Unit = {
+    import org.apache.hadoop.conf.Configuration
+    import org.apache.hadoop.security.alias.CredentialProviderFactory
+    import java.io.File
+    import java.nio.file.Files
+
+    // Create temporary credential provider storage
+    val tempDir = Files.createTempDirectory("test-credentials")
+    val keystoreFile = new File(tempDir.toFile, "test.jceks")
+    val keystorePath = s"jceks://file${keystoreFile.getAbsolutePath}"
+
+    // Create credential provider and add password
+    val hadoopConf = new Configuration()
+    hadoopConf.set(CredentialProviderFactory.CREDENTIAL_PROVIDER_PATH, keystorePath)
+
+    val provider = CredentialProviderFactory.getProviders(hadoopConf).get(0)
+    provider.createCredentialEntry("test.password", getPassword.toCharArray)
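+    // flush() persists the staged credential entry to the JCEKS keystore so that the
+    // "alias:test.password" reference below can be resolved through the credential provider path.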
+    provider.flush()
+
+    try {
+      val session = SparkSession
+        .builder()
+        .master("local[*]")
+        .config("spark.hadoop.hadoop.security.credential.provider.path", keystorePath)
+        .config("spark.sql.catalog.ob", OB_CATALOG_CLASS)
+        .config("spark.sql.catalog.ob.url", getJdbcUrl)
+        .config("spark.sql.catalog.ob.username", getUsername)
+        .config("spark.sql.catalog.ob.password", "alias:test.password") // Use alias format
+        .config("spark.sql.catalog.ob.schema-name", getSchemaName)
+        .getOrCreate()
+
+      session.sql("use ob;")
+      insertTestData(session, "PRODUCTS")
+      queryAndVerifyTableData(session, "PRODUCTS", expected)
+
+      session.stop()
+    } finally {
+      // Clean up temporary files
+      keystoreFile.delete()
+      tempDir.toFile.delete()
+    }
+  }
+
+  @Test
+  def testCredentialAliasPasswordWithDirectLoad(): Unit = {
+    import org.apache.hadoop.conf.Configuration
+    import org.apache.hadoop.security.alias.CredentialProviderFactory
+    import java.io.File
+    import java.nio.file.Files
+
+    // Create temporary credential provider storage
+    val tempDir = Files.createTempDirectory("test-credentials")
+    val keystoreFile = new File(tempDir.toFile, "test.jceks")
+    val keystorePath = s"jceks://file${keystoreFile.getAbsolutePath}"
+
+    // Create credential provider and add password
+    val hadoopConf = new Configuration()
+    hadoopConf.set(CredentialProviderFactory.CREDENTIAL_PROVIDER_PATH, keystorePath)
+
+    val provider = CredentialProviderFactory.getProviders(hadoopConf).get(0)
+    provider.createCredentialEntry("test.password", getPassword.toCharArray)
+    provider.flush()
+
+    try {
+      val session = SparkSession
+        .builder()
+        .master("local[*]")
+        .config("spark.hadoop.hadoop.security.credential.provider.path", keystorePath)
+        .config("spark.sql.catalog.ob", OB_CATALOG_CLASS)
+        .config("spark.sql.catalog.ob.url", getJdbcUrl)
+        .config("spark.sql.catalog.ob.username", getUsername)
+        .config("spark.sql.catalog.ob.password", "alias:test.password") // Use alias format
+        .config("spark.sql.catalog.ob.schema-name", getSchemaName)
+        .config("spark.sql.catalog.ob.direct-load.enabled", "true")
+        .config("spark.sql.catalog.ob.direct-load.host", getHost)
+        .config("spark.sql.catalog.ob.direct-load.rpc-port", getRpcPort)
+        .config("spark.sql.catalog.ob.direct-load.username", getUsername)
+        .getOrCreate()
+
+      session.sql("use ob;")
+      insertTestData(session, "PRODUCTS")
+      queryAndVerifyTableData(session, "PRODUCTS", expected)
+
+      session.stop()
+    } finally {
+      // Clean up temporary files
+      keystoreFile.delete()
+      tempDir.toFile.delete()
+    }
+  }
+
+  private def queryAndVerifyTableData(
+      session: SparkSession,
+      tableName: String,
+      expected: util.List[String]): Unit = {
+    import scala.collection.JavaConverters._
+    val actual = session
+      .sql(s"select * from $tableName")
+      .collect()
+      .map(
+        _.toString().drop(1).dropRight(1)
+      )
+      .toList
+      .asJava
+    assertEqualsInAnyOrder(expected, actual)
+  }
+
+  private def queryAndVerify(
+      session: SparkSession,
+      sql: String,
+      expected: util.List[String]): Unit = {
+    import scala.collection.JavaConverters._
+    val actual = session
+      .sql(sql)
+      .collect()
+      .map(
+        _.toString().drop(1).dropRight(1)
+      )
+      .toList
+      .asJava
+    println(actual)
+    assertEqualsInAnyOrder(expected, actual)
+  }
+
+  private def insertTestData(session: SparkSession, tableName: String): Unit = {
+    session.sql(
+      s"""
+         |INSERT INTO $getSchemaName.$tableName VALUES
+         |(101, 'scooter', 'Small 2-wheel scooter', 3.14),
+         |(102, 'car battery', '12V car battery', 8.1),
+         |(103, '12-pack drill bits', '12-pack of drill bits with sizes ranging from #40 to #3', 0.8),
+         |(104, 'hammer', '12oz carpenter\\'s hammer', 0.75),
+         |(105, 'hammer', '14oz carpenter\\'s hammer', 0.875),
+         |(106, 'hammer', '16oz carpenter\\'s hammer', 1.0),
+         |(107, 'rocks', 'box of assorted rocks', 5.3),
+         |(108, 'jacket', 'water resistent black wind breaker', 0.1),
+         |(109, 'spare tire', '24 inch spare tire', 22.2);
+         |""".stripMargin)
+  }
+
+  private def insertTestDataWithNullValue(session: SparkSession, tableName: String): Unit = {
+    session.sql(
+      s"""
+         |INSERT INTO $getSchemaName.$tableName VALUES
+         |(null, null, 'Small 2-wheel scooter', 3.14),
+         |(102, 'car battery', '12V car battery', 8.1),
+         |(103, '12-pack drill bits', '12-pack of drill bits with sizes ranging from #40 to #3', 0.8),
+         |(104, 'hammer', '12oz carpenter\\'s hammer', 0.75),
+         |(null, null, '14oz carpenter\\'s hammer', 0.875),
+         |(106, 'hammer', '16oz carpenter\\'s hammer', 1.0),
+         |(106, 'hammer', 'box of assorted rocks', null),
+         |(108, 'jacket', null, 0.1),
+         |(109, 'spare tire', '24 inch spare tire', 22.2);
+         |""".stripMargin)
+  }
+}
+
+object OBCatalogOracleITCase {
+  val expected: util.List[String] = util.Arrays.asList(
+    "101,scooter,Small 2-wheel scooter,3.14",
+    "102,car battery,12V car battery,8.10",
+    "103,12-pack drill bits,12-pack of drill bits with sizes ranging from #40 to #3,0.80",
+    "104,hammer,12oz carpenter's hammer,0.75",
+    "105,hammer,14oz carpenter's hammer,0.88",
+    "106,hammer,16oz carpenter's hammer,1.00",
+    "107,rocks,box of assorted rocks,5.30",
+    "108,jacket,water resistent black wind breaker,0.10",
+    "109,spare tire,24 inch spare tire,22.20"
+  )
+}