diff --git a/native-engine/datafusion-ext-plans/src/window/processors/lead_processor.rs b/native-engine/datafusion-ext-plans/src/window/processors/lead_processor.rs index c8b73face..3623a8a78 100644 --- a/native-engine/datafusion-ext-plans/src/window/processors/lead_processor.rs +++ b/native-engine/datafusion-ext-plans/src/window/processors/lead_processor.rs @@ -15,7 +15,11 @@ use std::sync::Arc; -use arrow::{array::ArrayRef, datatypes::DataType, record_batch::RecordBatch}; +use arrow::{ + array::{ArrayRef, new_empty_array}, + datatypes::DataType, + record_batch::RecordBatch, +}; use datafusion::{ common::{DataFusionError, Result, ScalarValue}, physical_expr::PhysicalExprRef, @@ -36,15 +40,19 @@ impl LeadProcessor { impl WindowFunctionProcessor for LeadProcessor { fn process_batch(&mut self, context: &WindowContext, batch: &RecordBatch) -> Result { - assert_eq!( - self.children.len(), - 3, - "lead expects input/offset/default children", - ); + if self.children.len() != 3 { + return Err(DataFusionError::Execution(format!( + "lead expects input/offset/default children, got {}", + self.children.len() + ))); + } let input_values = self.children[0] .evaluate(batch) .and_then(|v| v.into_array(batch.num_rows()))?; + if batch.num_rows() == 0 { + return Ok(new_empty_array(input_values.data_type())); + } let offset_values = self.children[1] .evaluate(batch) diff --git a/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronWindowSuite.scala b/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronWindowSuite.scala index 5f1e78fb8..f875f2442 100644 --- a/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronWindowSuite.scala +++ b/spark-extension-shims-spark/src/test/scala/org/apache/auron/AuronWindowSuite.scala @@ -61,4 +61,43 @@ class AuronWindowSuite extends AuronQueryTest with BaseAuronSQLSuite with AuronS } } } + + test("lag window function") { + withSQLConf("spark.auron.enable.window" -> "true") { + withTable("t1") { + sql("create table t1(id int, grp int, v string) using parquet") + sql("insert into t1 values (1, 1, 'a'), (2, 1, null), (3, 1, 'c'), (4, 2, 'x')") + + checkSparkAnswerAndOperator("""select + | id, + | grp, + | v, + | lag(v) over (partition by grp order by id) as prev_v, + | lag(v, 2, 'fallback') over (partition by grp order by id) as prev2_v + |from t1 + |""".stripMargin) + } + } + } + + test("lag window function with ignore nulls falls back") { + if (AuronTestUtils.isSparkV32OrGreater) { + withSQLConf("spark.auron.enable.window" -> "true") { + withTable("t1") { + sql("create table t1(id int, grp int, v string) using parquet") + sql("insert into t1 values (1, 1, 'a'), (2, 1, null), (3, 1, 'c'), (4, 2, 'x')") + + val df = checkSparkAnswer("""select + | id, + | grp, + | lag(v, 1, 'fallback') ignore nulls + | over (partition by grp order by id) as prev_non_null_v + |from t1 + |""".stripMargin) + val plan = stripAQEPlan(df.queryExecution.executedPlan) + assert(plan.collectFirst { case _: NativeWindowBase => true }.isEmpty) + } + } + } + } } diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeWindowBase.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeWindowBase.scala index f0336be8b..1b05423ca 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeWindowBase.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/auron/plan/NativeWindowBase.scala @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.auron.plan import scala.collection.immutable.SortedMap import scala.jdk.CollectionConverters._ +import scala.util.Try import org.apache.spark.OneToOneDependency import org.apache.spark.sql.auron.NativeConverters @@ -29,6 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute import org.apache.spark.sql.catalyst.expressions.CumeDist import org.apache.spark.sql.catalyst.expressions.DenseRank import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.Lag import org.apache.spark.sql.catalyst.expressions.Lead import org.apache.spark.sql.catalyst.expressions.Literal import org.apache.spark.sql.catalyst.expressions.NamedExpression @@ -92,14 +94,12 @@ abstract class NativeWindowBase( override def requiredChildOrdering: Seq[Seq[SortOrder]] = Seq(partitionSpec.map(SortOrder(_, Ascending)) ++ orderSpec) - private def leadIgnoreNulls(expr: Lead): Boolean = - expr.getClass.getMethods - .find(method => method.getName == "ignoreNulls" && method.getParameterCount == 0) - .exists(method => method.invoke(expr).asInstanceOf[Boolean]) - private def invokeNoArg[T](expr: Expression, methodName: String): T = expr.getClass.getMethod(methodName).invoke(expr).asInstanceOf[T] + private def ignoreNulls(expr: Expression): Boolean = + Try(invokeNoArg[Boolean](expr, "ignoreNulls")).getOrElse(false) + private def isNthValue(expr: Expression): Boolean = expr.getClass.getSimpleName == "NthValue" private def nthValueInput(expr: Expression): Expression = invokeNoArg[Expression](expr, "input") @@ -180,7 +180,18 @@ abstract class NativeWindowBase( assert( spec.frameSpecification == e.frame, s"window frame not supported: ${spec.frameSpecification}") - assert(!leadIgnoreNulls(e), "window function not supported: lead with IGNORE NULLS") + assert(!ignoreNulls(e), "window function not supported: lead with IGNORE NULLS") + windowExprBuilder.setFuncType(pb.WindowFunctionType.Window) + windowExprBuilder.setWindowFunc(pb.WindowFunction.LEAD) + windowExprBuilder.addChildren(NativeConverters.convertExpr(e.input)) + windowExprBuilder.addChildren(NativeConverters.convertExpr(e.offset)) + windowExprBuilder.addChildren(NativeConverters.convertExpr(e.default)) + + case e: Lag => + assert( + spec.frameSpecification == e.frame, + s"window frame not supported: ${spec.frameSpecification}") + assert(!ignoreNulls(e), "window function not supported: lag with IGNORE NULLS") windowExprBuilder.setFuncType(pb.WindowFunctionType.Window) windowExprBuilder.setWindowFunc(pb.WindowFunction.LEAD) windowExprBuilder.addChildren(NativeConverters.convertExpr(e.input))