diff --git a/compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala b/compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala index 9e6669fe9256..ab1fb4e9d1c7 100644 --- a/compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala +++ b/compiler/src/scala/quoted/runtime/impl/QuotesImpl.scala @@ -1406,6 +1406,45 @@ class QuotesImpl private (using val ctx: Context) extends Quotes, QuoteUnpickler end extension end TypeCaseDefMethods + + type Wildcard = tpd.Ident + + object WildcardTypeTest extends TypeTest[Tree, Wildcard]: + def unapply(x: Tree): Option[Wildcard & x.type] = x match + case x: (tpd.Ident & x.type) if x.name == nme.WILDCARD => Some(x) + case _ => None + end WildcardTypeTest + + object Wildcard extends WildcardModule: + def apply(): Wildcard = + withDefaultPos(untpd.Ident(nme.WILDCARD).withType(dotc.core.Symbols.defn.AnyType)) + def unapply(pattern: Wildcard): true = true + end Wildcard + + type TypedTree = tpd.Typed + + object TypedTreeTypeTest extends TypeTest[Tree, TypedTree]: + def unapply(x: Tree): Option[TypedTree & x.type] = x match + case x: (tpd.Typed & x.type) => Some(x) + case _ => None + end TypedTreeTypeTest + + object TypedTree extends TypedTreeModule: + def apply(expr: Term, tpt: TypeTree): Typed = + withDefaultPos(tpd.Typed(xCheckMacroValidExpr(expr), tpt)) + def copy(original: Tree)(expr: Term, tpt: TypeTree): Typed = + tpd.cpy.Typed(original)(xCheckMacroValidExpr(expr), tpt) + def unapply(x: Typed): (Term, TypeTree) = + (x.expr, x.tpt) + end TypedTree + + given TypedTreeMethods: TypedTreeMethods with + extension (self: Typed) + def tree: Tree = self.expr + def tpt: TypeTree = self.tpt + end extension + end TypedTreeMethods + type Bind = tpd.Bind object BindTypeTest extends TypeTest[Tree, Bind]: diff --git a/compiler/src/scala/quoted/runtime/impl/printers/Extractors.scala b/compiler/src/scala/quoted/runtime/impl/printers/Extractors.scala index a05f1e41f724..bf725bb7da1a 100644 --- a/compiler/src/scala/quoted/runtime/impl/printers/Extractors.scala +++ b/compiler/src/scala/quoted/runtime/impl/printers/Extractors.scala @@ -166,12 +166,16 @@ object Extractors { this += "CaseDef(" += pat += ", " += guard += ", " += body += ")" case TypeCaseDef(pat, body) => this += "TypeCaseDef(" += pat += ", " += body += ")" + case Wildcard() => + this += "Wildcard()" case Bind(name, body) => this += "Bind(\"" += name += "\", " += body += ")" case Unapply(fun, implicits, patterns) => this += "Unapply(" += fun += ", " ++= implicits += ", " ++= patterns += ")" case Alternatives(patterns) => this += "Alternatives(" ++= patterns += ")" + case TypedTree(tree, tpt) => + this += "TypedTree(" += tree += ", " += tpt += ")" } def visitConstant(x: Constant): this.type = x match { diff --git a/compiler/src/scala/quoted/runtime/impl/printers/SourceCode.scala b/compiler/src/scala/quoted/runtime/impl/printers/SourceCode.scala index e3f0f282e524..c47e19a9b864 100644 --- a/compiler/src/scala/quoted/runtime/impl/printers/SourceCode.scala +++ b/compiler/src/scala/quoted/runtime/impl/printers/SourceCode.scala @@ -328,7 +328,7 @@ object SourceCode { } this - case Ident("_") => + case Wildcard() => this += "_" case tree: Ident => @@ -453,6 +453,15 @@ object SourceCode { printTypeOrAnnots(tpt.tpe) } } + case TypedTree(tree1, tpt) => + printPattern(tree1) + tree1 match + case Wildcard() => + this += ":" + printType(tpt.tpe) + case _ => // Alternatives, Unapply, Bind + this + case Assign(lhs, rhs) => printTree(lhs) @@ -896,13 +905,13 @@ object SourceCode { } private def printPattern(pattern: Tree): this.type = pattern match { - case Ident("_") => + case Wildcard() => this += "_" - case Bind(name, Ident("_")) => + case Bind(name, Wildcard()) => this += name - case Bind(name, Typed(Ident("_"), tpt)) => + case Bind(name, Typed(Wildcard(), tpt)) => this += highlightValDef(name) += ": " printTypeTree(tpt) @@ -928,9 +937,13 @@ object SourceCode { case Alternatives(trees) => inParens(printPatterns(trees, " | ")) - case Typed(Ident("_"), tpt) => - this += "_: " - printTypeTree(tpt) + case TypedTree(tree1, tpt) => + tree1 match + case Wildcard() => + this += "_: " + printTypeTree(tpt) + case _ => + printPattern(tree1) case v: Term => printTree(v) diff --git a/library/src/scala/quoted/Quotes.scala b/library/src/scala/quoted/Quotes.scala index e3a4075e6821..5eba3114e5a1 100644 --- a/library/src/scala/quoted/Quotes.scala +++ b/library/src/scala/quoted/Quotes.scala @@ -133,7 +133,6 @@ trait Quotes { self: runtime.QuoteUnpickler & runtime.QuoteMatching => * | +- Apply * | +- TypeApply * | +- Super - * | +- Typed * | +- Assign * | +- Block * | +- Closure @@ -146,7 +145,16 @@ trait Quotes { self: runtime.QuoteUnpickler & runtime.QuoteMatching => * | +- Inlined * | +- SelectOuter * | +- While + * | +---+- Typed + * | / + * +- TypedTree +------------------ยท + * +- Wildcard + * +- Bind + * +- Unapply + * +- Alternatives * | + * +- CaseDef + * +- TypeCaseDef * | * +- TypeTree ----+- Inferred * | +- TypeIdent @@ -164,13 +172,6 @@ trait Quotes { self: runtime.QuoteUnpickler & runtime.QuoteMatching => * | * +- TypeBoundsTree * +- WildcardTypeTree - * | - * +- CaseDef - * | - * +- TypeCaseDef - * +- Bind - * +- Unapply - * +- Alternatives * * +- ParamClause -+- TypeParamClause * +- TermParamClause @@ -1120,8 +1121,12 @@ trait Quotes { self: runtime.QuoteUnpickler & runtime.QuoteMatching => /** `TypeTest` that allows testing at runtime in a pattern match if a `Tree` is a `Typed` */ given TypedTypeTest: TypeTest[Tree, Typed] - /** Tree representing a type ascription `x: T` in the source code */ - type Typed <: Term + /** Tree representing a type ascription `x: T` in the source code. + * + * Also represents a pattern that contains a term `x`. + * Other `: T` patterns use the more general `TypeTree`. + */ + type Typed <: Term & TypeTree /** Module object of `type Typed` */ val Typed: TypedModule @@ -2049,6 +2054,56 @@ trait Quotes { self: runtime.QuoteUnpickler & runtime.QuoteMatching => // ----- Patterns ------------------------------------------------ + /** Pattern representing a `_` wildcard. */ + type Wildcard <: Tree + + /** `TypeTest` that allows testing at runtime in a pattern match if a `Tree` is a `Wildcard` */ + given WildcardTypeTest: TypeTest[Tree, Wildcard] + + /** Module object of `type Wildcard` */ + val Wildcard: WildcardModule + + /** Methods of the module object `val Wildcard` */ + trait WildcardModule { this: Wildcard.type => + def apply(): Wildcard + def unapply(pattern: Wildcard): true + } + + /** `TypeTest` that allows testing at runtime in a pattern match if a `Tree` is a `TypedTree` */ + given TypedTreeTypeTest: TypeTest[Tree, TypedTree] + + /** Tree representing a type ascription or pattern `x: T` in the source code + * + * The tree `x` may contain a `Constant`, `Ref`, `Wildcard`, `Bind`, `Unapply` or `Alternatives`. + */ + type TypedTree <: Term + + /** Module object of `type TypedTree` */ + val TypedTree: TypedTreeModule + + /** Methods of the module object `val TypedTree` */ + trait TypedTreeModule { this: TypedTree.type => + + /** Create a type ascription `: ` */ + def apply(expr: Tree, tpt: TypeTree): TypedTree + + def copy(original: Tree)(expr: Tree, tpt: TypeTree): TypedTree + + /** Matches `: ` */ + def unapply(x: TypedTree): (Tree, TypeTree) + } + + /** Makes extension methods on `TypedTree` available without any imports */ + given TypedTreeMethods: TypedTreeMethods + + /** Extension methods of `TypedTree` */ + trait TypedTreeMethods: + extension (self: TypedTree) + def tree: Tree + def tpt: TypeTree + end extension + end TypedTreeMethods + /** Pattern representing a `_ @ _` binding. */ type Bind <: Tree @@ -4327,9 +4382,11 @@ trait Quotes { self: runtime.QuoteUnpickler & runtime.QuoteMatching => case TypeBoundsTree(lo, hi) => foldTree(foldTree(x, lo)(owner), hi)(owner) case CaseDef(pat, guard, body) => foldTree(foldTrees(foldTree(x, pat)(owner), guard)(owner), body)(owner) case TypeCaseDef(pat, body) => foldTree(foldTree(x, pat)(owner), body)(owner) + case Wildcard() => x case Bind(_, body) => foldTree(x, body)(owner) case Unapply(fun, implicits, patterns) => foldTrees(foldTrees(foldTree(x, fun)(owner), implicits)(owner), patterns)(owner) case Alternatives(patterns) => foldTrees(x, patterns)(owner) + case TypedTree(tree1, tpt) => foldTree(foldTree(x, tree1)(owner), tpt)(owner) } } end TreeAccumulator @@ -4387,12 +4444,15 @@ trait Quotes { self: runtime.QuoteUnpickler & runtime.QuoteMatching => transformCaseDef(tree)(owner) case tree: TypeCaseDef => transformTypeCaseDef(tree)(owner) + case Wildcard() => tree case pattern: Bind => Bind.copy(pattern)(pattern.name, pattern.pattern) case pattern: Unapply => Unapply.copy(pattern)(transformTerm(pattern.fun)(owner), transformSubTrees(pattern.implicits)(owner), transformTrees(pattern.patterns)(owner)) case pattern: Alternatives => Alternatives.copy(pattern)(transformTrees(pattern.patterns)(owner)) + case TypedTree(expr, tpt) => + TypedTree.copy(tree)(transformTree(expr)(owner), transformTypeTree(tpt)(owner)) } } @@ -4443,7 +4503,7 @@ trait Quotes { self: runtime.QuoteUnpickler & runtime.QuoteMatching => case New(tpt) => New.copy(tree)(transformTypeTree(tpt)(owner)) case Typed(expr, tpt) => - Typed.copy(tree)(/*FIXME #12222: transformTerm(expr)(owner)*/transformTree(expr)(owner).asInstanceOf[Term], transformTypeTree(tpt)(owner)) + Typed.copy(tree)(transformTerm(expr)(owner), transformTypeTree(tpt)(owner)) case tree: NamedArg => NamedArg.copy(tree)(tree.name, transformTerm(tree.value)(owner)) case Assign(lhs, rhs) => diff --git a/project/MiMaFilters.scala b/project/MiMaFilters.scala index 21fe1ea09c14..9099a5ba901c 100644 --- a/project/MiMaFilters.scala +++ b/project/MiMaFilters.scala @@ -8,6 +8,19 @@ object MiMaFilters { exclude[MissingClassProblem]("scala.annotation.internal.ProvisionalSuperClass"), // New APIs marked @experimental in 3.0.2 - exclude[MissingClassProblem]("scala.Selectable$WithoutPreciseParameterTypes") + exclude[MissingClassProblem]("scala.Selectable$WithoutPreciseParameterTypes"), + exclude[ReversedMissingMethodProblem]("scala.quoted.Quotes#reflectModule.WildcardTypeTest"), + exclude[ReversedMissingMethodProblem]("scala.quoted.Quotes#reflectModule.Wildcard"), + exclude[ReversedMissingMethodProblem]("scala.quoted.Quotes#reflectModule.TypedTreeTypeTest"), + exclude[ReversedMissingMethodProblem]("scala.quoted.Quotes#reflectModule.TypedTree"), + exclude[ReversedMissingMethodProblem]("scala.quoted.Quotes#reflectModule.TypedTreeMethods"), + exclude[DirectMissingMethodProblem]("scala.quoted.Quotes#reflectModule.WildcardTypeTest"), + exclude[DirectMissingMethodProblem]("scala.quoted.Quotes#reflectModule.Wildcard"), + exclude[DirectMissingMethodProblem]("scala.quoted.Quotes#reflectModule.TypedTreeTypeTest"), + exclude[DirectMissingMethodProblem]("scala.quoted.Quotes#reflectModule.TypedTree"), + exclude[DirectMissingMethodProblem]("scala.quoted.Quotes#reflectModule.TypedTreeMethods"), + exclude[MissingClassProblem]("scala.quoted.Quotes$reflectModule$TypedTreeMethods"), + exclude[MissingClassProblem]("scala.quoted.Quotes$reflectModule$TypedTreeModule"), + exclude[MissingClassProblem]("scala.quoted.Quotes$reflectModule$WildcardModule"), ) } diff --git a/tests/pos-macros/i11401/X_1.scala b/tests/pos-macros/i11401/X_1.scala index 92f868d996a4..b8a5d033ec58 100644 --- a/tests/pos-macros/i11401/X_1.scala +++ b/tests/pos-macros/i11401/X_1.scala @@ -16,7 +16,7 @@ class SLSelect[S]: def fold[S](s0:S)(step: (S,SLSelect[S])=> S): S = { ??? - } + } def fold_async[S](s0:S)(step: (S,SLSelect[S])=> Future[S]): Future[S] = { ??? @@ -27,7 +27,7 @@ class SLSelect[S]: await(s0.onRead(ch)(f).runAsync()) def runAsync(): Future[S] = ??? - + object X: @@ -36,21 +36,21 @@ object X: processImpl[T]('f) } - def processImpl[T:Type](t:Expr[T])(using Quotes):Expr[Future[T]] = + def processImpl[T:Type](t:Expr[T])(using Quotes):Expr[Future[T]] = import quotes.reflect._ val r = processTree[T](t.asTerm) r.asExprOf[Future[T]] - - def processTree[T:Type](using Quotes)(t: quotes.reflect.Term):quotes.reflect.Term = + + def processTree[T:Type](using Quotes)(t: quotes.reflect.Term):quotes.reflect.Term = import quotes.reflect._ val r: Term = t match case Inlined(_,List(),body) => processTree(body) - case Inlined(d,bindings,body) => + case Inlined(d,bindings,body) => Inlined(d,bindings,processTree[T](body)) case Block(stats,expr) => Block(stats,processTree(expr)) case Apply(Apply(TypeApply(Select(x,"fold"),targs),List(state)),List(fun)) => - val nFun = processLambda[T](fun) + val nFun = processLambda[T](fun) Apply(Apply(TypeApply(Select.unique(x,"fold_async"),targs),List(state)),List(nFun)) case Apply(TypeApply(Ident("await"),targs),List(body)) => body case Typed(x,tp) => Typed(processTree(x), Inferred(TypeRepr.of[Future].appliedTo(tp.tpe)) ) @@ -58,8 +58,8 @@ object X: val checker = new TreeMap() {} checker.transformTerm(r)(Symbol.spliceOwner) r - - def processLambda[T:Type](using Quotes)(fun: quotes.reflect.Term):quotes.reflect.Term = + + def processLambda[T:Type](using Quotes)(fun: quotes.reflect.Term):quotes.reflect.Term = import quotes.reflect._ def changeArgs(oldArgs:List[Tree], newArgs:List[Tree], body:Term, owner: Symbol):Term = diff --git a/tests/pos-macros/i12188b/Macro_1.scala b/tests/pos-macros/i12188b/Macro_1.scala new file mode 100644 index 000000000000..fa9ea10e666a --- /dev/null +++ b/tests/pos-macros/i12188b/Macro_1.scala @@ -0,0 +1,13 @@ +import scala.quoted.* + +object MatchTest { + inline def test[T](inline obj: Any): Unit = ${testImpl('obj)} + + def testImpl[T](objExpr: Expr[T])(using Quotes): Expr[Unit] = { + import quotes.reflect.* + // test that the extractors work + val Inlined(None, Nil, Block(Nil, Match(param @ Ident("a"), List(CaseDef(Literal(IntConstant(1)), None, Block(Nil, Literal(UnitConstant()))), CaseDef(Wildcard(), None, Block(Nil, Literal(UnitConstant()))))))) = objExpr.asTerm + // test that the constructors work + Block(Nil, Match(param, List(CaseDef(Literal(IntConstant(1)), None, Block(Nil, Literal(UnitConstant()))), CaseDef(Wildcard(), None, Block(Nil, Literal(UnitConstant())))))).asExprOf[Unit] + } +} diff --git a/tests/pos-macros/i12188b/Test_2.scala b/tests/pos-macros/i12188b/Test_2.scala new file mode 100644 index 000000000000..f9abca65a1f8 --- /dev/null +++ b/tests/pos-macros/i12188b/Test_2.scala @@ -0,0 +1,6 @@ + +def test(a: Int) = MatchTest.test { + a match + case 1 => + case _ => +} diff --git a/tests/pos-macros/i12188c/Macro_1.scala b/tests/pos-macros/i12188c/Macro_1.scala new file mode 100644 index 000000000000..fd627964bd61 --- /dev/null +++ b/tests/pos-macros/i12188c/Macro_1.scala @@ -0,0 +1,15 @@ +import scala.quoted.* + +object MatchTest { + inline def test(a: Int): Unit = ${testImpl('a)} + + def testImpl(a: Expr[Any])(using Quotes): Expr[Unit] = { + import quotes.reflect.* + val matchTree = Match(a.asTerm, List( + CaseDef(Literal(IntConstant(1)), None, Block(Nil, Literal(UnitConstant()))), + CaseDef(Alternatives(List(Literal(IntConstant(2)), Literal(IntConstant(3)), Literal(IntConstant(4)))), None, Block(Nil, Literal(UnitConstant()))), + CaseDef(TypedTree(Alternatives(List(Literal(IntConstant(4)), Literal(IntConstant(5)))), TypeIdent(defn.IntClass)), None, Block(Nil, Literal(UnitConstant()))), + CaseDef(TypedTree(Wildcard(), TypeIdent(defn.IntClass)), None, Block(Nil, Literal(UnitConstant()))))) + matchTree.asExprOf[Unit] + } +} diff --git a/tests/pos-macros/i12188c/Test_2.scala b/tests/pos-macros/i12188c/Test_2.scala new file mode 100644 index 000000000000..90391314a3a7 --- /dev/null +++ b/tests/pos-macros/i12188c/Test_2.scala @@ -0,0 +1,2 @@ + +def test(a: Int) = MatchTest.test(a) \ No newline at end of file diff --git a/tests/run-macros/i12188.check b/tests/run-macros/i12188.check new file mode 100644 index 000000000000..922215bbeac5 --- /dev/null +++ b/tests/run-macros/i12188.check @@ -0,0 +1,3 @@ +PC1 +PC2 +default diff --git a/tests/run-macros/i12188/Macro_1.scala b/tests/run-macros/i12188/Macro_1.scala new file mode 100644 index 000000000000..9a507e3c6c81 --- /dev/null +++ b/tests/run-macros/i12188/Macro_1.scala @@ -0,0 +1,23 @@ +import scala.quoted.* + +object MatchTest { + inline def test[T](inline obj: T): String = ${testImpl('obj)} + + def testImpl[T](objExpr: Expr[T])(using qctx: Quotes, t: Type[T]): Expr[String] = { + import qctx.reflect.* + + val obj = objExpr.asTerm + val cases = obj.tpe.typeSymbol.children.map { child => + val subtype = TypeIdent(child) + val bind = Symbol.newBind(Symbol.spliceOwner, "c", Flags.EmptyFlags, subtype.tpe) + CaseDef(Bind(bind, Typed(Ref(bind), subtype)), None, Literal(StringConstant(subtype.show))) + } ::: { + CaseDef(Wildcard(), None, Literal(StringConstant("default"))) + } :: Nil + val bind = Symbol.newBind(Symbol.spliceOwner, "o", Flags.EmptyFlags, obj.tpe) + val result = Match(obj, cases) + val code = result.show(using Printer.TreeAnsiCode) + // println(code) + result.asExprOf[String] + } +} diff --git a/tests/run-macros/i12188/Test_2.scala b/tests/run-macros/i12188/Test_2.scala new file mode 100644 index 000000000000..a3d0f0de58fb --- /dev/null +++ b/tests/run-macros/i12188/Test_2.scala @@ -0,0 +1,8 @@ +sealed trait P +case class PC1(a: String) extends P +case class PC2(b: Int) extends P + +@main def Test = + println(MatchTest.test(PC1("ab"): P)) + println(MatchTest.test(PC2(10): P)) + println(MatchTest.test(null: P)) diff --git a/tests/run-staging/i5161.check b/tests/run-staging/i5161.check index a178c827d633..27c72498d7f1 100644 --- a/tests/run-staging/i5161.check +++ b/tests/run-staging/i5161.check @@ -1,6 +1,6 @@ run : Some(2) show : scala.Tuple2.apply[scala.Option[scala.Int], scala.Option[scala.Int]](scala.Some.apply[scala.Int](1), scala.Some.apply[scala.Int](1)) match { - case scala.Tuple2((scala.Some(x): scala.Some[scala.Int]), (scala.Some(y): scala.Some[scala.Int])) => + case scala.Tuple2(scala.Some(x), scala.Some(y)) => scala.Some.apply[scala.Int](x.+(y)) case _ => scala.None