Skip to content

Commit 6e50bb1

Browse files
committed
add support for bucket script
1 parent 313b82a commit 6e50bb1

File tree

13 files changed

+391
-116
lines changed

13 files changed

+391
-116
lines changed

bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticAggregation.scala

Lines changed: 90 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ import app.softnetwork.elastic.sql.query.{
2828
MetricSelectorScript,
2929
NestedElement,
3030
NestedElements,
31+
SQLAggregation,
3132
SortOrder
3233
}
3334
import app.softnetwork.elastic.sql.function._
@@ -36,6 +37,7 @@ import app.softnetwork.elastic.sql.function.time.DateTrunc
3637
import app.softnetwork.elastic.sql.time.TimeUnit
3738
import com.sksamuel.elastic4s.ElasticApi.{
3839
avgAgg,
40+
bucketScriptAggregation,
3941
bucketSelectorAggregation,
4042
cardinalityAgg,
4143
maxAgg,
@@ -49,6 +51,7 @@ import com.sksamuel.elastic4s.ElasticApi.{
4951
import com.sksamuel.elastic4s.requests.script.Script
5052
import com.sksamuel.elastic4s.requests.searches.DateHistogramInterval
5153
import com.sksamuel.elastic4s.requests.searches.aggs.{
54+
AbstractAggregation,
5255
Aggregation,
5356
CardinalityAggregation,
5457
DateHistogramAggregation,
@@ -74,7 +77,7 @@ case class ElasticAggregation(
7477
nestedAgg: Option[NestedAggregation] = None,
7578
filteredAgg: Option[FilterAggregation] = None,
7679
aggType: AggregateFunction,
77-
agg: Aggregation,
80+
agg: AbstractAggregation,
7881
direction: Option[SortOrder] = None,
7982
nestedElement: Option[NestedElement] = None
8083
) {
@@ -94,7 +97,8 @@ object ElasticAggregation {
9497
def apply(
9598
sqlAgg: Field,
9699
having: Option[Criteria],
97-
bucketsDirection: Map[String, SortOrder]
100+
bucketsDirection: Map[String, SortOrder],
101+
allAggregations: Map[String, SQLAggregation]
98102
): ElasticAggregation = {
99103
import sqlAgg._
100104
val sourceField = identifier.path
@@ -111,9 +115,14 @@ object ElasticAggregation {
111115

112116
val distinct = identifier.distinct
113117

114-
val aggType = aggregateFunction.getOrElse(
115-
throw new IllegalArgumentException("Aggregation function is required")
116-
)
118+
val aggType = {
119+
if (isBucketScript) {
120+
BucketScriptAggregation(identifier)
121+
} else
122+
aggregateFunction.getOrElse(
123+
throw new IllegalArgumentException("Aggregation function is required")
124+
)
125+
}
117126

118127
val aggName = {
119128
if (fieldAlias.isDefined)
@@ -135,7 +144,8 @@ object ElasticAggregation {
135144

136145
val (aggFuncs, transformFuncs) = FunctionUtils.aggregateAndTransformFunctions(identifier)
137146

138-
require(aggFuncs.size == 1, s"Multiple aggregate functions not supported: $aggFuncs")
147+
if (!isBucketScript)
148+
require(aggFuncs.size == 1, s"Multiple aggregate functions not supported: $aggFuncs")
139149

140150
def aggWithFieldOrScript(
141151
buildField: (String, String) => Aggregation,
@@ -171,9 +181,8 @@ object ElasticAggregation {
171181
case th: WindowFunction =>
172182
val limit = {
173183
th match {
174-
case _: LastValue => 1
175-
// case _: FirstValue => 1
176-
case _ => th.limit.map(_.limit).getOrElse(1)
184+
case _: LastValue | _: FirstValue => Some(1)
185+
case _ => th.limit.map(_.limit)
177186
}
178187
}
179188
val topHits =
@@ -193,9 +202,9 @@ object ElasticAggregation {
193202
.groupBy(_.sourceField)
194203
.map(_._2.head)
195204
.map(f => f.sourceField -> Script(f.painless(None)).lang("painless"))
196-
.toMap
197-
)
198-
.size(limit) sortBy th.orderBy.sorts.map(sort =>
205+
.toMap,
206+
size = limit
207+
) sortBy th.orderBy.sorts.map(sort =>
199208
sort.order match {
200209
case Some(Desc) =>
201210
th.window match {
@@ -209,10 +218,25 @@ object ElasticAggregation {
209218
}
210219
}
211220
)
212-
/*th.fields.filter(_.isScriptField).foldLeft(topHits) { (agg, f) =>
213-
agg.script(f.sourceField, Script(f.painless, lang = Some("painless")))
214-
}*/
215221
topHits
222+
case script: BucketScriptAggregation =>
223+
val params = allAggregations.get(aggName) match {
224+
case Some(sqlAgg) =>
225+
sqlAgg.aggType match {
226+
case bsa: BucketScriptAggregation =>
227+
extractMetricsPathForBucketScript(bsa, allAggregations.values.toSeq)
228+
case _ => Map.empty
229+
}
230+
case None => Map.empty
231+
}
232+
val painless = script.identifier.painless(None)
233+
bucketScriptAggregation(
234+
aggName,
235+
Script(s"$painless").lang("painless"),
236+
params.toMap
237+
)
238+
case _ =>
239+
throw new IllegalArgumentException(s"Unsupported aggregation type: $aggType")
216240
}
217241

218242
val nestedElement = identifier.nestedElement
@@ -276,7 +300,7 @@ object ElasticAggregation {
276300
def buildBuckets(
277301
buckets: Seq[Bucket],
278302
bucketsDirection: Map[String, SortOrder],
279-
aggregations: Seq[Aggregation],
303+
aggregations: Seq[AbstractAggregation],
280304
aggregationsDirection: Map[String, SortOrder],
281305
having: Option[Criteria],
282306
nested: Option[NestedElement],
@@ -287,7 +311,7 @@ object ElasticAggregation {
287311
val currentBucketPath = bucket.identifier.path
288312

289313
val aggScript =
290-
if (bucket.shouldBeScripted) {
314+
if (!bucket.isBucketScript && bucket.shouldBeScripted) {
291315
val context = PainlessContext()
292316
val painless = bucket.painless(Some(context))
293317
Some(Script(s"$context$painless").lang("painless"))
@@ -579,6 +603,54 @@ object ElasticAggregation {
579603
}
580604
}
581605

606+
def extractMetricsPathForBucketScript(
607+
bucketScriptAggregation: BucketScriptAggregation,
608+
allAggregations: Seq[SQLAggregation]
609+
): Map[String, String] = {
610+
val currentBucketPath =
611+
bucketScriptAggregation.identifier.nestedElement.map(_.bucketPath).getOrElse("")
612+
// Extract ALL metrics paths
613+
val allMetricsPaths = bucketScriptAggregation.params.keys
614+
val result =
615+
allMetricsPaths.flatMap { metricName =>
616+
allAggregations.find(agg => agg.aggName == metricName || agg.field == metricName) match {
617+
case Some(sqlAgg) =>
618+
val metricBucketPath = sqlAgg.nestedElement
619+
.map(_.bucketPath)
620+
.getOrElse("")
621+
if (metricBucketPath == currentBucketPath) {
622+
// Metric of the same level
623+
Some(metricName -> metricName)
624+
} else if (isDirectChild(metricBucketPath, currentBucketPath)) {
625+
// Metric of a direct child
626+
// CHECK if it is a "global" metric (cardinality, etc.) or a bucket metric (avg, sum, etc.)
627+
val isGlobalMetric = sqlAgg.isGlobalMetric
628+
629+
if (isGlobalMetric) {
630+
// Global metric: can be referenced from the parent
631+
val childNestedName = sqlAgg.nestedElement
632+
.map(_.innerHitsName)
633+
.getOrElse("")
634+
// println(
635+
// s"[DEBUG extractMetricsPath] Direct child (global metric): $metricName -> $childNestedName>$metricName"
636+
// )
637+
Some(metricName -> s"$childNestedName>$metricName")
638+
} else {
639+
// Bucket metric: cannot be referenced from the parent
640+
// println(
641+
// s"[DEBUG extractMetricsPath] Direct child (bucket metric): $metricName -> SKIP (bucket-level metric)"
642+
// )
643+
None
644+
}
645+
} else {
646+
None
647+
}
648+
case _ => None
649+
}
650+
}
651+
result.toMap
652+
}
653+
582654
/** Extracts the buckets_path for a given bucket
583655
*/
584656
def extractMetricsPathForBucket(
@@ -596,7 +668,7 @@ object ElasticAggregation {
596668
// println(s"[DEBUG extractMetricsPath] allMetricsPaths = $allMetricsPaths")
597669

598670
// Filter and adapt the paths for this bucket
599-
val result = allMetricsPaths.flatMap { case (metricName, metricPath) =>
671+
val result = allMetricsPaths.flatMap { case (metricName, _) =>
600672
allElasticAggregations.find(agg =>
601673
agg.aggName == metricName || agg.field == metricName
602674
) match {

bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticSearchRequest.scala

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,10 +16,11 @@
1616

1717
package app.softnetwork.elastic.sql.bridge
1818

19-
import app.softnetwork.elastic.sql.query.{Bucket, Criteria, Except, Field}
19+
import app.softnetwork.elastic.sql.query.{Bucket, Criteria, Except, Field, FieldSort}
2020
import com.sksamuel.elastic4s.requests.searches.{SearchBodyBuilderFn, SearchRequest}
2121

2222
case class ElasticSearchRequest(
23+
sql: String,
2324
fields: Seq[Field],
2425
except: Option[Except],
2526
sources: Seq[String],
@@ -28,7 +29,8 @@ case class ElasticSearchRequest(
2829
offset: Option[Int],
2930
search: SearchRequest,
3031
buckets: Seq[Bucket] = Seq.empty,
31-
aggregations: Seq[ElasticAggregation] = Seq.empty
32+
having: Option[Criteria] = None,
33+
sorts: Seq[FieldSort] = Seq.empty
3234
) {
3335
def minScore(score: Option[Double]): ElasticSearchRequest = {
3436
score match {

bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/package.scala

Lines changed: 15 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ import com.sksamuel.elastic4s.requests.common.FetchSourceContext
3333
import com.sksamuel.elastic4s.requests.script.Script
3434
import com.sksamuel.elastic4s.requests.script.ScriptType.Source
3535
import com.sksamuel.elastic4s.requests.searches.aggs.{
36-
Aggregation,
36+
AbstractAggregation,
3737
FilterAggregation,
3838
NestedAggregation,
3939
TermsAggregation
@@ -148,7 +148,7 @@ package object bridge {
148148
implicit def requestToRootAggregations(
149149
request: SQLSearchRequest,
150150
aggregations: Seq[ElasticAggregation]
151-
): Seq[Aggregation] = {
151+
): Seq[AbstractAggregation] = {
152152
val notNestedAggregations = aggregations.filterNot(_.nested)
153153

154154
val notNestedBuckets = request.buckets.filterNot(_.nested)
@@ -263,7 +263,7 @@ package object bridge {
263263
requestToNestedFilterAggregation(request, n.innerHitsName)
264264

265265
// Build buckets for this nested aggregation
266-
val buckets: Seq[Aggregation] =
266+
val buckets: Seq[AbstractAggregation] =
267267
ElasticAggregation.buildBuckets(
268268
nestedBuckets,
269269
request.sorts -- directions.keys,
@@ -379,7 +379,7 @@ package object bridge {
379379
}
380380

381381
private def addNestedAggregationsToTermsAggregation(
382-
agg: Aggregation,
382+
agg: AbstractAggregation,
383383
nested: Seq[NestedAggregation]
384384
): Option[TermsAggregation] = {
385385
agg match {
@@ -403,24 +403,29 @@ package object bridge {
403403

404404
implicit def requestToElasticSearchRequest(request: SQLSearchRequest): ElasticSearchRequest =
405405
ElasticSearchRequest(
406+
request.sql,
406407
request.select.fields,
407408
request.select.except,
408409
request.sources,
409410
request.where.flatMap(_.criteria),
410411
request.limit.map(_.limit),
411-
request.limit.flatMap(_.offset.map(_.offset)),
412+
request.limit.flatMap(_.offset.map(_.offset)).orElse(Some(0)),
412413
request,
413414
request.buckets,
414-
request.aggregates.map(
415-
ElasticAggregation(_, request.having.flatMap(_.criteria), request.sorts)
416-
)
415+
request.having.flatMap(_.criteria),
416+
request.orderBy.map(_.sorts).getOrElse(Seq.empty)
417417
).minScore(request.score)
418418

419419
implicit def requestToSearchRequest(request: SQLSearchRequest): SearchRequest = {
420420
import request._
421421

422422
val aggregations = request.aggregates.map(
423-
ElasticAggregation(_, request.having.flatMap(_.criteria), request.sorts)
423+
ElasticAggregation(
424+
_,
425+
request.having.flatMap(_.criteria),
426+
request.sorts,
427+
request.sqlAggregations
428+
)
424429
)
425430

426431
val rootAggregations = requestToRootAggregations(request, aggregations)
@@ -990,7 +995,7 @@ package object bridge {
990995
case Left(l) =>
991996
val filteredAgg: Option[FilterAggregation] = requestToFilterAggregation(l)
992997
l.aggregates
993-
.map(ElasticAggregation(_, l.having.flatMap(_.criteria), l.sorts))
998+
.map(ElasticAggregation(_, l.having.flatMap(_.criteria), l.sorts, l.sqlAggregations))
994999
.map(aggregation => {
9951000
val queryFiltered =
9961001
l.where

0 commit comments

Comments
 (0)