Free ~> Trampoline: recursive program crashes with OutOfMemoryError

Question

Suppose that I'm trying to implement a very simple domain specific language with only one operation:

printLine(line)

Then I want to write a program that takes an integer n as input, prints something if n is divisible by 10k (that is, 10000), and then calls itself with n + 1, until n reaches some maximum value N.

Omitting all syntactic noise caused by for-comprehensions, what I want is:

@annotation.tailrec def p(n: Int): Unit = {
  if (n % 10000 == 0) printLine("line")
  if (n > N) () else p(n + 1)
}

Essentially, it would be a kind of "fizzbuzz".
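For reference, here is a runnable direct-style version of the sketch above, with no Free monad involved. It is a minimal baseline under assumptions that are not in the question: the object name DirectLoop is hypothetical, printLine only collects lines in a buffer instead of printing them, each line includes n, and N = 1000000. Because p is tail-recursive, scalac compiles it into a loop, so it runs in constant stack space and constant extra memory apart from the collected lines.

```scala
import scala.annotation.tailrec
import scala.collection.mutable.ListBuffer

object DirectLoop {
  val N = 1000000

  // Hypothetical stand-in for the DSL operation: it only records the line.
  val printed: ListBuffer[String] = ListBuffer.empty[String]
  def printLine(line: String): Unit = printed += line

  // Same shape as the sketch in the question; @tailrec makes scalac
  // check that the recursive call really is in tail position.
  @tailrec def p(n: Int): Unit = {
    if (n % 10000 == 0) printLine("line " + n)
    if (n > N) () else p(n + 1)
  }
}
```

Calling DirectLoop.p(0) visits n = 0 to N + 1 and records a line for each multiple of 10000 up to N, i.e. 101 lines.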

Here are a few attempts to implement this using the Free monad from Scalaz 7.3.0-M7:

import scalaz._

object Demo1 {

  // define operations of a little domain specific language
  sealed trait Lang[X]
  case class PrintLine(line: String) extends Lang[Unit]

  // define the domain specific language as the free monad of operations
  type Prog[X] = Free[Lang, X]

  import Free.{liftF, pure}

  // lift operations into the free monad
  def printLine(l: String): Prog[Unit] = liftF(PrintLine(l))
  def ret: Prog[Unit] = Free.pure(())

  // write a program that is just a loop that prints current index 
  // after every few iteration steps
  val mod =  100000
  val N =   1000000

  // straightforward syntax: deadly slow, exits with OutOfMemoryError
  def p0(i: Int): Prog[Unit] = for {
    _ <- (if (i % mod == 0) printLine("i = " + i) else ret)
    _ <- (if (i > N) ret else p0(i + 1))
  } yield ()

  // Same as above, but written out without `for`
  def p1(i: Int): Prog[Unit] = 
    (if (i % mod == 0) printLine("i = " + i) else ret).flatMap{
      ignore1 =>
      (if (i > N) ret else p1(i + 1)).map{ ignore2 => () }
    }

  // Same as above, with `map` attached to recursive call
  def p2(i: Int): Prog[Unit] = 
    (if (i % mod == 0) printLine("i = " + i) else ret).flatMap{
      ignore1 =>
      (if (i > N) ret else p2(i + 1).map{ ignore2 => () })
    }

  // Same as above, but without the `map`; performs ok.
  def p3(i: Int): Prog[Unit] = {
    (if (i % mod == 0) printLine("i = " + i) else ret).flatMap{ 
      ignore1 =>
      if (i > N) ret else p3(i + 1)
    }
  }

  // Variation of the above; Ok.
  def p4(i: Int): Prog[Unit] = (for {
    _ <- (if (i % mod == 0) printLine("i = " + i) else ret)
  } yield ()).flatMap{ ignored2 => 
    if (i > N) ret else p4(i + 1) 
  }

  // try to use the variable returned by the last generator after yield,
  // hope that the final `map` is optimized away (it's not optimized away...)
  def p5(i: Int): Prog[Unit] = for {
    _ <- (if (i % mod == 0) printLine("i = " + i) else ret)
    stopHere <- (if (i > N) ret else p5(i + 1))
  } yield stopHere

  // define an interpreter that translates the programs into Trampoline
  import scalaz.Trampoline
  type Exec[X] = Free.Trampoline[X]  
  val interpreter = new (Lang ~> Exec) {
    def apply[A](cmd: Lang[A]): Exec[A] = cmd match {
      case PrintLine(l) => Trampoline.delay(println(l))
    }
  }

  // try it out
  def main(args: Array[String]): Unit = {
    println("\n p0")
    p0(0).foldMap(interpreter).run // extremely slow; OutOfMemoryError
    println("\n p1")
    p1(0).foldMap(interpreter).run // extremely slow; OutOfMemoryError
    println("\n p2")
    p2(0).foldMap(interpreter).run // extremely slow; OutOfMemoryError
    println("\n p3")
    p3(0).foldMap(interpreter).run // ok 
    println("\n p4")
    p4(0).foldMap(interpreter).run // ok
    println("\n p5")
    p5(0).foldMap(interpreter).run // OutOfMemory
  }
}

Unfortunately, the straightforward translation (p0) seems to run with some kind of O(N^2) overhead, and crashes with an OutOfMemoryError. The problem seems to be that the for-comprehension appends a map{x => ()} after the recursive call to p0, which forces the Free monad to fill the entire memory with reminders to "finish 'p0' and then do nothing". If I manually "unroll" the for-comprehension and write out the last flatMap explicitly (as in p3 and p4), the problem goes away and everything runs smoothly. This, however, is an extremely brittle workaround: the behavior of the program changes dramatically if we simply append a map(id) to it, and this map(id) isn't even visible in the code, because it is generated automatically by the for-comprehension.
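To see the extra map concretely, the following sketch uses a hypothetical Recorder type (not part of Scalaz) whose flatMap and map do nothing except record that they were called. A for-comprehension with two generators ending in yield () is desugared by scalac into a.flatMap(_ => b.map(_ => ())), so the last generator always carries a trailing map, whereas the hand-written form used in p3 ends directly in b.

```scala
// Hypothetical "recorder": just enough flatMap/map for a for-comprehension.
// It records which methods were called, in order.
final case class Recorder(calls: List[String]) {
  def flatMap(f: Unit => Recorder): Recorder =
    Recorder(calls ++ ("flatMap" :: f(()).calls))
  // f is irrelevant here; only the fact that map was called is recorded.
  def map(f: Unit => Unit): Recorder = Recorder(calls :+ "map")
}

object Desugar {
  val a = Recorder(List("a"))
  val b = Recorder(List("b"))

  // The shape of p0: two generators, then yield ().
  val viaFor: Recorder = for {
    _ <- a
    _ <- b
  } yield ()

  // What scalac turns viaFor into: note the trailing map after b.
  val desugared: Recorder = a.flatMap(_ => b.map(_ => ()))

  // The shape of p3: last step written by hand, no map at all.
  val handWritten: Recorder = a.flatMap(_ => b)
}
```

Here Desugar.viaFor.calls is List("a", "flatMap", "b", "map"), identical to the explicitly desugared version, while Desugar.handWritten.calls contains no map.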

An older post (https://apocalisp.wordpress.com/2011/10/26/tail-call-elimination-in-scala-monads/) repeatedly advises wrapping recursive calls in a suspend. Here is an attempt with an Applicative instance and suspend:

import scalaz._

// Essentially same as in `Demo1`, but this time with 
// an `Applicative` and an explicit `Suspend` in the 
// `for`-comprehension
object Demo2 {

  sealed trait Lang[H]

  case class Const[H](h: H) extends Lang[H]
  case class PrintLine[H](line: String) extends Lang[H]

  implicit object Lang extends Applicative[Lang] {
    def point[A](a: => A): Lang[A] = Const(a)
    def ap[A, B](a: => Lang[A])(f: => Lang[A => B]): Lang[B] = a match {
      case Const(x) => {
        f match {
          case Const(ab) => Const(ab(x))
          case _ => throw new Error
        }
      }
      case PrintLine(l) => PrintLine(l)
    }
  }

  type Prog[X] = Free[Lang, X]

  import Free.{liftF, pure}
  def printLine(l: String): Prog[Unit] = liftF(PrintLine(l))
  def ret: Prog[Unit] = Free.pure(())

  val mod = 100000
  val N = 2000000

  // try to suspend the entire second generator
  def p7(i: Int): Prog[Unit] = for {
    _ <- (if (i % mod == 0) printLine("i = " + i) else ret)
    _ <- Free.suspend(if (i > N) ret else p7(i + 1))
  } yield ()

  // try to suspend the recursive call
  def p8(i: Int): Prog[Unit] = for {
    _ <- (if (i % mod == 0) printLine("i = " + i) else ret)
    _ <- if (i > N) ret else Free.suspend(p8(i + 1))
  } yield ()

  import scalaz.Trampoline
  type Exec[X] = Free.Trampoline[X]

  val interpreter = new (Lang ~> Exec) {
    def apply[A](cmd: Lang[A]): Exec[A] = cmd match {
      case Const(x) => Trampoline.done(x)
      case PrintLine(l) => 
        (Trampoline.delay(println(l))).asInstanceOf[Exec[A]]
    }
  }

  def main(args: Array[String]): Unit = {
    p7(0).foldMap(interpreter).run // extremely slow; OutOfMemoryError
    p8(0).foldMap(interpreter).run // same...
  }
}

Inserting suspend did not really help: it's still slow, and crashes with OutOfMemoryErrors.

Should I use the suspend somehow differently?

Maybe there is some purely syntactic remedy that makes it possible to use for-comprehensions without generating the map in the end?

I'd really appreciate if someone could point out what I'm doing wrong here, and how to repair it.

Answer

That superfluous map added by the Scala compiler moves the recursion from tail position to non-tail position. The Free monad still makes this stack safe, but the space complexity becomes O(N) instead of O(1). (In particular, it is still not O(N^2).)

Whether it is possible to make scalac optimize that map away makes for a separate question (which I don't know the answer to).

I will try to illustrate what happens when interpreting p1 versus p3. (I will ignore the translation to Trampoline, which is redundant; see below.)

For p3, let me use the following shorthand:

def cont(i: Int): Unit => Prog[Unit] =
  ignore1 => if (i > N) ret else p3(i + 1)

Now the interpretation of p3(0) goes as follows:

p3(0)
printLine("i = " + 0) flatMap cont(0)
// side-effect: println("i = 0")
cont(0)
p3(1)
ret flatMap cont(1)
cont(1)
p3(2)
ret flatMap cont(2)
cont(2)

And so on. You can see that the amount of memory needed at any point never exceeds some constant upper bound.

For p1, I will use the following shorthand:

def cont(i: Int): Unit => Prog[Unit] =
  ignore1 => (if (i > N) ret else p1(i + 1)).map{ ignore2 => () }

def cpu: Unit => Prog[Unit] = // constant pure unit
  ignore => Free.pure(())

Now the interpretation of p1(0) goes as follows:

p1(0)
printLine("i = " + 0) flatMap cont(0)
// side-effect: println("i = 0")
cont(0)
p1(1) map { ignore2 => () }
// Free.map is implemented via flatMap
p1(1) flatMap cpu
(ret flatMap cont(1)) flatMap cpu
cont(1) flatMap cpu
(p1(2) map { ignore2 => () }) flatMap cpu
(p1(2) flatMap cpu) flatMap cpu
((ret flatMap cont(2)) flatMap cpu) flatMap cpu
(cont(2) flatMap cpu) flatMap cpu
((p1(3) map { ignore2 => () }) flatMap cpu) flatMap cpu
((p1(3) flatMap cpu) flatMap cpu) flatMap cpu
(((ret flatMap cont(3)) flatMap cpu) flatMap cpu) flatMap cpu

And so on. You can see that the memory consumption grows linearly with N: every recursive call adds one more pending "flatMap cpu" to the chain. We have merely moved the evaluation from the stack to the heap.
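The difference between the two traces can be measured with a small, hand-rolled Free-like structure. This is a sketch for illustration only, not Scalaz's Free: Prog, Pure, PrintLine, FlatMap, Interp, Loops, and step are hypothetical names, and the interpreter keeps pending continuations on an explicit heap-allocated stack (as a trampoline does) and reports the largest number that were pending at the same time. For the p3 shape this stays at 1; for the p0 shape it grows with N, because every trailing map leaves one more continuation waiting.

```scala
// Hypothetical Free-like program type, with map expressed via flatMap.
sealed trait Prog[+A] {
  def flatMap[B](f: A => Prog[B]): Prog[B] = FlatMap(this, f)
  def map[B](f: A => B): Prog[B] = flatMap(a => Pure(f(a)))
}
final case class Pure[A](a: A) extends Prog[A]
final case class PrintLine(line: String) extends Prog[Unit]
final case class FlatMap[X, A](sub: Prog[X], k: X => Prog[A]) extends Prog[A]

object Interp {
  // Stack-safe loop with an explicit continuation stack (a List on the heap).
  // Returns the maximum number of continuations pending at once.
  def run(p: Prog[Any], out: String => Unit): Int = {
    var current: Prog[Any] = p
    var stack: List[Any => Prog[Any]] = Nil
    var depth = 0
    var maxDepth = 0
    var done = false
    while (!done) {
      current match {
        case FlatMap(sub, k) =>
          stack = k.asInstanceOf[Any => Prog[Any]] :: stack
          depth += 1
          maxDepth = math.max(maxDepth, depth)
          current = sub
        case PrintLine(line) =>
          out(line)
          current = Pure(())
        case Pure(a) =>
          stack match {
            case k :: rest =>
              stack = rest
              depth -= 1
              current = k(a)
            case Nil =>
              done = true
          }
      }
    }
    maxDepth
  }
}

object Loops {
  val mod = 100000
  val ret: Prog[Unit] = Pure(())
  def printLine(l: String): Prog[Unit] = PrintLine(l)
  def step(i: Int): Prog[Unit] = if (i % mod == 0) printLine("i = " + i) else ret

  // Shape of p3: the recursive call is the whole right-hand side of flatMap.
  def p3(i: Int, n: Int): Prog[Unit] =
    step(i).flatMap(_ => if (i > n) ret else p3(i + 1, n))

  // Shape of p0: the for-comprehension appends .map(_ => ()) to the recursion.
  def p0(i: Int, n: Int): Prog[Unit] = for {
    _ <- step(i)
    _ <- (if (i > n) ret else p0(i + 1, n))
  } yield ()
}
```

With n = 10000, Interp.run(Loops.p3(0, 10000), println) returns 1, whereas Interp.run(Loops.p0(0, 10000), println) returns 10002, i.e. n + 2. Both print only "i = 0", since mod = 100000.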

Takeaway: to keep Free memory-friendly, keep the recursion in "tail position", that is, on the right-hand side of flatMap (or map).

Aside: The translation to Trampoline is not necessary, since Free is already trampolined. You could interpret directly to Id and use foldMapRec for stack-safe interpretation:

val idInterpreter = new (Lang ~> Id) {
  def apply[A](cmd: Lang[A]): Id[A] = cmd match {
    case PrintLine(l) => println(l)
  }
}

p0(0).foldMapRec(idInterpreter)

This will save you some fraction of the memory, but it does not make the problem go away.
