Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,11 @@

use std::sync::Arc;

use arrow::{array::ArrayRef, datatypes::DataType, record_batch::RecordBatch};
use arrow::{
array::{ArrayRef, new_empty_array},
datatypes::DataType,
record_batch::RecordBatch,
};
use datafusion::{
common::{DataFusionError, Result, ScalarValue},
physical_expr::PhysicalExprRef,
Expand All @@ -36,15 +40,19 @@ impl LeadProcessor {

impl WindowFunctionProcessor for LeadProcessor {
fn process_batch(&mut self, context: &WindowContext, batch: &RecordBatch) -> Result<ArrayRef> {
assert_eq!(
self.children.len(),
3,
"lead expects input/offset/default children",
);
if self.children.len() != 3 {
return Err(DataFusionError::Execution(format!(
"lead expects input/offset/default children, got {}",
self.children.len()
)));
}

let input_values = self.children[0]
.evaluate(batch)
.and_then(|v| v.into_array(batch.num_rows()))?;
if batch.num_rows() == 0 {
return Ok(new_empty_array(input_values.data_type()));
}

let offset_values = self.children[1]
.evaluate(batch)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -61,4 +61,43 @@ class AuronWindowSuite extends AuronQueryTest with BaseAuronSQLSuite with AuronS
}
}
}

test("lag window function") {
withSQLConf("spark.auron.enable.window" -> "true") {
withTable("t1") {
sql("create table t1(id int, grp int, v string) using parquet")
sql("insert into t1 values (1, 1, 'a'), (2, 1, null), (3, 1, 'c'), (4, 2, 'x')")

checkSparkAnswerAndOperator("""select
| id,
| grp,
| v,
| lag(v) over (partition by grp order by id) as prev_v,
| lag(v, 2, 'fallback') over (partition by grp order by id) as prev2_v
|from t1
|""".stripMargin)
}
}
}

test("lag window function with ignore nulls falls back") {
if (AuronTestUtils.isSparkV32OrGreater) {
withSQLConf("spark.auron.enable.window" -> "true") {
withTable("t1") {
sql("create table t1(id int, grp int, v string) using parquet")
sql("insert into t1 values (1, 1, 'a'), (2, 1, null), (3, 1, 'c'), (4, 2, 'x')")

val df = checkSparkAnswer("""select
| id,
| grp,
| lag(v, 1, 'fallback') ignore nulls
| over (partition by grp order by id) as prev_non_null_v
|from t1
|""".stripMargin)
val plan = stripAQEPlan(df.queryExecution.executedPlan)
assert(plan.collectFirst { case _: NativeWindowBase => true }.isEmpty)
}
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package org.apache.spark.sql.execution.auron.plan

import scala.collection.immutable.SortedMap
import scala.jdk.CollectionConverters._
import scala.util.Try

import org.apache.spark.OneToOneDependency
import org.apache.spark.sql.auron.NativeConverters
Expand All @@ -29,6 +30,7 @@ import org.apache.spark.sql.catalyst.expressions.Attribute
import org.apache.spark.sql.catalyst.expressions.CumeDist
import org.apache.spark.sql.catalyst.expressions.DenseRank
import org.apache.spark.sql.catalyst.expressions.Expression
import org.apache.spark.sql.catalyst.expressions.Lag
import org.apache.spark.sql.catalyst.expressions.Lead
import org.apache.spark.sql.catalyst.expressions.Literal
import org.apache.spark.sql.catalyst.expressions.NamedExpression
Expand Down Expand Up @@ -92,14 +94,12 @@ abstract class NativeWindowBase(
override def requiredChildOrdering: Seq[Seq[SortOrder]] =
Seq(partitionSpec.map(SortOrder(_, Ascending)) ++ orderSpec)

private def leadIgnoreNulls(expr: Lead): Boolean =
expr.getClass.getMethods
.find(method => method.getName == "ignoreNulls" && method.getParameterCount == 0)
.exists(method => method.invoke(expr).asInstanceOf[Boolean])

private def invokeNoArg[T](expr: Expression, methodName: String): T =
expr.getClass.getMethod(methodName).invoke(expr).asInstanceOf[T]

private def ignoreNulls(expr: Expression): Boolean =
Try(invokeNoArg[Boolean](expr, "ignoreNulls")).getOrElse(false)

private def isNthValue(expr: Expression): Boolean = expr.getClass.getSimpleName == "NthValue"

private def nthValueInput(expr: Expression): Expression = invokeNoArg[Expression](expr, "input")
Expand Down Expand Up @@ -180,7 +180,18 @@ abstract class NativeWindowBase(
assert(
spec.frameSpecification == e.frame,
s"window frame not supported: ${spec.frameSpecification}")
assert(!leadIgnoreNulls(e), "window function not supported: lead with IGNORE NULLS")
assert(!ignoreNulls(e), "window function not supported: lead with IGNORE NULLS")
windowExprBuilder.setFuncType(pb.WindowFunctionType.Window)
windowExprBuilder.setWindowFunc(pb.WindowFunction.LEAD)
windowExprBuilder.addChildren(NativeConverters.convertExpr(e.input))
windowExprBuilder.addChildren(NativeConverters.convertExpr(e.offset))
windowExprBuilder.addChildren(NativeConverters.convertExpr(e.default))

case e: Lag =>
assert(
spec.frameSpecification == e.frame,
s"window frame not supported: ${spec.frameSpecification}")
assert(!ignoreNulls(e), "window function not supported: lag with IGNORE NULLS")
windowExprBuilder.setFuncType(pb.WindowFunctionType.Window)
windowExprBuilder.setWindowFunc(pb.WindowFunction.LEAD)
windowExprBuilder.addChildren(NativeConverters.convertExpr(e.input))
Expand Down