
How to write a type class `derived` method using macros

In the main derivation documentation page, we explained the details behind Mirrors and type class derivation. Here we demonstrate how to implement a type class derived method using macros only. We follow the same example of deriving Eq instances and, for simplicity, focus on product types, e.g. a case class Person. The low-level technique that we will use to implement the derived method exploits quotes, splices of both expressions and types, and the scala.quoted.Expr.summon method, which is the equivalent of scala.compiletime.summonFrom. Expr.summon is the variant suitable for use in a quote context, i.e. within macros.

The type class definition is the same as in the main derivation page:

trait Eq[T]:
  def eqv(x: T, y: T): Boolean

We need to implement an inline method Eq.derived on the companion object of Eq that calls into a macro to produce a quoted instance for Eq[T]. Here is a possible signature:

inline def derived[T]: Eq[T] = ${ derivedMacro[T] }

def derivedMacro[T: Type](using Quotes): Expr[Eq[T]] = ???

Note that, since the type T is used in a later macro compilation stage, it needs to be lifted to a quoted.Type by means of the corresponding context bound (T: Type in derivedMacro).

For comparison, here is the signature of the inline derived method from the main derivation page:

inline def derived[T](using m: Mirror.Of[T]): Eq[T] = ???

Note that the macro-based derived signature does not have a Mirror parameter. This is because derivedMacro can summon the Mirror itself inside its body, so there is no need to require it in the signature.
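For contrast, here is a minimal sketch of the inline-only approach, in which the Mirror arrives as a using parameter. It is a simplified version written for this comparison, not the exact code of the main derivation page: it assumes product types only (Mirror.ProductOf instead of Mirror.Of) and compares elements through productIterator.

```scala
import scala.compiletime.{erasedValue, summonInline}
import scala.deriving.Mirror

trait Eq[T]:
  def eqv(x: T, y: T): Boolean

object Eq:
  given Eq[String] = (x: String, y: String) => x == y
  given Eq[Int] = (x: Int, y: Int) => x == y

  // Summons one Eq instance per element type of the tuple type Elems.
  inline def summonInstances[Elems <: Tuple]: List[Eq[?]] =
    inline erasedValue[Elems] match
      case _: (elem *: elems) => summonInline[Eq[elem]] :: summonInstances[elems]
      case _: EmptyTuple      => Nil

  // Simplified inline-only derivation (product types only):
  // the Mirror is a using parameter instead of being summoned in a macro.
  inline def derived[T](using m: Mirror.ProductOf[T]): Eq[T] =
    val instances = summonInstances[m.MirroredElemTypes]
    (x: T, y: T) =>
      x.asInstanceOf[Product].productIterator
        .zip(y.asInstanceOf[Product].productIterator)
        .zip(instances)
        .forall { case ((a, b), inst) => inst.asInstanceOf[Eq[Any]].eqv(a, b) }

case class Person(name: String, age: Int) derives Eq
```

Unlike the macro version, this body walks a runtime list of instances on every call; the macro can instead unroll the comparison into a fixed chain of eqv calls.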

An additional advantage of implementing derivedMacro as a macro, rather than with inline methods, is that it is simpler to generate a fully optimised method body for eqv.

Let's say we wanted to derive an Eq instance for the following case class Person,

case class Person(name: String, age: Int) derives Eq

the equality check we are going to generate is the following:

(x: Person, y: Person) =>
  summon[Eq[String]].eqv(x.productElement(0).asInstanceOf[String], y.productElement(0).asInstanceOf[String])
  && summon[Eq[Int]].eqv(x.productElement(1).asInstanceOf[Int], y.productElement(1).asInstanceOf[Int])
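To make the target concrete, the following sketch writes that generated body by hand and wraps it with an eqProduct helper shaped like the one in the full code. It is only a hand-written approximation of the macro's output, assuming the givens for Eq[String] and Eq[Int] live in the Eq companion.

```scala
trait Eq[T]:
  def eqv(x: T, y: T): Boolean

object Eq:
  given Eq[String] = (x: String, y: String) => x == y
  given Eq[Int] = (x: Int, y: Int) => x == y

  // Same shape as the eqProduct helper used by derivedMacro.
  def eqProduct[T](body: (T, T) => Boolean): Eq[T] =
    new Eq[T] { def eqv(x: T, y: T): Boolean = body(x, y) }

case class Person(name: String, age: Int)

// Hand-written approximation of what the macro expands to for Person:
// one eqv call per field, chained with &&, no runtime loop.
val eqPerson: Eq[Person] = Eq.eqProduct { (x: Person, y: Person) =>
  summon[Eq[String]].eqv(x.productElement(0).asInstanceOf[String], y.productElement(0).asInstanceOf[String])
  && summon[Eq[Int]].eqv(x.productElement(1).asInstanceOf[Int], y.productElement(1).asInstanceOf[Int])
}
```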

Note that by using the reflection API it is possible to optimise further and reference the fields of Person directly, but for clarity we will only use quoted expressions.
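As a rough illustration of that further optimisation, the sketch below shows, written by hand, the kind of body a reflection-based expansion could aim for. The reflection code that would produce it is not shown; only the resulting shape is, and the name eqPersonDirect is hypothetical.

```scala
trait Eq[T]:
  def eqv(x: T, y: T): Boolean

object Eq:
  given Eq[String] = (x: String, y: String) => x == y
  given Eq[Int] = (x: Int, y: Int) => x == y

case class Person(name: String, age: Int)

// Hand-written: fields are accessed directly instead of through
// productElement, so no casts from Any are needed.
val eqPersonDirect: Eq[Person] = (x: Person, y: Person) =>
  summon[Eq[String]].eqv(x.name, y.name) && summon[Eq[Int]].eqv(x.age, y.age)
```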

The code that generates this body can be seen in the eqProductBody method, shown here as part of the definition of the derivedMacro method:

def derivedMacro[T: Type](using Quotes): Expr[Eq[T]] =

  val ev: Expr[Mirror.Of[T]] = Expr.summon[Mirror.Of[T]].get

  ev match
    case '{ $m: Mirror.ProductOf[T] { type MirroredElemTypes = elementTypes }} =>
      val elemInstances = summonInstances[T, elementTypes]
      def eqProductBody(x: Expr[Product], y: Expr[Product])(using Quotes): Expr[Boolean] = {
        if elemInstances.isEmpty then
          Expr(true)
        else
          elemInstances.zipWithIndex.map {
            case ('{ $elem: Eq[t] }, index) =>
              val indexExpr = Expr(index)
              val e1 = '{ $x.productElement($indexExpr).asInstanceOf[t] }
              val e2 = '{ $y.productElement($indexExpr).asInstanceOf[t] }
              '{ $elem.eqv($e1, $e2) }
          }.reduce((acc, elem) => '{ $acc && $elem })
        end if
      }
      '{ eqProduct((x: T, y: T) => ${eqProductBody('x.asExprOf[Product], 'y.asExprOf[Product])}) }

    // case for Mirror.SumOf[T] ...

Note that in the version without macros we can simply write summonInstances[T, m.MirroredElemTypes] inside the inline method. Here, since Expr.summon is required, we have to extract the element types in a macro fashion. Inside a macro, our first reaction would be to write the code below:

'{
  summonInstances[T, $m.MirroredElemTypes]
}

However, since $m is not a stable path, $m.MirroredElemTypes cannot be used as a type argument. Instead, we extract the tuple type of the element types by pattern matching on the quoted mirror, more specifically on its refined type:

   case '{ $m: Mirror.ProductOf[T] { type MirroredElemTypes = elementTypes }} => ...
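The refinement that this pattern matches can also be observed in ordinary code. In the sketch below (no macros involved), summoning the mirror with an explicit refinement compiles only because the compiler-synthesized Mirror for Person has MirroredElemTypes equal to (String, Int); that tuple type is what the quote pattern binds to elementTypes. Because the val m is a stable path, m.MirroredElemTypes is a valid type here, unlike $m.MirroredElemTypes inside a quote.

```scala
import scala.deriving.Mirror

case class Person(name: String, age: Int)

// Compiles only if the synthesized mirror's element types are exactly (String, Int).
val m = summon[Mirror.ProductOf[Person] { type MirroredElemTypes = (String, Int) }]

// m is stable, so its type member can be used directly as a type.
val elems: m.MirroredElemTypes = ("Ann", 30)

// fromProduct rebuilds a Person from a tuple of its elements.
val p: Person = m.fromProduct(elems)
```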

Shown below is the macro implementation of summonInstances, which calls deriveOrSummon[T, elem] for each type elem in the tuple type.

To understand deriveOrSummon, consider that if elem is a subtype of the parent type T, the derivation is recursive: the instance for elem is produced by calling derivedMacro again instead of being summoned. Recursive derivation usually happens for types such as scala.collection.immutable.::, a case of the List type being derived. If elem is not a subtype of T, then a contextual Eq[elem] instance must exist.
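To see where the recursion comes from, the sketch below inspects the mirrors of a hypothetical recursive enum IntList (not part of the original example). The sum mirror's element types are the cases, each a subtype of the parent IntList; the product mirror of Cons then mentions IntList again, so deriving Eq[IntList.Cons] in turn requires an Eq[IntList].

```scala
import scala.deriving.Mirror

// Hypothetical recursive data type, used only for illustration.
enum IntList:
  case Cons(head: Int, tail: IntList)
  case Empty

// Element types of the sum mirror: one per case, in declaration order.
val sumMirror = summon[Mirror.SumOf[IntList] {
  type MirroredElemTypes = (IntList.Cons, IntList.Empty.type)
}]

// Each case is a subtype of the parent type being derived.
val consIsSubtype = summon[IntList.Cons <:< IntList]

// Element types of the Cons product mirror: IntList appears again here.
val consMirror = summon[Mirror.ProductOf[IntList.Cons] {
  type MirroredElemTypes = (Int, IntList)
}]
```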

def summonInstances[T: Type, Elems: Type](using Quotes): List[Expr[Eq[?]]] =
  Type.of[Elems] match
    case '[elem *: elems] => deriveOrSummon[T, elem] :: summonInstances[T, elems]
    case '[EmptyTuple]    => Nil

def deriveOrSummon[T: Type, Elem: Type](using Quotes): Expr[Eq[Elem]] =
  Type.of[Elem] match
    case '[T] => deriveRec[T, Elem]
    case _    => '{ summonInline[Eq[Elem]] }

def deriveRec[T: Type, Elem: Type](using Quotes): Expr[Eq[Elem]] =
  Type.of[T] match
    case '[Elem] => '{ error("infinite recursive derivation") }
    case _       => derivedMacro[Elem] // recursive derivation

The full code is shown below:

import compiletime.*
import scala.deriving.*
import scala.quoted.*


trait Eq[T]:
  def eqv(x: T, y: T): Boolean

object Eq:
  given Eq[String]:
    def eqv(x: String, y: String) = x == y

  given Eq[Int]:
    def eqv(x: Int, y: Int) = x == y

  def eqProduct[T](body: (T, T) => Boolean): Eq[T] =
    new Eq[T]:
      def eqv(x: T, y: T): Boolean = body(x, y)

  def eqSum[T](body: (T, T) => Boolean): Eq[T] =
    new Eq[T]:
      def eqv(x: T, y: T): Boolean = body(x, y)

  def summonInstances[T: Type, Elems: Type](using Quotes): List[Expr[Eq[?]]] =
    Type.of[Elems] match
      case '[elem *: elems] => deriveOrSummon[T, elem] :: summonInstances[T, elems]
      case '[EmptyTuple]    => Nil

  def deriveOrSummon[T: Type, Elem: Type](using Quotes): Expr[Eq[Elem]] =
    Type.of[Elem] match
      case '[T] => deriveRec[T, Elem]
      case _    => '{ summonInline[Eq[Elem]] }

  def deriveRec[T: Type, Elem: Type](using Quotes): Expr[Eq[Elem]] =
    Type.of[T] match
      case '[Elem] => '{ error("infinite recursive derivation") }
      case _       => derivedMacro[Elem] // recursive derivation

  inline def derived[T]: Eq[T] = ${ derivedMacro[T] }

  def derivedMacro[T: Type](using Quotes): Expr[Eq[T]] =

    val ev: Expr[Mirror.Of[T]] = Expr.summon[Mirror.Of[T]].get

    ev match
      case '{ $m: Mirror.ProductOf[T] { type MirroredElemTypes = elementTypes }} =>
        val elemInstances = summonInstances[T, elementTypes]
        def eqProductBody(x: Expr[Product], y: Expr[Product])(using Quotes): Expr[Boolean] = {
          if elemInstances.isEmpty then
            Expr(true)
          else
            elemInstances.zipWithIndex.map {
              case ('{ $elem: Eq[t] }, index) =>
                val indexExpr = Expr(index)
                val e1 = '{ $x.productElement($indexExpr).asInstanceOf[t] }
                val e2 = '{ $y.productElement($indexExpr).asInstanceOf[t] }
                '{ $elem.eqv($e1, $e2) }
            }.reduce((acc, elem) => '{ $acc && $elem })
          end if
        }
        '{ eqProduct((x: T, y: T) => ${eqProductBody('x.asExprOf[Product], 'y.asExprOf[Product])}) }

      case '{ $m: Mirror.SumOf[T] { type MirroredElemTypes = elementTypes }} =>
        val elemInstances = summonInstances[T, elementTypes]
        val elements = Expr.ofList(elemInstances)

        def eqSumBody(x: Expr[T], y: Expr[T])(using Quotes): Expr[Boolean] =
          val ordx = '{ $m.ordinal($x) }
          val ordy = '{ $m.ordinal($y) }
          '{ $ordx == $ordy && $elements($ordx).asInstanceOf[Eq[Any]].eqv($x, $y) }

        '{ eqSum((x: T, y: T) => ${eqSumBody('x, 'y)}) }
  end derivedMacro
end Eq
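The sum branch can be approximated by hand in the same way. The sketch below assumes a hypothetical enum Shape and writes the per-case instances directly instead of deriving them; elements, ordx and ordy play the same roles as in eqSumBody: two values are equal only if they are the same case and that case's instance says so.

```scala
import scala.deriving.Mirror

trait Eq[T]:
  def eqv(x: T, y: T): Boolean

// Hypothetical sum type, used only for illustration.
enum Shape:
  case Circle(radius: Int)
  case Square(side: Int)

// Hand-written stand-ins for the recursively derived case instances.
val eqCircle: Eq[Shape.Circle] = (x, y) => x.radius == y.radius
val eqSquare: Eq[Shape.Square] = (x, y) => x.side == y.side

val mirror = summon[Mirror.SumOf[Shape]]

// One instance per case, indexed by ordinal, like `elements` in derivedMacro.
val elements: List[Eq[?]] = List(eqCircle, eqSquare)

// Same logic as eqSumBody: compare ordinals, then delegate to the case instance.
val eqShape: Eq[Shape] = (x, y) =>
  val ordx = mirror.ordinal(x)
  val ordy = mirror.ordinal(y)
  ordx == ordy && elements(ordx).asInstanceOf[Eq[Any]].eqv(x, y)
```

Keep in mind that a macro cannot be expanded in the same source file that defines it, so a type such as case class Person(name: String, age: Int) derives Eq has to be declared in a different file from the Eq object above.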