Skip to content

Commit 71979ef

Browse files
committed
update macros so as to reject SELECT * queries + accept queries with missing fields that have defaults or are options
1 parent a9fb3a9 commit 71979ef

File tree

2 files changed

+250
-15
lines changed

2 files changed

+250
-15
lines changed

macros-tests/src/test/scala/app/softnetwork/elastic/sql/macros/SQLQueryValidatorSpec.scala

Lines changed: 85 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -69,11 +69,50 @@ class SQLQueryValidatorSpec extends AnyFlatSpec with Matchers {
6969
)""")
7070
}
7171

72+
it should "accept query with missing Option fields" in {
73+
assertCompiles("""
74+
import app.softnetwork.elastic.client.macros.TestElasticClientApi
75+
import app.softnetwork.elastic.client.macros.TestElasticClientApi.defaultFormats
76+
import app.softnetwork.elastic.sql.macros.SQLQueryValidatorSpec.ProductWithOptional
77+
import app.softnetwork.elastic.sql.query.SQLQuery
78+
79+
TestElasticClientApi.searchAs[ProductWithOptional](
80+
"SELECT id, name FROM products"
81+
)
82+
""")
83+
}
84+
85+
it should "accept query with missing fields that have defaults" in {
86+
assertCompiles("""
87+
import app.softnetwork.elastic.client.macros.TestElasticClientApi
88+
import app.softnetwork.elastic.client.macros.TestElasticClientApi.defaultFormats
89+
import app.softnetwork.elastic.sql.macros.SQLQueryValidatorSpec.ProductWithDefaults
90+
import app.softnetwork.elastic.sql.query.SQLQuery
91+
92+
TestElasticClientApi.searchAs[ProductWithDefaults](
93+
"SELECT id, name FROM products"
94+
)
95+
""")
96+
}
97+
98+
it should "accept SELECT * with Unchecked variant" in {
99+
assertCompiles("""
100+
import app.softnetwork.elastic.client.macros.TestElasticClientApi
101+
import app.softnetwork.elastic.client.macros.TestElasticClientApi.defaultFormats
102+
import app.softnetwork.elastic.sql.macros.SQLQueryValidatorSpec.Product
103+
import app.softnetwork.elastic.sql.query.SQLQuery
104+
105+
TestElasticClientApi.searchAsUnchecked[Product](
106+
SQLQuery("SELECT * FROM products")
107+
)
108+
""")
109+
}
110+
72111
// ============================================================
73112
// Negative Tests (Should NOT Compile)
74113
// ============================================================
75114

76-
it should "reject missing fields" in {
115+
it should "REJECT query with missing required fields" in {
77116
assertDoesNotCompile("""
78117
import app.softnetwork.elastic.client.macros.TestElasticClientApi
79118
import app.softnetwork.elastic.client.macros.TestElasticClientApi.defaultFormats
@@ -85,7 +124,7 @@ class SQLQueryValidatorSpec extends AnyFlatSpec with Matchers {
85124
)""")
86125
}
87126

88-
it should "reject invalid field names" in {
127+
it should "REJECT query with invalid field names" in {
89128
assertDoesNotCompile("""
90129
import app.softnetwork.elastic.client.macros.TestElasticClientApi
91130
import app.softnetwork.elastic.client.macros.TestElasticClientApi.defaultFormats
@@ -97,7 +136,7 @@ class SQLQueryValidatorSpec extends AnyFlatSpec with Matchers {
97136
)""")
98137
}
99138

100-
it should "reject type mismatches" in {
139+
it should "REJECT query with type mismatches" in {
101140
assertDoesNotCompile("""
102141
import app.softnetwork.elastic.client.macros.TestElasticClientApi
103142
import app.softnetwork.elastic.client.macros.TestElasticClientApi.defaultFormats
@@ -122,7 +161,7 @@ class SQLQueryValidatorSpec extends AnyFlatSpec with Matchers {
122161
)""")
123162
}
124163

125-
it should "reject dynamic queries (non-literals)" in {
164+
it should "REJECT dynamic queries (non-literals)" in {
126165
assertDoesNotCompile("""
127166
import app.softnetwork.elastic.client.macros.TestElasticClientApi
128167
import app.softnetwork.elastic.client.macros.TestElasticClientApi.defaultFormats
@@ -134,6 +173,33 @@ class SQLQueryValidatorSpec extends AnyFlatSpec with Matchers {
134173
s"SELECT id, $dynamicField FROM products"
135174
)""")
136175
}
176+
177+
it should "REJECT SELECT * queries" in {
178+
assertDoesNotCompile("""
179+
import app.softnetwork.elastic.client.macros.TestElasticClientApi
180+
import app.softnetwork.elastic.client.macros.TestElasticClientApi.defaultFormats
181+
import app.softnetwork.elastic.sql.macros.SQLQueryValidatorSpec.Product
182+
import app.softnetwork.elastic.sql.query.SQLQuery
183+
184+
TestElasticClientApi.searchAs[Product](
185+
"SELECT * FROM products"
186+
)
187+
""")
188+
}
189+
190+
it should "REJECT SELECT * even with WHERE clause" in {
191+
assertDoesNotCompile("""
192+
import app.softnetwork.elastic.client.macros.TestElasticClientApi
193+
import app.softnetwork.elastic.client.macros.TestElasticClientApi.defaultFormats
194+
import app.softnetwork.elastic.sql.macros.SQLQueryValidatorSpec.Product
195+
import app.softnetwork.elastic.sql.query.SQLQuery
196+
197+
TestElasticClientApi.searchAs[Product](
198+
"SELECT * FROM products WHERE active = true"
199+
)
200+
""")
201+
}
202+
137203
}
138204

139205
object SQLQueryValidatorSpec {
@@ -146,6 +212,21 @@ object SQLQueryValidatorSpec {
146212
createdAt: java.time.LocalDateTime
147213
)
148214

215+
case class ProductWithOptional(
216+
id: String,
217+
name: String,
218+
price: Option[Double], // ✅ OK if missing
219+
stock: Option[Int] // ✅ OK if missing
220+
)
221+
222+
// Case class with default values
223+
case class ProductWithDefaults(
224+
id: String,
225+
name: String,
226+
price: Double = 0.0, // ✅ OK if missing
227+
stock: Int = 0 // ✅ OK if missing
228+
)
229+
149230
case class Numbers(
150231
tiny: Byte,
151232
small: Short,

macros/src/main/scala/app/softnetwork/elastic/sql/macros/SQLQueryValidator.scala

Lines changed: 165 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@ trait SQLQueryValidator {
2929

3030
/** Validates an SQL query against a type T. Returns the SQL query if valid, otherwise aborts
3131
* compilation.
32+
* @note
33+
* query fields must not exist in case class because we are using Jackson to deserialize the
34+
* results with the following option DeserializationFeature.FAIL_ON_UNKNOWN_PROPERTIES = false
3235
*/
3336
protected def validateSQLQuery[T: c.WeakTypeTag](c: blackbox.Context)(
3437
query: c.Expr[String]
@@ -46,6 +49,11 @@ trait SQLQueryValidator {
4649
// 2. Parse the SQL query
4750
val parsedQuery = parseSQLQuery(c)(sqlQuery)
4851

52+
// ============================================================
53+
// ✅ NEW: Reject SELECT *
54+
// ============================================================
55+
rejectSelectStar(c)(parsedQuery, sqlQuery)
56+
4957
// 3. Extract the selected fields
5058
val queryFields = extractQueryFields(parsedQuery)
5159

@@ -56,13 +64,13 @@ trait SQLQueryValidator {
5664
val caseClassFields = extractCaseClassFields(c)(tpe)
5765
c.echo(c.enclosingPosition, s"📦 Case class fields: ${caseClassFields.mkString(", ")}")
5866

59-
// 5. Validate the fields
60-
validateFields(c)(queryFields, caseClassFields, tpe)
67+
// 5. Validate: missing case class fields must have defaults or be Option
68+
validateMissingFieldsHaveDefaults(c)(queryFields, caseClassFields, tpe)
6169

62-
// 6. Validate the types
70+
// 7. Validate the types
6371
validateTypes(c)(parsedQuery, caseClassFields)
6472

65-
// 7. Return the validated request
73+
// 8. Return the validated request
6674
sqlQuery
6775
}
6876

@@ -108,6 +116,47 @@ trait SQLQueryValidator {
108116
}
109117
}
110118

119+
// ============================================================
120+
// ✅ Reject SELECT * (incompatible with compile-time validation)
121+
// ============================================================
122+
private def rejectSelectStar(c: blackbox.Context)(
123+
parsedQuery: SQLSearchRequest,
124+
sqlQuery: String
125+
): Unit = {
126+
127+
// Check if any field is a wildcard (*)
128+
val hasWildcard = parsedQuery.select.fields.exists { field =>
129+
field.identifier.name == "*"
130+
}
131+
132+
if (hasWildcard) {
133+
c.abort(
134+
c.enclosingPosition,
135+
s"""❌ SELECT * is not allowed with compile-time validation.
136+
|
137+
|Query: $sqlQuery
138+
|
139+
|Reason:
140+
| • Cannot validate field existence at compile-time
141+
| • Cannot validate type compatibility at compile-time
142+
| • Schema changes will break silently at runtime
143+
|
144+
|Solution:
145+
| 1. Explicitly list all required fields:
146+
| SELECT id, name, price FROM products
147+
|
148+
| 2. Use the *Unchecked() variant for dynamic queries:
149+
| searchAsUnchecked[Product](SQLQuery("SELECT * FROM products"))
150+
|
151+
|Best Practice:
152+
| Always explicitly select only the fields you need.
153+
|""".stripMargin
154+
)
155+
}
156+
157+
c.echo(c.enclosingPosition, "✅ No SELECT * detected")
158+
}
159+
111160
private def extractQueryFields(parsedQuery: SQLSearchRequest): Set[String] = {
112161
parsedQuery.select.fields.map { field =>
113162
field.fieldAlias.map(_.alias).getOrElse(field.identifier.name)
@@ -125,17 +174,21 @@ trait SQLQueryValidator {
125174
}.toMap
126175
}
127176

128-
private def validateFields(c: blackbox.Context)(
177+
// ============================================================
178+
// ✅ VALIDATION 1: Query fields must exist in case class
179+
// ============================================================
180+
@deprecated
181+
private def validateQueryFieldsExist(c: blackbox.Context)(
129182
queryFields: Set[String],
130183
caseClassFields: Map[String, c.universe.Type],
131184
tpe: c.universe.Type
132185
): Unit = {
133-
val missingFields = caseClassFields.keySet -- queryFields
186+
val unknownFields = queryFields.filterNot(f => caseClassFields.contains(f))
134187

135-
if (missingFields.nonEmpty) {
188+
if (unknownFields.nonEmpty) {
136189
val availableFields = caseClassFields.keys.toSeq.sorted.mkString(", ")
137-
val suggestions = missingFields.flatMap { missing =>
138-
findClosestMatch(missing, caseClassFields.keys.toSeq)
190+
val suggestions = unknownFields.flatMap { unknown =>
191+
findClosestMatch(unknown, caseClassFields.keys.toSeq)
139192
}
140193

141194
val suggestionMsg = if (suggestions.nonEmpty) {
@@ -144,13 +197,112 @@ trait SQLQueryValidator {
144197

145198
c.abort(
146199
c.enclosingPosition,
147-
s"❌ SQL case class fields in ${tpe.typeSymbol.name} not present in ${queryFields.mkString(",")}: " +
148-
s"${missingFields.mkString(", ")}\n" +
200+
s"❌ SQL query selects fields not present in ${tpe.typeSymbol.name}: " +
201+
s"${unknownFields.mkString(", ")}\n" +
149202
s"Available fields: $availableFields$suggestionMsg"
150203
)
151204
}
205+
206+
c.echo(c.enclosingPosition, "✅ All query fields exist in case class")
152207
}
153208

209+
// ============================================================
210+
// ✅ VALIDATION 2: Missing fields must have defaults or be Option
211+
// ============================================================
212+
private def validateMissingFieldsHaveDefaults(c: blackbox.Context)(
213+
queryFields: Set[String],
214+
caseClassFields: Map[String, c.universe.Type],
215+
tpe: c.universe.Type
216+
): Unit = {
217+
import c.universe._
218+
219+
val missingFields = caseClassFields.keySet -- queryFields
220+
221+
if (missingFields.isEmpty) {
222+
c.echo(c.enclosingPosition, "✅ No missing fields to validate")
223+
return
224+
}
225+
226+
c.echo(c.enclosingPosition, s"⚠️ Missing fields: ${missingFields.mkString(", ")}")
227+
228+
// Get constructor parameters with their positions
229+
val constructor = tpe.decl(termNames.CONSTRUCTOR).asMethod
230+
val params = constructor.paramLists.flatten
231+
232+
// Build map: fieldName -> (index, hasDefault, isOption)
233+
val fieldInfo = params.zipWithIndex.map { case (param, idx) =>
234+
val fieldName = param.name.toString
235+
val fieldType = param.typeSignature
236+
237+
// Check if Option
238+
val isOption = fieldType.typeConstructor =:= typeOf[Option[_]].typeConstructor
239+
240+
// Check if has default value
241+
val companionSymbol = tpe.typeSymbol.companion
242+
val hasDefault = if (companionSymbol != NoSymbol) {
243+
val companionType = companionSymbol.typeSignature
244+
val defaultMethodName = s"apply$$default$$${idx + 1}"
245+
companionType.member(TermName(defaultMethodName)) != NoSymbol
246+
} else {
247+
false
248+
}
249+
250+
(fieldName, (idx, hasDefault, isOption))
251+
}.toMap
252+
253+
// Check each missing field
254+
val fieldsWithoutDefaults = missingFields.filterNot { fieldName =>
255+
fieldInfo.get(fieldName) match {
256+
case Some((_, hasDefault, isOption)) =>
257+
if (isOption) {
258+
c.echo(c.enclosingPosition, s"✅ Field '$fieldName' is Option - OK")
259+
true
260+
} else if (hasDefault) {
261+
c.echo(c.enclosingPosition, s"✅ Field '$fieldName' has default value - OK")
262+
true
263+
} else {
264+
c.echo(c.enclosingPosition, s"❌ Field '$fieldName' has NO default and is NOT Option")
265+
false
266+
}
267+
case None =>
268+
c.echo(c.enclosingPosition, s"⚠️ Field '$fieldName' not found in constructor")
269+
false
270+
}
271+
}
272+
273+
if (fieldsWithoutDefaults.nonEmpty) {
274+
c.abort(
275+
c.enclosingPosition,
276+
s"❌ SQL query does not select the following required fields from ${tpe.typeSymbol.name}:\n" +
277+
s" ${fieldsWithoutDefaults.mkString(", ")}\n\n" +
278+
s"These fields are missing from the query:\n" +
279+
s" SELECT ${queryFields.mkString(", ")} FROM ...\n\n" +
280+
s"To fix this, either:\n" +
281+
s" 1. Add them to the SELECT clause\n" +
282+
s" 2. Make them Option[T] in the case class\n" +
283+
s" 3. Provide default values in the case class definition"
284+
)
285+
}
286+
287+
c.echo(c.enclosingPosition, "✅ All missing fields have defaults or are Option")
288+
}
289+
290+
// Helper: Get the index of a field in the case class constructor
291+
private def getFieldIndex(c: blackbox.Context)(
292+
tpe: c.universe.Type,
293+
fieldName: String
294+
): Int = {
295+
import c.universe._
296+
297+
val constructor = tpe.decl(termNames.CONSTRUCTOR).asMethod
298+
val params = constructor.paramLists.flatten
299+
300+
params.indexWhere(_.name.toString == fieldName)
301+
}
302+
303+
// ============================================================
304+
// VALIDATION 3: Type compatibility
305+
// ============================================================
154306
private def validateTypes(c: blackbox.Context)(
155307
parsedQuery: SQLSearchRequest,
156308
caseClassFields: Map[String, c.universe.Type]
@@ -172,6 +324,8 @@ trait SQLQueryValidator {
172324
case _ => // Cannot validate without type info
173325
}
174326
}
327+
328+
c.echo(c.enclosingPosition, "✅ Type validation passed")
175329
}
176330

177331
private def areTypesCompatible(c: blackbox.Context)(

0 commit comments

Comments
 (0)