diff --git a/src/compiler/scala/tools/nsc/backend/jvm/BCodeBodyBuilder.scala b/src/compiler/scala/tools/nsc/backend/jvm/BCodeBodyBuilder.scala index b416de69340e..1a86a6933e0b 100644 --- a/src/compiler/scala/tools/nsc/backend/jvm/BCodeBodyBuilder.scala +++ b/src/compiler/scala/tools/nsc/backend/jvm/BCodeBodyBuilder.scala @@ -13,13 +13,13 @@ package scala.tools.nsc package backend.jvm -import scala.annotation.{ switch, tailrec } +import scala.annotation.{switch, tailrec} import scala.collection.mutable.ListBuffer import scala.reflect.internal.Flags import scala.tools.asm -import scala.tools.asm.Opcodes -import scala.tools.asm.tree.{ InvokeDynamicInsnNode, MethodInsnNode, MethodNode } -import scala.tools.nsc.backend.jvm.BCodeHelpers.{ InvokeStyle, TestOp } +import scala.tools.asm.{ConstantDynamic, Handle, Opcodes} +import scala.tools.asm.tree.{InvokeDynamicInsnNode, MethodInsnNode, MethodNode} +import scala.tools.nsc.backend.jvm.BCodeHelpers.{InvokeStyle, TestOp} import scala.tools.nsc.backend.jvm.BackendReporting._ import scala.tools.nsc.backend.jvm.GenBCode._ @@ -340,18 +340,50 @@ abstract class BCodeBodyBuilder extends BCodeSkelBuilder { genLoadTo(arg, paramType, jumpDest) generatedDest = jumpDest } + case Apply(fun @ Select(qualifier, _), Apply(lookup, Nil) :: (fieldName @ Literal(Constant(_: String))) :: Literal(Constant(varHandleClass: Type)) :: staticArgs) + if currentSettings.target.value.toInt >= 11 + && fun.symbol.owner == definitions.ConstantBootstraps.moduleClass + && lookup.symbol.owner == definitions.MethodHandlesModule.moduleClass + && lookup.symbol.name.string_==("lookup") + && staticArgs.forall(_.isInstanceOf[Literal]) => + // Treat ConstantBoostraps.* with literal arguments as intrinsics + genLoadTo(ApplyDynamic(fun, List(Literal(Constant(fun.symbol)), fieldName) ++ staticArgs).setType(tree.tpe), expectedType, dest) case app: Apply => generatedType = genApply(app, expectedType) - case ApplyDynamic(qual, Literal(Constant(bootstrapMethodRef: Symbol)) :: staticAndDynamicArgs) => - val numDynamicArgs = qual.symbol.info.params.length - val (staticArgs, dynamicArgs) = staticAndDynamicArgs.splitAt(staticAndDynamicArgs.length - numDynamicArgs) - val bootstrapDescriptor = staticHandleFromSymbol(bootstrapMethodRef) - val bootstrapArgs = staticArgs.map({case t @ Literal(c: Constant) => bootstrapMethodArg(c, t.pos) case x => throw new MatchError(x)}) - val descriptor = methodBTypeFromMethodType(qual.symbol.info, isConstructor=false) - genLoadArguments(dynamicArgs, qual.symbol.info.params.map(param => typeToBType(param.info))) - mnode.visitInvokeDynamicInsn(qual.symbol.name.encoded, descriptor.descriptor, bootstrapDescriptor, bootstrapArgs : _*) + case ApplyDynamic(qual, Literal(Constant(bootstrapMethodRef: Symbol)) :: args) => + val paramSig = bootstrapMethodRef.info.params.map(_.info.typeSymbol) + val x = MethodHandles_LookupClass + paramSig.toList match { + case MethodHandles_LookupClass :: StringClass :: ClassClass :: _ => + args match { + case Literal(Constant(fieldName: String)) :: staticArgs => + val bootstrapDescriptor = staticHandleFromSymbol(qual.symbol) + val bootstrapArgs = staticArgs.map({case t @ Literal(c: Constant) => bootstrapMethodArg(c, t.pos) case x => throw new MatchError()}) + mnode.visitLdcInsn( + new ConstantDynamic( + fieldName, + typeToBType(qual.symbol.info.resultType).descriptor, + bootstrapDescriptor, + bootstrapArgs: _*) + ) + generatedType = classBTypeFromSymbol(tree.tpe.typeSymbol) + case _ => throw new MatchError(args) + } + case MethodHandles_LookupClass :: StringClass :: MethodTypeClass :: _ => + val staticAndDynamicArgs = args + val numDynamicArgs = qual.symbol.info.params.length + val (staticArgs, dynamicArgs) = staticAndDynamicArgs.splitAt(staticAndDynamicArgs.length - numDynamicArgs) + val bootstrapDescriptor = staticHandleFromSymbol(bootstrapMethodRef) + val bootstrapArgs = staticArgs.map({case t @ Literal(c: Constant) => bootstrapMethodArg(c, t.pos) case x => throw new MatchError(x)}) + val descriptor = methodBTypeFromMethodType(qual.symbol.info, isConstructor=false) + genLoadArguments(dynamicArgs, qual.symbol.info.params.map(param => typeToBType(param.info))) + mnode.visitInvokeDynamicInsn(qual.symbol.name.encoded, descriptor.descriptor, bootstrapDescriptor, bootstrapArgs : _*) + case _ => + abort(s"Unexpected bootstrap method signature: ${bootstrapMethodRef.info}") + } + case ApplyDynamic(qual, args) => abort("No invokedynamic support yet.") diff --git a/src/compiler/scala/tools/nsc/transform/TypingTransformers.scala b/src/compiler/scala/tools/nsc/transform/TypingTransformers.scala index d2185a7fe678..294074fce869 100644 --- a/src/compiler/scala/tools/nsc/transform/TypingTransformers.scala +++ b/src/compiler/scala/tools/nsc/transform/TypingTransformers.scala @@ -81,6 +81,27 @@ trait TypingTransformers { } override def transform(tree: Tree): Tree = tree match { + case Template(parents, self, stats) => + atOwner(currentOwner) { + val transformedStats = transformTrees(stats) + + if (stats eq transformedStats) super.transform(tree) + else { + val expanded = new mutable.ListBuffer[Tree] + + def expandStats(): Unit = transformedStats.foreach { + case EmptyTree => + case blk@Block(stats, expr) if blk.attachments.containsElement(ThicketAttachment) => + stats.foreach { s => if (s != EmptyTree) expanded += s } + if (expr != EmptyTree) expanded += expr + case t => + expanded += t + } + + expandStats() + treeCopy.Template(tree, super.transformTrees(parents), super.transformValDef(self), expanded.toList) + } + } case Block(stats, expr) => val transformedStats = transformTrees(stats) val transformedExpr = transform(expr) diff --git a/src/reflect/scala/reflect/internal/Definitions.scala b/src/reflect/scala/reflect/internal/Definitions.scala index 5be1e6a9ea1a..a5d6cfef1d93 100644 --- a/src/reflect/scala/reflect/internal/Definitions.scala +++ b/src/reflect/scala/reflect/internal/Definitions.scala @@ -631,7 +631,11 @@ trait Definitions extends api.StandardDefinitions { lazy val ScalaSignatureAnnotation = requiredClass[scala.reflect.ScalaSignature] lazy val ScalaLongSignatureAnnotation = requiredClass[scala.reflect.ScalaLongSignature] + lazy val ConstantBootstraps = getModuleIfDefined("java.lang.invoke.ConstantBootstraps") + lazy val MethodTypeClass = getClassIfDefined("java.lang.invoke.MethodType") lazy val MethodHandleClass = getClassIfDefined("java.lang.invoke.MethodHandle") + lazy val MethodHandlesModule = getModuleIfDefined("java.lang.invoke.MethodHandles") + lazy val MethodHandles_LookupClass = getMemberClass(MethodHandlesModule.moduleClass, TypeName("Lookup")) lazy val VarHandleClass = getClassIfDefined("java.lang.invoke.VarHandle") // Option classes diff --git a/test/junit/scala/tools/nsc/ClinitTest.scala b/test/junit/scala/tools/nsc/ClinitTest.scala new file mode 100644 index 000000000000..88a64560a5c6 --- /dev/null +++ b/test/junit/scala/tools/nsc/ClinitTest.scala @@ -0,0 +1,125 @@ +/* + * Scala (https://www.scala-lang.org) + * + * Copyright EPFL and Lightbend, Inc. + * + * Licensed under Apache License 2.0 + * (http://www.apache.org/licenses/LICENSE-2.0). + * + * See the NOTICE file distributed with this work for + * additional information regarding copyright ownership. + */ + +package scala.tools.nsc + +import java.io.{File, PrintWriter, StringWriter} +import java.lang.invoke.{ConstantBootstraps, MethodHandles, VarHandle} +import scala.annotation.nowarn +import scala.collection.immutable.{List, Nil, Seq} +import scala.reflect.internal.Flags.{FINAL, STATIC} +import scala.reflect.internal.util.ScalaClassLoader.URLClassLoader +import scala.tools.nsc.backend.jvm.AsmUtils +import scala.tools.nsc.{Global, Phase} +import scala.tools.nsc.plugins.{Plugin, PluginComponent} +import scala.tools.nsc.transform.TypingTransformers + +object ClinitTest { + + def main(args: Array[String]): Unit = { + val out = createTempDir() + val settings = new scala.tools.nsc.Settings() + settings.debug.value = true + settings.outdir.value = out.getAbsolutePath + settings.embeddedDefaults(getClass.getClassLoader) + settings.target.value = "11" + val isInSBT = !settings.classpath.isSetByUser + if (isInSBT) settings.usejavacp.value = true + + val global = new Global(settings) { + self => + @nowarn("cat=deprecation&msg=early initializers") + object late extends { + val global: self.type = self + } with DemoPlugin + + override protected def loadPlugins(): List[Plugin] = late :: Nil + } + import global._ + val run = new Run() + + run.compileUnits(newCompilationUnit( + """ + |class Staticify { + | useDirectVH + | + | + | def direct: String = null + | def useDirectVH = { + | import java.lang.invoke._ + | val vh = ConstantBootstraps.fieldVarHandle(MethodHandles.lookup(), "direct$impl", classOf[VarHandle], classOf[Staticify], classOf[String]) + | vh.get(this) + | } + |} + |""".stripMargin) :: Nil) + val loader = new URLClassLoader(Seq(new File(settings.outdir.value).toURI.toURL), global.getClass.getClassLoader) + + val bytecode = out.listFiles().flatMap { file => + val asmp = AsmUtils.textify(AsmUtils.readClass(file.getAbsolutePath)) + val sw = new StringWriter() + + asmp :: sw.toString :: Nil + }.mkString("\n\n") + println(bytecode) + + Class.forName("Staticify", true, loader).getDeclaredConstructor().newInstance() + } + + private def createTempDir(): File = { + val f = File.createTempFile("output", "") + f.delete() + f.mkdirs() + f + } +} +abstract class DemoPlugin extends Plugin { + + import global._ + override val description: String = "demo" + override val name: String = "demo" + + override val components: List[PluginComponent] = List(new PluginComponent with TypingTransformers { + val global: DemoPlugin.this.global.type = DemoPlugin.this.global + override def newPhase(prev: Phase): Phase = new StdPhase(prev) { + override def apply(unit: CompilationUnit): Unit = { + newTransformer(unit).transformUnit(unit) + } + } + + // If we run this before erasure we get an assertion error in specialConstructorErasure which expects + // constructors to return the class type, but we're returning Unit. + override val runsAfter: List[String] = "typer" :: Nil + override val phaseName: String = "demo" + private lazy val VarHandleClass = rootMirror.getClassIfDefined("java.lang.invoke.VarHandle") + private lazy val MethodHandlesClass = rootMirror.getModuleIfDefined("java.lang.invoke.MethodHandles") + + def newTransformer(unit: CompilationUnit) = new ThicketTransformer(newRootLocalTyper(unit)) { + override def transform(tree: Tree): Tree = tree match { + + case dd: DefDef if dd.name.string_==("direct") => + val cls = tree.symbol.owner + assert(cls.isClass, cls) + val implField = cls.newValue(dd.name.append("$impl").toTermName, tree.pos).setInfo(definitions.StringTpe) + cls.info.decls.enter(implField) + val implInit = q"null" + val vhField = cls.newValue(dd.name.append("$vh").toTermName, tree.pos, newFlags = STATIC).setInfo(VarHandleClass.tpeHK) + cls.info.decls.enter(vhField)/**/ + val vhInit = q"$MethodHandlesClass.lookup().findVarHandle(classOf[$cls], ${implField.name.dropLocal.encoded}, classOf[${implField.info.resultType.typeSymbol}])" + + localTyper.typed(Thicket(Block(dd, localTyper.atOwner(cls).typed(newValDef(implField, implInit)()), localTyper.atOwner(cls).typed(newValDef(vhField, vhInit)())))) + + case _ => + super.transform(tree) + } + } + }) +}