Skip to content

Commit a2a1112

Browse files
authored
Merge pull request #5078 from dotty-staging/fix-4984
Fix #4984: support name-based unapplySeq
2 parents e34bb2d + 6a0fc8c commit a2a1112

File tree

10 files changed

+274
-16
lines changed

10 files changed

+274
-16
lines changed

compiler/src/dotty/tools/dotc/core/Definitions.scala

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -422,6 +422,10 @@ class Definitions {
422422
def Seq_drop(implicit ctx: Context) = Seq_dropR.symbol
423423
lazy val Seq_lengthCompareR = SeqClass.requiredMethodRef(nme.lengthCompare)
424424
def Seq_lengthCompare(implicit ctx: Context) = Seq_lengthCompareR.symbol
425+
lazy val Seq_lengthR = SeqClass.requiredMethodRef(nme.length)
426+
def Seq_length(implicit ctx: Context) = Seq_lengthR.symbol
427+
lazy val Seq_toSeqR = SeqClass.requiredMethodRef(nme.toSeq)
428+
def Seq_toSeq(implicit ctx: Context) = Seq_toSeqR.symbol
425429

426430
lazy val ArrayType: TypeRef = ctx.requiredClassRef("scala.Array")
427431
def ArrayClass(implicit ctx: Context) = ArrayType.symbol.asClass

compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ object PatternMatcher {
254254
*/
255255
def matchElemsPlan(seqSym: Symbol, args: List[Tree], exact: Boolean, onSuccess: Plan) = {
256256
val selectors = args.indices.toList.map(idx =>
257-
ref(seqSym).select(nme.apply).appliedTo(Literal(Constant(idx))))
257+
ref(seqSym).select(defn.Seq_apply.matchingMember(seqSym.info)).appliedTo(Literal(Constant(idx))))
258258
TestPlan(LengthTest(args.length, exact), seqSym, seqSym.pos,
259259
matchArgsPlan(selectors, args, onSuccess))
260260
}
@@ -265,8 +265,13 @@ object PatternMatcher {
265265
def unapplySeqPlan(getResult: Symbol, args: List[Tree]): Plan = args.lastOption match {
266266
case Some(VarArgPattern(arg)) =>
267267
val matchRemaining =
268-
if (args.length == 1)
269-
patternPlan(getResult, arg, onSuccess)
268+
if (args.length == 1) {
269+
val toSeq = ref(getResult)
270+
.select(defn.Seq_toSeq.matchingMember(getResult.info))
271+
letAbstract(toSeq) { toSeqResult =>
272+
patternPlan(toSeqResult, arg, onSuccess)
273+
}
274+
}
270275
else {
271276
val dropped = ref(getResult)
272277
.select(defn.Seq_drop.matchingMember(getResult.info))
@@ -638,11 +643,18 @@ object PatternMatcher {
638643
case EqualTest(tree) =>
639644
tree.equal(scrutinee)
640645
case LengthTest(len, exact) =>
641-
scrutinee
642-
.select(defn.Seq_lengthCompare.matchingMember(scrutinee.tpe))
643-
.appliedTo(Literal(Constant(len)))
644-
.select(if (exact) defn.Int_== else defn.Int_>=)
645-
.appliedTo(Literal(Constant(0)))
646+
val lengthCompareSym = defn.Seq_lengthCompare.matchingMember(scrutinee.tpe)
647+
if (lengthCompareSym.exists)
648+
scrutinee
649+
.select(defn.Seq_lengthCompare.matchingMember(scrutinee.tpe))
650+
.appliedTo(Literal(Constant(len)))
651+
.select(if (exact) defn.Int_== else defn.Int_>=)
652+
.appliedTo(Literal(Constant(0)))
653+
else // try length
654+
scrutinee
655+
.select(defn.Seq_length.matchingMember(scrutinee.tpe))
656+
.select(if (exact) defn.Int_== else defn.Int_>=)
657+
.appliedTo(Literal(Constant(len)))
646658
case TypeTest(tpt) =>
647659
val expectedTp = tpt.tpe
648660

compiler/src/dotty/tools/dotc/typer/Applications.scala

Lines changed: 36 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -100,11 +100,43 @@ object Applications {
100100
Nil
101101
}
102102

103+
/** If `getType` is of the form:
104+
* ```
105+
* {
106+
* def lengthCompare(len: Int): Int // or, def length: Int
107+
* def apply(i: Int): T = a(i)
108+
* def drop(n: Int): scala.Seq[T]
109+
* def toSeq: scala.Seq[T]
110+
* }
111+
* ```
112+
* returns `T`, otherwise NoType.
113+
*/
114+
def unapplySeqTypeElemTp(getTp: Type): Type = {
115+
def lengthTp = ExprType(defn.IntType)
116+
def lengthCompareTp = MethodType(List(defn.IntType), defn.IntType)
117+
def applyTp(elemTp: Type) = MethodType(List(defn.IntType), elemTp)
118+
def dropTp(elemTp: Type) = MethodType(List(defn.IntType), defn.SeqType.appliedTo(elemTp))
119+
def toSeqTp(elemTp: Type) = ExprType(defn.SeqType.appliedTo(elemTp))
120+
121+
// the result type of `def apply(i: Int): T`
122+
val elemTp = getTp.member(nme.apply).suchThat(_.info <:< applyTp(WildcardType)).info.resultType
123+
124+
def hasMethod(name: Name, tp: Type) =
125+
getTp.member(name).suchThat(getTp.memberInfo(_) <:< tp).exists
126+
127+
val isValid =
128+
elemTp.exists &&
129+
(hasMethod(nme.lengthCompare, lengthCompareTp) || hasMethod(nme.length, lengthTp)) &&
130+
hasMethod(nme.drop, dropTp(elemTp)) &&
131+
hasMethod(nme.toSeq, toSeqTp(elemTp))
132+
133+
if (isValid) elemTp else NoType
134+
}
135+
103136
if (unapplyName == nme.unapplySeq) {
104-
if (unapplyResult derivesFrom defn.SeqClass) seqSelector :: Nil
105-
else if (isGetMatch(unapplyResult, pos) && getTp.derivesFrom(defn.SeqClass)) {
106-
val seqArg = getTp.elemType.hiBound
107-
if (seqArg.exists) args.map(Function.const(seqArg))
137+
if (isGetMatch(unapplyResult, pos)) {
138+
val elemTp = unapplySeqTypeElemTp(getTp)
139+
if (elemTp.exists) args.map(Function.const(elemTp))
108140
else fail
109141
}
110142
else fail

docs/docs/reference/changed/pattern-matching.md

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -63,12 +63,23 @@ object FirstChars {
6363
```
6464

6565

66-
## Seq Pattern
66+
## Name-based Seq Pattern
6767

6868
- Extractor defines `def unapplySeq(x: T): U`
6969
- `U` has (parameterless `def` or `val`) members `isEmpty: Boolean` and `get: S`
70-
- `S <: Seq[V]`
71-
- Pattern-matching on `N` pattern with types `V, V, ..., V`, where `N` is the runtime size of the `Seq`.
70+
- `S` conforms to `X`, `T2` and `T3` conform to `T1`
71+
72+
```Scala
73+
type X = {
74+
def lengthCompare(len: Int): Int // or, `def length: Int`
75+
def apply(i: Int): T1
76+
def drop(n: Int): scala.Seq[T2]
77+
def toSeq: scala.Seq[T3]
78+
}
79+
```
80+
81+
- Pattern-matching on _exactly_ `N` simple patterns with types `T1, T1, ..., T1`, where `N` is the runtime size of the sequence, or
82+
- Pattern-matching on `>= N` simple patterns and _a vararg pattern_ (e.g., `xs: _*`) with types `T1, T1, ..., T1, Seq[T1]`, where `N` is the minimum size of the sequence.
7283

7384
<!-- To be kept in sync with tests/new/patmat-spec.scala -->
7485

@@ -87,7 +98,7 @@ object CharList {
8798
```
8899

89100

90-
## Name Based Pattern
101+
## Name-based Pattern
91102

92103
- Extractor defines `def unapply(x: T): U`
93104
- `U` has (parameterless `def` or `val`) members `isEmpty: Boolean` and `get: S`

tests/neg/i4984.scala

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
object Array2 {
2+
def unapplySeq(x: Array[Int]): Data = new Data
3+
class Data {
4+
def isEmpty: Boolean = false
5+
def get: Data = this
6+
def lengthCompare(len: Int): Int = 0
7+
def apply(i: Int): Int = 3
8+
// drop return type, not conforming to apply's
9+
def drop(n: Int): scala.Seq[String] = Seq("hello")
10+
def toSeq: scala.Seq[Int] = Seq(6, 7)
11+
}
12+
}
13+
14+
object Array3 {
15+
def unapplySeq(x: Array[Int]): Data = new Data
16+
class Data {
17+
def isEmpty: Boolean = false
18+
def get: Data = this
19+
def lengthCompare(len: Int): Int = 0
20+
// missing apply
21+
def drop(n: Int): scala.Seq[Int] = ???
22+
def toSeq: scala.Seq[Int] = ???
23+
}
24+
}
25+
26+
object Test {
27+
def test(xs: Array[Int]): Int = xs match {
28+
case Array2(x, y) => 1 // error
29+
case Array2(x, y, xs: _*) => 2 // error
30+
case Array2(xs: _*) => 3 // error
31+
}
32+
33+
def test2(xs: Array[Int]): Int = xs match {
34+
case Array3(x, y) => 1 // error
35+
case Array3(x, y, xs: _*) => 2 // error
36+
case Array3(xs: _*) => 3 // error
37+
}
38+
}

tests/pos/i4984.scala

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
object Array2 {
2+
def unapplySeq[T](x: Array[T]): UnapplySeqWrapper[T] = new UnapplySeqWrapper(x)
3+
4+
final class UnapplySeqWrapper[T](private val a: Array[T]) extends AnyVal {
5+
def isEmpty: Boolean = false
6+
def get: UnapplySeqWrapper[T] = this
7+
def lengthCompare(len: Int): Int = a.lengthCompare(len)
8+
def apply(i: Int): T = a(i)
9+
def drop(n: Int): scala.Seq[T] = ???
10+
def toSeq: scala.Seq[T] = a.toSeq // clones the array
11+
}
12+
}
13+
14+
class Test {
15+
def test1(xs: Array[Int]): Int = xs match {
16+
case Array2(x, y) => x + y
17+
}
18+
19+
def test2(xs: Array[Int]): Seq[Int] = xs match {
20+
case Array2(x, y, xs:_*) => xs
21+
}
22+
23+
def test3(xs: Array[Int]): Seq[Int] = xs match {
24+
case Array2(xs:_*) => xs
25+
}
26+
}

tests/run/i4984b.scala

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
object Array2 {
2+
def unapplySeq(x: Array[Int]): Data = new Data
3+
4+
final class Data {
5+
def isEmpty: Boolean = false
6+
def get: Data = this
7+
def lengthCompare(len: Int): Int = 0
8+
def apply(i: Int): Int = 3
9+
def drop(n: Int): scala.Seq[Int] = Seq(2, 5)
10+
def toSeq: scala.Seq[Int] = Seq(6, 7)
11+
}
12+
}
13+
14+
object Test {
15+
def test1(xs: Array[Int]): Int = xs match {
16+
case Array2(x, y) => x + y
17+
}
18+
19+
def test2(xs: Array[Int]): Seq[Int] = xs match {
20+
case Array2(x, y, xs:_*) => xs
21+
}
22+
23+
def test3(xs: Array[Int]): Seq[Int] = xs match {
24+
case Array2(xs:_*) => xs
25+
}
26+
27+
def main(args: Array[String]): Unit = {
28+
test1(Array(3, 5))
29+
test2(Array(3, 5))
30+
test3(Array(3, 5))
31+
}
32+
}

tests/run/i4984c.scala

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
object Array2 {
2+
def unapplySeq(x: Array[Int]): Data = new Data
3+
4+
final class Data {
5+
def isEmpty: Boolean = false
6+
def get: Data = this
7+
def length: Int = 2
8+
def apply(i: Int): Int = 3
9+
def drop(n: Int): scala.Seq[Int] = Seq(2, 5)
10+
def toSeq: scala.Seq[Int] = Seq(6, 7)
11+
}
12+
}
13+
14+
object Test {
15+
def test1(xs: Array[Int]): Int = xs match {
16+
case Array2(x, y) => x + y
17+
}
18+
19+
def test2(xs: Array[Int]): Seq[Int] = xs match {
20+
case Array2(x, y, xs:_*) => xs
21+
}
22+
23+
def test3(xs: Array[Int]): Seq[Int] = xs match {
24+
case Array2(xs:_*) => xs
25+
}
26+
27+
def main(args: Array[String]): Unit = {
28+
test1(Array(3, 5))
29+
test2(Array(3, 5))
30+
test3(Array(3, 5))
31+
}
32+
}

tests/run/i4984d.scala

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
object Array2 {
2+
def unapplySeq(x: Array[Int]): Data1 = new Data1
3+
4+
class Data1 {
5+
def isEmpty: Boolean = false
6+
def get: Data2 = new Data2
7+
}
8+
9+
class Data2 {
10+
def apply(i: Int): Int = 3
11+
def drop(n: Int): scala.Seq[Int] = Seq(2, 5)
12+
def toSeq: scala.Seq[Int] = Seq(6, 7)
13+
def lengthCompare(len: Int): Int = 0
14+
}
15+
}
16+
17+
object Test {
18+
def test1(xs: Array[Int]): Int = xs match {
19+
case Array2(x, y) => x + y
20+
}
21+
22+
def test2(xs: Array[Int]): Seq[Int] = xs match {
23+
case Array2(x, y, xs:_*) => xs
24+
}
25+
26+
def test3(xs: Array[Int]): Seq[Int] = xs match {
27+
case Array2(xs:_*) => xs
28+
}
29+
30+
def main(args: Array[String]): Unit = {
31+
test1(Array(3, 5))
32+
test2(Array(3, 5))
33+
test3(Array(3, 5))
34+
}
35+
}

tests/run/i4984e.scala

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
object Array2 {
2+
def unapplySeq(x: Array[Int]): Data = new Data
3+
4+
final class Data {
5+
def isEmpty: Boolean = false
6+
def get: Data = this
7+
def lengthCompare(len: Int): Int = 0
8+
def lengthCompare: Int = 0
9+
def apply(i: Int): Int = 3
10+
def apply(i: String): Int = 3
11+
def drop(n: Int): scala.Seq[Int] = Seq(2, 5)
12+
def drop: scala.Seq[Int] = Seq(2, 5)
13+
def toSeq: scala.Seq[Int] = Seq(6, 7)
14+
def toSeq(x: Int): scala.Seq[Int] = Seq(6, 7)
15+
}
16+
}
17+
18+
object Test {
19+
def test1(xs: Array[Int]): Int = xs match {
20+
case Array2(x, y) => x + y
21+
}
22+
23+
def test2(xs: Array[Int]): Seq[Int] = xs match {
24+
case Array2(x, y, xs:_*) => xs
25+
}
26+
27+
def test3(xs: Array[Int]): Seq[Int] = xs match {
28+
case Array2(xs:_*) => xs
29+
}
30+
31+
def main(args: Array[String]): Unit = {
32+
test1(Array(3, 5))
33+
test2(Array(3, 5))
34+
test3(Array(3, 5))
35+
}
36+
}

0 commit comments

Comments
 (0)