Skip to content

Commit 63b84cb

Browse files
authored
Merge pull request #9451 from dotty-staging/java-generic-varargs
2 parents bc4b401 + 7e68962 commit 63b84cb

File tree

13 files changed

+162
-44
lines changed

13 files changed

+162
-44
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

+6-4
Original file line numberDiff line numberDiff line change
@@ -1684,10 +1684,12 @@ object desugar {
16841684
Apply(Select(Apply(scalaDot(nme.StringContext), strs), id).withSpan(tree.span), elems)
16851685
case PostfixOp(t, op) =>
16861686
if ((ctx.mode is Mode.Type) && !isBackquoted(op) && op.name == tpnme.raw.STAR) {
1687-
val seqType = if (ctx.compilationUnit.isJava) defn.ArrayType else defn.SeqType
1688-
Annotated(
1689-
AppliedTypeTree(ref(seqType), t),
1690-
New(ref(defn.RepeatedAnnot.typeRef), Nil :: Nil))
1687+
if ctx.compilationUnit.isJava then
1688+
AppliedTypeTree(ref(defn.RepeatedParamType), t)
1689+
else
1690+
Annotated(
1691+
AppliedTypeTree(ref(defn.SeqType), t),
1692+
New(ref(defn.RepeatedAnnot.typeRef), Nil :: Nil))
16911693
}
16921694
else {
16931695
assert(ctx.mode.isExpr || ctx.reporter.errorsReported || ctx.mode.is(Mode.Interactive), ctx.mode)

compiler/src/dotty/tools/dotc/core/Definitions.scala

+2
Original file line numberDiff line numberDiff line change
@@ -399,6 +399,8 @@ class Definitions {
399399
def runtimeMethodRef(name: PreName): TermRef = ScalaRuntimeModule.requiredMethodRef(name)
400400
def ScalaRuntime_drop: Symbol = runtimeMethodRef(nme.drop).symbol
401401
@tu lazy val ScalaRuntime__hashCode: Symbol = ScalaRuntimeModule.requiredMethod(nme._hashCode_)
402+
@tu lazy val ScalaRuntime_toArray: Symbol = ScalaRuntimeModule.requiredMethod(nme.toArray)
403+
@tu lazy val ScalaRuntime_toObjectArray: Symbol = ScalaRuntimeModule.requiredMethod(nme.toObjectArray)
402404

403405
@tu lazy val BoxesRunTimeModule: Symbol = requiredModule("scala.runtime.BoxesRunTime")
404406
@tu lazy val BoxesRunTimeModule_externalEquals: Symbol = BoxesRunTimeModule.info.decl(nme.equals_).suchThat(toDenot(_).info.firstParamTypes.size == 2).symbol

compiler/src/dotty/tools/dotc/core/classfile/ClassfileParser.scala

+41-14
Original file line numberDiff line numberDiff line change
@@ -280,15 +280,13 @@ class ClassfileParser(
280280
addConstructorTypeParams(denot)
281281
}
282282

283-
denot.info = pool.getType(in.nextChar)
283+
val isVarargs = denot.is(Flags.Method) && (jflags & JAVA_ACC_VARARGS) != 0
284+
denot.info = pool.getType(in.nextChar, isVarargs)
284285
if (isEnum) denot.info = ConstantType(Constant(sym))
285286
if (isConstructor) normalizeConstructorParams()
286-
denot.info = translateTempPoly(parseAttributes(sym, denot.info))
287+
denot.info = translateTempPoly(parseAttributes(sym, denot.info, isVarargs))
287288
if (isConstructor) normalizeConstructorInfo()
288289

289-
if (denot.is(Flags.Method) && (jflags & JAVA_ACC_VARARGS) != 0)
290-
denot.info = arrayToRepeated(denot.info)
291-
292290
if (ctx.explicitNulls) denot.info = JavaNullInterop.nullifyMember(denot.symbol, denot.info, isEnum)
293291

294292
// seal java enums
@@ -324,7 +322,7 @@ class ClassfileParser(
324322
case BOOL_TAG => defn.BooleanType
325323
}
326324

327-
private def sigToType(sig: SimpleName, owner: Symbol = null)(using Context): Type = {
325+
private def sigToType(sig: SimpleName, owner: Symbol = null, isVarargs: Boolean = false)(using Context): Type = {
328326
var index = 0
329327
val end = sig.length
330328
def accept(ch: Char): Unit = {
@@ -395,13 +393,42 @@ class ClassfileParser(
395393
val elemtp = sig2type(tparams, skiptvs)
396394
defn.ArrayOf(elemtp.translateJavaArrayElementType)
397395
case '(' =>
398-
// we need a method symbol. given in line 486 by calling getType(methodSym, ..)
396+
def isMethodEnd(i: Int) = sig(i) == ')'
397+
def isArray(i: Int) = sig(i) == '['
398+
399+
/** Is this a repeated parameter type?
400+
* This is true if we're in a vararg method and this is the last parameter.
401+
*/
402+
def isRepeatedParam(i: Int): Boolean =
403+
if !isVarargs then return false
404+
var cur = i
405+
// Repeated parameters are represented as arrays
406+
if !isArray(cur) then return false
407+
// Handle nested arrays: int[]...
408+
while isArray(cur) do
409+
cur += 1
410+
// Simple check to see if we're the last parameter: there should be no
411+
// array in the signature until the method end.
412+
while !isMethodEnd(cur) do
413+
if isArray(cur) then return false
414+
cur += 1
415+
true
416+
end isRepeatedParam
417+
399418
val paramtypes = new ListBuffer[Type]()
400419
var paramnames = new ListBuffer[TermName]()
401-
while (sig(index) != ')') {
420+
while !isMethodEnd(index) do
402421
paramnames += nme.syntheticParamName(paramtypes.length)
403-
paramtypes += objToAny(sig2type(tparams, skiptvs))
404-
}
422+
paramtypes += {
423+
if isRepeatedParam(index) then
424+
index += 1
425+
val elemType = sig2type(tparams, skiptvs)
426+
// `ElimRepeated` is responsible for correctly erasing this.
427+
defn.RepeatedParamType.appliedTo(elemType)
428+
else
429+
objToAny(sig2type(tparams, skiptvs))
430+
}
431+
405432
index += 1
406433
val restype = sig2type(tparams, skiptvs)
407434
JavaMethodType(paramnames.toList, paramtypes.toList, restype)
@@ -574,7 +601,7 @@ class ClassfileParser(
574601
None // ignore malformed annotations
575602
}
576603

577-
def parseAttributes(sym: Symbol, symtype: Type)(using Context): Type = {
604+
def parseAttributes(sym: Symbol, symtype: Type, isVarargs: Boolean = false)(using Context): Type = {
578605
var newType = symtype
579606

580607
def parseAttribute(): Unit = {
@@ -584,7 +611,7 @@ class ClassfileParser(
584611
attrName match {
585612
case tpnme.SignatureATTR =>
586613
val sig = pool.getExternalName(in.nextChar)
587-
newType = sigToType(sig, sym)
614+
newType = sigToType(sig, sym, isVarargs)
588615
if (ctx.debug && ctx.verbose)
589616
println("" + sym + "; signature = " + sig + " type = " + newType)
590617
case tpnme.SyntheticATTR =>
@@ -1103,8 +1130,8 @@ class ClassfileParser(
11031130
c
11041131
}
11051132

1106-
def getType(index: Int)(using Context): Type =
1107-
sigToType(getExternalName(index))
1133+
def getType(index: Int, isVarargs: Boolean = false)(using Context): Type =
1134+
sigToType(getExternalName(index), isVarargs = isVarargs)
11081135

11091136
def getSuperClass(index: Int)(using Context): Symbol = {
11101137
assert(index != 0, "attempt to parse java.lang.Object from classfile")

compiler/src/dotty/tools/dotc/core/unpickleScala2/Scala2Unpickler.scala

-15
Original file line numberDiff line numberDiff line change
@@ -59,21 +59,6 @@ object Scala2Unpickler {
5959
denot.info = PolyType.fromParams(denot.owner.typeParams, denot.info)
6060
}
6161

62-
/** Convert array parameters denoting a repeated parameter of a Java method
63-
* to `RepeatedParamClass` types.
64-
*/
65-
def arrayToRepeated(tp: Type)(using Context): Type = tp match {
66-
case tp: MethodType =>
67-
val lastArg = tp.paramInfos.last
68-
assert(lastArg isRef defn.ArrayClass)
69-
tp.derivedLambdaType(
70-
tp.paramNames,
71-
tp.paramInfos.init :+ lastArg.translateParameterized(defn.ArrayClass, defn.RepeatedParamClass),
72-
tp.resultType)
73-
case tp: PolyType =>
74-
tp.derivedLambdaType(tp.paramNames, tp.paramInfos, arrayToRepeated(tp.resultType))
75-
}
76-
7762
def ensureConstructor(cls: ClassSymbol, scope: Scope)(using Context): Unit = {
7863
if (scope.lookup(nme.CONSTRUCTOR) == NoSymbol) {
7964
val constr = newDefaultConstructor(cls)

compiler/src/dotty/tools/dotc/transform/ElimRepeated.scala

+62-9
Original file line numberDiff line numberDiff line change
@@ -20,8 +20,9 @@ object ElimRepeated {
2020
val name: String = "elimRepeated"
2121
}
2222

23-
/** A transformer that removes repeated parameters (T*) from all types, replacing
24-
* them with Seq types.
23+
/** A transformer that eliminates repeated parameters (T*) from all types, replacing
24+
* them with Seq or Array types and adapting repeated arguments to conform to
25+
* the transformed type if needed.
2526
*/
2627
class ElimRepeated extends MiniPhase with InfoTransformer { thisPhase =>
2728
import ast.tpd._
@@ -55,9 +56,28 @@ class ElimRepeated extends MiniPhase with InfoTransformer { thisPhase =>
5556
case tp @ MethodTpe(paramNames, paramTypes, resultType) =>
5657
val resultType1 = elimRepeated(resultType)
5758
val paramTypes1 =
58-
if paramTypes.nonEmpty && paramTypes.last.isRepeatedParam then
59-
val last = paramTypes.last.translateFromRepeated(toArray = tp.isJavaMethod)
60-
paramTypes.init :+ last
59+
val lastIdx = paramTypes.length - 1
60+
if lastIdx >= 0 then
61+
val last = paramTypes(lastIdx)
62+
if last.isRepeatedParam then
63+
val isJava = tp.isJavaMethod
64+
// A generic Java varargs `T...` where `T` is unbounded is erased to
65+
// `Object[]` in bytecode, we directly translate such a type to
66+
// `Array[_ <: Object]` instead of `Array[_ <: T]` here. This allows
67+
// the tree transformer of this phase to emit the correct adaptation
68+
// for repeated arguments if needed (for example, an `Array[Int]` will
69+
// be copied into an `Array[Object]`, see `adaptToArray`).
70+
val last1 =
71+
if isJava && {
72+
val elemTp = last.elemType
73+
elemTp.isInstanceOf[TypeParamRef] && elemTp.typeSymbol == defn.AnyClass
74+
}
75+
then
76+
defn.ArrayOf(TypeBounds.upper(defn.ObjectType))
77+
else
78+
last.translateFromRepeated(toArray = isJava)
79+
paramTypes.updated(lastIdx, last1)
80+
else paramTypes
6181
else paramTypes
6282
tp.derivedLambdaType(paramNames, paramTypes1, resultType1)
6383
case tp: PolyType =>
@@ -82,9 +102,10 @@ class ElimRepeated extends MiniPhase with InfoTransformer { thisPhase =>
82102
case arg: Typed if isWildcardStarArg(arg) =>
83103
val isJavaDefined = tree.fun.symbol.is(JavaDefined)
84104
val tpe = arg.expr.tpe
85-
if isJavaDefined && tpe.derivesFrom(defn.SeqClass) then
86-
seqToArray(arg.expr)
87-
else if !isJavaDefined && tpe.derivesFrom(defn.ArrayClass)
105+
if isJavaDefined then
106+
val pt = tree.fun.tpe.widen.firstParamTypes.last
107+
adaptToArray(arg.expr, pt.elemType.bounds.hi)
108+
else if tpe.derivesFrom(defn.ArrayClass) then
88109
arrayToSeq(arg.expr)
89110
else
90111
arg.expr
@@ -107,7 +128,39 @@ class ElimRepeated extends MiniPhase with InfoTransformer { thisPhase =>
107128
.appliedToType(elemType)
108129
.appliedTo(tree, clsOf(elemClass.typeRef))
109130

110-
/** Convert Java array argument to Scala Seq */
131+
/** Adapt a Seq or Array tree to be a subtype of `Array[_ <: $elemPt]`.
132+
*
133+
* @pre `elemPt` must either be a super type of the argument element type or `Object`.
134+
* The special handling of `Object` is required to deal with the translation
135+
* of generic Java varargs in `elimRepeated`.
136+
*/
137+
private def adaptToArray(tree: Tree, elemPt: Type)(implicit ctx: Context): Tree =
138+
val elemTp = tree.tpe.elemType
139+
val treeIsArray = tree.tpe.derivesFrom(defn.ArrayClass)
140+
if elemTp <:< elemPt then
141+
if treeIsArray then
142+
tree // no adaptation needed
143+
else
144+
tree match
145+
case SeqLiteral(elems, elemtpt) =>
146+
JavaSeqLiteral(elems, elemtpt).withSpan(tree.span)
147+
case _ =>
148+
// Convert a Seq[T] to an Array[$elemPt]
149+
ref(defn.DottyArraysModule)
150+
.select(nme.seqToArray)
151+
.appliedToType(elemPt)
152+
.appliedTo(tree, clsOf(elemPt))
153+
else if treeIsArray then
154+
// Convert an Array[T] to an Array[Object]
155+
ref(defn.ScalaRuntime_toObjectArray)
156+
.appliedTo(tree)
157+
else
158+
// Convert a Seq[T] to an Array[Object]
159+
ref(defn.ScalaRuntime_toArray)
160+
.appliedToType(elemTp)
161+
.appliedTo(tree)
162+
163+
/** Convert an Array into a scala.Seq */
111164
private def arrayToSeq(tree: Tree)(using Context): Tree =
112165
tpd.wrapArray(tree, tree.tpe.elemType)
113166

compiler/src/dotty/tools/dotc/typer/Typer.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -1723,7 +1723,7 @@ class Typer extends Namer
17231723
checkedArgs = checkedArgs.mapconserve(arg =>
17241724
checkSimpleKinded(checkNoWildcard(arg)))
17251725
else if (ctx.compilationUnit.isJava)
1726-
if (tpt1.symbol eq defn.ArrayClass) || (tpt1.symbol eq defn.RepeatedParamClass) then
1726+
if (tpt1.symbol eq defn.ArrayClass) then
17271727
checkedArgs match {
17281728
case List(arg) =>
17291729
val elemtp = arg.tpe.translateJavaArrayElementType

tests/neg/i533/Test.scala

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,6 @@ object Test {
33
val x = new Array[Int](1)
44
x(0) = 10
55
println(JA.get(x)) // error
6-
println(JA.getVarargs(x: _*)) // error
6+
println(JA.getVarargs(x: _*)) // now OK.
77
}
88
}

tests/pos/arrays2.scala

+1
Original file line numberDiff line numberDiff line change
@@ -24,4 +24,5 @@ one warning found
2424
// #2461
2525
object arrays3 {
2626
def apply[X <: AnyRef](xs : X*) : java.util.List[X] = java.util.Arrays.asList(xs: _*)
27+
def apply2[X](xs : X*) : java.util.List[X] = java.util.Arrays.asList(xs: _*)
2728
}

tests/run/i9439.scala

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
object Test {
2+
// First example with a concrete type <: AnyVal
3+
def main(args: Array[String]): Unit = {
4+
val coll = new java.util.ArrayList[Int]()
5+
java.util.Collections.addAll(coll, 5, 6)
6+
println(coll.size())
7+
8+
foo(5, 6)
9+
}
10+
11+
// Second example with an abstract type not known to be <: AnyRef
12+
def foo[A](a1: A, a2: A): Unit = {
13+
val coll = new java.util.ArrayList[A]()
14+
java.util.Collections.addAll(coll, a1, a2)
15+
println(coll.size())
16+
}
17+
}

tests/run/java-varargs-2/A.java

+10
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
class A {
2+
public static void foo(int... args) {
3+
}
4+
5+
public static <T> void gen(T... args) {
6+
}
7+
8+
public static <T extends java.io.Serializable> void gen2(T... args) {
9+
}
10+
}

tests/run/java-varargs-2/Test.scala

+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
object Test {
2+
def main(args: Array[String]): Unit = {
3+
A.foo(1)
4+
A.foo(Array(1): _*)
5+
A.foo(Seq(1): _*)
6+
A.gen(1)
7+
A.gen(Array(1): _*)
8+
A.gen(Seq(1): _*)
9+
A.gen2("")
10+
A.gen2(Array(""): _*)
11+
A.gen2(Seq(""): _*)
12+
}
13+
}

tests/run/java-varargs/A_1.java

+3
Original file line numberDiff line numberDiff line change
@@ -4,4 +4,7 @@ public static void foo(int... args) {
44

55
public static <T> void gen(T... args) {
66
}
7+
8+
public static <T extends java.io.Serializable> void gen2(T... args) {
9+
}
710
}

tests/run/java-varargs/Test_2.scala

+5
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,13 @@
11
object Test {
22
def main(args: Array[String]): Unit = {
3+
A_1.foo(1)
34
A_1.foo(Array(1): _*)
45
A_1.foo(Seq(1): _*)
6+
A_1.gen(1)
57
A_1.gen(Array(1): _*)
68
A_1.gen(Seq(1): _*)
9+
A_1.gen2("")
10+
A_1.gen2(Array(""): _*)
11+
A_1.gen2(Seq(""): _*)
712
}
813
}

0 commit comments

Comments
 (0)