
Commit 43063e2

viirya authored and HyukjinKwon committed
[SPARK-27217][SQL] Nested column aliasing for more operators which can prune nested column
### What changes were proposed in this pull request?

Currently we only push nested column pruning from a Project through a few operators such as LIMIT and SAMPLE. A few operators such as Aggregate and Expand can prune nested columns by themselves, without a Project on top. This patch extends the feature to those operators.

### Why are the changes needed?

Nested column pruning is currently applied in only a few cases, which limits its benefit. Extending its coverage makes the feature apply more generally across different queries.

### Does this PR introduce _any_ user-facing change?

Yes. More SQL operators are covered by nested column pruning.

### How was this patch tested?

Added unit tests and end-to-end tests.

Closes #28560 from viirya/SPARK-27217-2.

Authored-by: Liang-Chi Hsieh <[email protected]>
Signed-off-by: HyukjinKwon <[email protected]>
1 parent 82ff29b commit 43063e2
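
To make the new coverage concrete, here is a minimal sketch of a query shape this patch now prunes. It is not part of the commit; it assumes a Parquet-backed `contacts` table whose `name` column is a `struct<first:string, middle:string, last:string>`, mirroring the test data used in `SchemaPruningSuite` below:

```scala
// Sketch only: the Aggregate references nested fields directly, with no
// Project on top, which previously blocked nested column pruning.
val q = spark.sql("SELECT count(name.first) FROM contacts GROUP BY name.last")

// With spark.sql.optimizer.nestedSchemaPruning.enabled=true, the scan should
// now read only struct<name:struct<first:string,last:string>> rather than
// the full `name` struct.
q.explain(true)
```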

File tree

3 files changed: +190 -10 lines changed

sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasing.scala

Lines changed: 25 additions & 10 deletions
```diff
@@ -35,6 +35,11 @@ object NestedColumnAliasing {
     case Project(projectList, child)
         if SQLConf.get.nestedSchemaPruningEnabled && canProjectPushThrough(child) =>
       getAliasSubMap(projectList)
+
+    case plan if SQLConf.get.nestedSchemaPruningEnabled && canPruneOn(plan) =>
+      val exprCandidatesToPrune = plan.expressions
+      getAliasSubMap(exprCandidatesToPrune, plan.producedAttributes.toSeq)
+
     case _ => None
   }

@@ -48,7 +53,11 @@ object NestedColumnAliasing {
     case Project(projectList, child) =>
       Project(
         getNewProjectList(projectList, nestedFieldToAlias),
-        replaceChildrenWithAliases(child, attrToAliases))
+        replaceChildrenWithAliases(child, nestedFieldToAlias, attrToAliases))
+
+    // The operators reaching here were already guarded by `canPruneOn`.
+    case other =>
+      replaceChildrenWithAliases(other, nestedFieldToAlias, attrToAliases)
   }

   /**
@@ -68,10 +77,23 @@ object NestedColumnAliasing {
    */
   def replaceChildrenWithAliases(
       plan: LogicalPlan,
+      nestedFieldToAlias: Map[ExtractValue, Alias],
       attrToAliases: Map[ExprId, Seq[Alias]]): LogicalPlan = {
     plan.withNewChildren(plan.children.map { plan =>
       Project(plan.output.flatMap(a => attrToAliases.getOrElse(a.exprId, Seq(a))), plan)
-    })
+    }).transformExpressions {
+      case f: ExtractValue if nestedFieldToAlias.contains(f) =>
+        nestedFieldToAlias(f).toAttribute
+    }
+  }
+
+  /**
+   * Returns true for operators on which we can prune nested columns.
+   */
+  private def canPruneOn(plan: LogicalPlan) = plan match {
+    case _: Aggregate => true
+    case _: Expand => true
+    case _ => false
   }

   /**
@@ -204,15 +226,8 @@ object GeneratorNestedColumnAliasing {
       g: Generate,
       nestedFieldToAlias: Map[ExtractValue, Alias],
       attrToAliases: Map[ExprId, Seq[Alias]]): LogicalPlan = {
-    val newGenerator = g.generator.transform {
-      case f: ExtractValue if nestedFieldToAlias.contains(f) =>
-        nestedFieldToAlias(f).toAttribute
-    }.asInstanceOf[Generator]
-
     // Defer updating `Generate.unrequiredChildIndex` to next round of `ColumnPruning`.
-    val newGenerate = g.copy(generator = newGenerator)
-
-    NestedColumnAliasing.replaceChildrenWithAliases(newGenerate, attrToAliases)
+    NestedColumnAliasing.replaceChildrenWithAliases(g, nestedFieldToAlias, attrToAliases)
   }

   /**
```
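
Taken together, the new `unapply` case and the extended `replaceChildrenWithAliases` rewrite plans roughly as follows. This is a hand-written sketch, not real optimizer output; the expression IDs and `_gen_alias_*` names stand in for whatever the optimizer actually generates:

```
Before: the Aggregate references the nested field directly.
  Aggregate [id#1], [first(name#0.first) AS first#3]
  +- Relation contacts[id#1, name#0]

After: the field is aliased in a Project pushed below the Aggregate, so the
scan only needs name.first.
  Aggregate [id#1], [first(_gen_alias_4#4) AS first#3]
  +- Project [id#1, name#0.first AS _gen_alias_4#4]
     +- Relation contacts[id#1, name#0]
```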

sql/catalyst/src/test/scala/org/apache/spark/sql/catalyst/optimizer/NestedColumnAliasingSuite.scala

Lines changed: 94 additions & 0 deletions
```diff
@@ -341,6 +341,100 @@ class NestedColumnAliasingSuite extends SchemaPruningTest {
       .analyze
     comparePlans(optimized, expected)
   }
+
+  test("Nested field pruning for Aggregate") {
+    def runTest(basePlan: LogicalPlan => LogicalPlan): Unit = {
+      val query1 = basePlan(contact).groupBy($"id")(first($"name.first").as("first")).analyze
+      val optimized1 = Optimize.execute(query1)
+      val aliases1 = collectGeneratedAliases(optimized1)
+
+      val expected1 = basePlan(
+        contact
+          .select($"id", 'name.getField("first").as(aliases1(0)))
+      ).groupBy($"id")(first($"${aliases1(0)}").as("first")).analyze
+      comparePlans(optimized1, expected1)
+
+      val query2 = basePlan(contact).groupBy($"name.last")(first($"name.first").as("first")).analyze
+      val optimized2 = Optimize.execute(query2)
+      val aliases2 = collectGeneratedAliases(optimized2)
+
+      val expected2 = basePlan(
+        contact
+          .select('name.getField("last").as(aliases2(0)), 'name.getField("first").as(aliases2(1)))
+      ).groupBy($"${aliases2(0)}")(first($"${aliases2(1)}").as("first")).analyze
+      comparePlans(optimized2, expected2)
+    }
+
+    Seq(
+      (plan: LogicalPlan) => plan,
+      (plan: LogicalPlan) => plan.limit(100),
+      (plan: LogicalPlan) => plan.repartition(100),
+      (plan: LogicalPlan) => Sample(0.0, 0.6, false, 11L, plan)).foreach { base =>
+      runTest(base)
+    }
+
+    val query3 = contact.groupBy($"id")(first($"name"), first($"name.first").as("first")).analyze
+    val optimized3 = Optimize.execute(query3)
+    val expected3 = contact.select($"id", $"name")
+      .groupBy($"id")(first($"name"), first($"name.first").as("first")).analyze
+    comparePlans(optimized3, expected3)
+  }
+
+  test("Nested field pruning for Expand") {
+    def runTest(basePlan: LogicalPlan => LogicalPlan): Unit = {
+      val query1 = Expand(
+        Seq(
+          Seq($"name.first", $"name.middle"),
+          Seq(ConcatWs(Seq($"name.first", $"name.middle")),
+            ConcatWs(Seq($"name.middle", $"name.first")))
+        ),
+        Seq('a.string, 'b.string),
+        basePlan(contact)
+      ).analyze
+      val optimized1 = Optimize.execute(query1)
+      val aliases1 = collectGeneratedAliases(optimized1)
+
+      val expected1 = Expand(
+        Seq(
+          Seq($"${aliases1(0)}", $"${aliases1(1)}"),
+          Seq(ConcatWs(Seq($"${aliases1(0)}", $"${aliases1(1)}")),
+            ConcatWs(Seq($"${aliases1(1)}", $"${aliases1(0)}")))
+        ),
+        Seq('a.string, 'b.string),
+        basePlan(contact.select(
+          'name.getField("first").as(aliases1(0)),
+          'name.getField("middle").as(aliases1(1))))
+      ).analyze
+      comparePlans(optimized1, expected1)
+    }
+
+    Seq(
+      (plan: LogicalPlan) => plan,
+      (plan: LogicalPlan) => plan.limit(100),
+      (plan: LogicalPlan) => plan.repartition(100),
+      (plan: LogicalPlan) => Sample(0.0, 0.6, false, 11L, plan)).foreach { base =>
+      runTest(base)
+    }
+
+    val query2 = Expand(
+      Seq(
+        Seq($"name", $"name.middle"),
+        Seq($"name", ConcatWs(Seq($"name.middle", $"name.first")))
+      ),
+      Seq('a.string, 'b.string),
+      contact
+    ).analyze
+    val optimized2 = Optimize.execute(query2)
+    val expected2 = Expand(
+      Seq(
+        Seq($"name", $"name.middle"),
+        Seq($"name", ConcatWs(Seq($"name.middle", $"name.first")))
+      ),
+      Seq('a.string, 'b.string),
+      contact.select($"name")
+    ).analyze
+    comparePlans(optimized2, expected2)
+  }
 }

 object NestedColumnAliasingSuite {
```
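
These optimizer-level tests can be run in isolation; from the Spark repository root, an invocation along these lines should work (the usual sbt test syntax for the catalyst module, adjust to your setup):

```
build/sbt "catalyst/testOnly *NestedColumnAliasingSuite"
```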

sql/core/src/test/scala/org/apache/spark/sql/execution/datasources/SchemaPruningSuite.scala

Lines changed: 71 additions & 0 deletions
```diff
@@ -23,7 +23,9 @@ import org.scalactic.Equality

 import org.apache.spark.sql.{DataFrame, QueryTest, Row}
 import org.apache.spark.sql.catalyst.SchemaPruningTest
+import org.apache.spark.sql.catalyst.expressions.Concat
 import org.apache.spark.sql.catalyst.parser.CatalystSqlParser
+import org.apache.spark.sql.catalyst.plans.logical.Expand
 import org.apache.spark.sql.execution.FileSourceScanExec
 import org.apache.spark.sql.execution.adaptive.AdaptiveSparkPlanHelper
 import org.apache.spark.sql.functions._
@@ -338,6 +340,75 @@ abstract class SchemaPruningSuite
     }
   }

+  testSchemaPruning("select one deep nested complex field after repartition") {
+    val query = sql("select * from contacts")
+      .repartition(100)
+      .where("employer.company.address is not null")
+      .selectExpr("employer.id as employer_id")
+    checkScan(query,
+      "struct<employer:struct<id:int,company:struct<address:string>>>")
+    checkAnswer(query, Row(0) :: Nil)
+  }
+
+  testSchemaPruning("select nested field in aggregation function of Aggregate") {
+    val query1 = sql("select count(name.first) from contacts group by name.last")
+    checkScan(query1, "struct<name:struct<first:string,last:string>>")
+    checkAnswer(query1, Row(2) :: Row(2) :: Nil)
+
+    val query2 = sql("select count(name.first), sum(pets) from contacts group by id")
+    checkScan(query2, "struct<id:int,name:struct<first:string>,pets:int>")
+    checkAnswer(query2, Row(1, 1) :: Row(1, null) :: Row(1, 3) :: Row(1, null) :: Nil)
+
+    val query3 = sql("select count(name.first), first(name) from contacts group by id")
+    checkScan(query3, "struct<id:int,name:struct<first:string,middle:string,last:string>>")
+    checkAnswer(query3,
+      Row(1, Row("Jane", "X.", "Doe")) ::
+        Row(1, Row("Jim", null, "Jones")) ::
+        Row(1, Row("John", "Y.", "Doe")) ::
+        Row(1, Row("Janet", null, "Jones")) :: Nil)
+
+    val query4 = sql("select count(name.first), sum(pets) from contacts group by name.last")
+    checkScan(query4, "struct<name:struct<first:string,last:string>,pets:int>")
+    checkAnswer(query4, Row(2, null) :: Row(2, 4) :: Nil)
+  }
+
+  testSchemaPruning("select nested field in Expand") {
+    import org.apache.spark.sql.catalyst.dsl.expressions._
+
+    val query1 = Expand(
+      Seq(
+        Seq($"name.first", $"name.last"),
+        Seq(Concat(Seq($"name.first", $"name.last")),
+          Concat(Seq($"name.last", $"name.first")))
+      ),
+      Seq('a.string, 'b.string),
+      sql("select * from contacts").logicalPlan
+    ).toDF()
+    checkScan(query1, "struct<name:struct<first:string,last:string>>")
+    checkAnswer(query1,
+      Row("Jane", "Doe") ::
+        Row("JaneDoe", "DoeJane") ::
+        Row("John", "Doe") ::
+        Row("JohnDoe", "DoeJohn") ::
+        Row("Jim", "Jones") ::
+        Row("JimJones", "JonesJim") ::
+        Row("Janet", "Jones") ::
+        Row("JanetJones", "JonesJanet") :: Nil)
+
+    val name = StructType.fromDDL("first string, middle string, last string")
+    val query2 = Expand(
+      Seq(Seq($"name", $"name.last")),
+      Seq('a.struct(name), 'b.string),
+      sql("select * from contacts").logicalPlan
+    ).toDF()
+    checkScan(query2, "struct<name:struct<first:string,middle:string,last:string>>")
+    checkAnswer(query2,
+      Row(Row("Jane", "X.", "Doe"), "Doe") ::
+        Row(Row("John", "Y.", "Doe"), "Doe") ::
+        Row(Row("Jim", null, "Jones"), "Jones") ::
+        Row(Row("Janet", null, "Jones"), "Jones") :: Nil)
+  }
+
   protected def testSchemaPruning(testName: String)(testThunk: => Unit): Unit = {
     test(s"Spark vectorized reader - without partition data column - $testName") {
       withSQLConf(vectorizedReaderEnabledKey -> "true") {
```