Skip to content

Commit 3640d84

Browse files
wbtlbwangbo
andauthored
feat(sharding): support sum function (#388)
* fix(sharding): fix min max error Signed-off-by: wangbo <wangbo@sphere-ex.com> * feat(sharding): support sum Signed-off-by: wangbo <wangbo@sphere-ex.com> * feat(sharding): support sum Signed-off-by: wangbo <wangbo@sphere-ex.com> * chore(sharding): delete space Signed-off-by: wangbo <wangbo@sphere-ex.com> Signed-off-by: wangbo <wangbo@sphere-ex.com> Co-authored-by: wangbo <wangbo@sphere-ex.com>
1 parent 520e59f commit 3640d84

File tree

4 files changed

+64
-6
lines changed

4 files changed

+64
-6
lines changed

pisa-proxy/protocol/mysql/src/row.rs

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ pub trait RowData<T: AsRef<[u8]>> {
2828
fn get_row_data_with_name(&mut self, name: &str) -> value::Result<RowPartData>;
2929
}
3030

31-
#[derive(Clone)]
31+
#[derive(Clone, Debug)]
3232
pub enum RowDataTyp<T: AsRef<[u8]>> {
3333
Text(RowDataText<T>),
3434
Binary(RowDataBinary<T>),
@@ -239,7 +239,6 @@ impl<T: AsRef<[u8]>> RowData<T> for RowDataBinary<T> {
239239

240240
// Need to add packet header and null_map to returnd data
241241
let raw_data = &self.buf.as_ref()[start_pos + pos as usize..(start_pos + pos as usize + length as usize)];
242-
println!("eeeeeeeeeeeee {:?}", &raw_data[..]);
243242
return Ok(Some(
244243
RowPartData {
245244
data: raw_data.into(),

pisa-proxy/proxy/strategy/src/sharding_rewrite/meta.rs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ pub enum FieldWrapFunc {
3939
Min,
4040
Max,
4141
Count,
42+
Sum,
4243
None,
4344
}
4445

@@ -48,6 +49,7 @@ impl AsRef<str> for FieldWrapFunc {
4849
Self::Max => "max",
4950
Self::Min => "min",
5051
Self::Count => "count",
52+
Self::Sum => "sum",
5153
Self::None => "none",
5254
}
5355
}
@@ -378,6 +380,14 @@ impl Transformer for RewriteMetaData {
378380
);
379381
}
380382

383+
AggFuncName::Sum => {
384+
self.state = ScanState::FieldWrapFunc(
385+
item.span,
386+
FieldWrapFunc::Sum,
387+
item.alias_name.clone(),
388+
);
389+
}
390+
381391
_ => {}
382392
}
383393
return false;

pisa-proxy/proxy/strategy/src/sharding_rewrite/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1378,7 +1378,7 @@ impl ShardingRewrite {
13781378
.iter()
13791379
.filter_map(|f| {
13801380
if let FieldMeta::Ident(meta) = f {
1381-
let is_match = matches!(meta.wrap_func, FieldWrapFunc::Max| FieldWrapFunc::Min |FieldWrapFunc::Count);
1381+
let is_match = matches!(meta.wrap_func, FieldWrapFunc::Max| FieldWrapFunc::Min |FieldWrapFunc::Count| FieldWrapFunc::Sum);
13821382
return is_match.then(|| meta.clone())
13831383
} else {
13841384
None

pisa-proxy/runtime/mysql/src/server/executor.rs

Lines changed: 52 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -215,7 +215,8 @@ where
215215
.map_err(ErrorKind::from)?;
216216

217217
let ro = &req.rewrite_outputs;
218-
Self::handle_min_max(ro, &mut chunk, row_data.clone(), is_binary)?;
218+
219+
Self::handle_agg(ro, &mut chunk, row_data.clone(), is_binary)?;
219220

220221
let avg_change = get_avg_change(&ro.results[0].changes);
221222

@@ -260,7 +261,6 @@ where
260261

261262
let count: u64 = count_data.par_iter().sum();
262263
let sum: u64 = sum_data.par_iter().sum();
263-
264264
chunk.par_iter_mut().for_each(|x| {
265265
let mut row_data = row_data.clone();
266266
row_data.with_buf(&x[4..]);
@@ -386,7 +386,7 @@ where
386386
Ok(())
387387
}
388388

389-
fn handle_min_max(
389+
fn handle_agg(
390390
ro: &ShardingRewriteOutput,
391391
chunk: &mut [BytesMut],
392392
row_data: RowDataTyp<&[u8]>,
@@ -410,6 +410,54 @@ where
410410
a.cmp(&b)
411411
});
412412
}
413+
FieldWrapFunc::Sum => {
414+
let sum_data: Vec<_> = chunk
415+
.par_iter()
416+
.map(|x| -> Result<u64, Error> {
417+
let mut row_data = row_data.clone();
418+
row_data.with_buf(&x[4..]);
419+
let sum = if is_binary {
420+
let sum = decode_with_name::<&[u8], String>(
421+
&mut row_data,
422+
&agg.name,
423+
is_binary,
424+
).unwrap();
425+
if let Some(sum) = sum {
426+
sum.parse::<u64>().map_err(|e| ErrorKind::Runtime(e.into()))?
427+
} else {
428+
0
429+
}
430+
} else {
431+
decode_with_name::<&[u8], u64>(&mut row_data, &agg.name, is_binary)
432+
.map_err(|e| ErrorKind::Runtime(e))?
433+
.unwrap_or_else(|| 0)
434+
};
435+
436+
Ok(sum)
437+
})
438+
.collect::<Result<Vec<_>, _>>().unwrap();
439+
440+
let sum: u64 = sum_data.par_iter().sum();
441+
chunk.par_iter_mut().for_each(|x| {
442+
let mut row_data = row_data.clone();
443+
row_data.with_buf(&x[4..]);
444+
let sum_data =
445+
row_data.get_row_data_with_name(&agg.name).unwrap().unwrap();
446+
let part_data = RowPartData {
447+
data: vec![].into(),
448+
start_idx: sum_data.start_idx,
449+
part_encode_length: sum_data.part_encode_length,
450+
part_data_length: sum_data.part_data_length,
451+
};
452+
453+
let sum = format!("{:.4}", sum as u64);
454+
455+
row_data_cut_merge(x, &part_data, |data: &mut BytesMut| {
456+
data.put_lenc_int(sum.len() as u64, false);
457+
data.extend_from_slice(sum.as_bytes());
458+
});
459+
});
460+
}
413461
_ => {}
414462
}
415463

@@ -709,5 +757,6 @@ fn get_min_max_value<'a>(
709757

710758
row_data.with_buf(&b[4..]);
711759
let b = decode_with_name::<&[u8], u64>(row_data, name, is_binary).unwrap().unwrap();
760+
712761
(a, b)
713762
}

0 commit comments

Comments
 (0)