Is there an efficient way to generate N random integers in a range that have a given sum or average?

Problem description

Is there an efficient way to generate a random combination of N integers such that—

  • each integer is in the interval [min, max],
  • the integers have a sum of sum,
  • the integers can appear in any order (e.g., random order), and
  • the combination is chosen uniformly at random from among all combinations that meet the other requirements?

Is there a similar algorithm for random combinations in which the integers must appear in sorted order by their values (rather than in any order)?

(Choosing an appropriate combination with a mean of mean is a special case, if sum = N * mean. This problem is equivalent to generating a uniform random partition of sum into N parts that are each in the interval [min, max] and appear in any order or in sorted order by their values, as the case may be.)

I am aware that this problem can be solved in the following way for combinations that appear in random order (EDIT [Apr. 27]: Algorithm modified.):

  1. If N * max < sum or N * min > sum, there is no solution.

  2. If N * max == sum, there is only one solution, in which all N numbers are equal to max. If N * min == sum, there is only one solution, in which all N numbers are equal to min.

  3. Use the algorithm given in Smith and Tromble ("Sampling from the Unit Simplex", 2004) to generate N random non-negative integers with the sum sum - N * min.

  4. Add min to each number generated this way.

  5. If any number is greater than max, go to step 3.

However, this algorithm is slow if max is much less than sum. For example, according to my tests (with an implementation of the special case above involving mean), the algorithm rejects, on average—

  • about 1.6 samples if N = 7, min = 3, max = 10, sum = 42, but
  • about 30.6 samples if N = 20, min = 3, max = 10, sum = 120.
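The rejection rate can be predicted in advance. Assume the Smith and Tromble step is uniform over all compositions (ordered N-tuples of non-negative integers) of S = sum - N * min. Then a draw is accepted with probability p = V / T, where T = C(S + N - 1, N - 1) and V is the number of those compositions whose parts are all at most r = max - min. V follows from the standard inclusion-exclusion identity V = Σ_k (-1)^k C(N, k) C(S - k(r + 1) + N - 1, N - 1), which is not part of the question. The expected number of rejections per accepted sample is 1/p - 1. The following Java sketch uses hypothetical class and method names:

```java
import java.math.BigDecimal;
import java.math.BigInteger;
import java.math.MathContext;

// Hypothetical helper: predicts how often one unconstrained Smith-Tromble draw
// already satisfies the max constraint.
class AcceptanceRate {
    // Exact binomial coefficient C(n, k); 0 if the arguments are out of range.
    static BigInteger binomial(int n, int k) {
        if (n < 0 || k < 0 || k > n) return BigInteger.ZERO;
        BigInteger result = BigInteger.ONE;
        for (int i = 1; i <= k; i++)
            // after this step result == C(n - k + i, i), always an integer
            result = result.multiply(BigInteger.valueOf(n - k + i)).divide(BigInteger.valueOf(i));
        return result;
    }

    // V: ordered n-tuples of integers in [0, r] summing to s
    // (inclusion-exclusion over k parts forced above r).
    static BigInteger boundedCompositions(int n, int r, int s) {
        BigInteger v = BigInteger.ZERO;
        for (int k = 0; k <= n && k * (r + 1) <= s; k++) {
            BigInteger term = binomial(n, k).multiply(binomial(s - k * (r + 1) + n - 1, n - 1));
            v = (k % 2 == 0) ? v.add(term) : v.subtract(term);
        }
        return v;
    }

    // p = V / T with T = C(S + N - 1, N - 1) and S = sum - N * min.
    static double acceptanceProbability(int n, int min, int max, int sum) {
        int s = sum - n * min;
        int r = max - min;
        BigInteger valid = boundedCompositions(n, r, s);
        BigInteger all = binomial(s + n - 1, n - 1);
        return new BigDecimal(valid).divide(new BigDecimal(all), MathContext.DECIMAL64).doubleValue();
    }

    public static void main(String[] args) {
        int[][] cases = {{7, 3, 10, 42}, {20, 3, 10, 120}};
        for (int[] c : cases) {
            double p = acceptanceProbability(c[0], c[1], c[2], c[3]);
            System.out.printf("N=%d: p=%.4f, expected rejections per accepted sample=%.2f%n",
                    c[0], p, 1 / p - 1);
        }
    }
}
```

For N = 7, min = 3, max = 10, sum = 42 this gives V = 115788 and T = 296010. So p ≈ 0.391, or about 1.56 rejections per accepted sample, which is consistent with the roughly 1.6 measured above. With the mean held fixed, V tends to shrink relative to T as N grows, which is why the rejection count rises so quickly.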

Is there a way to modify this algorithm to be efficient for large N while still meeting the requirements above?

EDIT:

As an alternative suggested in the comments, an efficient way of producing a valid random combination (that satisfies all but the last requirement) is:

  1. Calculate X, the number of valid combinations possible given sum, min, and max.
  2. Choose Y, a uniform random integer in [0, X).
  3. Convert ("unrank") Y to a valid combination.

However, is there a formula for calculating the number of valid combinations (or permutations), and is there a way to convert an integer to a valid combination? [EDIT (Apr. 28): Same for permutations rather than combinations].
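For the unsorted (permutation) case, both parts of the question can be answered with dynamic programming rather than a closed formula. First, shift everything by min, so that each value lies in [0, r] with r = max - min and the target sum is S = sum - N * min. Let c(k, s) be the number of ordered k-tuples in [0, r] with sum s. Then c(0, 0) = 1, c(0, s) = 0 for s > 0, and c(k, s) = Σ_{v=0}^{min(r, s)} c(k - 1, s - v).

With this table, X = c(N, S). Y can be unranked one position at a time: for the current position, try v = 0, 1, ... in turn, skipping c(k - 1, s - v) ranks per value until Y falls inside a block.

Below is a minimal Java sketch under these definitions (all names are hypothetical). It uses BigInteger so that X is exact. It draws Y uniformly by rejecting random bit strings that are not below X, which takes fewer than two draws on average.

```java
import java.math.BigInteger;
import java.util.Arrays;
import java.util.Random;

// Hypothetical sketch: count, unrank and sample ordered N-tuples of integers
// in [min, max] with a given sum (the "random order" case).
class BoundedCompositionRanker {
    private final int n, min, range, shiftedSum;
    // count[k][s] = number of ordered k-tuples in [0, range] with sum s
    private final BigInteger[][] count;

    BoundedCompositionRanker(int n, int min, int max, int sum) {
        this.n = n;
        this.min = min;
        this.range = max - min;
        this.shiftedSum = sum - n * min;
        if (n <= 0 || range < 0 || shiftedSum < 0 || shiftedSum > n * range)
            throw new IllegalArgumentException("no valid tuple exists");
        count = new BigInteger[n + 1][shiftedSum + 1];
        for (BigInteger[] row : count)
            Arrays.fill(row, BigInteger.ZERO);
        count[0][0] = BigInteger.ONE;
        for (int k = 1; k <= n; k++)
            for (int s = 0; s <= shiftedSum; s++)
                for (int v = 0; v <= Math.min(range, s); v++) // v = value of the k-th element
                    count[k][s] = count[k][s].add(count[k - 1][s - v]);
    }

    // X: the number of valid tuples.
    BigInteger total() {
        return count[n][shiftedSum];
    }

    // Maps a rank in [0, X) to the tuple with that rank in lexicographic order.
    int[] unrank(BigInteger rank) {
        if (rank.signum() < 0 || rank.compareTo(total()) >= 0)
            throw new IllegalArgumentException("rank out of range");
        int[] result = new int[n];
        int s = shiftedSum;
        for (int i = 0; i < n; i++) {
            int remaining = n - i - 1;
            for (int v = 0; v <= Math.min(range, s); v++) {
                BigInteger block = count[remaining][s - v]; // tuples with result[i] == min + v
                if (rank.compareTo(block) < 0) {
                    result[i] = min + v;
                    s -= v;
                    break;
                }
                rank = rank.subtract(block);
            }
        }
        return result;
    }

    // Uniform Y in [0, X): random bit strings of X's length, rejected while >= X.
    int[] sample(Random rnd) {
        BigInteger x = total();
        BigInteger y;
        do {
            y = new BigInteger(x.bitLength(), rnd);
        } while (y.compareTo(x) >= 0);
        return unrank(y);
    }
}
```

Building the table takes O(N · S · r) BigInteger additions. Each unrank call then costs O(N · r) comparisons. For example, with N = 3, min = 0, max = 3, sum = 5 there are X = 12 tuples; rank 0 is [0, 2, 3] and rank 11 is [3, 2, 0].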

EDIT (Apr. 27):

After reading Devroye's Non-Uniform Random Variate Generation (1986), I can confirm that this is a problem of generating a random partition. Also, Exercise 2 (especially part E) on page 661 is relevant to this question.

EDIT (Apr. 28):

As it turned out, the algorithm I gave is uniform when the integers involved are given in random order, as opposed to sorted order by their values. Since both problems are of general interest, I have modified this question to seek a canonical answer for both problems.

The following Ruby code can be used to verify potential solutions for uniformity (where algorithm(...) is the candidate algorithm):

combos={}
permus={}
mn=0
mx=6
sum=12
for x in mn..mx
  for y in mn..mx
    for z in mn..mx
      if x+y+z==sum
        permus[[x,y,z]]=0
      end
      if x+y+z==sum and x<=y and y<=z
        combos[[x,y,z]]=0
      end
    end
  end
end

3000.times {|x|
 f=algorithm(3,sum,mn,mx)
 combos[f.sort]+=1
 permus[f]+=1
}
p combos
p permus
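Printed counts are hard to judge by eye. One common option, not part of the question, is Pearson's chi-squared statistic against a uniform expectation. For k possible outcomes it has k - 1 degrees of freedom under uniformity, so values far above k - 1 point to bias.

For the test above (N = 3, values 0 to 6, sum 12) there are 28 valid permutations and 7 valid sorted combinations. Test permus for an unsorted generator and combos for a sorted one. Here is a minimal Java sketch (class name hypothetical) that takes the tallied counts:

```java
import java.util.Map;

// Hypothetical helper: Pearson's chi-squared statistic of observed counts
// against a uniform distribution over the same set of outcomes.
class UniformityCheck {
    // Every valid outcome must appear as a key, including those seen 0 times;
    // the Ruby code above pre-fills combos and permus with zeros, which gives exactly that.
    static double chiSquared(Map<?, Integer> counts) {
        long total = 0;
        for (int c : counts.values())
            total += c;
        double expected = (double) total / counts.size();
        double stat = 0;
        for (int c : counts.values()) {
            double d = c - expected;
            stat += d * d / expected;
        }
        return stat;
    }
}
```

With 3000 trials spread over the 28 permutations, each expected count is about 107.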

EDIT (Apr. 29): Re-added Ruby code of current implementation.

The following code example is given in Ruby, but my question is independent of programming language:

def posintwithsum(n, total)
    raise if n <= 0 or total <=0
    ls = [0]
    ret = []
    while ls.length < n
      c = 1+rand(total-1)
      found = false
      for j in 1...ls.length
        if ls[j] == c
          found = true
          break
        end
      end
      if found == false;ls.push(c);end
    end
    ls.sort!
    ls.push(total)
    for i in 1...ls.length
       ret.push(ls[i] - ls[i - 1])
    end
    return ret
end

def integersWithSum(n, total)
 raise if n <= 0 or total <=0
 ret = posintwithsum(n, total + n)
 for i in 0...ret.length
    ret[i] = ret[i] - 1
 end
 return ret
end

# Generate 100 valid samples
mn=3
mx=10
sum=42
n=7
100.times {
 while true
    pp=integersWithSum(n,sum-n*mn).map{|x| x+mn }
    if !pp.find{|x| x>mx }
      p pp; break # Output the sample and break
    end
 end
}
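posintwithsum draws n - 1 distinct cut points from 1 to total - 1. It retries on duplicates and checks each candidate with a linear scan. The same uniformly random subset can be drawn without retries using Robert Floyd's subset-sampling algorithm. That technique is swapped in here and is not the question's code.

Here is a Java sketch with hypothetical names, covering integersWithSum and the rejection loop from the Ruby code:

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.Random;
import java.util.Set;

// Hypothetical sketch of the question's integersWithSum plus its rejection loop,
// with the duplicate-retry loop replaced by Floyd's subset sampling.
class FloydSimplexSampler {
    // n non-negative integers with the given total, uniform over all ordered n-tuples.
    static int[] integersWithSum(int n, int total, Random rnd) {
        if (n <= 0 || total < 0)
            throw new IllegalArgumentException();
        int m = total + n;     // as in the Ruby code: n positive parts summing to total + n
        int picks = n - 1;     // distinct cut points to choose from 1..m-1
        Set<Integer> cuts = new HashSet<>();
        // Floyd's algorithm: each iteration adds exactly one new element
        for (int j = m - picks; j <= m - 1; j++) {
            int t = 1 + rnd.nextInt(j);   // uniform in [1, j]
            if (!cuts.add(t))
                cuts.add(j);              // t was already taken; j cannot be
        }
        int[] points = new int[n + 1];    // points[0] = 0
        int k = 1;
        for (int c : cuts)
            points[k++] = c;
        points[n] = m;
        Arrays.sort(points, 1, n);        // sort only the cut points
        int[] ret = new int[n];
        for (int i = 0; i < n; i++)
            ret[i] = points[i + 1] - points[i] - 1; // positive gap minus 1
        return ret;
    }

    // The rejection loop from the question: shift by min and retry while any value exceeds max.
    static int[] boundedSample(int n, int min, int max, int sum, Random rnd) {
        if (sum < n * min || sum > n * max)
            throw new IllegalArgumentException("no solution");
        while (true) {
            int[] v = integersWithSum(n, sum - n * min, rnd);
            boolean ok = true;
            for (int i = 0; i < n; i++) {
                v[i] += min;
                if (v[i] > max)
                    ok = false;
            }
            if (ok)
                return v;
        }
    }
}
```

This removes only the cost of duplicate cut points. The chance that a draw passes the max check is unchanged, so the rejection loop still slows down as N grows.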

Solution

Here's my solution in Java. It is fully functional and contains two generators: PermutationPartitionGenerator for unsorted partitions and CombinationPartitionGenerator for sorted partitions. Your generator is also implemented, in the class SmithTromblePartitionGenerator, for comparison. The class SequentialEnumerator enumerates all possible partitions (unsorted or sorted, depending on the parameter) in sequential order. I have added thorough tests (including your test cases) for all of these generators. For the most part, the implementation is self-explanatory. If you have any questions, I will answer them in a couple of days.
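Both exact generators rest on counting tables, with all values shifted by min.

- **PermutationPartitionGenerator:** table[n][s] is the number of ordered n-tuples in [0, range] with sum s. It is built as table[n][s] = Σ_{i = max(0, s - range)}^{s} table[n - 1][i].
- **CombinationPartitionGenerator:** table[n][m][s] counts sorted n-tuples in [0, m] with sum s. It uses table[n][m][s] = table[n][m - 1][s] + table[n - 1][m][s - m], because the largest element is either below m or exactly m.

Each get() call then walks backwards through the table, picking each element with probability proportional to the number of ways to complete the tuple.

Below is a minimal, self-contained sketch of the two recurrences (class name hypothetical). It uses long instead of BigInteger, so it only suits small inputs. It is checked against the question's Ruby test case (N = 3, min = 0, max = 6, sum = 12), which has 28 valid permutations and 7 valid sorted combinations.

```java
// Hypothetical sketch of the two counting recurrences, using long arithmetic
// (the answer's code uses BigInteger to avoid overflow for large inputs).
class PartitionCounts {
    // Ordered n-tuples of integers in [0, range] with the given sum (unsorted partitions).
    static long permutations(int n, int range, int sum) {
        long[] prev = new long[sum + 1];
        prev[0] = 1;                                  // zero numbers: only sum 0
        for (int k = 1; k <= n; k++) {
            long[] cur = new long[sum + 1];
            for (int s = 0; s <= sum; s++)
                for (int i = Math.max(0, s - range); i <= s; i++)
                    cur[s] += prev[i];                // last element is s - i
            prev = cur;
        }
        return prev[sum];
    }

    // Non-decreasing n-tuples of integers in [0, maxValue] with the given sum (sorted partitions).
    static long combinations(int n, int maxValue, int sum) {
        // prev[m][s] holds the count for k - 1 numbers bounded by m
        long[][] prev = new long[maxValue + 1][sum + 1];
        for (int m = 0; m <= maxValue; m++)
            prev[m][0] = 1;
        for (int k = 1; k <= n; k++) {
            long[][] cur = new long[maxValue + 1][sum + 1];
            for (int m = 0; m <= maxValue; m++)
                for (int s = 0; s <= sum; s++) {
                    long below = (m == 0) ? 0 : cur[m - 1][s]; // largest element < m
                    long exact = (m <= s) ? prev[m][s - m] : 0; // largest element == m
                    cur[m][s] = below + exact;
                }
            prev = cur;
        }
        return prev[maxValue][sum];
    }

    public static void main(String[] args) {
        // N = 3, min = 0, max = 6, sum = 12: range 6 and shifted sum 12
        System.out.println(permutations(3, 6, 12));  // 28
        System.out.println(combinations(3, 6, 12));  // 7
    }
}
```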

import java.util.Random;
import java.util.function.Supplier;

public abstract class PartitionGenerator implements Supplier<int[]>{
    public static final Random rand = new Random();
    protected final int numberCount;
    protected final int min;
    protected final int range;
    protected final int sum; // shifted sum
    protected final boolean sorted;

    protected PartitionGenerator(int numberCount, int min, int max, int sum, boolean sorted) {
        if (numberCount <= 0)
            throw new IllegalArgumentException("Number count should be positive");
        this.numberCount = numberCount;
        this.min = min;
        range = max - min;
        if (range < 0)
            throw new IllegalArgumentException("min > max");
        sum -= numberCount * min;
        if (sum < 0)
            throw new IllegalArgumentException("Sum is too small");
        if (numberCount * range < sum)
            throw new IllegalArgumentException("Sum is too large");
        this.sum = sum;
        this.sorted = sorted;
    }

    // Whether this generator returns sorted arrays (i.e. combinations)
    public final boolean isSorted() {
        return sorted;
    }

    public interface GeneratorFactory {
        PartitionGenerator create(int numberCount, int min, int max, int sum);
    }
}

import java.math.BigInteger;

// Permutations with repetition (i.e. unsorted vectors) with given sum
public class PermutationPartitionGenerator extends PartitionGenerator {
    private final double[][] distributionTable;

    public PermutationPartitionGenerator(int numberCount, int min, int max, int sum) {
        super(numberCount, min, max, sum, false);
        distributionTable = calculateSolutionCountTable();
    }

    private double[][] calculateSolutionCountTable() {
        double[][] table = new double[numberCount + 1][sum + 1];
        BigInteger[] a = new BigInteger[sum + 1];
        BigInteger[] b = new BigInteger[sum + 1];
        for (int i = 1; i <= sum; i++)
            a[i] = BigInteger.ZERO;
        a[0] = BigInteger.ONE;
        table[0][0] = 1.0;
        for (int n = 1; n <= numberCount; n++) {
            double[] t = table[n];
            for (int s = 0; s <= sum; s++) {
                BigInteger z = BigInteger.ZERO;
                for (int i = Math.max(0, s - range); i <= s; i++)
                    z = z.add(a[i]);
                b[s] = z;
                t[s] = z.doubleValue();
            }
            // swap a and b
            BigInteger[] c = b;
            b = a;
            a = c;
        }
        return table;
    }

    @Override
    public int[] get() {
        int[] p = new int[numberCount];
        int s = sum; // current sum
        for (int i = numberCount - 1; i >= 0; i--) {
            double t = rand.nextDouble() * distributionTable[i + 1][s];
            double[] tableRow = distributionTable[i];
            int oldSum = s;
            // lowerBound is introduced only for safety, it shouldn't be crossed 
            int lowerBound = s - range;
            if (lowerBound < 0)
                lowerBound = 0;
            s++;
            do
                t -= tableRow[--s];
            // s can be equal to lowerBound here with t > 0 only due to imprecise subtraction
            while (t > 0 && s > lowerBound);
            p[i] = min + (oldSum - s);
        }
        assert s == 0;
        return p;
    }

    public static final GeneratorFactory factory = (numberCount, min, max,sum) ->
        new PermutationPartitionGenerator(numberCount, min, max, sum);
}

import java.math.BigInteger;

// Combinations with repetition (i.e. sorted vectors) with given sum 
public class CombinationPartitionGenerator extends PartitionGenerator {
    private final double[][][] distributionTable;

    public CombinationPartitionGenerator(int numberCount, int min, int max, int sum) {
        super(numberCount, min, max, sum, true);
        distributionTable = calculateSolutionCountTable();
    }

    private double[][][] calculateSolutionCountTable() {
        double[][][] table = new double[numberCount + 1][range + 1][sum + 1];
        BigInteger[][] a = new BigInteger[range + 1][sum + 1];
        BigInteger[][] b = new BigInteger[range + 1][sum + 1];
        double[][] t = table[0];
        for (int m = 0; m <= range; m++) {
            a[m][0] = BigInteger.ONE;
            t[m][0] = 1.0;
            for (int s = 1; s <= sum; s++) {
                a[m][s] = BigInteger.ZERO;
                t[m][s] = 0.0;
            }
        }
        for (int n = 1; n <= numberCount; n++) {
            t = table[n];
            for (int m = 0; m <= range; m++)
                for (int s = 0; s <= sum; s++) {
                    BigInteger z;
                    if (m == 0)
                        z = a[0][s];
                    else {
                        z = b[m - 1][s];
                        if (m <= s)
                            z = z.add(a[m][s - m]);
                    }
                    b[m][s] = z;
                    t[m][s] = z.doubleValue();
                }
            // swap a and b
            BigInteger[][] c = b;
            b = a;
            a = c;
        }
        return table;
    }

    @Override
    public int[] get() {
        int[] p = new int[numberCount];
        int m = range; // current max
        int s = sum; // current sum
        for (int i = numberCount - 1; i >= 0; i--) {
            double t = rand.nextDouble() * distributionTable[i + 1][m][s];
            double[][] tableCut = distributionTable[i];
            if (s < m)
                m = s;
            s -= m;
            while (true) {
                t -= tableCut[m][s];
                // m can be 0 here with t > 0 only due to imprecise subtraction
                if (t <= 0 || m == 0)
                    break;
                m--;
                s++;
            }
            p[i] = min + m;
        }
        assert s == 0;
        return p;
    }

    public static final GeneratorFactory factory = (numberCount, min, max, sum) ->
        new CombinationPartitionGenerator(numberCount, min, max, sum);
}

import java.util.*;

public class SmithTromblePartitionGenerator extends PartitionGenerator {
    public SmithTromblePartitionGenerator(int numberCount, int min, int max, int sum) {
        super(numberCount, min, max, sum, false);
    }

    @Override
    public int[] get() {
        List<Integer> ls = new ArrayList<>(numberCount + 1);
        int[] ret = new int[numberCount];
        int increasedSum = sum + numberCount;
        while (true) {
            ls.add(0);
            while (ls.size() < numberCount) {
                int c = 1 + rand.nextInt(increasedSum - 1);
                if (!ls.contains(c))
                    ls.add(c);
            }
            Collections.sort(ls);
            ls.add(increasedSum);
            boolean good = true;
            for (int i = 0; i < numberCount; i++) {
                int x = ls.get(i + 1) - ls.get(i) - 1;
                if (x > range) {
                    good = false;
                    break;
                }
                ret[i] = x;
            }
            if (good) {
                for (int i = 0; i < numberCount; i++)
                    ret[i] += min;
                return ret;
            }
            ls.clear();
        }
    }

    public static final GeneratorFactory factory = (numberCount, min, max, sum) ->
        new SmithTromblePartitionGenerator(numberCount, min, max, sum);
}

import java.util.Arrays;

// Enumerates all partitions with given parameters
public class SequentialEnumerator extends PartitionGenerator {
    private final int max;
    private final int[] p;
    private boolean finished;

    public SequentialEnumerator(int numberCount, int min, int max, int sum, boolean sorted) {
        super(numberCount, min, max, sum, sorted);
        this.max = max;
        p = new int[numberCount];
        startOver();
    }

    private void startOver() {
        finished = false;
        int unshiftedSum = sum + numberCount * min;
        fillMinimal(0, Math.max(min, unshiftedSum - (numberCount - 1) * max), unshiftedSum);
    }

    private void fillMinimal(int beginIndex, int minValue, int fillSum) {
        int fillRange = max - minValue;
        if (fillRange == 0)
            Arrays.fill(p, beginIndex, numberCount, max);
        else {
            int fillCount = numberCount - beginIndex;
            fillSum -= fillCount * minValue;
            int maxCount = fillSum / fillRange;
            int maxStartIndex = numberCount - maxCount;
            Arrays.fill(p, maxStartIndex, numberCount, max);
            fillSum -= maxCount * fillRange;
            Arrays.fill(p, beginIndex, maxStartIndex, minValue);
            if (fillSum != 0)
                p[maxStartIndex - 1] = minValue + fillSum;
        }
    }

    @Override
    public int[] get() { // returns null when there are no more partitions, then starts over
        if (finished) {
            startOver();
            return null;
        }
        int[] pCopy = p.clone();
        if (numberCount > 1) {
            int i = numberCount;
            int s = p[--i];
            while (i > 0) {
                int x = p[--i];
                if (x == max) {
                    s += x;
                    continue;
                }
                x++;
                s--;
                int minRest = sorted ? x : min;
                if (s < minRest * (numberCount - i - 1)) {
                    s += x;
                    continue;
                }
                p[i++]++;
                fillMinimal(i, minRest, s);
                return pCopy;
            }
        }
        finished = true;
        return pCopy;
    }

    public static final GeneratorFactory permutationFactory = (numberCount, min, max, sum) ->
        new SequentialEnumerator(numberCount, min, max, sum, false);
    public static final GeneratorFactory combinationFactory = (numberCount, min, max, sum) ->
        new SequentialEnumerator(numberCount, min, max, sum, true);
}

import java.util.*;
import java.util.function.BiConsumer;
import PartitionGenerator.GeneratorFactory;

public class Test {
    private final int numberCount;
    private final int min;
    private final int max;
    private final int sum;
    private final int repeatCount;
    private final BiConsumer<PartitionGenerator, Test> procedure;

    public Test(int numberCount, int min, int max, int sum, int repeatCount,
            BiConsumer<PartitionGenerator, Test> procedure) {
        this.numberCount = numberCount;
        this.min = min;
        this.max = max;
        this.sum = sum;
        this.repeatCount = repeatCount;
        this.procedure = procedure;
    }

    @Override
    public String toString() {
        return String.format("=== %d numbers from [%d, %d] with sum %d, %d iterations ===",
                numberCount, min, max, sum, repeatCount);
    }

    private static class GeneratedVector {
        final int[] v;

        GeneratedVector(int[] vect) {
            v = vect;
        }

        @Override
        public int hashCode() {
            return Arrays.hashCode(v);
        }

        @Override
        public boolean equals(Object obj) {
            if (this == obj)
                return true;
            return Arrays.equals(v, ((GeneratedVector)obj).v);
        }

        @Override
        public String toString() {
            return Arrays.toString(v);
        }
    }

    private static final Comparator<Map.Entry<GeneratedVector, Integer>> lexicographical = (e1, e2) -> {
        int[] v1 = e1.getKey().v;
        int[] v2 = e2.getKey().v;
        int len = v1.length;
        int d = len - v2.length;
        if (d != 0)
            return d;
        for (int i = 0; i < len; i++) {
            d = v1[i] - v2[i];
            if (d != 0)
                return d;
        }
        return 0;
    };

    private static final Comparator<Map.Entry<GeneratedVector, Integer>> byCount =
            Comparator.<Map.Entry<GeneratedVector, Integer>>comparingInt(Map.Entry::getValue)
            .thenComparing(lexicographical);

    public static int SHOW_MISSING_LIMIT = 10;

    private static void checkMissingPartitions(Map<GeneratedVector, Integer> map, PartitionGenerator reference) {
        int missingCount = 0;
        while (true) {
            int[] v = reference.get();
            if (v == null)
                break;
            GeneratedVector gv = new GeneratedVector(v);
            if (!map.containsKey(gv)) {
                if (missingCount == 0)
                    System.out.println(" Missing:");
                if (++missingCount > SHOW_MISSING_LIMIT) {
                    System.out.println("  . . .");
                    break;
                }
                System.out.println(gv);
            }
        }
    }

    public static final BiConsumer<PartitionGenerator, Test> distributionTest(boolean sortByCount) {
        return (PartitionGenerator gen, Test test) -> {
            System.out.print("\n" + getName(gen) + "\n\n");
            Map<GeneratedVector, Integer> combos = new HashMap<>();
            // There's no point in checking permus for sorted generators
            // because, for them, permus are the same as combos
            Map<GeneratedVector, Integer> permus = gen.isSorted() ? null : new HashMap<>();
            for (int i = 0; i < test.repeatCount; i++) {
                int[] v = gen.get();
                if (v == null && gen instanceof SequentialEnumerator)
                    break;
                if (permus != null) {
                    permus.merge(new GeneratedVector(v), 1, Integer::sum);
                    v = v.clone();
                    Arrays.sort(v);
                }
                combos.merge(new GeneratedVector(v), 1, Integer::sum);
            }
            Set<Map.Entry<GeneratedVector, Integer>> sortedEntries = new TreeSet<>(
                    sortByCount ? byCount : lexicographical);
            System.out.println("Combos" + (gen.isSorted() ? ":" : " (don't have to be uniform):"));
            sortedEntries.addAll(combos.entrySet());
            for (Map.Entry<GeneratedVector, Integer> e : sortedEntries)
                System.out.println(e);
            checkMissingPartitions(combos, test.getGenerator(SequentialEnumerator.combinationFactory));
            if (permus != null) {
                System.out.println("\nPermus:");
                sortedEntries.clear();
                sortedEntries.addAll(permus.entrySet());
                for (Map.Entry<GeneratedVector, Integer> e : sortedEntries)
                    System.out.println(e);
                checkMissingPartitions(permus, test.getGenerator(SequentialEnumerator.permutationFactory));
            }
        };
    }

    public static final BiConsumer<PartitionGenerator, Test> correctnessTest =
        (PartitionGenerator gen, Test test) -> {
        String genName = getName(gen);
        for (int i = 0; i < test.repeatCount; i++) {
            int[] v = gen.get();
            if (v == null && gen instanceof SequentialEnumerator)
                v = gen.get();
            if (v.length != test.numberCount)
                throw new RuntimeException(genName + ": array of wrong length");
            int s = 0;
            if (gen.isSorted()) {
                if (v[0] < test.min || v[v.length - 1] > test.max)
                    throw new RuntimeException(genName + ": generated number is out of range");
                int prev = test.min;
                for (int x : v) {
                    if (x < prev)
                        throw new RuntimeException(genName + ": unsorted array");
                    s += x;
                    prev = x;
                }
            } else
                for (int x : v) {
                    if (x < test.min || x > test.max)
                        throw new RuntimeException(genName + ": generated number is out of range");
                    s += x;
                }
            if (s != test.sum)
                throw new RuntimeException(genName + ": wrong sum");
        }
        System.out.format("%30s :   correctness test passed%n", genName);
    };

    public static final BiConsumer<PartitionGenerator, Test> performanceTest =
        (PartitionGenerator gen, Test test) -> {
        long time = System.nanoTime();
        for (int i = 0; i < test.repeatCount; i++)
            gen.get();
        time = System.nanoTime() - time;
        System.out.format("%30s : %8.3f s %10.0f ns/test%n", getName(gen), time * 1e-9, time * 1.0 / test.repeatCount);
    };

    public PartitionGenerator getGenerator(GeneratorFactory factory) {
        return factory.create(numberCount, min, max, sum);
    }

    public static String getName(PartitionGenerator gen) {
        String name = gen.getClass().getSimpleName();
        if (gen instanceof SequentialEnumerator)
            return (gen.isSorted() ? "Sorted " : "Unsorted ") + name;
        else
            return name;
    }

    public static GeneratorFactory[] factories = { SmithTromblePartitionGenerator.factory,
            PermutationPartitionGenerator.factory, CombinationPartitionGenerator.factory,
            SequentialEnumerator.permutationFactory, SequentialEnumerator.combinationFactory };

    public static void main(String[] args) {
        Test[] tests = {
                            new Test(3, 0, 3, 5, 3_000, distributionTest(false)),
                            new Test(3, 0, 6, 12, 3_000, distributionTest(true)),
                            new Test(50, -10, 20, 70, 2_000, correctnessTest),
                            new Test(7, 3, 10, 42, 1_000_000, performanceTest),
                            new Test(20, 3, 10, 120, 100_000, performanceTest)
                       };
        for (Test t : tests) {
            System.out.println(t);
            for (GeneratorFactory factory : factories) {
                PartitionGenerator candidate = t.getGenerator(factory);
                t.procedure.accept(candidate, t);
            }
            System.out.println();
        }
    }
}

You can try this on Ideone.
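When reading the distributionTest output, it helps to know how many distinct outcomes a uniform sampler should spread its draws over. The following is a minimal sketch, not part of the original answer. The class `PartitionCounts` and its methods are hypothetical names. It counts, by dynamic programming over the shifted values `x - min`, both the ordered tuples (what the "Permus" section lists) and the nondecreasing tuples (what the "Combos" section lists) with every element in [min, max] and the given sum. It uses `BigInteger` because these counts can overflow `long` for larger parameters, such as the 50-number correctness test.

```java
import java.math.BigInteger;
import java.util.Arrays;

// Hypothetical helper, not part of the original answer: counts the distinct
// outcomes that the distribution test can report for given parameters.
public class PartitionCounts {

    // Ordered n-tuples with every element in [min, max] and the given sum ("Permus").
    static BigInteger countPermutations(int n, int min, int max, int sum) {
        int range = max - min;                       // shifted elements lie in [0, range]
        long target = (long) sum - (long) n * min;   // shifted sum
        if (n < 0 || range < 0 || target < 0 || target > (long) n * range)
            return BigInteger.ZERO;
        int s = (int) target;
        // ways[t] = number of tuples of the current length whose shifted sum is t
        BigInteger[] ways = new BigInteger[s + 1];
        Arrays.fill(ways, BigInteger.ZERO);
        ways[0] = BigInteger.ONE;
        for (int i = 0; i < n; i++) {
            BigInteger[] next = new BigInteger[s + 1];
            Arrays.fill(next, BigInteger.ZERO);
            for (int t = 0; t <= s; t++) {
                if (ways[t].signum() == 0)
                    continue;
                // append one more element x in [0, range]
                for (int x = 0; x <= range && t + x <= s; x++)
                    next[t + x] = next[t + x].add(ways[t]);
            }
            ways = next;
        }
        return ways[s];
    }

    // Nondecreasing n-tuples (multisets) with elements in [min, max] and the given sum ("Combos").
    static BigInteger countCombinations(int n, int min, int max, int sum) {
        int range = max - min;
        long target = (long) sum - (long) n * min;
        if (n < 0 || range < 0 || target < 0 || target > (long) n * range)
            return BigInteger.ZERO;
        int s = (int) target;
        // ways[k][t] = multisets of size k, built from the shifted values processed so far, with sum t
        BigInteger[][] ways = new BigInteger[n + 1][s + 1];
        for (BigInteger[] row : ways)
            Arrays.fill(row, BigInteger.ZERO);
        ways[0][0] = BigInteger.ONE;
        for (int v = 0; v <= range; v++)          // each distinct value is processed once...
            for (int k = 1; k <= n; k++)          // ...but ascending k lets it be used repeatedly
                for (int t = v; t <= s; t++)
                    ways[k][t] = ways[k][t].add(ways[k - 1][t - v]);
        return ways[n][s];
    }

    public static void main(String[] args) {
        // Same parameters as the two distribution tests in the harness above
        System.out.println(countCombinations(3, 0, 3, 5) + " combos, "
                + countPermutations(3, 0, 3, 5) + " permus");
        System.out.println(countCombinations(3, 0, 6, 12) + " combos, "
                + countPermutations(3, 0, 6, 12) + " permus");
    }
}
```

Running `main` prints `3 combos, 12 permus` and `7 combos, 28 permus`. With 3,000 draws, a sampler that is uniform over permutations should show each permutation roughly 250 times in the first test (about 107 in the second). A sampler that is uniform over combinations should show each combination roughly 1,000 times (about 429 in the second). The same count also explains the rejection rate in the question. For N = 7, min = 3, max = 10, sum = 42, `countPermutations` gives 115788 valid tuples, out of C(27, 6) = 296010 compositions that the unbounded Smith–Tromble step can produce. That makes about 2.56 draws per accepted sample, or roughly 1.56 rejections, consistent with the figure of about 1.6 mentioned in the question.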
