Skip to content

Commit 9533843

Browse files
committed
add support for count, min, max, avg and sum with partitions
1 parent 4d49bad commit 9533843

File tree

11 files changed

+1060
-179
lines changed

11 files changed

+1060
-179
lines changed

bridge/src/main/scala/app/softnetwork/elastic/sql/bridge/ElasticAggregation.scala

Lines changed: 66 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -181,46 +181,73 @@ object ElasticAggregation {
181181
case AVG => aggWithFieldOrScript(avgAgg, (name, s) => avgAgg(name, sourceField).script(s))
182182
case SUM => aggWithFieldOrScript(sumAgg, (name, s) => sumAgg(name, sourceField).script(s))
183183
case th: WindowFunction =>
184-
val limit = {
185-
th match {
186-
case _: LastValue | _: FirstValue => Some(1)
187-
case _ => th.limit.map(_.limit)
188-
}
189-
}
190-
val topHits =
191-
topHitsAgg(aggName)
192-
.fetchSource(
193-
th.identifier.name +: th.fields
194-
.filterNot(_.isScriptField)
195-
.filterNot(_.sourceField == th.identifier.name)
196-
.map(_.sourceField)
197-
.distinct
198-
.toArray,
199-
Array.empty
200-
)
201-
.copy(
202-
scripts = th.fields
203-
.filter(_.isScriptField)
204-
.groupBy(_.sourceField)
205-
.map(_._2.head)
206-
.map(f => f.sourceField -> Script(f.painless(None)).lang("painless"))
207-
.toMap,
208-
size = limit
209-
) sortBy th.orderBy.sorts.map(sort =>
210-
sort.order match {
211-
case Some(Desc) =>
212-
th.window match {
213-
case LAST_VALUE => FieldSort(sort.field.name).asc()
214-
case _ => FieldSort(sort.field.name).desc()
215-
}
216-
case _ =>
217-
th.window match {
218-
case LAST_VALUE => FieldSort(sort.field.name).desc()
219-
case _ => FieldSort(sort.field.name).asc()
220-
}
184+
th.window match {
185+
case COUNT =>
186+
val field =
187+
sourceField match {
188+
case "*" | "_id" | "_index" | "_type" => "_index"
189+
case _ => sourceField
190+
}
191+
if (distinct)
192+
cardinalityAgg(aggName, field)
193+
else {
194+
valueCountAgg(aggName, field)
221195
}
222-
)
223-
topHits
196+
case MIN =>
197+
aggWithFieldOrScript(minAgg, (name, s) => minAgg(name, sourceField).script(s))
198+
case MAX =>
199+
aggWithFieldOrScript(maxAgg, (name, s) => maxAgg(name, sourceField).script(s))
200+
case AVG =>
201+
aggWithFieldOrScript(avgAgg, (name, s) => avgAgg(name, sourceField).script(s))
202+
case SUM =>
203+
aggWithFieldOrScript(sumAgg, (name, s) => sumAgg(name, sourceField).script(s))
204+
case _ =>
205+
val limit = {
206+
th match {
207+
case _: LastValue | _: FirstValue => Some(1)
208+
case _ => th.limit.map(_.limit)
209+
}
210+
}
211+
val topHits =
212+
topHitsAgg(aggName)
213+
.fetchSource(
214+
th.identifier.name +: th.fields
215+
.filterNot(_.isScriptField)
216+
.filterNot(_.sourceField == th.identifier.name)
217+
.map(_.sourceField)
218+
.distinct
219+
.toArray,
220+
Array.empty
221+
)
222+
.copy(
223+
scripts = th.fields
224+
.filter(_.isScriptField)
225+
.groupBy(_.sourceField)
226+
.map(_._2.head)
227+
.map(f => f.sourceField -> Script(f.painless(None)).lang("painless"))
228+
.toMap,
229+
size = limit,
230+
sorts = th.orderBy
231+
.map(
232+
_.sorts.map(sort =>
233+
sort.order match {
234+
case Some(Desc) =>
235+
th.window match {
236+
case LAST_VALUE => FieldSort(sort.field.name).asc()
237+
case _ => FieldSort(sort.field.name).desc()
238+
}
239+
case _ =>
240+
th.window match {
241+
case LAST_VALUE => FieldSort(sort.field.name).desc()
242+
case _ => FieldSort(sort.field.name).asc()
243+
}
244+
}
245+
)
246+
)
247+
.getOrElse(Seq.empty)
248+
)
249+
topHits
250+
}
224251
case script: BucketScriptAggregation =>
225252
val params = allAggregations.get(aggName) match {
226253
case Some(sqlAgg) =>

core/src/main/scala/app/softnetwork/elastic/client/ElasticConversion.scala

Lines changed: 101 additions & 59 deletions
Original file line numberDiff line numberDiff line change
@@ -148,11 +148,27 @@ trait ElasticConversion {
148148

149149
case (None, Some(aggs)) =>
150150
// Case 2 : only aggregations
151-
parseAggregations(aggs, Map.empty, fieldAliases, aggregations)
151+
val ret = parseAggregations(aggs, Map.empty, fieldAliases, aggregations)
152+
val groupedRows: Map[String, Seq[Map[String, Any]]] =
153+
ret.groupBy(_.getOrElse("bucket_root", "").toString)
154+
groupedRows.values.foldLeft(Seq(Map.empty[String, Any])) { (acc, group) =>
155+
for {
156+
accMap <- acc
157+
groupMap <- group
158+
} yield accMap ++ groupMap
159+
}
152160

153161
case (Some(hits), Some(aggs)) if hits.isEmpty =>
154162
// Case 3 : aggregations with no hits
155-
parseAggregations(aggs, Map.empty, fieldAliases, aggregations)
163+
val ret = parseAggregations(aggs, Map.empty, fieldAliases, aggregations)
164+
val groupedRows: Map[String, Seq[Map[String, Any]]] =
165+
ret.groupBy(_.getOrElse("bucket_root", "").toString)
166+
groupedRows.values.foldLeft(Seq(Map.empty[String, Any])) { (acc, group) =>
167+
for {
168+
accMap <- acc
169+
groupMap <- group
170+
} yield accMap ++ groupMap
171+
}
156172

157173
case (Some(hits), Some(aggs)) if hits.nonEmpty =>
158174
// Case 4 : Hits + global aggregations + top_hits aggregations
@@ -355,7 +371,7 @@ trait ElasticConversion {
355371
}
356372
} else if (bucketAggs.isEmpty) {
357373
// No buckets : it is a leaf aggregation (metrics or top_hits)
358-
val metrics = extractMetrics(aggsNode)
374+
val metrics = extractMetrics(aggsNode, aggregations)
359375
val allTopHits = extractAllTopHits(aggsNode, fieldAliases, aggregations)
360376

361377
if (allTopHits.nonEmpty) {
@@ -369,6 +385,7 @@ trait ElasticConversion {
369385
// Handle each aggregation with buckets
370386
bucketAggs.flatMap { case (aggName, buckets, _) =>
371387
buckets.flatMap { bucket =>
388+
val metrics = extractMetrics(bucket, aggregations)
372389
val allTopHits = extractAllTopHits(bucket, fieldAliases, aggregations)
373390

374391
val bucketKey = extractBucketKey(bucket)
@@ -379,7 +396,7 @@ trait ElasticConversion {
379396
val currentContext = parentContext ++ Map(
380397
aggName -> bucketKey,
381398
s"${aggName}_doc_count" -> docCount
382-
) ++ allTopHits
399+
) ++ metrics ++ allTopHits
383400

384401
// Check for sub-aggregations
385402
val subAggFields = bucket
@@ -468,62 +485,76 @@ trait ElasticConversion {
468485

469486
/** Extract metrics from an aggregation node.
  *
  * Each property of `aggsNode` is read under its normalized aggregation key and recognised as one
  * of three shapes, tried in this order:
  *   - a simple metric (a non-null `value` field), returned as a Long, a Double or text;
  *   - a stats result (`count`, `sum` and `avg` fields present), returned as a nested map;
  *   - a percentiles result (a `values` object), returned as a map of percentile key to Double.
  *
  * Properties matching none of these shapes contribute no metric.
  *
  * When at least one property name is present in `aggregations`, the result additionally carries
  * a `bucket_root` entry holding the bucket root of the last such property in iteration order
  * (this happens even if that property itself yielded no metric).
  *
  * @param aggsNode
  *   the JSON node holding the aggregation results
  * @param aggregations
  *   the known aggregations, indexed by name
  * @return
  *   the extracted metrics, or an empty map when `aggsNode` is not a JSON object
  */
def extractMetrics(
  aggsNode: JsonNode,
  aggregations: Map[String, ClientAggregation]
): Map[String, Any] = {

  // Simple metric: { "value": <number | text> }
  def simpleMetric(node: JsonNode): Option[Any] =
    Option(node.get("value")).filterNot(_.isNull).map { metricValue =>
      // The explicit `Any` ascription prevents Long -> Double numeric widening.
      val extracted: Any =
        if (metricValue.isIntegralNumber) metricValue.asLong()
        else if (metricValue.isFloatingPointNumber) metricValue.asDouble()
        else metricValue.asText()
      extracted
    }

  // Stats aggregation: `count` is always kept, the other figures only when present and non-null
  def statsMetric(node: JsonNode): Option[Any] =
    if (node.has("count") && node.has("sum") && node.has("avg")) {
      val optionalFigures: List[(String, Any)] =
        List("sum", "avg", "min", "max").flatMap { key =>
          Option(node.get(key)).filterNot(_.isNull).map(figure => key -> (figure.asDouble(): Any))
        }
      Some(Map[String, Any]("count" -> node.get("count").asLong()) ++ optionalFigures)
    } else {
      None
    }

  // Percentiles: { "values": { "<percentile>": <double>, ... } }
  def percentilesMetric(node: JsonNode): Option[Any] =
    Option(node.get("values")).filter(_.isObject).map { values =>
      values
        .properties()
        .asScala
        .map(pEntry => pEntry.getKey -> pEntry.getValue.asDouble())
        .toMap
    }

  if (!aggsNode.isObject) {
    Map.empty
  } else {
    // Normalize every key exactly once and keep the node's iteration order, so that both the
    // metrics map and the bucket root are resolved deterministically (last entry wins).
    val entries: List[(String, JsonNode)] =
      aggsNode.properties().asScala.toList.map { entry =>
        normalizeAggregationKey(entry.getKey) -> entry.getValue
      }

    val bucketRoot: Option[String] =
      entries
        .flatMap { case (name, _) => aggregations.get(name) }
        .lastOption
        .map(_.bucketRoot)

    val metrics: Map[String, Any] =
      entries.flatMap { case (name, node) =>
        simpleMetric(node)
          .orElse(statsMetric(node))
          .orElse(percentilesMetric(node))
          .map(name -> _)
      }.toMap

    bucketRoot.fold(metrics)(root => metrics + ("bucket_root" -> root))
  }
}
528559

529560
/** Extract all top_hits aggregations with their names and hits */
@@ -533,6 +564,7 @@ trait ElasticConversion {
533564
aggregations: Map[String, ClientAggregation]
534565
): Map[String, Any] = {
535566
if (!aggsNode.isObject) return Map.empty
567+
var bucketRoot: Option[String] = None
536568
val allTopHits =
537569
aggsNode
538570
.properties()
@@ -553,13 +585,20 @@ trait ElasticConversion {
553585
// Process each top_hits aggregation with their names
554586
val row = allTopHits.map { case (topHitName, hits) =>
555587
// Determine if it is a multivalued aggregation (array_agg, ...)
556-
val hasMultipleValues = aggregations.get(topHitName) match {
588+
val agg = aggregations.get(topHitName)
589+
val hasMultipleValues = agg match {
557590
case Some(agg) => agg.multivalued
558591
case None =>
559592
// Fallback on naming convention if aggregation is not found
560593
!topHitName.toLowerCase.matches("(first|last)_.*")
561594
}
562595

596+
agg match {
597+
case Some(agg) =>
598+
bucketRoot = Some(agg.bucketRoot)
599+
case _ =>
600+
}
601+
563602
val processedHits = hits.map { hit =>
564603
val source = extractSource(hit, fieldAliases)
565604
if (hasMultipleValues) {
@@ -582,7 +621,7 @@ trait ElasticConversion {
582621
} else {
583622
val metadata = extractHitMetadata(hit)
584623
val innerHits = extractInnerHits(hit, fieldAliases)
585-
source ++ metadata ++ innerHits
624+
source ++ metadata ++ innerHits ++ Map("bucket_root" -> bucketRoot)
586625
}
587626
}
588627

@@ -600,7 +639,10 @@ trait ElasticConversion {
600639
}
601640
}
602641

603-
row
642+
bucketRoot match {
643+
case Some(root) => row + ("bucket_root" -> root)
644+
case None => row
645+
}
604646
}
605647

606648
/** Extract global metrics from aggregations (for hits + aggs case)

core/src/main/scala/app/softnetwork/elastic/client/package.scala

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,8 @@ package object client extends SerializationApi {
171171
distinct: Boolean,
172172
sourceField: String,
173173
windowing: Boolean,
174-
bucketPath: String
174+
bucketPath: String,
175+
bucketRoot: String
175176
) {
176177
def multivalued: Boolean = aggType == AggregationType.ArrayAgg
177178
def singleValued: Boolean = !multivalued
@@ -187,6 +188,11 @@ package object client extends SerializationApi {
187188
case _: FirstValue => AggregationType.FirstValue
188189
case _: LastValue => AggregationType.LastValue
189190
case _: ArrayAgg => AggregationType.ArrayAgg
191+
case _: CountAgg => AggregationType.Count
192+
case _: MinAgg => AggregationType.Min
193+
case _: MaxAgg => AggregationType.Max
194+
case _: AvgAgg => AggregationType.Avg
195+
case _: SumAgg => AggregationType.Sum
190196
case _ => throw new IllegalArgumentException(s"Unsupported aggregation type: ${agg.aggType}")
191197
}
192198
ClientAggregation(
@@ -195,7 +201,8 @@ package object client extends SerializationApi {
195201
agg.distinct,
196202
agg.sourceField,
197203
agg.aggType.isWindowing,
198-
agg.bucketPath
204+
agg.bucketPath,
205+
agg.bucketRoot
199206
)
200207
}
201208
}

0 commit comments

Comments
 (0)