From a73b66447bd0c525d2eb7a1ac972a2ecb9cfeba7 Mon Sep 17 00:00:00 2001 From: Xiangyi Zhu <82511136+zhuxiangyi@users.noreply.github.com> Date: Fri, 26 Jun 2026 14:46:09 +0800 Subject: [PATCH] [AURON #2362] Support native bit_and / bit_or / bit_xor aggregate Implement native bit_and / bit_or / bit_xor aggregates: - native: add a generic AggBitwise<P> (agg/bitwise.rs) with AggBitAnd / AggBitOr / AggBitXor aliases. The accumulator is a single column of the input type; the first non-null value initializes the slot and each subsequent value is folded in with the bitwise operator. The operators are associative and commutative, so the result is order-independent, and null inputs are skipped (an all-null group yields null). Integral inputs only (Int8/Int16/Int32/Int64). Wire through the AggFunction enum, create_agg, the protobuf contract (BIT_AND / BIT_OR / BIT_XOR), the protobuf->AggFunction conversion, and the window-agg mapping. - spark-extension: add the BitAndAgg / BitOrAgg / BitXorAgg expression conversions in NativeConverters; declare the buffer schema in NativeAggBase.computeNativeAggBufferDataTypes (Seq(dataType)) so the partial -> shuffle -> final buffer schema matches the native side. Tests: - Rust unit test agg_exec::test::test_agg_bitwise (partial -> final), including an all-null group asserting null for all three aggregates. - Scala e2e AuronDataFrameAggregateSuite "native bit_and / bit_or / bit_xor aggregate" (spark34 + spark35), covering the partial -> shuffle -> final native path (incl. all-null group) and asserting NativeAggBase offload. 
--- .../sql/AuronDataFrameAggregateSuite.scala | 35 ++- .../sql/AuronDataFrameAggregateSuite.scala | 35 ++- native-engine/auron-planner/proto/auron.proto | 3 + native-engine/auron-planner/src/lib.rs | 3 + native-engine/auron-planner/src/planner.rs | 9 + .../datafusion-ext-plans/src/agg/agg.rs | 13 + .../datafusion-ext-plans/src/agg/bitwise.rs | 222 ++++++++++++++++++ .../datafusion-ext-plans/src/agg/mod.rs | 4 + .../datafusion-ext-plans/src/agg_exec.rs | 122 ++++++++++ .../spark/sql/auron/NativeConverters.scala | 12 +- .../execution/auron/plan/NativeAggBase.scala | 3 + 11 files changed, 458 insertions(+), 3 deletions(-) create mode 100644 native-engine/datafusion-ext-plans/src/agg/bitwise.rs diff --git a/auron-spark-tests/spark34/src/test/scala/org/apache/spark/sql/AuronDataFrameAggregateSuite.scala b/auron-spark-tests/spark34/src/test/scala/org/apache/spark/sql/AuronDataFrameAggregateSuite.scala index d1361ab7e..c3584eaf7 100644 --- a/auron-spark-tests/spark34/src/test/scala/org/apache/spark/sql/AuronDataFrameAggregateSuite.scala +++ b/auron-spark-tests/spark34/src/test/scala/org/apache/spark/sql/AuronDataFrameAggregateSuite.scala @@ -21,7 +21,7 @@ import scala.util.Random import org.apache.spark.sql.execution.WholeStageCodegenExec import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.auron.plan.NativeAggBase -import org.apache.spark.sql.functions.{collect_list, monotonically_increasing_id, rand, randn, spark_partition_id, sum} +import org.apache.spark.sql.functions.{collect_list, expr, monotonically_increasing_id, rand, randn, spark_partition_id, sum} import org.apache.spark.sql.internal.SQLConf class AuronDataFrameAggregateSuite extends DataFrameAggregateSuite with SparkQueryTestsBase { @@ -75,4 +75,37 @@ class AuronDataFrameAggregateSuite extends DataFrameAggregateSuite with SparkQue rand(Random.nextLong()), randn(Random.nextLong())).foreach(assertNoExceptions) } + + testAuron("native bit_and / bit_or / bit_xor 
aggregate") { + // bit_* are integral-only, skip nulls, and are order-independent + // (associative + commutative), so the grouped result is deterministic. + // k=1: v = [3, 5, 1] => bit_and=1, bit_or=7, bit_xor=7 + // k=2: v = [12, null, 10] => bit_and=8, bit_or=14, bit_xor=6 + // k=3: v = [null, null] => bit_and=null, bit_or=null, bit_xor=null + val df = Seq[(Int, Option[Int])]( + (1, Some(3)), + (1, Some(5)), + (1, Some(1)), + (2, Some(12)), + (2, None), + (2, Some(10)), + (3, None), + (3, None)) + .toDF("k", "v") + + val aggDF = df + .groupBy("k") + .agg( + expr("bit_and(v)").as("ba"), + expr("bit_or(v)").as("bo"), + expr("bit_xor(v)").as("bx")) + + checkAnswer(aggDF, Seq(Row(1, 1, 7, 7), Row(2, 8, 14, 6), Row(3, null, null, null))) + + // the aggregate must be offloaded to the native engine + assert(getExecutedPlan(aggDF).exists { + case _: NativeAggBase => true + case _ => false + }) + } } diff --git a/auron-spark-tests/spark35/src/test/scala/org/apache/spark/sql/AuronDataFrameAggregateSuite.scala b/auron-spark-tests/spark35/src/test/scala/org/apache/spark/sql/AuronDataFrameAggregateSuite.scala index d1361ab7e..c3584eaf7 100644 --- a/auron-spark-tests/spark35/src/test/scala/org/apache/spark/sql/AuronDataFrameAggregateSuite.scala +++ b/auron-spark-tests/spark35/src/test/scala/org/apache/spark/sql/AuronDataFrameAggregateSuite.scala @@ -21,7 +21,7 @@ import scala.util.Random import org.apache.spark.sql.execution.WholeStageCodegenExec import org.apache.spark.sql.execution.aggregate.HashAggregateExec import org.apache.spark.sql.execution.auron.plan.NativeAggBase -import org.apache.spark.sql.functions.{collect_list, monotonically_increasing_id, rand, randn, spark_partition_id, sum} +import org.apache.spark.sql.functions.{collect_list, expr, monotonically_increasing_id, rand, randn, spark_partition_id, sum} import org.apache.spark.sql.internal.SQLConf class AuronDataFrameAggregateSuite extends DataFrameAggregateSuite with SparkQueryTestsBase { @@ -75,4 +75,37 @@ 
class AuronDataFrameAggregateSuite extends DataFrameAggregateSuite with SparkQue rand(Random.nextLong()), randn(Random.nextLong())).foreach(assertNoExceptions) } + + testAuron("native bit_and / bit_or / bit_xor aggregate") { + // bit_* are integral-only, skip nulls, and are order-independent + // (associative + commutative), so the grouped result is deterministic. + // k=1: v = [3, 5, 1] => bit_and=1, bit_or=7, bit_xor=7 + // k=2: v = [12, null, 10] => bit_and=8, bit_or=14, bit_xor=6 + // k=3: v = [null, null] => bit_and=null, bit_or=null, bit_xor=null + val df = Seq[(Int, Option[Int])]( + (1, Some(3)), + (1, Some(5)), + (1, Some(1)), + (2, Some(12)), + (2, None), + (2, Some(10)), + (3, None), + (3, None)) + .toDF("k", "v") + + val aggDF = df + .groupBy("k") + .agg( + expr("bit_and(v)").as("ba"), + expr("bit_or(v)").as("bo"), + expr("bit_xor(v)").as("bx")) + + checkAnswer(aggDF, Seq(Row(1, 1, 7, 7), Row(2, 8, 14, 6), Row(3, null, null, null))) + + // the aggregate must be offloaded to the native engine + assert(getExecutedPlan(aggDF).exists { + case _: NativeAggBase => true + case _ => false + }) + } } diff --git a/native-engine/auron-planner/proto/auron.proto b/native-engine/auron-planner/proto/auron.proto index 13b9f48bc..a2336614d 100644 --- a/native-engine/auron-planner/proto/auron.proto +++ b/native-engine/auron-planner/proto/auron.proto @@ -148,6 +148,9 @@ enum AggFunction { FIRST = 7; FIRST_IGNORES_NULL = 8; BLOOM_FILTER = 9; + BIT_AND = 10; + BIT_OR = 11; + BIT_XOR = 12; BRICKHOUSE_COLLECT = 1000; BRICKHOUSE_COMBINE_UNIQUE = 1001; UDAF = 1002; diff --git a/native-engine/auron-planner/src/lib.rs b/native-engine/auron-planner/src/lib.rs index a0f7b83d2..b02447086 100644 --- a/native-engine/auron-planner/src/lib.rs +++ b/native-engine/auron-planner/src/lib.rs @@ -135,6 +135,9 @@ impl From for AggFunction { protobuf::AggFunction::CollectSet => AggFunction::CollectSet, protobuf::AggFunction::First => AggFunction::First, protobuf::AggFunction::FirstIgnoresNull => 
AggFunction::FirstIgnoresNull, + protobuf::AggFunction::BitAnd => AggFunction::BitAnd, + protobuf::AggFunction::BitOr => AggFunction::BitOr, + protobuf::AggFunction::BitXor => AggFunction::BitXor, protobuf::AggFunction::BloomFilter => AggFunction::BloomFilter, protobuf::AggFunction::BrickhouseCollect => AggFunction::BrickhouseCollect, protobuf::AggFunction::BrickhouseCombineUnique => AggFunction::BrickhouseCombineUnique, diff --git a/native-engine/auron-planner/src/planner.rs b/native-engine/auron-planner/src/planner.rs index 418cc951d..8459af261 100644 --- a/native-engine/auron-planner/src/planner.rs +++ b/native-engine/auron-planner/src/planner.rs @@ -680,6 +680,15 @@ impl PhysicalPlanner { protobuf::AggFunction::FirstIgnoresNull => { WindowFunction::Agg(AggFunction::FirstIgnoresNull) } + protobuf::AggFunction::BitAnd => { + WindowFunction::Agg(AggFunction::BitAnd) + } + protobuf::AggFunction::BitOr => { + WindowFunction::Agg(AggFunction::BitOr) + } + protobuf::AggFunction::BitXor => { + WindowFunction::Agg(AggFunction::BitXor) + } protobuf::AggFunction::BloomFilter => { WindowFunction::Agg(AggFunction::BloomFilter) } diff --git a/native-engine/datafusion-ext-plans/src/agg/agg.rs b/native-engine/datafusion-ext-plans/src/agg/agg.rs index 5eb4c3dad..7ea6ce5c9 100644 --- a/native-engine/datafusion-ext-plans/src/agg/agg.rs +++ b/native-engine/datafusion-ext-plans/src/agg/agg.rs @@ -27,6 +27,7 @@ use crate::agg::{ AggFunction, acc::AccColumnRef, avg::AggAvg, + bitwise::{AggBitAnd, AggBitOr, AggBitXor}, bloom_filter::AggBloomFilter, brickhouse, collect::{AggCollectList, AggCollectSet}, @@ -212,6 +213,18 @@ pub fn create_agg( let dt = children[0].data_type(input_schema)?; Arc::new(AggFirstIgnoresNull::try_new(children[0].clone(), dt)?) } + AggFunction::BitAnd => { + let dt = children[0].data_type(input_schema)?; + Arc::new(AggBitAnd::try_new(children[0].clone(), dt)?) 
+ } + AggFunction::BitOr => { + let dt = children[0].data_type(input_schema)?; + Arc::new(AggBitOr::try_new(children[0].clone(), dt)?) + } + AggFunction::BitXor => { + let dt = children[0].data_type(input_schema)?; + Arc::new(AggBitXor::try_new(children[0].clone(), dt)?) + } AggFunction::BloomFilter => { let dt = children[0].data_type(input_schema)?; let empty_batch = RecordBatch::new_empty(Arc::new(Schema::empty())); diff --git a/native-engine/datafusion-ext-plans/src/agg/bitwise.rs b/native-engine/datafusion-ext-plans/src/agg/bitwise.rs new file mode 100644 index 000000000..1dc785d0f --- /dev/null +++ b/native-engine/datafusion-ext-plans/src/agg/bitwise.rs @@ -0,0 +1,222 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use std::{ + any::Any, + fmt::{Debug, Formatter}, + marker::PhantomData, + ops::{BitAnd, BitOr, BitXor}, + sync::Arc, +}; + +use arrow::{array::*, datatypes::*}; +use datafusion::{common::Result, physical_expr::PhysicalExprRef}; +use datafusion_ext_commons::{df_execution_err, downcast_any}; + +use crate::{ + agg::{ + Agg, + acc::{AccColumnRef, AccPrimColumn, create_acc_generic_column}, + agg::IdxSelection, + }, + idx_for_zipped, +}; + +pub type AggBitAnd = AggBitwise; +pub type AggBitOr = AggBitwise; +pub type AggBitXor = AggBitwise; + +/// Native implementation of Spark's bit_and / bit_or / bit_xor aggregates. +/// +/// These only accept integral inputs. The accumulator is a single column of the +/// same type as the input: the first non-null value initializes the slot and +/// every subsequent value is folded in with the bitwise operator. Because the +/// operators are associative and commutative, the result is independent of the +/// visiting/merge order, and null inputs are simply skipped (an all-null group +/// yields null). +pub struct AggBitwise { + child: PhysicalExprRef, + data_type: DataType, + acc_array_data_types: Vec, + _phantom: PhantomData

, +} + +impl AggBitwise

{ + pub fn try_new(child: PhysicalExprRef, data_type: DataType) -> Result { + match &data_type { + DataType::Int8 | DataType::Int16 | DataType::Int32 | DataType::Int64 => {} + other => df_execution_err!("{} only supports integral types, got {other:?}", P::NAME)?, + } + let acc_array_data_types = vec![data_type.clone()]; + Ok(Self { + child, + data_type, + acc_array_data_types, + _phantom: Default::default(), + }) + } +} + +impl Debug for AggBitwise

{ + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "{}({:?})", P::NAME, self.child) + } +} + +impl Agg for AggBitwise

{ + fn as_any(&self) -> &dyn Any { + self + } + + fn exprs(&self) -> Vec { + vec![self.child.clone()] + } + + fn with_new_exprs(&self, exprs: Vec) -> Result> { + Ok(Arc::new(Self::try_new( + exprs[0].clone(), + self.data_type.clone(), + )?)) + } + + fn data_type(&self) -> &DataType { + &self.data_type + } + + fn nullable(&self) -> bool { + true + } + + fn create_acc_column(&self, num_rows: usize) -> AccColumnRef { + create_acc_generic_column(self.data_type.clone(), num_rows) + } + + fn acc_array_data_types(&self) -> &[DataType] { + &self.acc_array_data_types + } + + fn partial_update( + &self, + accs: &mut AccColumnRef, + acc_idx: IdxSelection<'_>, + partial_args: &[ArrayRef], + partial_arg_idx: IdxSelection<'_>, + ) -> Result<()> { + let partial_arg = &partial_args[0]; + accs.ensure_size(acc_idx); + + macro_rules! handle_int { + ($array_ty:ty, $native:ty) => {{ + let partial_arg = downcast_any!(partial_arg, $array_ty)?; + let accs = downcast_any!(accs, mut AccPrimColumn<$native>)?; + idx_for_zipped! { + ((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => { + if partial_arg.is_valid(partial_arg_idx) { + let partial_value = partial_arg.value(partial_arg_idx); + accs.update_value(acc_idx, partial_value, |v| P::op(v, partial_value)); + } + } + } + }}; + } + + match &self.data_type { + DataType::Int8 => handle_int!(Int8Array, i8), + DataType::Int16 => handle_int!(Int16Array, i16), + DataType::Int32 => handle_int!(Int32Array, i32), + DataType::Int64 => handle_int!(Int64Array, i64), + other => df_execution_err!("{} only supports integral types, got {other:?}", P::NAME)?, + } + Ok(()) + } + + fn partial_merge( + &self, + accs: &mut AccColumnRef, + acc_idx: IdxSelection<'_>, + merging_accs: &mut AccColumnRef, + merging_acc_idx: IdxSelection<'_>, + ) -> Result<()> { + accs.ensure_size(acc_idx); + + macro_rules! 
handle_int { + ($native:ty) => {{ + let accs = downcast_any!(accs, mut AccPrimColumn<$native>)?; + let merging_accs = downcast_any!(merging_accs, mut AccPrimColumn<$native>)?; + idx_for_zipped! { + ((acc_idx, merging_acc_idx) in (acc_idx, merging_acc_idx)) => { + if let Some(merging_value) = merging_accs.value(merging_acc_idx) { + accs.update_value(acc_idx, merging_value, |v| P::op(v, merging_value)); + } + } + } + }}; + } + + match &self.data_type { + DataType::Int8 => handle_int!(i8), + DataType::Int16 => handle_int!(i16), + DataType::Int32 => handle_int!(i32), + DataType::Int64 => handle_int!(i64), + other => df_execution_err!("{} only supports integral types, got {other:?}", P::NAME)?, + } + Ok(()) + } + + fn final_merge(&self, accs: &mut AccColumnRef, acc_idx: IdxSelection<'_>) -> Result { + Ok(accs.freeze_to_arrays(acc_idx)?[0].clone()) + } +} + +pub trait AggBitwiseParams: 'static + Send + Sync { + const NAME: &'static str; + fn op(a: T, b: T) -> T + where + T: BitAnd + BitOr + BitXor; +} + +pub struct AggBitAndParams; +pub struct AggBitOrParams; +pub struct AggBitXorParams; + +impl AggBitwiseParams for AggBitAndParams { + const NAME: &'static str = "bit_and"; + fn op(a: T, b: T) -> T + where + T: BitAnd + BitOr + BitXor, + { + a & b + } +} + +impl AggBitwiseParams for AggBitOrParams { + const NAME: &'static str = "bit_or"; + fn op(a: T, b: T) -> T + where + T: BitAnd + BitOr + BitXor, + { + a | b + } +} + +impl AggBitwiseParams for AggBitXorParams { + const NAME: &'static str = "bit_xor"; + fn op(a: T, b: T) -> T + where + T: BitAnd + BitOr + BitXor, + { + a ^ b + } +} diff --git a/native-engine/datafusion-ext-plans/src/agg/mod.rs b/native-engine/datafusion-ext-plans/src/agg/mod.rs index 565e15b16..bd58205fe 100644 --- a/native-engine/datafusion-ext-plans/src/agg/mod.rs +++ b/native-engine/datafusion-ext-plans/src/agg/mod.rs @@ -19,6 +19,7 @@ pub mod agg_ctx; pub mod agg_hash_map; pub mod agg_table; pub mod avg; +pub mod bitwise; pub mod bloom_filter; pub 
mod brickhouse; pub mod collect; @@ -69,6 +70,9 @@ pub enum AggFunction { Min, First, FirstIgnoresNull, + BitAnd, + BitOr, + BitXor, CollectList, CollectSet, BloomFilter, diff --git a/native-engine/datafusion-ext-plans/src/agg_exec.rs b/native-engine/datafusion-ext-plans/src/agg_exec.rs index d75d304f0..2792aed81 100644 --- a/native-engine/datafusion-ext-plans/src/agg_exec.rs +++ b/native-engine/datafusion-ext-plans/src/agg_exec.rs @@ -694,6 +694,128 @@ mod test { Ok(()) } + #[tokio::test(flavor = "multi_thread", worker_threads = 1)] + async fn test_agg_bitwise() -> Result<()> { + MemManager::init(10000); + + // group key "k" and a nullable integer column "v". bit_* skip nulls and + // are order-independent (associative + commutative). + // k=1: v = [3, 5, 1] -> bit_and=1, bit_or=7, bit_xor=7 + // k=2: v = [12, null, 10] -> bit_and=8, bit_or=14, bit_xor=6 + // k=3: v = [null, null] -> bit_and=null, bit_or=null, bit_xor=null + // (the all-null group pins the "skip nulls, never seed" invariant) + let schema = Arc::new(Schema::new(vec![ + Field::new("k", DataType::Int32, false), + Field::new("v", DataType::Int32, true), + ])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(vec![1, 2, 1, 2, 1, 2, 3, 3])), + Arc::new(Int32Array::from(vec![ + Some(3), + Some(12), + Some(5), + None, + Some(1), + Some(10), + None, + None, + ])), + ], + )?; + let input: Arc = + Arc::new(TestMemoryExec::try_new(&[vec![batch]], schema, None)?); + + let agg_bit_and = create_agg( + AggFunction::BitAnd, + &[phys_expr::col("v", &input.schema())?], + &input.schema(), + DataType::Int32, + )?; + let agg_bit_or = create_agg( + AggFunction::BitOr, + &[phys_expr::col("v", &input.schema())?], + &input.schema(), + DataType::Int32, + )?; + let agg_bit_xor = create_agg( + AggFunction::BitXor, + &[phys_expr::col("v", &input.schema())?], + &input.schema(), + DataType::Int32, + )?; + let aggs_agg_expr = vec![ + AggExpr { + field_name: "agg_bit_and".to_string(), + 
mode: Partial, + filter: None, + agg: agg_bit_and, + }, + AggExpr { + field_name: "agg_bit_or".to_string(), + mode: Partial, + filter: None, + agg: agg_bit_or, + }, + AggExpr { + field_name: "agg_bit_xor".to_string(), + mode: Partial, + filter: None, + agg: agg_bit_xor, + }, + ]; + + let agg_exec_partial = AggExec::try_new( + HashAgg, + vec![GroupingExpr { + field_name: "k".to_string(), + expr: Arc::new(Column::new("k", 0)), + }], + aggs_agg_expr.clone(), + false, + input, + )?; + + let agg_exec_final = AggExec::try_new( + HashAgg, + vec![GroupingExpr { + field_name: "k".to_string(), + expr: Arc::new(Column::new("k", 0)), + }], + aggs_agg_expr + .into_iter() + .map(|mut agg| { + agg.agg = agg + .agg + .with_new_exprs(vec![Arc::new(phys_expr::Literal::new( + ScalarValue::Null, + ))])?; + agg.mode = Final; + Ok(agg) + }) + .collect::>()?, + false, + Arc::new(agg_exec_partial), + )?; + + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); + let output_final = agg_exec_final.execute(0, task_ctx)?; + let batches = datafusion::physical_plan::common::collect(output_final).await?; + let expected = vec![ + "+---+-------------+------------+-------------+", + "| k | agg_bit_and | agg_bit_or | agg_bit_xor |", + "+---+-------------+------------+-------------+", + "| 1 | 1 | 7 | 7 |", + "| 2 | 8 | 14 | 6 |", + "| 3 | | | |", + "+---+-------------+------------+-------------+", + ]; + assert_batches_sorted_eq!(expected, &batches); + Ok(()) + } + #[tokio::test] async fn test_agg_with_filter() -> Result<()> { MemManager::init(1000); diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala b/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala index 1a3aa3e6f..6f1dc4494 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala @@ -37,7 +37,7 @@ import 
org.apache.spark.internal.Logging import org.apache.spark.sql.auron.util.Using import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Average, CollectList, CollectSet, Count, DeclarativeAggregate, First, Max, Min, Sum, TypedImperativeAggregate} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Average, BitAndAgg, BitOrAgg, BitXorAgg, CollectList, CollectSet, Count, DeclarativeAggregate, First, Max, Min, Sum, TypedImperativeAggregate} import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.catalyst.expressions.codegen.ExprCode import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero @@ -1271,6 +1271,16 @@ object NativeConverters extends Logging { }) aggBuilder.addChildren(convertExpr(child)) + case e: BitAndAgg => + aggBuilder.setAggFunction(pb.AggFunction.BIT_AND) + aggBuilder.addChildren(convertExpr(e.child)) + case e: BitOrAgg => + aggBuilder.setAggFunction(pb.AggFunction.BIT_OR) + aggBuilder.addChildren(convertExpr(e.child)) + case e: BitXorAgg => + aggBuilder.setAggFunction(pb.AggFunction.BIT_XOR) + aggBuilder.addChildren(convertExpr(e.child)) + case CollectList(child, _, _) => aggBuilder.setAggFunction(pb.AggFunction.COLLECT_LIST) aggBuilder.addChildren(convertExpr(child)) diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeAggBase.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeAggBase.scala index 755fb6466..1590a0f7f 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeAggBase.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeAggBase.scala @@ -308,6 +308,9 @@ object NativeAggBase extends Logging { case f: Average => Seq(f.dataType, LongType) case f @ 
First(_, true) => Seq(f.dataType) case f @ First(_, false) => Seq(f.dataType, BooleanType) + case f: BitAndAgg => Seq(f.dataType) + case f: BitOrAgg => Seq(f.dataType) + case f: BitXorAgg => Seq(f.dataType) case _ => Seq(BinaryType) } }