@@ -28,6 +28,7 @@ import app.softnetwork.elastic.sql.query.{
2828 MetricSelectorScript ,
2929 NestedElement ,
3030 NestedElements ,
31+ SQLAggregation ,
3132 SortOrder
3233}
3334import app .softnetwork .elastic .sql .function ._
@@ -36,6 +37,7 @@ import app.softnetwork.elastic.sql.function.time.DateTrunc
3637import app .softnetwork .elastic .sql .time .TimeUnit
3738import com .sksamuel .elastic4s .ElasticApi .{
3839 avgAgg ,
40+ bucketScriptAggregation ,
3941 bucketSelectorAggregation ,
4042 cardinalityAgg ,
4143 maxAgg ,
@@ -49,6 +51,7 @@ import com.sksamuel.elastic4s.ElasticApi.{
4951import com .sksamuel .elastic4s .requests .script .Script
5052import com .sksamuel .elastic4s .requests .searches .DateHistogramInterval
5153import com .sksamuel .elastic4s .requests .searches .aggs .{
54+ AbstractAggregation ,
5255 Aggregation ,
5356 CardinalityAggregation ,
5457 DateHistogramAggregation ,
@@ -74,7 +77,7 @@ case class ElasticAggregation(
7477 nestedAgg : Option [NestedAggregation ] = None ,
7578 filteredAgg : Option [FilterAggregation ] = None ,
7679 aggType : AggregateFunction ,
77- agg : Aggregation ,
80+ agg : AbstractAggregation ,
7881 direction : Option [SortOrder ] = None ,
7982 nestedElement : Option [NestedElement ] = None
8083) {
@@ -94,7 +97,8 @@ object ElasticAggregation {
9497 def apply (
9598 sqlAgg : Field ,
9699 having : Option [Criteria ],
97- bucketsDirection : Map [String , SortOrder ]
100+ bucketsDirection : Map [String , SortOrder ],
101+ allAggregations : Map [String , SQLAggregation ]
98102 ): ElasticAggregation = {
99103 import sqlAgg ._
100104 val sourceField = identifier.path
@@ -111,9 +115,14 @@ object ElasticAggregation {
111115
112116 val distinct = identifier.distinct
113117
114- val aggType = aggregateFunction.getOrElse(
115- throw new IllegalArgumentException (" Aggregation function is required" )
116- )
118+ val aggType = {
119+ if (isBucketScript) {
120+ BucketScriptAggregation (identifier)
121+ } else
122+ aggregateFunction.getOrElse(
123+ throw new IllegalArgumentException (" Aggregation function is required" )
124+ )
125+ }
117126
118127 val aggName = {
119128 if (fieldAlias.isDefined)
@@ -135,7 +144,8 @@ object ElasticAggregation {
135144
136145 val (aggFuncs, transformFuncs) = FunctionUtils .aggregateAndTransformFunctions(identifier)
137146
138- require(aggFuncs.size == 1 , s " Multiple aggregate functions not supported: $aggFuncs" )
147+ if (! isBucketScript)
148+ require(aggFuncs.size == 1 , s " Multiple aggregate functions not supported: $aggFuncs" )
139149
140150 def aggWithFieldOrScript (
141151 buildField : (String , String ) => Aggregation ,
@@ -171,9 +181,8 @@ object ElasticAggregation {
171181 case th : WindowFunction =>
172182 val limit = {
173183 th match {
174- case _ : LastValue => 1
175- // case _: FirstValue => 1
176- case _ => th.limit.map(_.limit).getOrElse(1 )
184+ case _ : LastValue | _ : FirstValue => Some (1 )
185+ case _ => th.limit.map(_.limit)
177186 }
178187 }
179188 val topHits =
@@ -193,9 +202,9 @@ object ElasticAggregation {
193202 .groupBy(_.sourceField)
194203 .map(_._2.head)
195204 .map(f => f.sourceField -> Script (f.painless(None )).lang(" painless" ))
196- .toMap
197- )
198- .size(limit ) sortBy th.orderBy.sorts.map(sort =>
205+ .toMap,
206+ size = limit
207+ ) sortBy th.orderBy.sorts.map(sort =>
199208 sort.order match {
200209 case Some (Desc ) =>
201210 th.window match {
@@ -209,10 +218,25 @@ object ElasticAggregation {
209218 }
210219 }
211220 )
212- /* th.fields.filter(_.isScriptField).foldLeft(topHits) { (agg, f) =>
213- agg.script(f.sourceField, Script(f.painless, lang = Some("painless")))
214- }*/
215221 topHits
222+ case script : BucketScriptAggregation =>
223+ val params = allAggregations.get(aggName) match {
224+ case Some (sqlAgg) =>
225+ sqlAgg.aggType match {
226+ case bsa : BucketScriptAggregation =>
227+ extractMetricsPathForBucketScript(bsa, allAggregations.values.toSeq)
228+ case _ => Map .empty
229+ }
230+ case None => Map .empty
231+ }
232+ val painless = script.identifier.painless(None )
233+ bucketScriptAggregation(
234+ aggName,
235+ Script (s " $painless" ).lang(" painless" ),
236+ params.toMap
237+ )
238+ case _ =>
239+ throw new IllegalArgumentException (s " Unsupported aggregation type: $aggType" )
216240 }
217241
218242 val nestedElement = identifier.nestedElement
@@ -276,7 +300,7 @@ object ElasticAggregation {
276300 def buildBuckets (
277301 buckets : Seq [Bucket ],
278302 bucketsDirection : Map [String , SortOrder ],
279- aggregations : Seq [Aggregation ],
303+ aggregations : Seq [AbstractAggregation ],
280304 aggregationsDirection : Map [String , SortOrder ],
281305 having : Option [Criteria ],
282306 nested : Option [NestedElement ],
@@ -287,7 +311,7 @@ object ElasticAggregation {
287311 val currentBucketPath = bucket.identifier.path
288312
289313 val aggScript =
290- if (bucket.shouldBeScripted) {
314+ if (! bucket.isBucketScript && bucket.shouldBeScripted) {
291315 val context = PainlessContext ()
292316 val painless = bucket.painless(Some (context))
293317 Some (Script (s " $context$painless" ).lang(" painless" ))
@@ -579,6 +603,54 @@ object ElasticAggregation {
579603 }
580604 }
581605
606+ def extractMetricsPathForBucketScript (
607+ bucketScriptAggregation : BucketScriptAggregation ,
608+ allAggregations : Seq [SQLAggregation ]
609+ ): Map [String , String ] = {
610+ val currentBucketPath =
611+ bucketScriptAggregation.identifier.nestedElement.map(_.bucketPath).getOrElse(" " )
612+ // Extract ALL metrics paths
613+ val allMetricsPaths = bucketScriptAggregation.params.keys
614+ val result =
615+ allMetricsPaths.flatMap { metricName =>
616+ allAggregations.find(agg => agg.aggName == metricName || agg.field == metricName) match {
617+ case Some (sqlAgg) =>
618+ val metricBucketPath = sqlAgg.nestedElement
619+ .map(_.bucketPath)
620+ .getOrElse(" " )
621+ if (metricBucketPath == currentBucketPath) {
622+ // Metric of the same level
623+ Some (metricName -> metricName)
624+ } else if (isDirectChild(metricBucketPath, currentBucketPath)) {
625+ // Metric of a direct child
626+ // CHECK if it is a "global" metric (cardinality, etc.) or a bucket metric (avg, sum, etc.)
627+ val isGlobalMetric = sqlAgg.isGlobalMetric
628+
629+ if (isGlobalMetric) {
630+ // Global metric: can be referenced from the parent
631+ val childNestedName = sqlAgg.nestedElement
632+ .map(_.innerHitsName)
633+ .getOrElse(" " )
634+ // println(
635+ // s"[DEBUG extractMetricsPath] Direct child (global metric): $metricName -> $childNestedName>$metricName"
636+ // )
637+ Some (metricName -> s " $childNestedName> $metricName" )
638+ } else {
639+ // Bucket metric: cannot be referenced from the parent
640+ // println(
641+ // s"[DEBUG extractMetricsPath] Direct child (bucket metric): $metricName -> SKIP (bucket-level metric)"
642+ // )
643+ None
644+ }
645+ } else {
646+ None
647+ }
648+ case _ => None
649+ }
650+ }
651+ result.toMap
652+ }
653+
582654 /** Extracts the buckets_path for a given bucket
583655 */
584656 def extractMetricsPathForBucket (
@@ -596,7 +668,7 @@ object ElasticAggregation {
596668 // println(s"[DEBUG extractMetricsPath] allMetricsPaths = $allMetricsPaths")
597669
598670 // Filter and adapt the paths for this bucket
599- val result = allMetricsPaths.flatMap { case (metricName, metricPath ) =>
671+ val result = allMetricsPaths.flatMap { case (metricName, _ ) =>
600672 allElasticAggregations.find(agg =>
601673 agg.aggName == metricName || agg.field == metricName
602674 ) match {
0 commit comments