diff --git a/native-engine/datafusion-ext-plans/src/agg/collect.rs b/native-engine/datafusion-ext-plans/src/agg/collect.rs index 984f7031f..3389d6d7c 100644 --- a/native-engine/datafusion-ext-plans/src/agg/collect.rs +++ b/native-engine/datafusion-ext-plans/src/agg/collect.rs @@ -532,6 +532,12 @@ impl InternalSet { iter } + fn into_ordered_positions(self) -> Vec<(u32, u32)> { + let mut positions: Vec<_> = self.into_iter().collect(); + positions.sort_unstable_by_key(|&(pos, _)| pos); + positions + } + fn convert_to_huge_if_needed(&mut self, list: &mut AccList) { if let Self::Small(s) = self && s.len() >= 4 @@ -561,12 +567,12 @@ impl AccSet { } pub fn merge(&mut self, other: &mut Self) { - if self.set.len() < other.set.len() { - // ensure the probed set is smaller - std::mem::swap(self, other); - } - for pos_len in std::mem::take(&mut other.set).into_iter() { - self.append_raw(other.list.ref_raw(pos_len)); + let other_raw = std::mem::take(&mut other.list.raw); + let other_positions = std::mem::take(&mut other.set).into_ordered_positions(); + + for (pos, len) in other_positions { + let raw = &other_raw[pos as usize..][..len as usize]; + self.append_raw(raw); } } @@ -694,6 +700,52 @@ mod tests { assert_eq!(acc_set1.list.raw.len(), 12); // 4 bytes for each int32 assert_eq!(acc_set1.set.len(), 3); + let values: Vec = acc_set1.into_values(DataType::Int32, false).collect(); + assert_eq!(values, vec![value1, value2, value3]); + } + + #[test] + fn test_acc_set_merge_preserves_first_occurrence_order_when_rhs_is_larger() { + let mut acc_set1 = AccSet::default(); + let mut acc_set2 = AccSet::default(); + let value1 = ScalarValue::Int32(Some(1)); + let value2 = ScalarValue::Int32(Some(2)); + let value3 = ScalarValue::Int32(Some(3)); + + acc_set1.append(&value1, false); + acc_set2.append(&value2, false); + acc_set2.append(&value3, false); + + acc_set1.merge(&mut acc_set2); + + let values: Vec = acc_set1.into_values(DataType::Int32, false).collect(); + assert_eq!(values, vec![value1, value2, value3]); + } + + #[test] + fn test_acc_set_merge_preserves_first_occurrence_order_when_rhs_becomes_huge() { + let mut lhs = AccSet::default(); + let mut rhs = AccSet::default(); + let value1 = ScalarValue::Int32(Some(1)); + let value2 = ScalarValue::Int32(Some(2)); + let value3 = ScalarValue::Int32(Some(3)); + let value4 = ScalarValue::Int32(Some(4)); + let value5 = ScalarValue::Int32(Some(5)); + + lhs.append(&value1, false); + rhs.append(&value2, false); + rhs.append(&value3, false); + rhs.append(&value4, false); + rhs.append(&value5, false); + + assert!(matches!(&rhs.set, InternalSet::Huge(_))); + + lhs.merge(&mut rhs); + + let values: Vec = lhs.into_values(DataType::Int32, false).collect(); + assert_eq!(values, vec![value1, value2, value3, value4, value5]); + assert_eq!(rhs.list.raw.len(), 0); + assert_eq!(rhs.set.len(), 0); } #[test] @@ -746,4 +798,33 @@ mod tests { assert_eq!(acc_col.take_values(2), acc_col_unspill.take_values(2)); Ok(()) } + + #[test] + fn test_acc_set_merge_preserves_first_occurrence_order_after_rhs_spill() -> Result<()> { + let value1 = ScalarValue::Int32(Some(1)); + let value2 = ScalarValue::Int32(Some(2)); + let value3 = ScalarValue::Int32(Some(3)); + + let mut lhs = AccSetColumn::empty(DataType::Int32); + lhs.resize(1); + lhs.append_item(0, &value1); + + let mut rhs = AccSetColumn::empty(DataType::Int32); + rhs.resize(1); + rhs.append_item(0, &value2); + rhs.append_item(0, &value3); + + let mut spill: Box = Box::new(vec![]); + let mut spill_writer = spill.get_compressed_writer(); + rhs.spill(IdxSelection::Range(0, 1), &mut spill_writer)?; + spill_writer.finish()?; + + let mut rhs_unspill = AccSetColumn::empty(DataType::Int32); + rhs_unspill.unspill(1, &mut spill.get_compressed_reader())?; + + lhs.merge_items(0, &mut rhs_unspill, 0); + + assert_eq!(lhs.take_values(0), vec![value1, value2, value3]); + Ok(()) + } }