From 1b0897896701134cd612d63c29901cb69b87c090 Mon Sep 17 00:00:00 2001 From: Asish Kumar Date: Tue, 14 Apr 2026 01:46:02 +0530 Subject: [PATCH] [AURON #2178] [AURON #2179] Implement native support for first_value and last_value window functions Spark `first_value(...)` and `last_value(...)` are not supported in Auron's native window execution path, causing queries using them to fall back to Spark instead of running natively. This extends native window coverage for both functions through the existing aggregate-window infrastructure. Changes included here: - map Spark window `First` and `Last` expressions to native aggregate window functions - add native `AggLast` and `AggLastIgnoresNull` implementations - extend protobuf, planner, and aggregate factory mappings for `LAST` and `LAST_IGNORES_NULL` - support `last(...)` as a native group aggregate and cover that path with a deterministic test - add Scala regression coverage for leading-null `first_value`, string/int/boolean `last_value`, and IGNORE NULLS variants - hoist the typed-null scalar value in `AggLast` instead of rebuilding it for every null row Signed-off-by: Asish Kumar --- native-engine/auron-planner/proto/auron.proto | 2 + native-engine/auron-planner/src/lib.rs | 2 + native-engine/auron-planner/src/planner.rs | 6 + .../datafusion-ext-plans/src/agg/agg.rs | 10 + .../datafusion-ext-plans/src/agg/last.rs | 236 ++++++++++++++++++ .../src/agg/last_ignores_null.rs | 232 +++++++++++++++++ .../datafusion-ext-plans/src/agg/mod.rs | 4 + .../org/apache/auron/AuronWindowSuite.scala | 145 +++++++++++ .../spark/sql/auron/NativeConverters.scala | 14 +- .../auron/plan/NativeWindowBase.scala | 29 +++ 10 files changed, 679 insertions(+), 1 deletion(-) create mode 100644 native-engine/datafusion-ext-plans/src/agg/last.rs create mode 100644 native-engine/datafusion-ext-plans/src/agg/last_ignores_null.rs diff --git a/native-engine/auron-planner/proto/auron.proto b/native-engine/auron-planner/proto/auron.proto index 
a905c8a36..7cba87445 100644 --- a/native-engine/auron-planner/proto/auron.proto +++ b/native-engine/auron-planner/proto/auron.proto @@ -148,6 +148,8 @@ enum AggFunction { FIRST = 7; FIRST_IGNORES_NULL = 8; BLOOM_FILTER = 9; + LAST = 10; + LAST_IGNORES_NULL = 11; BRICKHOUSE_COLLECT = 1000; BRICKHOUSE_COMBINE_UNIQUE = 1001; UDAF = 1002; diff --git a/native-engine/auron-planner/src/lib.rs b/native-engine/auron-planner/src/lib.rs index c118cd2b0..4502a303d 100644 --- a/native-engine/auron-planner/src/lib.rs +++ b/native-engine/auron-planner/src/lib.rs @@ -138,6 +138,8 @@ impl From for AggFunction { protobuf::AggFunction::CollectSet => AggFunction::CollectSet, protobuf::AggFunction::First => AggFunction::First, protobuf::AggFunction::FirstIgnoresNull => AggFunction::FirstIgnoresNull, + protobuf::AggFunction::Last => AggFunction::Last, + protobuf::AggFunction::LastIgnoresNull => AggFunction::LastIgnoresNull, protobuf::AggFunction::BloomFilter => AggFunction::BloomFilter, protobuf::AggFunction::BrickhouseCollect => AggFunction::BrickhouseCollect, protobuf::AggFunction::BrickhouseCombineUnique => AggFunction::BrickhouseCombineUnique, diff --git a/native-engine/auron-planner/src/planner.rs b/native-engine/auron-planner/src/planner.rs index b1ee15843..578ef0044 100644 --- a/native-engine/auron-planner/src/planner.rs +++ b/native-engine/auron-planner/src/planner.rs @@ -702,6 +702,12 @@ impl PhysicalPlanner { protobuf::AggFunction::FirstIgnoresNull => { WindowFunction::Agg(AggFunction::FirstIgnoresNull) } + protobuf::AggFunction::Last => { + WindowFunction::Agg(AggFunction::Last) + } + protobuf::AggFunction::LastIgnoresNull => { + WindowFunction::Agg(AggFunction::LastIgnoresNull) + } protobuf::AggFunction::BloomFilter => { WindowFunction::Agg(AggFunction::BloomFilter) } diff --git a/native-engine/datafusion-ext-plans/src/agg/agg.rs b/native-engine/datafusion-ext-plans/src/agg/agg.rs index 5eb4c3dad..99adc470d 100644 --- a/native-engine/datafusion-ext-plans/src/agg/agg.rs +++ 
b/native-engine/datafusion-ext-plans/src/agg/agg.rs @@ -33,6 +33,8 @@ use crate::agg::{ count::AggCount, first::AggFirst, first_ignores_null::AggFirstIgnoresNull, + last::AggLast, + last_ignores_null::AggLastIgnoresNull, maxmin::{AggMax, AggMin}, spark_udaf_wrapper::SparkUDAFWrapper, sum::AggSum, @@ -212,6 +214,14 @@ pub fn create_agg( let dt = children[0].data_type(input_schema)?; Arc::new(AggFirstIgnoresNull::try_new(children[0].clone(), dt)?) } + AggFunction::Last => { + let dt = children[0].data_type(input_schema)?; + Arc::new(AggLast::try_new(children[0].clone(), dt)?) + } + AggFunction::LastIgnoresNull => { + let dt = children[0].data_type(input_schema)?; + Arc::new(AggLastIgnoresNull::try_new(children[0].clone(), dt)?) + } AggFunction::BloomFilter => { let dt = children[0].data_type(input_schema)?; let empty_batch = RecordBatch::new_empty(Arc::new(Schema::empty())); diff --git a/native-engine/datafusion-ext-plans/src/agg/last.rs b/native-engine/datafusion-ext-plans/src/agg/last.rs new file mode 100644 index 000000000..aa1989397 --- /dev/null +++ b/native-engine/datafusion-ext-plans/src/agg/last.rs @@ -0,0 +1,236 @@ +// Licensed to the Apache Software Foundation (ASF) under one or more +// contributor license agreements. See the NOTICE file distributed with +// this work for additional information regarding copyright ownership. +// The ASF licenses this file to You under the Apache License, Version 2.0 +// (the "License"); you may not use this file except in compliance with +// the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+
+use std::{
+    any::Any,
+    fmt::{Debug, Formatter},
+    sync::Arc,
+};
+
+use arrow::{array::*, datatypes::*};
+use datafusion::{
+    common::{Result, ScalarValue},
+    physical_expr::PhysicalExprRef,
+};
+use datafusion_ext_commons::{downcast_any, scalar_value::compacted_scalar_value_from_array};
+
+use crate::{
+    agg::{
+        Agg,
+        acc::{
+            AccBooleanColumn, AccBytes, AccBytesColumn, AccColumnRef, AccPrimColumn,
+            AccScalarValueColumn, create_acc_generic_column,
+        },
+        agg::IdxSelection,
+    },
+    idx_for_zipped,
+};
+
+pub struct AggLast {
+    child: PhysicalExprRef,
+    data_type: DataType,
+    acc_array_data_types: Vec<DataType>, // one accumulator column, typed like the input
+}
+
+impl AggLast {
+    pub fn try_new(child: PhysicalExprRef, data_type: DataType) -> Result<Self> {
+        let acc_array_data_types = vec![data_type.clone()];
+        Ok(Self {
+            child,
+            data_type,
+            acc_array_data_types,
+        })
+    }
+}
+
+impl Debug for AggLast {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        write!(f, "Last({:?})", self.child)
+    }
+}
+
+impl Agg for AggLast {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn exprs(&self) -> Vec<PhysicalExprRef> {
+        vec![self.child.clone()]
+    }
+
+    fn with_new_exprs(&self, exprs: Vec<PhysicalExprRef>) -> Result<Arc<dyn Agg>> {
+        Ok(Arc::new(Self::try_new(
+            exprs[0].clone(),
+            self.data_type.clone(),
+        )?))
+    }
+
+    fn data_type(&self) -> &DataType {
+        &self.data_type
+    }
+
+    fn nullable(&self) -> bool {
+        true // the tracked value itself may be null
+    }
+
+    fn create_acc_column(&self, num_rows: usize) -> AccColumnRef {
+        create_acc_generic_column(self.data_type.clone(), num_rows)
+    }
+
+    fn acc_array_data_types(&self) -> &[DataType] {
+        &self.acc_array_data_types
+    }
+
+    fn partial_update(
+        &self,
+        accs: &mut AccColumnRef,
+        acc_idx: IdxSelection<'_>,
+        partial_args: &[ArrayRef],
+        partial_arg_idx: IdxSelection<'_>,
+    ) -> Result<()> {
+        let partial_arg = &partial_args[0];
+        accs.ensure_size(acc_idx);
+
+        macro_rules! handle_bytes {
+            ($array:expr) => {{
+                let accs = downcast_any!(accs, mut AccBytesColumn)?;
+                let partial_arg = $array;
+                idx_for_zipped! {
+                    ((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => {
+                        if partial_arg.is_valid(partial_arg_idx) {
+                            accs.set_value(acc_idx, Some(AccBytes::from(partial_arg.value(partial_arg_idx).as_ref())));
+                        } else {
+                            accs.set_value(acc_idx, None);
+                        }
+                    }
+                }
+            }}
+        }
+
+        downcast_primitive_array! {
+            partial_arg => {
+                if let Ok(accs) = downcast_any!(accs, mut AccPrimColumn<_>) {
+                    idx_for_zipped! {
+                        ((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => {
+                            if partial_arg.is_valid(partial_arg_idx) {
+                                accs.set_value(acc_idx, Some(partial_arg.value(partial_arg_idx)));
+                            } else {
+                                accs.set_value(acc_idx, None);
+                            }
+                        }
+                    }
+                } // NOTE(review): a failed downcast is skipped silently here, other arms use `?`
+            }
+            DataType::Boolean => {
+                let accs = downcast_any!(accs, mut AccBooleanColumn)?;
+                let partial_arg = downcast_any!(partial_arg, BooleanArray)?;
+                idx_for_zipped! {
+                    ((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => {
+                        if partial_arg.is_valid(partial_arg_idx) {
+                            accs.set_value(acc_idx, Some(partial_arg.value(partial_arg_idx)));
+                        } else {
+                            accs.set_value(acc_idx, None);
+                        }
+                    }
+                }
+            }
+            DataType::Utf8 => handle_bytes!(downcast_any!(partial_arg, StringArray)?),
+            DataType::Binary => handle_bytes!(downcast_any!(partial_arg, BinaryArray)?),
+            _other => {
+                let accs = downcast_any!(accs, mut AccScalarValueColumn)?;
+                let null_value = ScalarValue::try_from(&self.data_type)?; // typed null, built once
+                idx_for_zipped! {
+                    ((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => {
+                        if partial_arg.is_valid(partial_arg_idx) {
+                            accs.set_value(acc_idx, compacted_scalar_value_from_array(partial_arg, partial_arg_idx)?);
+                        } else {
+                            accs.set_value(acc_idx, null_value.clone());
+                        }
+                    }
+                }
+            }
+        }
+        Ok(())
+    }
+
+    fn partial_merge(
+        &self,
+        accs: &mut AccColumnRef,
+        acc_idx: IdxSelection<'_>,
+        merging_accs: &mut AccColumnRef,
+        merging_acc_idx: IdxSelection<'_>,
+    ) -> Result<()> {
+        accs.ensure_size(acc_idx);
+
+        // last(): the merging accumulator always wins, even when its value is null
+        macro_rules! handle_primitive {
+            ($ty:ty) => {{
+                type TNative = <$ty as ArrowPrimitiveType>::Native;
+                let accs = downcast_any!(accs, mut AccPrimColumn<TNative>)?;
+                let merging_accs = downcast_any!(merging_accs, mut AccPrimColumn<_>)?;
+                idx_for_zipped! {
+                    ((acc_idx, merging_acc_idx) in (acc_idx, merging_acc_idx)) => {
+                        accs.set_value(acc_idx, merging_accs.value(merging_acc_idx));
+                    }
+                }
+            }}
+        }
+
+        macro_rules! handle_boolean {
+            () => {{
+                let accs = downcast_any!(accs, mut AccBooleanColumn)?;
+                let merging_accs = downcast_any!(merging_accs, mut AccBooleanColumn)?;
+                idx_for_zipped! {
+                    ((acc_idx, merging_acc_idx) in (acc_idx, merging_acc_idx)) => {
+                        accs.set_value(acc_idx, merging_accs.value(merging_acc_idx));
+                    }
+                }
+            }};
+        }
+
+        macro_rules! handle_bytes {
+            () => {{
+                let accs = downcast_any!(accs, mut AccBytesColumn)?;
+                let merging_accs = downcast_any!(merging_accs, mut AccBytesColumn)?;
+                idx_for_zipped! {
+                    ((acc_idx, merging_acc_idx) in (acc_idx, merging_acc_idx)) => {
+                        accs.set_value(acc_idx, merging_accs.take_value(merging_acc_idx));
+                    }
+                }
+            }};
+        }
+
+        downcast_primitive! {
+            (&self.data_type) => (handle_primitive),
+            DataType::Boolean => handle_boolean!(),
+            DataType::Utf8 | DataType::Binary => handle_bytes!(),
+            DataType::Null => {} // no-op for the null type
+            _ => {
+                let accs = downcast_any!(accs, mut AccScalarValueColumn)?;
+                let merging_accs = downcast_any!(merging_accs, mut AccScalarValueColumn)?;
+                idx_for_zipped! {
+                    ((acc_idx, merging_acc_idx) in (acc_idx, merging_acc_idx)) => {
+                        accs.set_value(acc_idx, merging_accs.take_value(merging_acc_idx));
+                    }
+                }
+            }
+        }
+        Ok(())
+    }
+
+    fn final_merge(&self, accs: &mut AccColumnRef, acc_idx: IdxSelection<'_>) -> Result<ArrayRef> {
+        Ok(accs.freeze_to_arrays(acc_idx)?[0].clone()) // single accumulator column is the result
+    }
+}
diff --git a/native-engine/datafusion-ext-plans/src/agg/last_ignores_null.rs b/native-engine/datafusion-ext-plans/src/agg/last_ignores_null.rs
new file mode 100644
index 000000000..fde2afd94
--- /dev/null
+++ b/native-engine/datafusion-ext-plans/src/agg/last_ignores_null.rs
@@ -0,0 +1,232 @@
+// Licensed to the Apache Software Foundation (ASF) under one or more
+// contributor license agreements.  See the NOTICE file distributed with
+// this work for additional information regarding copyright ownership.
+// The ASF licenses this file to You under the Apache License, Version 2.0
+// (the "License"); you may not use this file except in compliance with
+// the License.  You may obtain a copy of the License at
+//
+//    http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::{
+    any::Any,
+    fmt::{Debug, Formatter},
+    sync::Arc,
+};
+
+use arrow::{array::*, datatypes::*};
+use datafusion::{common::Result, physical_expr::PhysicalExprRef};
+use datafusion_ext_commons::{downcast_any, scalar_value::compacted_scalar_value_from_array};
+
+use crate::{
+    agg::{
+        Agg,
+        acc::{
+            AccBooleanColumn, AccBytes, AccBytesColumn, AccColumnRef, AccPrimColumn,
+            AccScalarValueColumn, create_acc_generic_column,
+        },
+        agg::IdxSelection,
+    },
+    idx_for_zipped,
+};
+
+pub struct AggLastIgnoresNull {
+    child: PhysicalExprRef,
+    data_type: DataType,
+    acc_array_data_types: Vec<DataType>, // one accumulator column, typed like the input
+}
+
+impl AggLastIgnoresNull {
+    pub fn try_new(child: PhysicalExprRef, data_type: DataType) -> Result<Self> {
+        let acc_array_data_types = vec![data_type.clone()];
+        Ok(Self {
+            child,
+            data_type,
+            acc_array_data_types,
+        })
+    }
+}
+
+impl Debug for AggLastIgnoresNull {
+    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+        write!(f, "LastIgnoresNull({:?})", self.child)
+    }
+}
+
+impl Agg for AggLastIgnoresNull {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn exprs(&self) -> Vec<PhysicalExprRef> {
+        vec![self.child.clone()]
+    }
+
+    fn with_new_exprs(&self, exprs: Vec<PhysicalExprRef>) -> Result<Arc<dyn Agg>> {
+        Ok(Arc::new(Self::try_new(
+            exprs[0].clone(),
+            self.data_type.clone(),
+        )?))
+    }
+
+    fn data_type(&self) -> &DataType {
+        &self.data_type
+    }
+
+    fn nullable(&self) -> bool {
+        true
+    }
+
+    fn create_acc_column(&self, num_rows: usize) -> AccColumnRef {
+        create_acc_generic_column(self.data_type.clone(), num_rows)
+    }
+
+    fn acc_array_data_types(&self) -> &[DataType] {
+        &self.acc_array_data_types
+    }
+
+    fn partial_update(
+        &self,
+        accs: &mut AccColumnRef,
+        acc_idx: IdxSelection<'_>,
+        partial_args: &[ArrayRef],
+        partial_arg_idx: IdxSelection<'_>,
+    ) -> Result<()> {
+        let partial_arg = &partial_args[0];
+        accs.ensure_size(acc_idx);
+
+        macro_rules! handle_bytes {
+            ($array:expr) => {{
+                let accs = downcast_any!(accs, mut AccBytesColumn)?;
+                let partial_arg = $array;
+                idx_for_zipped! {
+                    ((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => {
+                        if partial_arg.is_valid(partial_arg_idx) {
+                            accs.set_value(acc_idx, Some(AccBytes::from(partial_arg.value(partial_arg_idx).as_ref())));
+                        }
+                    }
+                }
+            }}
+        }
+
+        downcast_primitive_array! {
+            partial_arg => {
+                if let Ok(accs) = downcast_any!(accs, mut AccPrimColumn<_>) {
+                    idx_for_zipped! {
+                        ((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => {
+                            if partial_arg.is_valid(partial_arg_idx) {
+                                accs.set_value(acc_idx, Some(partial_arg.value(partial_arg_idx)));
+                            }
+                        }
+                    }
+                } // NOTE(review): a failed downcast is skipped silently here, other arms use `?`
+            }
+            DataType::Boolean => {
+                let accs = downcast_any!(accs, mut AccBooleanColumn)?;
+                let partial_arg = downcast_any!(partial_arg, BooleanArray)?;
+                idx_for_zipped! {
+                    ((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => {
+                        if partial_arg.is_valid(partial_arg_idx) {
+                            accs.set_value(acc_idx, Some(partial_arg.value(partial_arg_idx)));
+                        }
+                    }
+                }
+            }
+            DataType::Utf8 => handle_bytes!(downcast_any!(partial_arg, StringArray)?),
+            DataType::Binary => handle_bytes!(downcast_any!(partial_arg, BinaryArray)?),
+            _other => {
+                let accs = downcast_any!(accs, mut AccScalarValueColumn)?;
+                idx_for_zipped! {
+                    ((acc_idx, partial_arg_idx) in (acc_idx, partial_arg_idx)) => {
+                        if partial_arg.is_valid(partial_arg_idx) {
+                            accs.set_value(acc_idx, compacted_scalar_value_from_array(partial_arg, partial_arg_idx)?);
+                        }
+                    }
+                }
+            }
+        }
+        Ok(())
+    }
+
+    fn partial_merge(
+        &self,
+        accs: &mut AccColumnRef,
+        acc_idx: IdxSelection<'_>,
+        merging_accs: &mut AccColumnRef,
+        merging_acc_idx: IdxSelection<'_>,
+    ) -> Result<()> {
+        accs.ensure_size(acc_idx);
+
+        // last() ignoring nulls: the merging accumulator wins only when it holds a value
+        macro_rules! handle_primitive {
+            ($ty:ty) => {{
+                type TNative = <$ty as ArrowPrimitiveType>::Native;
+                let accs = downcast_any!(accs, mut AccPrimColumn<TNative>)?;
+                let merging_accs = downcast_any!(merging_accs, mut AccPrimColumn<_>)?;
+                idx_for_zipped! {
+                    ((acc_idx, merging_acc_idx) in (acc_idx, merging_acc_idx)) => {
+                        if merging_accs.value(merging_acc_idx).is_some() {
+                            accs.set_value(acc_idx, merging_accs.value(merging_acc_idx));
+                        }
+                    }
+                }
+            }}
+        }
+
+        macro_rules! handle_boolean {
+            () => {{
+                let accs = downcast_any!(accs, mut AccBooleanColumn)?;
+                let merging_accs = downcast_any!(merging_accs, mut AccBooleanColumn)?;
+                idx_for_zipped! {
+                    ((acc_idx, merging_acc_idx) in (acc_idx, merging_acc_idx)) => {
+                        if merging_accs.value(merging_acc_idx).is_some() {
+                            accs.set_value(acc_idx, merging_accs.value(merging_acc_idx));
+                        }
+                    }
+                }
+            }};
+        }
+
+        macro_rules! handle_bytes {
+            () => {{
+                let accs = downcast_any!(accs, mut AccBytesColumn)?;
+                let merging_accs = downcast_any!(merging_accs, mut AccBytesColumn)?;
+                idx_for_zipped! {
+                    ((acc_idx, merging_acc_idx) in (acc_idx, merging_acc_idx)) => {
+                        if merging_accs.value(merging_acc_idx).is_some() {
+                            accs.set_value(acc_idx, merging_accs.take_value(merging_acc_idx));
+                        }
+                    }
+                }
+            }};
+        }
+
+        downcast_primitive! {
+            (&self.data_type) => (handle_primitive),
+            DataType::Boolean => handle_boolean!(),
+            DataType::Utf8 | DataType::Binary => handle_bytes!(),
+            DataType::Null => {} // no-op for the null type
+            _ => {
+                let accs = downcast_any!(accs, mut AccScalarValueColumn)?;
+                let merging_accs = downcast_any!(merging_accs, mut AccScalarValueColumn)?;
+                idx_for_zipped! {
+                    ((acc_idx, merging_acc_idx) in (acc_idx, merging_acc_idx)) => {
+                        if !merging_accs.value(merging_acc_idx).is_null() {
+                            accs.set_value(acc_idx, merging_accs.take_value(merging_acc_idx));
+                        }
+                    }
+                }
+            }
+        }
+        Ok(())
+    }
+
+    fn final_merge(&self, accs: &mut AccColumnRef, acc_idx: IdxSelection<'_>) -> Result<ArrayRef> {
+        Ok(accs.freeze_to_arrays(acc_idx)?[0].clone()) // single accumulator column is the result
+    }
+}
diff --git a/native-engine/datafusion-ext-plans/src/agg/mod.rs b/native-engine/datafusion-ext-plans/src/agg/mod.rs
index 565e15b16..9867524ca 100644
--- a/native-engine/datafusion-ext-plans/src/agg/mod.rs
+++ b/native-engine/datafusion-ext-plans/src/agg/mod.rs
@@ -25,6 +25,8 @@ pub mod collect;
 pub mod count;
 pub mod first;
 pub mod first_ignores_null;
+pub mod last;
+pub mod last_ignores_null;
 pub mod maxmin;
 pub mod spark_udaf_wrapper;
 pub mod sum;
@@ -69,6 +71,8 @@ pub enum AggFunction {
     Min,
     First,
     FirstIgnoresNull,
+    Last,
+    LastIgnoresNull,
     CollectList,
     CollectSet,
     BloomFilter,
diff --git a/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronWindowSuite.scala b/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronWindowSuite.scala
index 5f1e78fb8..9f5aadd9f 100644
--- a/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronWindowSuite.scala
+++ b/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronWindowSuite.scala
@@ -61,4 +61,149 @@ class AuronWindowSuite extends AuronQueryTest with BaseAuronSQLSuite with AuronS
       }
     }
   }
+
+  test("first_value window function") {
+    withSQLConf("spark.auron.enable.window" -> "true") {
+      withTable("t1") {
+        sql("create table t1(id int, grp int, v string) using parquet")
+        sql("insert into t1 values (1, 1, null), (2, 1, 'b'), (3, 1, 'c'), (4, 2, 'x')")
+
+        checkSparkAnswerAndOperator("""select
+            | id,
+            | grp,
+            | v,
+            | first_value(v) over (partition by grp order by id) as first_v
+            |from t1
+            |""".stripMargin)
+      }
+    }
+  }
+
+  test("first_value window function with ignore nulls") {
+    if
(AuronTestUtils.isSparkV32OrGreater) { + withSQLConf("spark.auron.enable.window" -> "true") { + withTable("t1") { + sql("create table t1(id int, grp int, v string) using parquet") + sql("insert into t1 values (1, 1, null), (2, 1, 'b'), (3, 1, 'c'), (4, 2, 'x')") + + checkSparkAnswerAndOperator("""select + | id, + | grp, + | v, + | first_value(v) ignore nulls over (partition by grp order by id) as first_non_null_v + |from t1 + |""".stripMargin) + } + } + } + } + + test("last_value window function") { + withSQLConf("spark.auron.enable.window" -> "true") { + withTable("t1") { + sql("create table t1(id int, grp int, v string) using parquet") + sql("insert into t1 values (1, 1, 'a'), (2, 1, null), (3, 1, 'c'), (4, 2, 'x')") + + checkSparkAnswerAndOperator("""select + | id, + | grp, + | v, + | last_value(v) over (partition by grp order by id) as last_v + |from t1 + |""".stripMargin) + } + } + } + + test("last_value window function with ignore nulls") { + if (AuronTestUtils.isSparkV32OrGreater) { + withSQLConf("spark.auron.enable.window" -> "true") { + withTable("t1") { + sql("create table t1(id int, grp int, v string) using parquet") + sql("insert into t1 values (1, 1, 'a'), (2, 1, null), (3, 1, 'c'), (4, 2, 'x')") + + checkSparkAnswerAndOperator("""select + | id, + | grp, + | v, + | last_value(v) ignore nulls over (partition by grp order by id) as last_non_null_v + |from t1 + |""".stripMargin) + } + } + } + } + + test("last_value window function over int and boolean columns") { + withSQLConf("spark.auron.enable.window" -> "true") { + withTable("t1") { + sql("create table t1(id int, grp int, int_v int, bool_v boolean) using parquet") + sql("""insert into t1 values + | (1, 1, 10, true), + | (2, 1, null, false), + | (3, 1, 30, null), + | (4, 2, null, true), + | (5, 2, 50, false) + |""".stripMargin) + + checkSparkAnswerAndOperator("""select + | id, + | grp, + | int_v, + | bool_v, + | last_value(int_v) over (partition by grp order by id) as last_int_v, + | last_value(bool_v) 
over (partition by grp order by id) as last_bool_v + |from t1 + |""".stripMargin) + } + } + } + + test("last_value window function with ignore nulls over int and boolean columns") { + if (AuronTestUtils.isSparkV32OrGreater) { + withSQLConf("spark.auron.enable.window" -> "true") { + withTable("t1") { + sql("create table t1(id int, grp int, int_v int, bool_v boolean) using parquet") + sql("""insert into t1 values + | (1, 1, 10, true), + | (2, 1, null, false), + | (3, 1, 30, null), + | (4, 2, null, true), + | (5, 2, 50, false) + |""".stripMargin) + + checkSparkAnswerAndOperator("""select + | id, + | grp, + | int_v, + | bool_v, + | last_value(int_v) ignore nulls over (partition by grp order by id) as last_int_v, + | last_value(bool_v) ignore nulls over (partition by grp order by id) as last_bool_v + |from t1 + |""".stripMargin) + } + } + } + } + + test("last aggregate function") { + withTable("t1") { + sql("create table t1(grp int, v string) using parquet") + sql("""insert into t1 values + | (1, 'a'), + | (1, 'a'), + | (2, null), + | (2, null), + | (3, 'z') + |""".stripMargin) + + checkSparkAnswerAndOperator("""select + | grp, + | last(v) as last_v, + | last(v, true) as last_ignore_nulls_v + |from t1 + |group by grp + |""".stripMargin) + } + } } diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala b/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala index 378a8d662..3119dd861 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/auron/NativeConverters.scala @@ -37,7 +37,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.sql.auron.util.Using import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions._ -import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Average, CollectList, CollectSet, Count, 
DeclarativeAggregate, First, Max, Min, Sum, TypedImperativeAggregate} +import org.apache.spark.sql.catalyst.expressions.aggregate.{AggregateExpression, AggregateFunction, Average, CollectList, CollectSet, Count, DeclarativeAggregate, First, Last, Max, Min, Sum, TypedImperativeAggregate} import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.catalyst.expressions.codegen.ExprCode import org.apache.spark.sql.catalyst.optimizer.NormalizeNaNAndZero @@ -1276,6 +1276,18 @@ object NativeConverters extends Logging { }) aggBuilder.addChildren(convertExpr(child)) + case Last(child, ignoresNullExpr) => + val ignoresNull = ignoresNullExpr.asInstanceOf[Any] match { + case Literal(v: Boolean, BooleanType) => v + case v: Boolean => v + } + aggBuilder.setAggFunction(if (ignoresNull) { + pb.AggFunction.LAST_IGNORES_NULL + } else { + pb.AggFunction.LAST + }) + aggBuilder.addChildren(convertExpr(child)) + case CollectList(child, _, _) => aggBuilder.setAggFunction(pb.AggFunction.COLLECT_LIST) aggBuilder.addChildren(convertExpr(child)) diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeWindowBase.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeWindowBase.scala index f0336be8b..dcd61aaa7 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeWindowBase.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeWindowBase.scala @@ -40,6 +40,8 @@ import org.apache.spark.sql.catalyst.expressions.SortOrder import org.apache.spark.sql.catalyst.expressions.WindowExpression import org.apache.spark.sql.catalyst.expressions.aggregate.Average import org.apache.spark.sql.catalyst.expressions.aggregate.Count +import org.apache.spark.sql.catalyst.expressions.aggregate.First +import org.apache.spark.sql.catalyst.expressions.aggregate.Last import org.apache.spark.sql.catalyst.expressions.aggregate.Max import 
org.apache.spark.sql.catalyst.expressions.aggregate.Min import org.apache.spark.sql.catalyst.expressions.aggregate.Sum @@ -50,6 +52,7 @@ import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.UnaryExecNode import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.types.BooleanType import org.apache.auron.{protobuf => pb} import org.apache.auron.metric.SparkMetricNode @@ -234,6 +237,32 @@ abstract class NativeWindowBase( windowExprBuilder.setAggFunc(pb.AggFunction.COUNT) windowExprBuilder.addChildren(NativeConverters.convertExpr(child)) + case First(child, ignoresNullExpr) => + assert( + spec.frameSpecification == RowNumber().frame, // only supports RowFrame(Unbounded, CurrentRow) + s"window frame not supported: ${spec.frameSpecification}") + val ignoresNull = ignoresNullExpr.asInstanceOf[Any] match { + case Literal(v: Boolean, BooleanType) => v + case v: Boolean => v + } + windowExprBuilder.setFuncType(pb.WindowFunctionType.Agg) + windowExprBuilder.setAggFunc( + if (ignoresNull) pb.AggFunction.FIRST_IGNORES_NULL else pb.AggFunction.FIRST) + windowExprBuilder.addChildren(NativeConverters.convertExpr(child)) + + case Last(child, ignoresNullExpr) => + assert( + spec.frameSpecification == RowNumber().frame, // only supports RowFrame(Unbounded, CurrentRow) + s"window frame not supported: ${spec.frameSpecification}") + val ignoresNull = ignoresNullExpr.asInstanceOf[Any] match { + case Literal(v: Boolean, BooleanType) => v + case v: Boolean => v + } + windowExprBuilder.setFuncType(pb.WindowFunctionType.Agg) + windowExprBuilder.setAggFunc( + if (ignoresNull) pb.AggFunction.LAST_IGNORES_NULL else pb.AggFunction.LAST) + windowExprBuilder.addChildren(NativeConverters.convertExpr(child)) + case other => throw new NotImplementedError(s"window function not supported: $other") }