11/*
2- * Copyright (c) 2014 Indiana University
2+ * Copyright (c) 2001-2002 Enthought, Inc. 2003-2019, SciPy Developers.
33 * All rights reserved.
4- * Written by Prof. Gary L. Pavlis, Dept. of Geol. Sci.,
5- * Indiana University, Bloomington, IN
6- * This software is licensed under the New BSD license:
7- * Redistribution and use in source and binary forms,
8- * with or without modification, are permitted provided
9- * that the following conditions are met:
10- * Redistributions of source code must retain the above
11- * copyright notice, this list of conditions and the
12- * following disclaimer.
13- * Redistributions in binary form must reproduce the
14- * above copyright notice, this list of conditions and
15- * the following disclaimer in the documentation and/or
16- * other materials provided with the distribution.
17- * Neither the name of Indiana University nor
18- * the names of its contributors may be used to endorse
19- * or promote products derived from this software without
20- * specific prior written permission.
21- * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
22- * CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED
23- * WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
24- * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
25- * PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL
26- * THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY
27- * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
28- * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
29- * PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF
30- * USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
31- * HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER
32- * IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
33- * NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
34- * USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
35- * POSSIBILITY OF SUCH DAMAGE.
4+ *
5+ * Redistribution and use in source and binary forms, with or without
6+ * modification, are permitted provided that the following conditions are
7+ * met:
8+ *
9+ * * Redistributions of source code must retain the above copyright
10+ * notice, this list of conditions and the following disclaimer.
11+ *
12+ * * Redistributions in binary form must reproduce the above
13+ * copyright notice, this list of conditions and the following
14+ * disclaimer in the documentation and/or other materials provided
15+ * with the distribution.
16+ *
17+ * * Neither the name of the copyright holder nor the names of its
18+ * contributors may be used to endorse or promote products derived
19+ * from this software without specific prior written permission.
20+ *
21+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22+ * "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23+ * LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24+ * A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
25+ * OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
26+ * SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
27+ * LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
28+ * DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
29+ * THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
30+ * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
31+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
3632 */
33+
3734/*
38- * The next function is taken from
39- * https://github.com/antelopeusersgroup/antelope_contrib/blob/master/lib/location/libgenloc/erfinv.c.
40- * Output was modified to be inf or -inf when input is 1 or -1.
35+ * The functions in this file are taken from
36+ * https://github.com/scipy/scipy/blob/master/scipy/special/cephes/polevl.h
37+ * https://github.com/scipy/scipy/blob/master/scipy/special/cephes/ndtri.c
38+ * https://github.com/scipy/scipy/blob/master/scipy/special/cephes/erfinv.c
4139 */
40+
4241#ifndef MXNET_OPERATOR_CONTRIB_ERFINV_INL_H_
4342#define MXNET_OPERATOR_CONTRIB_ERFINV_INL_H_
4443
4544#define _USE_MATH_DEFINES
4645
46+ #include < assert.h>
4747#include < mxnet/base.h>
4848#include < limits>
4949#include " math.h"
@@ -52,49 +52,262 @@ namespace mxnet {
5252namespace op {
5353namespace mshadow_op {
5454
55- /* ! \brief inverse gauss error function */
55+
56+ /*
57+ * Evaluate polynomial
58+ *
59+ *
60+ *
61+ * SYNOPSIS:
62+ *
63+ * int N;
64+ * double x, y, coef[N+1], polevl[];
65+ *
66+ * y = polevl( x, coef, N );
67+ *
68+ *
69+ *
70+ * DESCRIPTION:
71+ *
72+ * Evaluates polynomial of degree N:
73+ *
 *     y = C_0 + C_1 x + C_2 x^2 + ... + C_N x^N
77+ *
78+ * Coefficients are stored in reverse order:
79+ *
 *     coef[0] = C_N, ..., coef[N] = C_0.
82+ *
 * The function p1evl() assumes that the leading coefficient C_N = 1.0
 * and that it is omitted from the array, so coef[] holds only the N
 * remaining coefficients. Its calling arguments are otherwise the same
 * as polevl().
86+ *
87+ *
88+ * SPEED:
89+ *
90+ * In the interest of speed, there are no checks for out
91+ * of bounds arithmetic. This routine is used by most of
92+ * the functions in the library. Depending on available
93+ * equipment features, the user may wish to rewrite the
94+ * program in microcode or assembly language.
95+ *
96+ */
97+
/*!
 * \brief Evaluate a polynomial of degree N using Horner's scheme.
 *
 * Coefficients are stored highest power first, so coef[0] multiplies x^N
 * and coef[N] is the constant term; coef must hold N + 1 values.
 *
 * \param x    point at which the polynomial is evaluated
 * \param coef array of N + 1 coefficients, highest power first
 * \param N    degree of the polynomial; N == 0 yields the constant coef[0]
 * \return the value of the polynomial at x
 */
MSHADOW_XINLINE static double polevl(double x, const double coef[], int N) {
  double ans = coef[0];
  // A counted loop (rather than do/while(--i)) keeps N == 0 well defined:
  // no element past coef[N] is ever read.
  for (int i = 1; i <= N; ++i) {
    ans = ans * x + coef[i];
  }
  return ans;
}
113+
/*!
 * \brief Evaluate a monic polynomial of degree N using Horner's scheme.
 *
 * The leading coefficient (of x^N) is taken to be 1.0 and is not stored;
 * coef holds the remaining N coefficients, highest power first, so
 * coef[N - 1] is the constant term.
 *
 * \param x    point at which the polynomial is evaluated
 * \param coef array of N coefficients (leading 1.0 omitted)
 * \param N    degree of the polynomial; N == 1 yields x + coef[0]
 * \return the value of the polynomial at x
 */
MSHADOW_XINLINE static double p1evl(double x, const double coef[], int N) {
  // The implicit leading 1.0 folds into the first Horner step.
  double ans = x + coef[0];
  // A counted loop (rather than do/while(--i)) keeps N == 1 well defined:
  // no element past coef[N - 1] is ever read.
  for (int i = 1; i < N; ++i) {
    ans = ans * x + coef[i];
  }
  return ans;
}
129+
130+
131+ /* Inverse of Normal distribution function
132+ *
133+ * SYNOPSIS:
134+ *
135+ * double x, y, ndtri();
136+ *
137+ * x = ndtri( y );
138+ *
139+ * domain: 0 < y < 1
140+ *
141+ *
142+ *
143+ * DESCRIPTION:
144+ *
145+ * Returns the argument, x, for which the area under the
146+ * Gaussian probability density function (integrated from
147+ * minus infinity to x) is equal to y.
148+ *
149+ *
150+ * For small arguments 0 < y < exp(-2), the program computes
151+ * z = sqrt( -2.0 * log(y) ); then the approximation is
152+ * x = z - log(z)/z - (1/z) P(1/z) / Q(1/z).
153+ * There are two rational functions P/Q, one for 0 < y < exp(-32)
154+ * and the other for y up to exp(-2). For larger arguments,
 * w = y - 0.5, and x/sqrt(2pi) = w + w**3 R(w**2)/S(w**2).
156+ *
157+ *
158+ * ACCURACY:
159+ *
160+ * Relative error:
161+ * arithmetic domain # trials peak rms
162+ * IEEE 0.125, 1 20000 7.2e-16 1.3e-16
163+ * IEEE 3e-308, 0.135 50000 4.6e-16 9.8e-17
164+ *
165+ */
166+
167+ MSHADOW_XINLINE static double ndtri (double y0) {
168+ assert (y0 > 0 && y0 < 1 );
169+
170+ /* sqrt(2pi) */
171+ double s2pi = 2.50662827463100050242E0 ;
172+
173+ /* approximation for 0 <= |y - 0.5| <= 3/8 */
174+ double P0[5 ] = {
175+ -5.99633501014107895267E1 ,
176+ 9.80010754185999661536E1 ,
177+ -5.66762857469070293439E1 ,
178+ 1.39312609387279679503E1 ,
179+ -1.23916583867381258016E0 ,
180+ };
181+ double Q0[8 ] = {
182+ /* 1.00000000000000000000E0, */
183+ 1.95448858338141759834E0 ,
184+ 4.67627912898881538453E0 ,
185+ 8.63602421390890590575E1 ,
186+ -2.25462687854119370527E2 ,
187+ 2.00260212380060660359E2 ,
188+ -8.20372256168333339912E1 ,
189+ 1.59056225126211695515E1 ,
190+ -1.18331621121330003142E0 ,
191+ };
192+
193+ /* Approximation for interval z = sqrt(-2 log y ) between 2 and 8
194+ * i.e., y between exp(-2) = .135 and exp(-32) = 1.27e-14.
195+ */
196+ double P1[9 ] = {
197+ 4.05544892305962419923E0 ,
198+ 3.15251094599893866154E1 ,
199+ 5.71628192246421288162E1 ,
200+ 4.40805073893200834700E1 ,
201+ 1.46849561928858024014E1 ,
202+ 2.18663306850790267539E0 ,
203+ -1.40256079171354495875E-1 ,
204+ -3.50424626827848203418E-2 ,
205+ -8.57456785154685413611E-4 ,
206+ };
207+ double Q1[8 ] = {
208+ /* 1.00000000000000000000E0, */
209+ 1.57799883256466749731E1 ,
210+ 4.53907635128879210584E1 ,
211+ 4.13172038254672030440E1 ,
212+ 1.50425385692907503408E1 ,
213+ 2.50464946208309415979E0 ,
214+ -1.42182922854787788574E-1 ,
215+ -3.80806407691578277194E-2 ,
216+ -9.33259480895457427372E-4 ,
217+ };
218+
219+ /* Approximation for interval z = sqrt(-2 log y ) between 8 and 64
220+ * i.e., y between exp(-32) = 1.27e-14 and exp(-2048) = 3.67e-890.
221+ */
222+ double P2[9 ] = {
223+ 3.23774891776946035970E0 ,
224+ 6.91522889068984211695E0 ,
225+ 3.93881025292474443415E0 ,
226+ 1.33303460815807542389E0 ,
227+ 2.01485389549179081538E-1 ,
228+ 1.23716634817820021358E-2 ,
229+ 3.01581553508235416007E-4 ,
230+ 2.65806974686737550832E-6 ,
231+ 6.23974539184983293730E-9 ,
232+ };
233+ double Q2[8 ] = {
234+ /* 1.00000000000000000000E0, */
235+ 6.02427039364742014255E0 ,
236+ 3.67983563856160859403E0 ,
237+ 1.37702099489081330271E0 ,
238+ 2.16236993594496635890E-1 ,
239+ 1.34204006088543189037E-2 ,
240+ 3.28014464682127739104E-4 ,
241+ 2.89247864745380683936E-6 ,
242+ 6.79019408009981274425E-9 ,
243+ };
244+
245+ double x, y, z, y2, x0, x1;
246+ bool code = true ;
247+ y = y0;
248+ if (y > (1.0 - 0.13533528323661269189 )) { /* 0.135... = exp(-2) */
249+ y = 1.0 - y;
250+ code = false ;
251+ }
252+
253+ if (y > 0.13533528323661269189 ) {
254+ y = y - 0.5 ;
255+ y2 = y * y;
256+ x = y + y * (y2 * polevl (y2, P0, 4 ) / p1evl (y2, Q0, 8 ));
257+ x = x * s2pi;
258+ return (x);
259+ }
260+
261+ x = sqrt (-2.0 * log (y));
262+ x0 = x - log (x) / x;
263+
264+ z = 1.0 / x;
265+ if (x < 8.0 ) { /* y > exp(-32) = 1.2664165549e-14 */
266+ x1 = z * polevl (z, P1, 8 ) / p1evl (z, Q1, 8 );
267+ } else {
268+ x1 = z * polevl (z, P2, 8 ) / p1evl (z, Q2, 8 );
269+ }
270+
271+ x = x0 - x1;
272+ if (code) {
273+ x = -x;
274+ }
275+ return (x);
276+ }
277+
278+
279+ /* ! \brief inverse of the error function */
56280struct erfinv : public mxnet_op ::tunable {
57281 template <typename DType>
58282 MSHADOW_XINLINE static DType Map (DType v) {
59- /* Function to calculate inverse error function. Rational approximation
60- is used to generate an initial approximation, which is then improved to
61- full accuracy by two steps of Newton's method. Code is a direct
62- translation of the erfinv m file in matlab version 2.0.
63- Author: Gary L. Pavlis, Indiana University
64- Date: February 1996
65- */
66- const double central_range = 0.7 ;
283+ /* Inverse of the error function.
284+ * Computes the inverse of the error function on the restricted domain
285+ * -1 < y < 1. This restriction ensures the existence of a unique result
286+ * such that erf(erfinv(y)) = y.
287+ */
288+ const double domain_lb = -1 ;
289+ const double domain_ub = 1 ;
290+
291+ const double thresh = 1e-7 ;
67292 double y = static_cast <double >(v);
68- double y_fab = std::fabs (y);
69- /* working variables */
70- double x = 0.0 ;
71- double z, num, dem;
72- /* coefficients in rational expansion */
73- double a[4 ]={ 0.886226899 , -1.645349621 , 0.914624893 , -0.140543331 };
74- double b[4 ]={-2.118377725 , 1.442710462 , -0.329097515 , 0.012229801 };
75- double c[4 ]={-1.970840454 , -1.624906493 , 3.429567803 , 1.641345311 };
76- double d[2 ]={ 3.543889200 , 1.637067800 };
77- if (y_fab > 1.0 ) {
78- /* This needs IEEE constant*/
79- return DType (std::numeric_limits<double >::quiet_NaN ());
80- } else if (y_fab == 1.0 ) {
81- return DType ((std::copysign (1.0 , y))*std::numeric_limits<double >::infinity ());
82- } else if (y_fab <= central_range) {
83- z = y*y;
84- num = (((a[3 ]*z + a[2 ])*z + a[1 ])*z + a[0 ]);
85- dem = ((((b[3 ]*z + b[2 ])*z + b[1 ])*z +b[0 ])*z + 1.0 );
86- x = y*num/dem;
87- } else {
88- z = std::sqrt (-std::log ((1.0 -y_fab)/2.0 ));
89- num = ((c[3 ]*z + c[2 ])*z + c[1 ])*z + c[0 ];
90- dem = (d[1 ]*z + d[0 ])*z + 1.0 ;
91- x = (std::copysign (1.0 , y))*num/dem;
293+
294+ /*
295+ * For small arguments, use the Taylor expansion
296+ * erf(y) = 2/\sqrt{\pi} (y - y^3 / 3 + O(y^5)), y\to 0
297+ * where we only retain the linear term.
298+ * Otherwise, y + 1 loses precision for |y| << 1.
299+ */
300+ if ((-thresh < y) && (y < thresh)) {
301+ return DType (y / M_2_SQRTPI);
92302 }
93- /* Two steps of Newton-Raphson correction */
94- x = x - (std::erf (x) - y)/((2.0 /std::sqrt (M_PI))*std::exp (-x*x));
95- x = x - (std::erf (x) - y)/((2.0 /std::sqrt (M_PI))*std::exp (-x*x));
96303
97- return DType (x);
304+ if ((domain_lb < y) && (y < domain_ub)) {
305+ return DType (ndtri (0.5 * (y+1 )) * M_SQRT1_2);
306+ } else if (y == domain_lb || y == domain_ub) {
307+ return DType (std::copysign (1.0 , y) * std::numeric_limits<double >::infinity ());
308+ } else {
309+ return DType (std::numeric_limits<double >::quiet_NaN ());
310+ }
98311 }
99312};
100313
0 commit comments