Skip to content

Commit 59aa3d5

Browse files
mpetruskasrowen
authored andcommitted
[SPARK-20706][SPARK-SHELL] Spark-shell not overriding method/variable definition
## What changes were proposed in this pull request? [SPARK-20706](https://issues.apache.org/jira/browse/SPARK-20706): Spark-shell not overriding method/variable definition This is a Scala repl bug ( [SI-9740](scala/bug#9740) ), was fixed in version 2.11.9 ( [see the original PR](scala/scala#5090) ) ## How was this patch tested? Added a new test case in `ReplSuite`. Author: Mark Petruska <[email protected]> Closes #19879 from mpetruska/SPARK-20706.
1 parent 1e17ab8 commit 59aa3d5

File tree

2 files changed

+167
-10
lines changed

2 files changed

+167
-10
lines changed

repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoopInterpreter.scala

Lines changed: 137 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717

1818
package org.apache.spark.repl
1919

20+
import scala.collection.mutable
2021
import scala.tools.nsc.Settings
2122
import scala.tools.nsc.interpreter._
2223

@@ -30,7 +31,7 @@ class SparkILoopInterpreter(settings: Settings, out: JPrintWriter) extends IMain
3031

3132
override def chooseHandler(member: intp.global.Tree): MemberHandler = member match {
3233
case member: Import => new SparkImportHandler(member)
33-
case _ => super.chooseHandler (member)
34+
case _ => super.chooseHandler(member)
3435
}
3536

3637
class SparkImportHandler(imp: Import) extends ImportHandler(imp: Import) {
@@ -100,4 +101,139 @@ class SparkILoopInterpreter(settings: Settings, out: JPrintWriter) extends IMain
100101
override def typeOfExpression(expr: String, silent: Boolean): global.Type =
101102
expressionTyper.typeOfExpression(expr, silent)
102103

104+
105+
import global.Name
106+
override def importsCode(wanted: Set[Name], wrapper: Request#Wrapper,
107+
definesClass: Boolean, generousImports: Boolean): ComputedImports = {
108+
109+
import global._
110+
import definitions.{ ObjectClass, ScalaPackage, JavaLangPackage, PredefModule }
111+
import memberHandlers._
112+
113+
val header, code, trailingBraces, accessPath = new StringBuilder
114+
val currentImps = mutable.HashSet[Name]()
115+
// only emit predef import header if name not resolved in history, loosely
116+
var predefEscapes = false
117+
118+
/**
119+
* Narrow down the list of requests from which imports
120+
* should be taken. Removes requests which cannot contribute
121+
* useful imports for the specified set of wanted names.
122+
*/
123+
case class ReqAndHandler(req: Request, handler: MemberHandler)
124+
125+
def reqsToUse: List[ReqAndHandler] = {
126+
/**
127+
* Loop through a list of MemberHandlers and select which ones to keep.
128+
* 'wanted' is the set of names that need to be imported.
129+
*/
130+
def select(reqs: List[ReqAndHandler], wanted: Set[Name]): List[ReqAndHandler] = {
131+
// Single symbol imports might be implicits! See bug #1752. Rather than
132+
// try to finesse this, we will mimic all imports for now.
133+
def keepHandler(handler: MemberHandler) = handler match {
134+
// While defining classes in class based mode - implicits are not needed.
135+
case h: ImportHandler if isClassBased && definesClass =>
136+
h.importedNames.exists(x => wanted.contains(x))
137+
case _: ImportHandler => true
138+
case x if generousImports => x.definesImplicit ||
139+
(x.definedNames exists (d => wanted.exists(w => d.startsWith(w))))
140+
case x => x.definesImplicit ||
141+
(x.definedNames exists wanted)
142+
}
143+
144+
reqs match {
145+
case Nil =>
146+
predefEscapes = wanted contains PredefModule.name ; Nil
147+
case rh :: rest if !keepHandler(rh.handler) => select(rest, wanted)
148+
case rh :: rest =>
149+
import rh.handler._
150+
val augment = rh match {
151+
case ReqAndHandler(_, _: ImportHandler) => referencedNames
152+
case _ => Nil
153+
}
154+
val newWanted = wanted ++ augment -- definedNames -- importedNames
155+
rh :: select(rest, newWanted)
156+
}
157+
}
158+
159+
/** Flatten the handlers out and pair each with the original request */
160+
select(allReqAndHandlers reverseMap { case (r, h) => ReqAndHandler(r, h) }, wanted).reverse
161+
}
162+
163+
// add code for a new object to hold some imports
164+
def addWrapper() {
165+
import nme.{ INTERPRETER_IMPORT_WRAPPER => iw }
166+
code append (wrapper.prewrap format iw)
167+
trailingBraces append wrapper.postwrap
168+
accessPath append s".$iw"
169+
currentImps.clear()
170+
}
171+
172+
def maybeWrap(names: Name*) = if (names exists currentImps) addWrapper()
173+
174+
def wrapBeforeAndAfter[T](op: => T): T = {
175+
addWrapper()
176+
try op finally addWrapper()
177+
}
178+
179+
// imports from Predef are relocated to the template header to allow hiding.
180+
def checkHeader(h: ImportHandler) = h.referencedNames contains PredefModule.name
181+
182+
// loop through previous requests, adding imports for each one
183+
wrapBeforeAndAfter {
184+
// Reusing a single temporary value when import from a line with multiple definitions.
185+
val tempValLines = mutable.Set[Int]()
186+
for (ReqAndHandler(req, handler) <- reqsToUse) {
187+
val objName = req.lineRep.readPathInstance
188+
handler match {
189+
case h: ImportHandler if checkHeader(h) =>
190+
header.clear()
191+
header append f"${h.member}%n"
192+
// If the user entered an import, then just use it; add an import wrapping
193+
// level if the import might conflict with some other import
194+
case x: ImportHandler if x.importsWildcard =>
195+
wrapBeforeAndAfter(code append (x.member + "\n"))
196+
case x: ImportHandler =>
197+
maybeWrap(x.importedNames: _*)
198+
code append (x.member + "\n")
199+
currentImps ++= x.importedNames
200+
201+
case x if isClassBased =>
202+
for (sym <- x.definedSymbols) {
203+
maybeWrap(sym.name)
204+
x match {
205+
case _: ClassHandler =>
206+
code.append(s"import ${objName}${req.accessPath}.`${sym.name}`\n")
207+
case _ =>
208+
val valName = s"${req.lineRep.packageName}${req.lineRep.readName}"
209+
if (!tempValLines.contains(req.lineRep.lineId)) {
210+
code.append(s"val $valName: ${objName}.type = $objName\n")
211+
tempValLines += req.lineRep.lineId
212+
}
213+
code.append(s"import ${valName}${req.accessPath}.`${sym.name}`\n")
214+
}
215+
currentImps += sym.name
216+
}
217+
// For other requests, import each defined name.
218+
// import them explicitly instead of with _, so that
219+
// ambiguity errors will not be generated. Also, quote
220+
// the name of the variable, so that we don't need to
221+
// handle quoting keywords separately.
222+
case x =>
223+
for (sym <- x.definedSymbols) {
224+
maybeWrap(sym.name)
225+
code append s"import ${x.path}\n"
226+
currentImps += sym.name
227+
}
228+
}
229+
}
230+
}
231+
232+
val computedHeader = if (predefEscapes) header.toString else ""
233+
ComputedImports(computedHeader, code.toString, trailingBraces.toString, accessPath.toString)
234+
}
235+
236+
private def allReqAndHandlers =
237+
prevRequestList flatMap (req => req.handlers map (req -> _))
238+
103239
}

repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala

Lines changed: 30 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -227,14 +227,35 @@ class ReplSuite extends SparkFunSuite {
227227
assertDoesNotContain("error: not found: value sc", output)
228228
}
229229

230-
test("spark-shell should find imported types in class constructors and extends clause") {
231-
val output = runInterpreter("local",
232-
"""
233-
|import org.apache.spark.Partition
234-
|class P(p: Partition)
235-
|class P(val index: Int) extends Partition
236-
""".stripMargin)
237-
assertDoesNotContain("error: not found: type Partition", output)
238-
}
230+
test("spark-shell should find imported types in class constructors and extends clause") {
231+
val output = runInterpreter("local",
232+
"""
233+
|import org.apache.spark.Partition
234+
|class P(p: Partition)
235+
|class P(val index: Int) extends Partition
236+
""".stripMargin)
237+
assertDoesNotContain("error: not found: type Partition", output)
238+
}
239+
240+
test("spark-shell should shadow val/def definitions correctly") {
241+
val output1 = runInterpreter("local",
242+
"""
243+
|def myMethod() = "first definition"
244+
|val tmp = myMethod(); val out = tmp
245+
|def myMethod() = "second definition"
246+
|val tmp = myMethod(); val out = s"$tmp aabbcc"
247+
""".stripMargin)
248+
assertContains("second definition aabbcc", output1)
249+
250+
val output2 = runInterpreter("local",
251+
"""
252+
|val a = 1
253+
|val b = a; val c = b;
254+
|val a = 2
255+
|val b = a; val c = b;
256+
|s"!!$b!!"
257+
""".stripMargin)
258+
assertContains("!!2!!", output2)
259+
}
239260

240261
}

0 commit comments

Comments
 (0)