Skip to content

Commit 25dd1de

Browse files
authored
Merge pull request #4422 from dotty-staging/fix-4373
Fix #4373: reject wildcard types in syntactically invalid positions
2 parents 9a30e20 + 0804fde commit 25dd1de

File tree

7 files changed

+131
-35
lines changed

7 files changed

+131
-35
lines changed

compiler/src/dotty/tools/dotc/ast/Desugar.scala

+1-3
Original file line numberDiff line numberDiff line change
@@ -1130,9 +1130,7 @@ object desugar {
11301130
Apply(Select(Apply(Ident(nme.StringContext), strs), id), elems)
11311131
case InfixOp(l, op, r) =>
11321132
if (ctx.mode is Mode.Type)
1133-
if (!op.isBackquoted && op.name == tpnme.raw.AMP) AndTypeTree(l, r) // l & r
1134-
else if (!op.isBackquoted && op.name == tpnme.raw.BAR) OrTypeTree(l, r) // l | r
1135-
else AppliedTypeTree(op, l :: r :: Nil) // op[l, r]
1133+
AppliedTypeTree(op, l :: r :: Nil) // op[l, r]
11361134
else {
11371135
assert(ctx.mode is Mode.Pattern) // expressions are handled separately by `binop`
11381136
Apply(op, l :: r :: Nil) // op(l, r)

compiler/src/dotty/tools/dotc/parsing/Parsers.scala

+67-30
Original file line numberDiff line numberDiff line change
@@ -452,7 +452,7 @@ object Parsers {
452452
if (isLeftAssoc(op1) != op2LeftAssoc)
453453
syntaxError(MixedLeftAndRightAssociativeOps(op1, op2, op2LeftAssoc), offset)
454454

455-
def reduceStack(base: List[OpInfo], top: Tree, prec: Int, leftAssoc: Boolean, op2: Name): Tree = {
455+
def reduceStack(base: List[OpInfo], top: Tree, prec: Int, leftAssoc: Boolean, op2: Name, isType: Boolean): Tree = {
456456
if (opStack != base && precedence(opStack.head.operator.name) == prec)
457457
checkAssoc(opStack.head.offset, opStack.head.operator.name, op2, leftAssoc)
458458
def recur(top: Tree): Tree = {
@@ -464,7 +464,15 @@ object Parsers {
464464
opStack = opStack.tail
465465
recur {
466466
atPos(opInfo.operator.pos union opInfo.operand.pos union top.pos) {
467-
InfixOp(opInfo.operand, opInfo.operator, top)
467+
val op = opInfo.operator
468+
val l = opInfo.operand
469+
val r = top
470+
if (isType && !op.isBackquoted && op.name == tpnme.raw.BAR) {
471+
OrTypeTree(checkAndOrArgument(l), checkAndOrArgument(r))
472+
} else if (isType && !op.isBackquoted && op.name == tpnme.raw.AMP) {
473+
AndTypeTree(checkAndOrArgument(l), checkAndOrArgument(r))
474+
} else
475+
InfixOp(l, op, r)
468476
}
469477
}
470478
}
@@ -488,20 +496,20 @@ object Parsers {
488496
var top = first
489497
while (isIdent && isOperator) {
490498
val op = if (isType) typeIdent() else termIdent()
491-
top = reduceStack(base, top, precedence(op.name), isLeftAssoc(op.name), op.name)
499+
top = reduceStack(base, top, precedence(op.name), isLeftAssoc(op.name), op.name, isType)
492500
opStack = OpInfo(top, op, in.offset) :: opStack
493501
newLineOptWhenFollowing(canStartOperand)
494502
if (maybePostfix && !canStartOperand(in.token)) {
495503
val topInfo = opStack.head
496504
opStack = opStack.tail
497-
val od = reduceStack(base, topInfo.operand, 0, true, in.name)
505+
val od = reduceStack(base, topInfo.operand, 0, true, in.name, isType)
498506
return atPos(startOffset(od), topInfo.offset) {
499507
PostfixOp(od, topInfo.operator)
500508
}
501509
}
502510
top = operand()
503511
}
504-
reduceStack(base, top, 0, true, in.name)
512+
reduceStack(base, top, 0, true, in.name, isType)
505513
}
506514

507515
/* -------- IDENTIFIERS AND LITERALS ------------------------------------------- */
@@ -709,15 +717,7 @@ object Parsers {
709717
/** Same as [[typ]], but if this results in a wildcard it emits a syntax error and
710718
* returns a tree for type `Any` instead.
711719
*/
712-
def toplevelTyp(): Tree = {
713-
val t = typ()
714-
findWildcardType(t) match {
715-
case Some(wildcardPos) =>
716-
syntaxError(UnboundWildcardType(), wildcardPos)
717-
scalaAny
718-
case None => t
719-
}
720-
}
720+
def toplevelTyp(): Tree = checkWildcard(typ())
721721

722722
/** Type ::= [FunArgMods] FunArgTypes `=>' Type
723723
* | HkTypeParamClause `->' Type
@@ -768,9 +768,16 @@ object Parsers {
768768
accept(RPAREN)
769769
if (imods.is(Implicit) || isValParamList || in.token == ARROW) functionRest(ts)
770770
else {
771-
for (t <- ts)
772-
if (t.isInstanceOf[ByNameTypeTree])
773-
syntaxError(ByNameParameterNotSupported())
771+
val ts1 =
772+
for (t <- ts) yield {
773+
t match {
774+
case t@ByNameTypeTree(t1) =>
775+
syntaxError(ByNameParameterNotSupported(t), t.pos)
776+
t1
777+
case _ =>
778+
t
779+
}
780+
}
774781
val tuple = atPos(start) { makeTupleOrParens(ts) }
775782
infixTypeRest(
776783
refinedTypeRest(
@@ -784,7 +791,7 @@ object Parsers {
784791
val start = in.offset
785792
val tparams = typeParamClause(ParamOwner.TypeParam)
786793
if (in.token == ARROW)
787-
atPos(start, in.skipToken())(LambdaTypeTree(tparams, typ()))
794+
atPos(start, in.skipToken())(LambdaTypeTree(tparams, toplevelTyp()))
788795
else { accept(ARROW); typ() }
789796
}
790797
else infixType()
@@ -822,7 +829,7 @@ object Parsers {
822829

823830
def refinedTypeRest(t: Tree): Tree = {
824831
newLineOptWhenFollowedBy(LBRACE)
825-
if (in.token == LBRACE) refinedTypeRest(atPos(startOffset(t)) { RefinedTypeTree(t, refinement()) })
832+
if (in.token == LBRACE) refinedTypeRest(atPos(startOffset(t)) { RefinedTypeTree(checkWildcard(t), refinement()) })
826833
else t
827834
}
828835

@@ -835,7 +842,7 @@ object Parsers {
835842
if (ctx.settings.strict.value)
836843
deprecationWarning(DeprecatedWithOperator())
837844
in.nextToken()
838-
AndTypeTree(t, withType())
845+
AndTypeTree(checkAndOrArgument(t), checkAndOrArgument(withType()))
839846
}
840847
else t
841848

@@ -886,7 +893,7 @@ object Parsers {
886893
private def simpleTypeRest(t: Tree): Tree = in.token match {
887894
case HASH => simpleTypeRest(typeProjection(t))
888895
case LBRACKET => simpleTypeRest(atPos(startOffset(t)) {
889-
AppliedTypeTree(t, typeArgs(namedOK = false, wildOK = true)) })
896+
AppliedTypeTree(checkWildcard(t), typeArgs(namedOK = false, wildOK = true)) })
890897
case _ => t
891898
}
892899

@@ -917,7 +924,7 @@ object Parsers {
917924
else Nil
918925
first :: rest
919926
}
920-
def typParser() = if (wildOK) typ() else toplevelTyp()
927+
def typParser() = checkWildcard(typ(), wildOK)
921928
if (namedOK && in.token == IDENTIFIER)
922929
typParser() match {
923930
case Ident(name) if in.token == EQUALS =>
@@ -1001,17 +1008,46 @@ object Parsers {
10011008
else if (location == Location.InPattern) refinedType()
10021009
else infixType()
10031010

1004-
/** Checks whether `t` is a wildcard type.
1005-
* If it is, returns the [[Position]] where the wildcard occurs.
1011+
/** Checks whether `t` represents a non-value type (wildcard types, or ByNameTypeTree).
1012+
* If it is, returns the [[Tree]] which immediately represents the non-value type.
10061013
*/
10071014
@tailrec
1008-
private final def findWildcardType(t: Tree): Option[Position] = t match {
1009-
case TypeBoundsTree(_, _) => Some(t.pos)
1010-
case Parens(t1) => findWildcardType(t1)
1011-
case Annotated(t1, _) => findWildcardType(t1)
1015+
private final def findNonValueTypeTree(t: Tree, alsoNonValue: Boolean): Option[Tree] = t match {
1016+
case TypeBoundsTree(_, _) => Some(t)
1017+
case ByNameTypeTree(_) if alsoNonValue => Some(t)
1018+
case Parens(t1) => findNonValueTypeTree(t1, alsoNonValue)
1019+
case Annotated(t1, _) => findNonValueTypeTree(t1, alsoNonValue)
10121020
case _ => None
10131021
}
10141022

1023+
def rejectWildcard(t: Tree, fallbackTree: Tree): Tree =
1024+
findNonValueTypeTree(t, false) match {
1025+
case Some(wildcardTree) =>
1026+
syntaxError(UnboundWildcardType(), wildcardTree.pos)
1027+
fallbackTree
1028+
case None => t
1029+
}
1030+
1031+
1032+
def checkWildcard(t: Tree, wildOK: Boolean = false, fallbackTree: Tree = scalaAny): Tree =
1033+
if (wildOK)
1034+
t
1035+
else
1036+
rejectWildcard(t, fallbackTree)
1037+
1038+
def checkAndOrArgument(t: Tree): Tree =
1039+
findNonValueTypeTree(t, true) match {
1040+
case Some(typTree) =>
1041+
typTree match {
1042+
case typTree: TypeBoundsTree =>
1043+
syntaxError(UnboundWildcardType(), typTree.pos)
1044+
case typTree: ByNameTypeTree =>
1045+
syntaxError(ByNameParameterNotSupported(typTree), typTree.pos)
1046+
}
1047+
scalaAny
1048+
case None => t
1049+
}
1050+
10151051
/* ----------- EXPRESSIONS ------------------------------------------------ */
10161052

10171053
/** EqualsExpr ::= `=' Expr
@@ -2148,7 +2184,7 @@ object Parsers {
21482184
in.token match {
21492185
case EQUALS =>
21502186
in.nextToken()
2151-
TypeDef(name, lambdaAbstract(tparams, typ())).withMods(mods).setComment(in.getDocComment(start))
2187+
TypeDef(name, lambdaAbstract(tparams, toplevelTyp())).withMods(mods).setComment(in.getDocComment(start))
21522188
case SUPERTYPE | SUBTYPE | SEMI | NEWLINE | NEWLINES | COMMA | RBRACE | EOF =>
21532189
TypeDef(name, lambdaAbstract(tparams, typeBounds())).withMods(mods).setComment(in.getDocComment(start))
21542190
case _ =>
@@ -2276,7 +2312,8 @@ object Parsers {
22762312
/** ConstrApp ::= SimpleType {ParArgumentExprs}
22772313
*/
22782314
val constrApp = () => {
2279-
val t = annotType()
2315+
// Using Ident(nme.ERROR) to avoid causing cascade errors on non-user-written code
2316+
val t = checkWildcard(annotType(), fallbackTree = Ident(nme.ERROR))
22802317
if (in.token == LPAREN) parArgumentExprss(wrapNew(t))
22812318
else t
22822319
}

compiler/src/dotty/tools/dotc/reporting/diagnostic/messages.scala

+2-2
Original file line numberDiff line numberDiff line change
@@ -696,10 +696,10 @@ object messages {
696696
}
697697
}
698698

699-
case class ByNameParameterNotSupported()(implicit ctx: Context)
699+
case class ByNameParameterNotSupported(tpe: untpd.TypTree)(implicit ctx: Context)
700700
extends Message(ByNameParameterNotSupportedID) {
701701
val kind = "Syntax"
702-
val msg = "By-name parameter type not allowed here."
702+
val msg = hl"By-name parameter type ${tpe} not allowed here."
703703

704704
val explanation =
705705
hl"""|By-name parameters act like functions that are only evaluated when referenced,

tests/neg/i4373.scala

+51
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
trait Base
2+
trait TypeConstr[X]
3+
4+
class X1[A >: _ | X1[_]] // error
5+
class X2[A >: _ & X2[_]] // error
6+
class X3[A >: X3[_] | _] // error
7+
class X4[A >: X4[_] & _] // error
8+
class X5[A >: _ with X5[_]] // error
9+
class X6[A >: X6[_] with _] // error
10+
11+
class A1 extends _ // error
12+
class A2 extends _ with _ // error // error
13+
class A3 extends Base with _ // error
14+
class A4 extends _ with Base // error
15+
16+
object Test {
17+
type T1 = _ // error
18+
type T2 = _[Int] // error
19+
type T3 = _ { type S } // error
20+
type T4 = [X] => _ // error
21+
22+
// Open questions:
23+
type T5 = TypeConstr[_ { type S }] // error
24+
type T6 = TypeConstr[_[Int]] // error
25+
26+
// expression types
27+
type T7 = (=> Int) | (Int => Int) // error
28+
type T8 = (=> Int) & (Int => Int) // error
29+
type T9 = (=> Int) with (Int => Int) // error
30+
type T10 = (Int => Int) | (=> Int) // error
31+
type T11 = (Int => Int) & (=> Int) // error
32+
type T12 = (Int => Int) with (=> Int) // error
33+
34+
// annotations
35+
type T13 = _ @ annotation.tailrec // error
36+
type T14 = Int @ _ // error
37+
type T15 = (_ | Int) @ annotation.tailrec // error
38+
type T16 = (Int | _) @ annotation.tailrec // error
39+
type T17 = Int @ (_ | annotation.tailrec) // error
40+
type T18 = Int @ (annotation.tailrec | _) // error
41+
42+
type T19 = (_ with Int) @ annotation.tailrec // error
43+
type T20 = (Int with _) @ annotation.tailrec // error
44+
type T21 = Int @ (_ with annotation.tailrec) // error // error
45+
type T22 = Int @ (annotation.tailrec with _) // error // error
46+
47+
type T23 = (_ & Int) @ annotation.tailrec // error
48+
type T24 = (Int & _) @ annotation.tailrec // error
49+
type T25 = Int @ (_ & annotation.tailrec) // error
50+
type T26 = Int @ (annotation.tailrec & _) // error
51+
}

tests/neg/i4373a.scala

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
// ==> 040fb47fbaf718cecb11a7d51ac5a48bf4f6a1fe.scala <==
2+
object x0 {
3+
val x0 : _ with // error // error // error

tests/neg/i4373b.scala

+4
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
// ==> 05bef7805687ba94da37177f7568e3ba7da1f91c.scala <==
2+
class x0 {
3+
x1: // error
4+
x0 | _ // error // error

tests/neg/i4373c.scala

+3
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
// ==> 18b253a4a89a84c5674165c6fc3efafad535eee3.scala <==
2+
object x0 {
3+
def x1[x2 <:_[ // error // error // error

0 commit comments

Comments
 (0)