来自“Programming Scala"的合并排序导致堆栈溢出 [英] Merge sort from "Programming Scala" causes stack overflow
问题描述
以下算法的直接剪切和粘贴:
A direct cut and paste of the following algorithm:
def msort[T](less: (T, T) => Boolean)
(xs: List[T]): List[T] = {
def merge(xs: List[T], ys: List[T]): List[T] =
(xs, ys) match {
case (Nil, _) => ys
case (_, Nil) => xs
case (x :: xs1, y :: ys1) =>
if (less(x, y)) x :: merge(xs1, ys)
else y :: merge(xs, ys1)
}
val n = xs.length / 2
if (n == 0) xs
else {
val (ys, zs) = xs splitAt n
merge(msort(less)(ys), msort(less)(zs))
}
}
在 5000 个长列表上导致 StackOverflowError.
causes a StackOverflowError on 5000 long lists.
有没有什么办法可以优化这种情况,以免发生这种情况?
Is there any way to optimize this so that this doesn't occur?
推荐答案
这样做是因为它不是尾递归的.您可以通过使用非严格集合或使其成为尾递归来解决此问题.
It is doing this because it isn't tail-recursive. You can fix this by either using a non-strict collection, or by making it tail-recursive.
后一种解决方案是这样的:
The latter solution goes like this:
def msort[T](less: (T, T) => Boolean)
(xs: List[T]): List[T] = {
def merge(xs: List[T], ys: List[T], acc: List[T]): List[T] =
(xs, ys) match {
case (Nil, _) => ys.reverse ::: acc
case (_, Nil) => xs.reverse ::: acc
case (x :: xs1, y :: ys1) =>
if (less(x, y)) merge(xs1, ys, x :: acc)
else merge(xs, ys1, y :: acc)
}
val n = xs.length / 2
if (n == 0) xs
else {
val (ys, zs) = xs splitAt n
merge(msort(less)(ys), msort(less)(zs), Nil).reverse
}
}
使用非严格性涉及按名称传递参数,或使用非严格性集合,例如 Stream
.以下代码使用 Stream
只是为了防止堆栈溢出,而在其他地方使用 List
:
Using non-strictness involves either passing parameters by-name, or using non-strict collections such as Stream
. The following code uses Stream
just to prevent stack overflow, and List
elsewhere:
def msort[T](less: (T, T) => Boolean)
(xs: List[T]): List[T] = {
def merge(left: List[T], right: List[T]): Stream[T] = (left, right) match {
case (x :: xs, y :: ys) if less(x, y) => Stream.cons(x, merge(xs, right))
case (x :: xs, y :: ys) => Stream.cons(y, merge(left, ys))
case _ => if (left.isEmpty) right.toStream else left.toStream
}
val n = xs.length / 2
if (n == 0) xs
else {
val (ys, zs) = xs splitAt n
merge(msort(less)(ys), msort(less)(zs)).toList
}
}
这篇关于来自“Programming Scala"的合并排序导致堆栈溢出的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持IT屋!