Skip to content

Commit 865604a

Browse files
committed
Allow cross references in collective extensions
1 parent eb060b6 commit 865604a

File tree

4 files changed

+48
-3
lines changed

4 files changed

+48
-3
lines changed

compiler/src/dotty/tools/dotc/core/Flags.scala

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -499,7 +499,7 @@ object Flags {
499499

500500
/** Flags that can apply to a module val */
501501
val RetainedModuleValFlags: FlagSet = RetainedModuleValAndClassFlags |
502-
Override | Final | Method | Implicit | Given | Lazy |
502+
Override | Final | Method | Implicit | Given | Lazy | Extension |
503503
Accessor | AbsOverride | StableRealizable | Captured | Synchronized | Erased
504504

505505
/** Flags that can apply to a module class */

compiler/src/dotty/tools/dotc/transform/SymUtils.scala

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -220,4 +220,7 @@ class SymUtils(val self: Symbol) extends AnyVal {
220220
/** Is symbol a splice operation? */
221221
def isSplice(implicit ctx: Context): Boolean =
222222
self == defn.InternalQuoted_exprSplice || self == defn.QuotedType_splice
223+
224+
def isCollectiveExtensionClass(using Context): Boolean =
225+
self.is(ModuleClass) && self.sourceModule.is(Extension, butNot = Method)
223226
}

compiler/src/dotty/tools/dotc/typer/Typer.scala

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -401,6 +401,19 @@ class Typer extends Namer
401401
if (name == nme.ROOTPKG)
402402
return tree.withType(defn.RootPackage.termRef)
403403

404+
/** Convert a reference `f` to an extension method in a collective extension
405+
* on parameter `x` to `x.f`
406+
*/
407+
def extensionMethodSelect(xmethod: Symbol): untpd.Tree =
408+
val leadParamName = xmethod.info.paramNamess.head.head
409+
def isLeadParam(sym: Symbol) =
410+
sym.is(Param) && sym.owner.owner == xmethod.owner && sym.name == leadParamName
411+
def leadParam(ctx: Context): Symbol =
412+
ctx.scope.lookupAll(leadParamName).find(isLeadParam) match
413+
case Some(param) => param
414+
case None => leadParam(ctx.outersIterator.dropWhile(_.scope eq ctx.scope).next)
415+
untpd.cpy.Select(tree)(untpd.ref(leadParam(ctx).termRef), name)
416+
404417
val rawType = {
405418
val saved1 = unimported
406419
val saved2 = foundUnderScala2
@@ -441,8 +454,14 @@ class Typer extends Namer
441454
errorType(new MissingIdent(tree, kind, name.show), tree.sourcePos)
442455

443456
val tree1 = ownType match {
444-
case ownType: NamedType if !prefixIsElidable(ownType) =>
445-
ref(ownType).withSpan(tree.span)
457+
case ownType: NamedType =>
458+
val sym = ownType.symbol
459+
if sym.isAllOf(ExtensionMethod)
460+
&& sym.owner.isCollectiveExtensionClass
461+
&& ctx.owner.isContainedIn(sym.owner)
462+
then typed(extensionMethodSelect(sym), pt)
463+
else if prefixIsElidable(ownType) then tree.withType(ownType)
464+
else ref(ownType).withSpan(tree.span)
446465
case _ =>
447466
tree.withType(ownType)
448467
}

tests/run/collective-extensions.scala

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,23 @@
1+
extension on (x: String):
2+
def foo(y: String): String = x ++ y
3+
def bar(y: String): String = foo(y)
4+
def baz(y: String): String =
5+
val x = y
6+
bar(x)
7+
def bam(y: String): String = this.baz(x)(y)
8+
def ban(foo: String): String = x + foo
9+
def bao(y: String): String =
10+
val bam = "ABC"
11+
x ++ y ++ bam
12+
13+
def app(n: Int, suffix: String): String =
14+
if n == 0 then x ++ suffix
15+
else app(n - 1, suffix ++ suffix)
16+
17+
@main def Test =
18+
assert("abc".bar("def") == "abcdef")
19+
assert("abc".baz("def") == "abcdef")
20+
assert("abc".bam("def") == "abcdef")
21+
assert("abc".ban("def") == "abcdef")
22+
assert("abc".bao("def") == "abcdefABC")
23+
assert("abc".app(3, "!") == "abc!!!!!!!!")

0 commit comments

Comments
 (0)