View Javadoc
1   package nl.tudelft.simulation.jstats.math;
2   
3   import org.djutils.exceptions.Throw;
4   
5   /**
6    * The ProbMath class defines some very basic probabilistic mathematical functions.
7    * <p>
8    * Copyright (c) 2004-2025 Delft University of Technology, Jaffalaan 5, 2628 BX Delft, the Netherlands. All rights reserved. See
9    * for project information <a href="https://simulation.tudelft.nl/dsol/manual/" target="_blank">DSOL Manual</a>. The DSOL
10   * project is distributed under a three-clause BSD-style license, which can be found at
11   * <a href="https://simulation.tudelft.nl/dsol/docs/latest/license.html" target="_blank">DSOL License</a>.
12   * </p>
13   * @author <a href="https://github.com/averbraeck" target="_blank"> Alexander Verbraeck</a>
14   */
15  public final class ProbMath
16  {
17      /** stored values of n! as a double value. */
18      public static final double[] FACTORIAL_DOUBLE;
19  
20      /** stored values of n! as a long value. */
21      public static final long[] FACTORIAL_LONG;
22  
23      /** stored values of a^n as a double value. */
24      public static final double[] POW2_DOUBLE;
25  
26      /** stored values of 2^n as a long value. */
27      public static final long[] POW2_LONG;
28  
29      /** calculate n! and 2^n. */
30      static
31      {
32          FACTORIAL_DOUBLE = new double[171];
33          FACTORIAL_DOUBLE[0] = 1.0;
34          double d = 1.0;
35          for (int i = 1; i <= 170; i++)
36          {
37              d = d * i;
38              FACTORIAL_DOUBLE[i] = d;
39          }
40  
41          FACTORIAL_LONG = new long[21];
42          FACTORIAL_LONG[0] = 1L;
43          long n = 1L;
44          for (int i = 1; i <= 20; i++)
45          {
46              n = n * i;
47              FACTORIAL_LONG[i] = n;
48          }
49  
50          POW2_DOUBLE = new double[1024];
51          POW2_DOUBLE[0] = 1.0;
52          for (int i = 1; i < 1024; i++)
53          {
54              POW2_DOUBLE[i] = 2.0 * POW2_DOUBLE[i - 1];
55          }
56  
57          POW2_LONG = new long[63];
58          POW2_LONG[0] = 1L;
59          for (int i = 1; i < 63; i++)
60          {
61              POW2_LONG[i] = 2L * POW2_LONG[i - 1];
62          }
63      }
64  
65      /**
66       * Utility class.
67       */
68      private ProbMath()
69      {
70          // unreachable code for the utility class
71      }
72  
73      /**
74       * Compute n factorial as a double. fac(n) = n * (n-1) * (n-2) * ... 2 * 1.
75       * @param n int; the value to calculate n! for
76       * @return double; n factorial
77       */
78      public static double factorial(final int n)
79      {
80          Throw.when(n < 0, IllegalArgumentException.class, "n! with n<0 is invalid");
81          Throw.when(n > 170, IllegalArgumentException.class, "n! with n>170 is larger than Double.MAX_VALUE");
82          return FACTORIAL_DOUBLE[n];
83      }
84  
85      /**
86       * Compute n factorial as a double. fac(n) = n * (n-1) * (n-2) * ... 2 * 1.
87       * @param n long; the value to calculate n! for
88       * @return double; n factorial
89       */
90      public static double factorial(final long n)
91      {
92          Throw.when(n < 0, IllegalArgumentException.class, "n! with n<0 is invalid");
93          Throw.when(n > 170, IllegalArgumentException.class, "n! with n>170 is larger than Double.MAX_VALUE");
94          return FACTORIAL_DOUBLE[(int) n];
95      }
96  
97      /**
98       * Compute n factorial as a long. fac(n) = n * (n-1) * (n-2) * ... 2 * 1.
99       * @param n int; the value to calculate n! for
100      * @return double; n factorial
101      */
102     public static long fac(final int n)
103     {
104         Throw.when(n < 0, IllegalArgumentException.class, "n! with n<0 is invalid");
105         Throw.when(n > 20, IllegalArgumentException.class, "n! with n>20 is more than 64 bits long");
106         return FACTORIAL_LONG[n];
107     }
108 
109     /**
110      * Compute n factorial as a long. fac(n) = n * (n-1) * (n-2) * ... 2 * 1.
111      * @param n long; the value to calculate n! for
112      * @return double; n factorial
113      */
114     public static long fac(final long n)
115     {
116         Throw.when(n < 0, IllegalArgumentException.class, "n! with n<0 is invalid");
117         Throw.when(n > 20, IllegalArgumentException.class, "n! with n>20 is more than 64 bits long");
118         return FACTORIAL_LONG[(int) n];
119     }
120 
121     /**
122      * Compute the k-permutations of n.
123      * @param n int; the first parameter
124      * @param k int; the second parameter
125      * @return the k-permutations of n
126      */
127     public static double permutations(final int n, final int k)
128     {
129         if (k > n)
130         {
131             throw new IllegalArgumentException("permutations of (n,k) with k>n");
132         }
133         return factorial(n) / factorial(n - k);
134     }
135 
136     /**
137      * Compute the k-permutations of n.
138      * @param n long; the first parameter
139      * @param k long; the second parameter
140      * @return the k-permutations of n
141      */
142     public static double permutations(final long n, final long k)
143     {
144         if (k > n)
145         {
146             throw new IllegalArgumentException("permutations of (n,k) with k>n");
147         }
148         return factorial(n) / factorial(n - k);
149     }
150 
151     /**
152      * Computes the k-permutations of n as a long.
153      * @param n int; the first parameter
154      * @param k int; the second parameter
155      * @return the k-permutations of n
156      */
157     public static long perm(final int n, final int k)
158     {
159         if (k > n)
160         {
161             throw new IllegalArgumentException("permutations of (n,k) with k>n");
162         }
163         return fac(n) / fac(n - k);
164     }
165 
166     /**
167      * Computes the k-permutations of n as a long.
168      * @param n long; the first parameter
169      * @param k long; the second parameter
170      * @return the k-permutations of n
171      */
172     public static long perm(final long n, final long k)
173     {
174         if (k > n)
175         {
176             throw new IllegalArgumentException("permutations of (n,k) with k>n");
177         }
178         return fac(n) / fac(n - k);
179     }
180 
181     /**
182      * computes the combinations of n over k.
183      * @param n int; the first parameter
184      * @param k int; the second parameter
185      * @return the combinations of n over k
186      */
187     public static double combinations(final int n, final int k)
188     {
189         if (k > n)
190         {
191             throw new IllegalArgumentException("combinations of (n,k) with k>n");
192         }
193         return factorial(n) / (factorial(k) * factorial(n - k));
194     }
195 
196     /**
197      * computes the combinations of n over k.
198      * @param n long; the first parameter
199      * @param k long; the second parameter
200      * @return the combinations of n over k
201      */
202     public static double combinations(final long n, final long k)
203     {
204         if (k > n)
205         {
206             throw new IllegalArgumentException("combinations of (n,k) with k>n");
207         }
208         return factorial(n) / (factorial(k) * factorial(n - k));
209     }
210 
211     /**
212      * computes the combinations of n over k as a long.
213      * @param n int; the first parameter
214      * @param k int; the second parameter
215      * @return the combinations of n over k
216      */
217     public static long comb(final int n, final int k)
218     {
219         if (k > n)
220         {
221             throw new IllegalArgumentException("combinations of (n,k) with k>n");
222         }
223         return fac(n) / (fac(k) * fac(n - k));
224     }
225 
226     /**
227      * computes the combinations of n over k as a long.
228      * @param n long; the first parameter
229      * @param k long; the second parameter
230      * @return the combinations of n over k
231      */
232     public static long comb(final long n, final long k)
233     {
234         if (k > n)
235         {
236             throw new IllegalArgumentException("combinations of (n,k) with k>n");
237         }
238         return fac(n) / (fac(k) * fac(n - k));
239     }
240 
241     /**
242      * Approximates erf(z) using a Taylor series.<br>
243      * The Taylor series for erf(z) for abs(z) <u>&lt;</u> 0.5 that is used is:<br>
244      * &nbsp; &nbsp; erf(z) = (exp(-z<sup>2</sup>) / &radic;&pi;) &Sigma; [ 2z<sup>2n + 1</sup> / (2n + 1)!!]<br>
245      * The Taylor series for erf(z) for abs(z) &gt; 3.7 that is used is:<br>
246      * &nbsp; &nbsp; erf(z) = 1 - (exp(-z<sup>2</sup>) / &radic;&pi;) &Sigma; [ (-1)<sup>n</sup> (2n - 1)!! z<sup>-(2n +
247      * 1)</sup> / 2<sup>n</sup>]<br>
248      * See <a href="https://mathworld.wolfram.com/Erf.html">https://mathworld.wolfram.com/Erf.html</a>. <br>
249      * For 0.5 &lt; z &lt; 3.7 it approximates erf(z) using the following Taylor series:<br>
250      * &nbsp; &nbsp; erf(z) = (2/&radic;&pi;) (z - z<sup>3</sup>/3 + z<sup>5</sup>/10 - z<sup>7</sup>/42 + z<sup>9</sup>/216 -
251      * ...)<br>
252      * The factors are given by <a href="https://oeis.org/A007680">https://oeis.org/A007680</a>, which evaluates to a(n) =
253      * (2n+1)n!. See <a href="https://en.wikipedia.org/wiki/Error_function">https://en.wikipedia.org/wiki/Error_function</a>.
254      * @param z double; the value to calculate the error function for
255      * @return erf(z)
256      */
257     public static double erf(final double z)
258     {
259         double zpos = Math.abs(z);
260         if (zpos < 0.5)
261         {
262             return erfSmall(z);
263         }
264         if (zpos > 3.8)
265         {
266             return erfBig(z);
267         }
268         return erfTaylor(z);
269     }
270 
271     /**
272      * The Taylor series for erf(z) for abs(z) <u>&lt;</u> 0.5 that is used is:<br>
273      * &nbsp; &nbsp; erf(z) = (exp(-z<sup>2</sup>) / &radic;&pi;) &Sigma; [ 2z<sup>2n + 1</sup> / (2n + 1)!!]<br>
274      * where the !! operator is the 'double factorial' operator which is (n).(n-2)...8.4.2 for even n, and (n).(n-2)...3.5.1 for
275      * odd n. See <a href="https://mathworld.wolfram.com/Erf.html">https://mathworld.wolfram.com/Erf.html</a> formula (9) and
276      * (10). This function would work well for z <u>&lt;</u> 0.5.
277      * @param z double; the parameter
278      * @return double; erf(x)
279      */
280     private static double erfSmall(final double z)
281     {
282         double zpos = Math.abs(z);
283         // @formatter:off
284         double sum = zpos 
285                 + 2.0 * Math.pow(zpos, 3) / 3.0 
286                 + 4.0 * Math.pow(zpos, 5) / 15.0 
287                 + 8.0 * Math.pow(zpos, 7) / 105.0
288                 + 16.0 * Math.pow(zpos, 9) / 945.0 
289                 + 32.0 * Math.pow(zpos, 11) / 10395.0
290                 + 64.0 * Math.pow(zpos, 13) / 135135.0 
291                 + 128.0 * Math.pow(zpos, 15) / 2027025.0
292                 + 256.0 * Math.pow(zpos, 17) / 34459425.0 
293                 + 512.0 * Math.pow(zpos, 19) / 654729075.0
294                 + 1024.0 * Math.pow(zpos, 21) / 13749310575.0;
295         // @formatter:on
296         return Math.signum(z) * sum * 2.0 * Math.exp(-zpos * zpos) / Math.sqrt(Math.PI);
297     }
298 
299     /**
300      * Calculate erf(z) for large values using the Taylor series:<br>
301      * &nbsp; &nbsp; erf(z) = 1 - (exp(-z<sup>2</sup>) / &radic;&pi;) &Sigma; [ (-1)<sup>n</sup> (2n - 1)!! z<sup>-(2n +
302      * 1)</sup> / 2<sup>n</sup>]<br>
303      * where the !! operator is the 'double factorial' operator which is (n).(n-2)...8.4.2 for even n, and (n).(n-2)...3.5.1 for
304      * odd n. See <a href="https://mathworld.wolfram.com/Erf.html">https://mathworld.wolfram.com/Erf.html</a> formula (18) to
305      * (20). This function would work well for z <u>&gt;</u> 3.7.
306      * @param z double; the argument
307      * @return double; erf(z)
308      */
309     private static double erfBig(final double z)
310     {
311         double zpos = Math.abs(z);
312         // @formatter:off
313         double sum = 1.0 / zpos 
314                 - (1.0 / 2.0) * Math.pow(zpos, -3) 
315                 + (3.0 / 4.0) * Math.pow(zpos, -5)
316                 - (15.0 / 8.0) * Math.pow(zpos, -7) 
317                 + (105.0 / 16.0) * Math.pow(zpos, -9) 
318                 - (945.0 / 32.0) * Math.pow(zpos, -11)
319                 + (10395.0 / 64.0) * Math.pow(zpos, -13) 
320                 - (135135.0 / 128.0) * Math.pow(zpos, -15)
321                 + (2027025.0 / 256.0) * Math.pow(zpos, -17);
322         // @formatter:on
323         return Math.signum(z) * (1.0 - sum * Math.exp(-zpos * zpos) / Math.sqrt(Math.PI));
324     }
325 
326     /**
327      * Calculate erf(z) using the Taylor series:<br>
328      * &nbsp; &nbsp; erf(z) = (2/&radic;&pi;) (z - z<sup>3</sup>/3 + z<sup>5</sup>/10 - z<sup>7</sup>/42 + z<sup>9</sup>/216 -
329      * ...)<br>
330      * The factors are given by <a href="https://oeis.org/A007680">https://oeis.org/A007680</a>, which evaluates to a(n) =
331      * (2n+1)n!. See <a href="https://en.wikipedia.org/wiki/Error_function">https://en.wikipedia.org/wiki/Error_function</a>.
332      * This works pretty well on the interval [0.5,3.7].
333      * @param z double; the argument
334      * @return double; erf(z)
335      */
336     private static double erfTaylor(final double z)
337     {
338         double zpos = Math.abs(z);
339         double d = zpos;
340         double zpow = zpos;
341         for (int i = 1; i < 64; i++) // max 64 steps
342         {
343             // calculate Math.pow(zpos, 2 * i + 1) / ((2 * i + 1) * factorial(i));
344             zpow *= zpos * zpos;
345             double term = zpow / ((2.0 * i + 1.0) * ProbMath.factorial(i));
346             d += term * ((i & 1) == 0 ? 1 : -1);
347             if (term < 1E-16)
348             {
349                 break;
350             }
351         }
352         return Math.signum(z) * d * 2.0 / Math.sqrt(Math.PI);
353     }
354 
355     /**
356      * Approximates erf<sup>-1</sup>(p) based on
357      * <a href="http://www.naic.edu/~jeffh/inverse_cerf.c">http://www.naic.edu/~jeffh/inverse_cerf.c</a> code.
358      * @param y double; the cumulative probability to calculate the inverse error function for
359      * @return erf<sup>-1</sup>(p)
360      */
361     public static double erfInv(final double y)
362     {
363         double ax, t, ret;
364         ax = Math.abs(y);
365 
366         /*
367          * This approximation, taken from Table 10 of Blair et al., is valid for |x|<=0.75 and has a maximum relative error of
368          * 4.47 x 10^-8.
369          */
370         if (ax <= 0.75)
371         {
372 
373             double[] p = new double[] {-13.0959967422, 26.785225760, -9.289057635};
374             double[] q = new double[] {-12.0749426297, 30.960614529, -17.149977991, 1.00000000};
375 
376             t = ax * ax - 0.75 * 0.75;
377             ret = ax * (p[0] + t * (p[1] + t * p[2])) / (q[0] + t * (q[1] + t * (q[2] + t * q[3])));
378 
379         }
380         else if (ax >= 0.75 && ax <= 0.9375)
381         {
382             double[] p = new double[] {-.12402565221, 1.0688059574, -1.9594556078, .4230581357};
383             double[] q = new double[] {-.08827697997, .8900743359, -2.1757031196, 1.0000000000};
384 
385             /*
386              * This approximation, taken from Table 29 of Blair et al., is valid for .75<=|x|<=.9375 and has a maximum relative
387              * error of 4.17 x 10^-8.
388              */
389             t = ax * ax - 0.9375 * 0.9375;
390             ret = ax * (p[0] + t * (p[1] + t * (p[2] + t * p[3]))) / (q[0] + t * (q[1] + t * (q[2] + t * q[3])));
391 
392         }
393         else if (ax >= 0.9375 && ax <= (1.0 - 1.0e-9))
394         {
395             double[] p =
396                     new double[] {.1550470003116, 1.382719649631, .690969348887, -1.128081391617, .680544246825, -.16444156791};
397             double[] q = new double[] {.155024849822, 1.385228141995, 1.000000000000};
398 
399             /*
400              * This approximation, taken from Table 50 of Blair et al., is valid for .9375<=|x|<=1-10^-100 and has a maximum
401              * relative error of 2.45 x 10^-8.
402              */
403             t = 1.0 / Math.sqrt(-Math.log(1.0 - ax));
404             ret = (p[0] / t + p[1] + t * (p[2] + t * (p[3] + t * (p[4] + t * p[5])))) / (q[0] + t * (q[1] + t * (q[2])));
405         }
406         else
407         {
408             ret = Double.POSITIVE_INFINITY;
409         }
410 
411         return Math.signum(y) * ret;
412     }
413 
414     /** Coefficients for the ln(gamma(x)) function. */
415     private static final double[] GAMMALN_COF = {76.18009172947146, -86.50532032941677, 24.01409824083091, -1.231739572450155,
416             0.1208650973866179e-2, -0.5395239384953e-5};
417 
418     /**
419      * Calculates ln(gamma(x)). Java version of gammln function in Numerical Recipes in C, p.214.
420      * @param xx double; the value to calculate the gamma function for
421      * @return double; gamma(x)
422      * @throws IllegalArgumentException when x is &lt; 0
423      */
424     public static double gammaln(final double xx)
425     {
426         Throw.when(xx < 0, IllegalArgumentException.class, "gamma function not defined for real values < 0");
427         double x, y, tmp, ser;
428         x = xx;
429         y = x;
430         tmp = x + 5.5;
431         tmp -= (x + 0.5) * Math.log(tmp);
432         ser = 1.000000000190015;
433         for (int j = 0; j <= 5; j++)
434         {
435             ser += GAMMALN_COF[j] / ++y;
436         }
437         return -tmp + Math.log(2.5066282746310005 * ser / x);
438     }
439 
440     /**
441      * Calculates gamma(x). Based on the gammln function in Numerical Recipes in C, p.214.
442      * @param x double; the value to calculate the gamma function for
443      * @return double; gamma(x)
444      * @throws IllegalArgumentException when x is &lt; 0
445      */
446     public static double gamma(final double x)
447     {
448         return Math.exp(gammaln(x));
449     }
450 
451     /**
452      * Calculates Beta(z, w) where Beta(z, w) = &Gamma;(z) &Gamma;(w) / &Gamma;(z + w).
453      * @param z double; beta function parameter 1
454      * @param w ; beta function parameter 2
455      * @return double; beta(z, w)
456      * @throws IllegalArgumentException when one of the parameters is &lt; 0
457      */
458     public static double beta(final double z, final double w)
459     {
460         Throw.when(z < 0 || w < 0, IllegalArgumentException.class, "beta function not defined for negative arguments");
461         return Math.exp(gammaln(z) + gammaln(w) - gammaln(z + w));
462     }
463 
464 }