From 143d5b3f9de70f7c99a288227e6bc1cdbb8e6b2d Mon Sep 17 00:00:00 2001 From: Yilin Wei Date: Sat, 10 Feb 2024 16:48:48 +0000 Subject: [PATCH] Add initial experiments on mirror support. --- .../dotty/tools/dotc/core/Definitions.scala | 3 ++ .../src/dotty/tools/dotc/core/SymUtils.scala | 13 +++++-- .../tools/dotc/transform/PatternMatcher.scala | 18 +++------- .../dotty/tools/dotc/typer/Synthesizer.scala | 8 +++++ .../src/scala/runtime/JavaRecordMirror.scala | 34 +++++++++++++++++++ .../java-records-mirror/FromScala.scala | 5 +++ tests/pos-java16+/java-records-mirror/R2.java | 1 + 7 files changed, 65 insertions(+), 17 deletions(-) create mode 100644 library/src/scala/runtime/JavaRecordMirror.scala create mode 100644 tests/pos-java16+/java-records-mirror/FromScala.scala create mode 100644 tests/pos-java16+/java-records-mirror/R2.java diff --git a/compiler/src/dotty/tools/dotc/core/Definitions.scala b/compiler/src/dotty/tools/dotc/core/Definitions.scala index 6195a79ba0e2..d758f10aefb9 100644 --- a/compiler/src/dotty/tools/dotc/core/Definitions.scala +++ b/compiler/src/dotty/tools/dotc/core/Definitions.scala @@ -974,6 +974,9 @@ class Definitions { @tu lazy val RuntimeTuples_isInstanceOfEmptyTuple: Symbol = RuntimeTuplesModule.requiredMethod("isInstanceOfEmptyTuple") @tu lazy val RuntimeTuples_isInstanceOfNonEmptyTuple: Symbol = RuntimeTuplesModule.requiredMethod("isInstanceOfNonEmptyTuple") + @tu lazy val JavaRecordReflectMirrorTypeRef: TypeRef = requiredClassRef("scala.runtime.JavaRecordMirror") + @tu lazy val JavaRecordReflectMirrorModule: Symbol = requiredModule("scala.runtime.JavaRecordMirror") + @tu lazy val TupledFunctionTypeRef: TypeRef = requiredClassRef("scala.util.TupledFunction") def TupledFunctionClass(using Context): ClassSymbol = TupledFunctionTypeRef.symbol.asClass def RuntimeTupleFunctionsModule(using Context): Symbol = requiredModule("scala.runtime.TupledFunctions") diff --git a/compiler/src/dotty/tools/dotc/core/SymUtils.scala b/compiler/src/dotty/tools/dotc/core/SymUtils.scala index ef64119bcd20..db57e689392c 100644 --- a/compiler/src/dotty/tools/dotc/core/SymUtils.scala +++ b/compiler/src/dotty/tools/dotc/core/SymUtils.scala @@ -99,13 +99,14 @@ class SymUtils: def canAccessCtor: Boolean = def isAccessible(sym: Symbol): Boolean = ctx.owner.isContainedIn(sym) def isSub(sym: Symbol): Boolean = ctx.owner.ownersIterator.exists(_.derivesFrom(sym)) - val ctor = self.primaryConstructor + val ctor = if defn.isJavaRecordClass(self) then self.javaCanonicalConstructor else self.primaryConstructor (!ctor.isOneOf(Private | Protected) || isSub(self)) // we cant access the ctor because we do not extend cls && (!ctor.privateWithin.exists || isAccessible(ctor.privateWithin)) // check scope is compatible def companionMirror = self.useCompanionAsProductMirror - if (!self.is(CaseClass)) "it is not a case class" + + if (!(self.is(CaseClass) || defn.isJavaRecordClass(self))) "it is not a case class or record class" else if (self.is(Abstract)) "it is an abstract class" else if (self.primaryConstructor.info.paramInfoss.length != 1) "it takes more than one parameter list" else if self.isDerivedValueClass then "it is a value class" @@ -146,7 +147,7 @@ class SymUtils: && (!self.is(Method) || self.is(Accessor)) def useCompanionAsProductMirror(using Context): Boolean = - self.linkedClass.exists && !self.is(Scala2x) && !self.linkedClass.is(Case) + self.linkedClass.exists && !self.is(Scala2x) && !self.linkedClass.is(Case) && !defn.isJavaRecordClass(self) def useCompanionAsSumMirror(using Context): Boolean = def companionExtendsSum(using Context): Boolean = @@ -249,6 +250,12 @@ class SymUtils: def caseAccessors(using Context): List[Symbol] = self.info.decls.filter(_.is(CaseAccessor)) + // TODO: I'm convinced that we need to introduce a flag to get the canonical constructor. + // we should also check whether the names are erased in the ctor. If not, we should + // be able to infer the components directly from the constructor. + def javaCanonicalConstructor(using Context): Symbol = + self.info.decls.filter(_.isConstructor).tail.head + // TODO: Check if `Synthetic` is stamped properly def javaRecordComponents(using Context): List[Symbol] = self.info.decls.filter(sym => sym.is(Synthetic) && sym.is(Method) && !sym.isConstructor) diff --git a/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala b/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala index a4dd0494bc2c..0d0fd3b6e900 100644 --- a/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala +++ b/compiler/src/dotty/tools/dotc/transform/PatternMatcher.scala @@ -345,23 +345,13 @@ object PatternMatcher { def resultTypeSym = unapp.symbol.info.resultType.typeSymbol - def isSyntheticJavaRecordUnapply(sym: Symbol) = - // Since the `unapply` symbol is marked as inline, the `Typer` wraps the body of the `unapply` in a separate - // anonymous class. The result type alone is not enough to distinguish that we're calling the synthesized unapply — - // we could have defined a separate `unapply` method returning a Java record somewhere, hence we resort to using - // the `coord`. - sym.is(Synthetic) && sym.isAnonymousClass && { - val resultSym = resultTypeSym - // TODO: Can a user define a separate unapply function in Java? - val unapplyFn = resultSym.linkedClass.info.decl(nme.unapply) - // TODO: This is nasty, can we add an attachment on the anonymous function for a prior link? - defn.isJavaRecordClass(resultSym) && unapplyFn.symbol.coord == sym.coord - } - + // TODO: Check Scala -> Java, erased? + def isJavaRecordUnapply(sym: Symbol) = defn.isJavaRecordClass(resultTypeSym) def tupleSel(sym: Symbol) = ref(scrutinee).select(sym) def recordSel(sym: Symbol) = tupleSel(sym).appliedToTermArgs(Nil) - if (isSyntheticJavaRecordUnapply(unapp.symbol.owner)) + // TODO: Move this to the correct location + if (isJavaRecordUnapply(unapp.symbol.owner)) val components = resultTypeSym.javaRecordComponents.map(recordSel) matchArgsPlan(components, args, onSuccess) else if (isSyntheticScala2Unapply(unapp.symbol) && caseAccessors.length == args.length) diff --git a/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala b/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala index c94724faf4d4..285cecab2033 100644 --- a/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala +++ b/compiler/src/dotty/tools/dotc/typer/Synthesizer.scala @@ -407,6 +407,12 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context): def newTupleMirror(arity: Int): Tree = New(defn.RuntimeTupleMirrorTypeRef, Literal(Constant(arity)) :: Nil) + def newJavaRecordReflectMirror(tpe: Type) = + ref(defn.JavaRecordReflectMirrorModule) + .select(nme.apply) + .appliedToType(tpe) + .appliedTo(clsOf(tpe)) + def makeProductMirror(pre: Type, cls: Symbol, tps: Option[List[Type]]): TreeWithErrors = val accessors = cls.caseAccessors val elemLabels = accessors.map(acc => ConstantType(Constant(acc.name.toString))) @@ -427,6 +433,7 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context): } val mirrorRef = if cls.useCompanionAsProductMirror then companionPath(mirroredType, span) + else if defn.isJavaRecordClass(cls) then newJavaRecordReflectMirror(cls.typeRef) else if defn.isTupleClass(cls) then newTupleMirror(typeElems.size) // TODO: cls == defn.PairClass when > 22 else anonymousMirror(monoType, MirrorImpl.OfProduct(pre), span) withNoErrors(mirrorRef.cast(mirrorType).withSpan(span)) @@ -458,6 +465,7 @@ class Synthesizer(typer: Typer)(using @constructorOnly c: Context): val reason = s"it reduces to a tuple with arity $arity, expected arity <= $maxArity" withErrors(i"${defn.PairClass} is not a generic product because $reason") case MirrorSource.ClassSymbol(pre, cls) => + if cls.isGenericProduct then if ctx.runZincPhases then // The mirror should be resynthesized if the constructor of the diff --git a/library/src/scala/runtime/JavaRecordMirror.scala b/library/src/scala/runtime/JavaRecordMirror.scala new file mode 100644 index 000000000000..1beb26a5e31c --- /dev/null +++ b/library/src/scala/runtime/JavaRecordMirror.scala @@ -0,0 +1,34 @@ +package scala.runtime + +import java.lang.Record +import java.lang.reflect.Constructor +import scala.reflect.ClassTag + +// TODO: Rename to JavaRecordReflectMirror +object JavaRecordMirror: + + def apply[T <: Record](clazz: Class[T]): JavaRecordMirror[T] = + val components = clazz.getRecordComponents.nn + val constructorTypes = components.map(_.nn.getType.nn) + val constr = clazz.getDeclaredConstructor(constructorTypes*).nn + new JavaRecordMirror(components.length, constr) + + def of[T <: Record : ClassTag]: JavaRecordMirror[T] = + JavaRecordMirror(summon[ClassTag[T]].runtimeClass.asInstanceOf[Class[T]]) + +// TODO: Is a constructor serializable? +final class JavaRecordMirror[T] private(arity: Int, constr: Constructor[T]) extends scala.deriving.Mirror.Product with Serializable: + + override type MirroredMonoType <: Record + + final def fromProduct(product: Product): MirroredMonoType = + if product.productArity != arity then + throw IllegalArgumentException(s"expected Product with $arity elements, got ${product.productArity}") + else + // TODO: Check this byte code, we want to unroll to give a happy medium between JIT'ing and having tons of extra classes + val t = arity match + case 0 => constr.newInstance() + case 1 => constr.newInstance(product.productElement(0)) + case 2 => constr.newInstance(product.productElement(0), product.productElement(1)) + + t.nn.asInstanceOf[MirroredMonoType] diff --git a/tests/pos-java16+/java-records-mirror/FromScala.scala b/tests/pos-java16+/java-records-mirror/FromScala.scala new file mode 100644 index 000000000000..facf93a12c4e --- /dev/null +++ b/tests/pos-java16+/java-records-mirror/FromScala.scala @@ -0,0 +1,5 @@ +import scala.deriving.Mirror + +object C: + def useR2: Unit = + summon[Mirror.Of[R2]] diff --git a/tests/pos-java16+/java-records-mirror/R2.java b/tests/pos-java16+/java-records-mirror/R2.java new file mode 100644 index 000000000000..9ea613fd1ca9 --- /dev/null +++ b/tests/pos-java16+/java-records-mirror/R2.java @@ -0,0 +1 @@ +public record R2(int i, String s) {}