From a73b66447bd0c525d2eb7a1ac972a2ecb9cfeba7 Mon Sep 17 00:00:00 2001
From: Xiangyi Zhu <82511136+zhuxiangyi@users.noreply.github.com>
Date: Fri, 26 Jun 2026 14:46:09 +0800
Subject: [PATCH] [AURON #2362] Support native bit_and / bit_or / bit_xor
aggregate
Implement native bit_and / bit_or / bit_xor aggregates:
- native: add a generic AggBitwise<P: AggBitwiseParams>
(agg/bitwise.rs) with AggBitAnd /
AggBitOr / AggBitXor aliases. The accumulator is a single column of the
input type; the first non-null value initializes the slot and each
subsequent value is folded in with the bitwise operator. The operators
are associative and commutative, so the result is order-independent, and
null inputs are skipped (an all-null group yields null). Only integral
inputs are supported (Int8/Int16/Int32/Int64). Wire through the AggFunction enum,
create_agg, the protobuf contract (BIT_AND / BIT_OR / BIT_XOR), the
protobuf->AggFunction conversion, and the window-agg mapping.
- spark-extension: add the BitAndAgg / BitOrAgg / BitXorAgg expression
conversions in NativeConverters; declare the buffer schema in
NativeAggBase.computeNativeAggBufferDataTypes (Seq(dataType)) so the
partial -> shuffle -> final buffer schema matches the native side.
Tests:
- Rust unit test agg_exec::test::test_agg_bitwise (partial -> final),
including an all-null group asserting null for all three aggregates.
- Scala e2e AuronDataFrameAggregateSuite "native bit_and / bit_or / bit_xor
aggregate" (spark34 + spark35), covering the partial -> shuffle -> final
native path (incl. all-null group) and asserting NativeAggBase offload.
---
.../sql/AuronDataFrameAggregateSuite.scala | 35 ++-
.../sql/AuronDataFrameAggregateSuite.scala | 35 ++-
native-engine/auron-planner/proto/auron.proto | 3 +
native-engine/auron-planner/src/lib.rs | 3 +
native-engine/auron-planner/src/planner.rs | 9 +
.../datafusion-ext-plans/src/agg/agg.rs | 13 +
.../datafusion-ext-plans/src/agg/bitwise.rs | 222 ++++++++++++++++++
.../datafusion-ext-plans/src/agg/mod.rs | 4 +
.../datafusion-ext-plans/src/agg_exec.rs | 122 ++++++++++
.../spark/sql/auron/NativeConverters.scala | 12 +-
.../execution/auron/plan/NativeAggBase.scala | 3 +
11 files changed, 458 insertions(+), 3 deletions(-)
create mode 100644 native-engine/datafusion-ext-plans/src/agg/bitwise.rs
diff --git a/auron-spark-tests/spark34/src/test/scala/org/apache/spark/sql/AuronDataFrameAggregateSuite.scala b/auron-spark-tests/spark34/src/test/scala/org/apache/spark/sql/AuronDataFrameAggregateSuite.scala
index d1361ab7e..c3584eaf7 100644
--- a/auron-spark-tests/spark34/src/test/scala/org/apache/spark/sql/AuronDataFrameAggregateSuite.scala
+++ b/auron-spark-tests/spark34/src/test/scala/org/apache/spark/sql/AuronDataFrameAggregateSuite.scala
@@ -21,7 +21,7 @@ import scala.util.Random
import org.apache.spark.sql.execution.WholeStageCodegenExec
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.auron.plan.NativeAggBase
-import org.apache.spark.sql.functions.{collect_list, monotonically_increasing_id, rand, randn, spark_partition_id, sum}
+import org.apache.spark.sql.functions.{collect_list, expr, monotonically_increasing_id, rand, randn, spark_partition_id, sum}
import org.apache.spark.sql.internal.SQLConf
class AuronDataFrameAggregateSuite extends DataFrameAggregateSuite with SparkQueryTestsBase {
@@ -75,4 +75,37 @@ class AuronDataFrameAggregateSuite extends DataFrameAggregateSuite with SparkQue
rand(Random.nextLong()),
randn(Random.nextLong())).foreach(assertNoExceptions)
}
+
+ testAuron("native bit_and / bit_or / bit_xor aggregate") {
+ // bit_* are integral-only, skip nulls, and are order-independent
+ // (associative + commutative), so the grouped result is deterministic.
+ // k=1: v = [3, 5, 1] => bit_and=1, bit_or=7, bit_xor=7
+ // k=2: v = [12, null, 10] => bit_and=8, bit_or=14, bit_xor=6
+ // k=3: v = [null, null] => bit_and=null, bit_or=null, bit_xor=null
+ val df = Seq[(Int, Option[Int])](
+ (1, Some(3)),
+ (1, Some(5)),
+ (1, Some(1)),
+ (2, Some(12)),
+ (2, None),
+ (2, Some(10)),
+ (3, None),
+ (3, None))
+ .toDF("k", "v")
+
+ val aggDF = df
+ .groupBy("k")
+ .agg(
+ expr("bit_and(v)").as("ba"),
+ expr("bit_or(v)").as("bo"),
+ expr("bit_xor(v)").as("bx"))
+
+ checkAnswer(aggDF, Seq(Row(1, 1, 7, 7), Row(2, 8, 14, 6), Row(3, null, null, null)))
+
+ // the aggregate must be offloaded to the native engine
+ assert(getExecutedPlan(aggDF).exists {
+ case _: NativeAggBase => true
+ case _ => false
+ })
+ }
}
diff --git a/auron-spark-tests/spark35/src/test/scala/org/apache/spark/sql/AuronDataFrameAggregateSuite.scala b/auron-spark-tests/spark35/src/test/scala/org/apache/spark/sql/AuronDataFrameAggregateSuite.scala
index d1361ab7e..c3584eaf7 100644
--- a/auron-spark-tests/spark35/src/test/scala/org/apache/spark/sql/AuronDataFrameAggregateSuite.scala
+++ b/auron-spark-tests/spark35/src/test/scala/org/apache/spark/sql/AuronDataFrameAggregateSuite.scala
@@ -21,7 +21,7 @@ import scala.util.Random
import org.apache.spark.sql.execution.WholeStageCodegenExec
import org.apache.spark.sql.execution.aggregate.HashAggregateExec
import org.apache.spark.sql.execution.auron.plan.NativeAggBase
-import org.apache.spark.sql.functions.{collect_list, monotonically_increasing_id, rand, randn, spark_partition_id, sum}
+import org.apache.spark.sql.functions.{collect_list, expr, monotonically_increasing_id, rand, randn, spark_partition_id, sum}
import org.apache.spark.sql.internal.SQLConf
class AuronDataFrameAggregateSuite extends DataFrameAggregateSuite with SparkQueryTestsBase {
@@ -75,4 +75,37 @@ class AuronDataFrameAggregateSuite extends DataFrameAggregateSuite with SparkQue
rand(Random.nextLong()),
randn(Random.nextLong())).foreach(assertNoExceptions)
}
+
+ testAuron("native bit_and / bit_or / bit_xor aggregate") {
+ // bit_* are integral-only, skip nulls, and are order-independent
+ // (associative + commutative), so the grouped result is deterministic.
+ // k=1: v = [3, 5, 1] => bit_and=1, bit_or=7, bit_xor=7
+ // k=2: v = [12, null, 10] => bit_and=8, bit_or=14, bit_xor=6
+ // k=3: v = [null, null] => bit_and=null, bit_or=null, bit_xor=null
+ val df = Seq[(Int, Option[Int])](
+ (1, Some(3)),
+ (1, Some(5)),
+ (1, Some(1)),
+ (2, Some(12)),
+ (2, None),
+ (2, Some(10)),
+ (3, None),
+ (3, None))
+ .toDF("k", "v")
+
+ val aggDF = df
+ .groupBy("k")
+ .agg(
+ expr("bit_and(v)").as("ba"),
+ expr("bit_or(v)").as("bo"),
+ expr("bit_xor(v)").as("bx"))
+
+ checkAnswer(aggDF, Seq(Row(1, 1, 7, 7), Row(2, 8, 14, 6), Row(3, null, null, null)))
+
+ // the aggregate must be offloaded to the native engine
+ assert(getExecutedPlan(aggDF).exists {
+ case _: NativeAggBase => true
+ case _ => false
+ })
+ }
}
diff --git a/native-engine/auron-planner/proto/auron.proto b/native-engine/auron-planner/proto/auron.proto
index 13b9f48bc..a2336614d 100644
--- a/native-engine/auron-planner/proto/auron.proto
+++ b/native-engine/auron-planner/proto/auron.proto
@@ -148,6 +148,9 @@ enum AggFunction {
FIRST = 7;
FIRST_IGNORES_NULL = 8;
BLOOM_FILTER = 9;
+ BIT_AND = 10;
+ BIT_OR = 11;
+ BIT_XOR = 12;
BRICKHOUSE_COLLECT = 1000;
BRICKHOUSE_COMBINE_UNIQUE = 1001;
UDAF = 1002;
diff --git a/native-engine/auron-planner/src/lib.rs b/native-engine/auron-planner/src/lib.rs
index a0f7b83d2..b02447086 100644
--- a/native-engine/auron-planner/src/lib.rs
+++ b/native-engine/auron-planner/src/lib.rs
@@ -135,6 +135,9 @@ impl From for AggFunction {
protobuf::AggFunction::CollectSet => AggFunction::CollectSet,
protobuf::AggFunction::First => AggFunction::First,
protobuf::AggFunction::FirstIgnoresNull => AggFunction::FirstIgnoresNull,
+ protobuf::AggFunction::BitAnd => AggFunction::BitAnd,
+ protobuf::AggFunction::BitOr => AggFunction::BitOr,
+ protobuf::AggFunction::BitXor => AggFunction::BitXor,
protobuf::AggFunction::BloomFilter => AggFunction::BloomFilter,
protobuf::AggFunction::BrickhouseCollect => AggFunction::BrickhouseCollect,
protobuf::AggFunction::BrickhouseCombineUnique => AggFunction::BrickhouseCombineUnique,
diff --git a/native-engine/auron-planner/src/planner.rs b/native-engine/auron-planner/src/planner.rs
index 418cc951d..8459af261 100644
--- a/native-engine/auron-planner/src/planner.rs
+++ b/native-engine/auron-planner/src/planner.rs
@@ -680,6 +680,15 @@ impl PhysicalPlanner {
protobuf::AggFunction::FirstIgnoresNull => {
WindowFunction::Agg(AggFunction::FirstIgnoresNull)
}
+ protobuf::AggFunction::BitAnd => {
+ WindowFunction::Agg(AggFunction::BitAnd)
+ }
+ protobuf::AggFunction::BitOr => {
+ WindowFunction::Agg(AggFunction::BitOr)
+ }
+ protobuf::AggFunction::BitXor => {
+ WindowFunction::Agg(AggFunction::BitXor)
+ }
protobuf::AggFunction::BloomFilter => {
WindowFunction::Agg(AggFunction::BloomFilter)
}
diff --git a/native-engine/datafusion-ext-plans/src/agg/agg.rs b/native-engine/datafusion-ext-plans/src/agg/agg.rs
index 5eb4c3dad..7ea6ce5c9 100644
--- a/native-engine/datafusion-ext-plans/src/agg/agg.rs
+++ b/native-engine/datafusion-ext-plans/src/agg/agg.rs
@@ -27,6 +27,7 @@ use crate::agg::{
AggFunction,
acc::AccColumnRef,
avg::AggAvg,
+ bitwise::{AggBitAnd, AggBitOr, AggBitXor},
bloom_filter::AggBloomFilter,
brickhouse,
collect::{AggCollectList, AggCollectSet},
@@ -212,6 +213,18 @@ pub fn create_agg(
let dt = children[0].data_type(input_schema)?;
Arc::new(AggFirstIgnoresNull::try_new(children[0].clone(), dt)?)
}
+ AggFunction::BitAnd => {
+ let dt = children[0].data_type(input_schema)?;
+ Arc::new(AggBitAnd::try_new(children[0].clone(), dt)?)
+ }
+ AggFunction::BitOr => {
+ let dt = children[0].data_type(input_schema)?;
+ Arc::new(AggBitOr::try_new(children[0].clone(), dt)?)
+ }
+ AggFunction::BitXor => {
+ let dt = children[0].data_type(input_schema)?;
+ Arc::new(AggBitXor::try_new(children[0].clone(), dt)?)
+ }
AggFunction::BloomFilter => {
let dt = children[0].data_type(input_schema)?;
let empty_batch = RecordBatch::new_empty(Arc::new(Schema::empty()));
diff --git a/native-engine/datafusion-ext-plans/src/agg/bitwise.rs b/native-engine/datafusion-ext-plans/src/agg/bitwise.rs
new file mode 100644
index 000000000..1dc785d0f
--- /dev/null
+++ b/native-engine/datafusion-ext-plans/src/agg/bitwise.rs
@@ -0,0 +1,222 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements. See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::{
+ any::Any,
+ fmt::{Debug, Formatter},
+ marker::PhantomData,
+ ops::{BitAnd, BitOr, BitXor},
+ sync::Arc,
+};
+
+use arrow::{array::*, datatypes::*};
+use datafusion::{common::Result, physical_expr::PhysicalExprRef};
+use datafusion_ext_commons::{df_execution_err, downcast_any};
+
+use crate::{
+ agg::{
+ Agg,
+ acc::{AccColumnRef, AccPrimColumn, create_acc_generic_column},
+ agg::IdxSelection,
+ },
+ idx_for_zipped,
+};
+
+pub type AggBitAnd = AggBitwise;
+pub type AggBitOr = AggBitwise;
+pub type AggBitXor = AggBitwise;
+
+/// Native implementation of Spark's bit_and / bit_or / bit_xor aggregates.
+///
+/// These only accept integral inputs. The accumulator is a single column of the
+/// same type as the input: the first non-null value initializes the slot and
+/// every subsequent value is folded in with the bitwise operator. Because the
+/// operators are associative and commutative, the result is independent of the
+/// visiting/merge order, and null inputs are simply skipped (an all-null group
+/// yields null).
+pub struct AggBitwise {
+ child: PhysicalExprRef,
+ data_type: DataType,
+ acc_array_data_types: Vec,
+ _phantom: PhantomData,
+}
+
+impl AggBitwise {
+ pub fn try_new(child: PhysicalExprRef, data_type: DataType) -> Result {
+ match &data_type {
+ DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {}
+ other => df_execution_err!("{} only supports integral types, got {other:?}", P::NAME)?,
+ }
+ let acc_array_data_types = vec![data_type.clone()];
+ Ok(Self {
+ child,
+ data_type,
+ acc_array_data_types,
+ _phantom: Default::default(),
+ })
+ }
+}
+
+impl Debug for AggBitwise {
+ fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+ write!(f, "{}({:?})", P::NAME, self.child)
+ }
+}
+
+impl Agg for AggBitwise {
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+
+ fn exprs(&self) -> Vec {
+ vec![self.child.clone()]
+ }
+
+ fn with_new_exprs(&self, exprs: Vec) -> Result> {
+ Ok(Arc::new(Self::try_new(
+ exprs[0].clone(),
+ self.data_type.clone(),
+ )?))
+ }
+
+ fn data_type(&self) -> &DataType {
+ &self.data_type
+ }
+
+ fn nullable(&self) -> bool {
+ true
+ }
+
+ fn create_acc_column(&self, num_rows: usize) -> AccColumnRef {
+ create_acc_generic_column(self.data_type.clone(), num_rows)
+ }
+
+ fn acc_array_data_types(&self) -> &[DataType] {
+ &self.acc_array_data_types
+ }
+
+ fn partial_update(
+ &self,
+ accs: &mut AccColumnRef,
+ acc_idx: IdxSelection<'_>,
+ partial_args: &[ArrayRef],
+ partial_arg_idx: IdxSelection<'_>,
+ ) -> Result<()> {
+ let partial_arg = &partial_args[0];
+ accs.ensure_size(acc_idx);
+
+ macro_rules! handle_int {
+ ($array_ty:ty, $native:ty) => {{
+ let partial_arg = downcast_any!(partial_arg, $array_ty)?;
+ let accs = downcast_any!(accs, mut AccPrimColumn<$native>)?;
+ idx_for_zipped! {
+ ((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => {
+ if partial_arg.is_valid(partial_arg_idx) {
+ let partial_value = partial_arg.value(partial_arg_idx);
+ accs.update_value(acc_idx, partial_value, |v| P::op(v, partial_value));
+ }
+ }
+ }
+ }};
+ }
+
+ match &self.data_type {
+ DataType::Int8 => handle_int!(Int8Array, i8),
+ DataType::Int16 => handle_int!(Int16Array, i16),
+ DataType::Int32 => handle_int!(Int32Array, i32),
+ DataType::Int64 => handle_int!(Int64Array, i64),
+ other => df_execution_err!("{} only supports integral types, got {other:?}", P::NAME)?,
+ }
+ Ok(())
+ }
+
+ fn partial_merge(
+ &self,
+ accs: &mut AccColumnRef,
+ acc_idx: IdxSelection<'_>,
+ merging_accs: &mut AccColumnRef,
+ merging_acc_idx: IdxSelection<'_>,
+ ) -> Result<()> {
+ accs.ensure_size(acc_idx);
+
+ macro_rules! handle_int {
+ ($native:ty) => {{
+ let accs = downcast_any!(accs, mut AccPrimColumn<$native>)?;
+ let merging_accs = downcast_any!(merging_accs, mut AccPrimColumn<$native>)?;
+ idx_for_zipped! {
+ ((acc_idx, merging_acc_idx) in (acc_idx, merging_acc_idx)) => {
+ if let Some(merging_value) = merging_accs.value(merging_acc_idx) {
+ accs.update_value(acc_idx, merging_value, |v| P::op(v, merging_value));
+ }
+ }
+ }
+ }};
+ }
+
+ match &self.data_type {
+ DataType::Int8 => handle_int!(i8),
+ DataType::Int16 => handle_int!(i16),
+ DataType::Int32 => handle_int!(i32),
+ DataType::Int64 => handle_int!(i64),
+ other => df_execution_err!("{} only supports integral types, got {other:?}", P::NAME)?,
+ }
+ Ok(())
+ }
+
+ fn final_merge(&self, accs: &mut AccColumnRef, acc_idx: IdxSelection<'_>) -> Result {
+ Ok(accs.freeze_to_arrays(acc_idx)?[0].clone())
+ }
+}
+
+pub trait AggBitwiseParams: 'static + Send + Sync {
+ const NAME: &'static str;
+ fn op(a: T, b: T) -> T
+ where
+ T: BitAnd