Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.

Commit 7862c00

Browse files
authored
[v1.9.x] modify erfinv implementation based on scipy (#20517)
* modify erfinv implementation based on scipy * fix lint * fix lint * fix host/device gpu error * fix flag
1 parent 1716402 commit 7862c00

1 file changed

Lines changed: 286 additions & 73 deletions

File tree

src/operator/contrib/erfinv-inl.h

Lines changed: 286 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -1,49 +1,49 @@
11
/*
2-
* Copyright (c) 2014 Indiana University
2+
* Copyright (c) 2001-2002 Enthought, Inc. 2003-2019, SciPy Developers.
33
* All rights reserved.
4-
* Written by Prof. Gary L. Pavlis, Dept. of Geol. Sci.,
5-
* Indiana University, Bloomington, IN
6-
* This software is licensed under the New BSD license:
7-
* Redistribution and use in source and binary forms,
8-
* with or without modification, are permitted provided
9-
* that the following conditions are met:
10-
* Redistributions of source code must retain the above
11-
* copyright notice, this list of conditions and the
12-
* following disclaimer.
13-
* Redistributions in binary form must reproduce the
14-
* above copyright notice, this list of conditions and
15-
* the following disclaimer in the documentation and/or
16-
* other materials provided with the distribution.
17-
* Neither the name of Indiana University nor
18-
* the names of its contributors may be used to endorse
19-
* or promote products derived from this software without
20-
* specific prior written permission.
21-
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
22-
* CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED
23-
* WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
24-
* WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A
25-
* PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL
26-
* THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR ANY
27-
* DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
28-
* CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
29-
* PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF
30-
* USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION)
31-
* HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER
32-
* IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING
33-
* NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
34-
* USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
35-
* POSSIBILITY OF SUCH DAMAGE.
4+
*
5+
* Redistribution and use in source and binary forms, with or without
6+
* modification, are permitted provided that the following conditions are
7+
* met:
8+
*
9+
* * Redistributions of source code must retain the above copyright
10+
* notice, this list of conditions and the following disclaimer.
11+
*
12+
* * Redistributions in binary form must reproduce the above
13+
* copyright notice, this list of conditions and the following
14+
* disclaimer in the documentation and/or other materials provided
15+
* with the distribution.
16+
*
17+
* * Neither the name of the copyright holder nor the names of its
18+
* contributors may be used to endorse or promote products derived
19+
* from this software without specific prior written permission.
20+
*
21+
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
22+
* "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
23+
* LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
24+
* A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
25+
* OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
26+
* SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
27+
* LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
28+
* DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
29+
* THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
30+
* (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
31+
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
3632
*/
33+
3734
/*
38-
* The next function is taken from
39-
* https://github.com/antelopeusersgroup/antelope_contrib/blob/master/lib/location/libgenloc/erfinv.c.
40-
* Output was modified to be inf or -inf when input is 1 or -1.
35+
* The functions in this file are taken from
36+
* https://github.com/scipy/scipy/blob/master/scipy/special/cephes/polevl.h
37+
* https://github.com/scipy/scipy/blob/master/scipy/special/cephes/ndtri.c
38+
* https://github.com/scipy/scipy/blob/master/scipy/special/cephes/erfinv.c
4139
*/
40+
4241
#ifndef MXNET_OPERATOR_CONTRIB_ERFINV_INL_H_
4342
#define MXNET_OPERATOR_CONTRIB_ERFINV_INL_H_
4443

4544
#define _USE_MATH_DEFINES
4645

46+
#include <assert.h>
4747
#include <mxnet/base.h>
4848
#include <limits>
4949
#include "math.h"
@@ -52,49 +52,262 @@ namespace mxnet {
5252
namespace op {
5353
namespace mshadow_op {
5454

55-
/*! \brief inverse gauss error function */
55+
56+
/*
57+
* Evaluate polynomial
58+
*
59+
*
60+
*
61+
* SYNOPSIS:
62+
*
63+
* int N;
64+
* double x, y, coef[N+1], polevl[];
65+
*
66+
* y = polevl( x, coef, N );
67+
*
68+
*
69+
*
70+
* DESCRIPTION:
71+
*
72+
* Evaluates polynomial of degree N:
73+
*
74+
* 2 N
75+
* y = C + C x + C x +...+ C x
76+
* 0 1 2 N
77+
*
78+
* Coefficients are stored in reverse order:
79+
*
80+
* coef[0] = C , ..., coef[N] = C .
81+
* N 0
82+
*
83+
* The function p1evl() assumes that coef[N] = 1.0 and is
84+
* omitted from the array. Its calling arguments are
85+
* otherwise the same as polevl().
86+
*
87+
*
88+
* SPEED:
89+
*
90+
* In the interest of speed, there are no checks for out
91+
* of bounds arithmetic. This routine is used by most of
92+
* the functions in the library. Depending on available
93+
* equipment features, the user may wish to rewrite the
94+
* program in microcode or assembly language.
95+
*
96+
*/
97+
98+
MSHADOW_XINLINE static double polevl(double x, const double coef[], int N) {
99+
const double *p;
100+
double ans;
101+
int i;
102+
103+
p = coef;
104+
ans = *p++;
105+
i = N;
106+
107+
do {
108+
ans = ans * x + *p++;
109+
} while (--i);
110+
111+
return (ans);
112+
}
113+
114+
MSHADOW_XINLINE static double p1evl(double x, const double coef[], int N) {
115+
const double *p;
116+
double ans;
117+
int i;
118+
119+
p = coef;
120+
ans = x + *p++;
121+
i = N - 1;
122+
123+
do {
124+
ans = ans * x + *p++;
125+
} while (--i);
126+
127+
return (ans);
128+
}
129+
130+
131+
/* Inverse of Normal distribution function
132+
*
133+
* SYNOPSIS:
134+
*
135+
* double x, y, ndtri();
136+
*
137+
* x = ndtri( y );
138+
*
139+
* domain: 0 < y < 1
140+
*
141+
*
142+
*
143+
* DESCRIPTION:
144+
*
145+
* Returns the argument, x, for which the area under the
146+
* Gaussian probability density function (integrated from
147+
* minus infinity to x) is equal to y.
148+
*
149+
*
150+
* For small arguments 0 < y < exp(-2), the program computes
151+
* z = sqrt( -2.0 * log(y) ); then the approximation is
152+
* x = z - log(z)/z - (1/z) P(1/z) / Q(1/z).
153+
* There are two rational functions P/Q, one for 0 < y < exp(-32)
154+
* and the other for y up to exp(-2). For larger arguments,
155+
* w = y - 0.5, and x/sqrt(2pi) = w + w**3 R(w**2)/S(w**2)).
156+
*
157+
*
158+
* ACCURACY:
159+
*
160+
* Relative error:
161+
* arithmetic domain # trials peak rms
162+
* IEEE 0.125, 1 20000 7.2e-16 1.3e-16
163+
* IEEE 3e-308, 0.135 50000 4.6e-16 9.8e-17
164+
*
165+
*/
166+
167+
MSHADOW_XINLINE static double ndtri(double y0) {
168+
assert(y0 > 0 && y0 < 1);
169+
170+
/* sqrt(2pi) */
171+
double s2pi = 2.50662827463100050242E0;
172+
173+
/* approximation for 0 <= |y - 0.5| <= 3/8 */
174+
double P0[5] = {
175+
-5.99633501014107895267E1,
176+
9.80010754185999661536E1,
177+
-5.66762857469070293439E1,
178+
1.39312609387279679503E1,
179+
-1.23916583867381258016E0,
180+
};
181+
double Q0[8] = {
182+
/* 1.00000000000000000000E0, */
183+
1.95448858338141759834E0,
184+
4.67627912898881538453E0,
185+
8.63602421390890590575E1,
186+
-2.25462687854119370527E2,
187+
2.00260212380060660359E2,
188+
-8.20372256168333339912E1,
189+
1.59056225126211695515E1,
190+
-1.18331621121330003142E0,
191+
};
192+
193+
/* Approximation for interval z = sqrt(-2 log y ) between 2 and 8
194+
* i.e., y between exp(-2) = .135 and exp(-32) = 1.27e-14.
195+
*/
196+
double P1[9] = {
197+
4.05544892305962419923E0,
198+
3.15251094599893866154E1,
199+
5.71628192246421288162E1,
200+
4.40805073893200834700E1,
201+
1.46849561928858024014E1,
202+
2.18663306850790267539E0,
203+
-1.40256079171354495875E-1,
204+
-3.50424626827848203418E-2,
205+
-8.57456785154685413611E-4,
206+
};
207+
double Q1[8] = {
208+
/* 1.00000000000000000000E0, */
209+
1.57799883256466749731E1,
210+
4.53907635128879210584E1,
211+
4.13172038254672030440E1,
212+
1.50425385692907503408E1,
213+
2.50464946208309415979E0,
214+
-1.42182922854787788574E-1,
215+
-3.80806407691578277194E-2,
216+
-9.33259480895457427372E-4,
217+
};
218+
219+
/* Approximation for interval z = sqrt(-2 log y ) between 8 and 64
220+
* i.e., y between exp(-32) = 1.27e-14 and exp(-2048) = 3.67e-890.
221+
*/
222+
double P2[9] = {
223+
3.23774891776946035970E0,
224+
6.91522889068984211695E0,
225+
3.93881025292474443415E0,
226+
1.33303460815807542389E0,
227+
2.01485389549179081538E-1,
228+
1.23716634817820021358E-2,
229+
3.01581553508235416007E-4,
230+
2.65806974686737550832E-6,
231+
6.23974539184983293730E-9,
232+
};
233+
double Q2[8] = {
234+
/* 1.00000000000000000000E0, */
235+
6.02427039364742014255E0,
236+
3.67983563856160859403E0,
237+
1.37702099489081330271E0,
238+
2.16236993594496635890E-1,
239+
1.34204006088543189037E-2,
240+
3.28014464682127739104E-4,
241+
2.89247864745380683936E-6,
242+
6.79019408009981274425E-9,
243+
};
244+
245+
double x, y, z, y2, x0, x1;
246+
bool code = true;
247+
y = y0;
248+
if (y > (1.0 - 0.13533528323661269189)) { /* 0.135... = exp(-2) */
249+
y = 1.0 - y;
250+
code = false;
251+
}
252+
253+
if (y > 0.13533528323661269189) {
254+
y = y - 0.5;
255+
y2 = y * y;
256+
x = y + y * (y2 * polevl(y2, P0, 4) / p1evl(y2, Q0, 8));
257+
x = x * s2pi;
258+
return (x);
259+
}
260+
261+
x = sqrt(-2.0 * log(y));
262+
x0 = x - log(x) / x;
263+
264+
z = 1.0 / x;
265+
if (x < 8.0) { /* y > exp(-32) = 1.2664165549e-14 */
266+
x1 = z * polevl(z, P1, 8) / p1evl(z, Q1, 8);
267+
} else {
268+
x1 = z * polevl(z, P2, 8) / p1evl(z, Q2, 8);
269+
}
270+
271+
x = x0 - x1;
272+
if (code) {
273+
x = -x;
274+
}
275+
return (x);
276+
}
277+
278+
279+
/*! \brief inverse of the error function */
56280
struct erfinv : public mxnet_op::tunable {
57281
template<typename DType>
58282
MSHADOW_XINLINE static DType Map(DType v) {
59-
/* Function to calculate inverse error function. Rational approximation
60-
is used to generate an initial approximation, which is then improved to
61-
full accuracy by two steps of Newton's method. Code is a direct
62-
translation of the erfinv m file in matlab version 2.0.
63-
Author: Gary L. Pavlis, Indiana University
64-
Date: February 1996
65-
*/
66-
const double central_range = 0.7;
283+
/* Inverse of the error function.
284+
* Computes the inverse of the error function on the restricted domain
285+
* -1 < y < 1. This restriction ensures the existence of a unique result
286+
* such that erf(erfinv(y)) = y.
287+
*/
288+
const double domain_lb = -1;
289+
const double domain_ub = 1;
290+
291+
const double thresh = 1e-7;
67292
double y = static_cast<double>(v);
68-
double y_fab = std::fabs(y);
69-
/*working variables */
70-
double x = 0.0;
71-
double z, num, dem;
72-
/* coefficients in rational expansion */
73-
double a[4]={ 0.886226899, -1.645349621, 0.914624893, -0.140543331};
74-
double b[4]={-2.118377725, 1.442710462, -0.329097515, 0.012229801};
75-
double c[4]={-1.970840454, -1.624906493, 3.429567803, 1.641345311};
76-
double d[2]={ 3.543889200, 1.637067800};
77-
if (y_fab > 1.0) {
78-
/* This needs IEEE constant*/
79-
return DType(std::numeric_limits<double>::quiet_NaN());
80-
} else if (y_fab == 1.0) {
81-
return DType((std::copysign(1.0, y))*std::numeric_limits<double>::infinity());
82-
} else if (y_fab <= central_range) {
83-
z = y*y;
84-
num = (((a[3]*z + a[2])*z + a[1])*z + a[0]);
85-
dem = ((((b[3]*z + b[2])*z + b[1])*z +b[0])*z + 1.0);
86-
x = y*num/dem;
87-
} else {
88-
z = std::sqrt(-std::log((1.0-y_fab)/2.0));
89-
num = ((c[3]*z + c[2])*z + c[1])*z + c[0];
90-
dem = (d[1]*z + d[0])*z + 1.0;
91-
x = (std::copysign(1.0, y))*num/dem;
293+
294+
/*
295+
* For small arguments, use the Taylor expansion
296+
* erf(y) = 2/\sqrt{\pi} (y - y^3 / 3 + O(y^5)), y\to 0
297+
* where we only retain the linear term.
298+
* Otherwise, y + 1 loses precision for |y| << 1.
299+
*/
300+
if ((-thresh < y) && (y < thresh)) {
301+
return DType(y / M_2_SQRTPI);
92302
}
93-
/* Two steps of Newton-Raphson correction */
94-
x = x - (std::erf(x) - y)/((2.0/std::sqrt(M_PI))*std::exp(-x*x));
95-
x = x - (std::erf(x) - y)/((2.0/std::sqrt(M_PI))*std::exp(-x*x));
96303

97-
return DType(x);
304+
if ((domain_lb < y) && (y < domain_ub)) {
305+
return DType(ndtri(0.5 * (y+1)) * M_SQRT1_2);
306+
} else if (y == domain_lb || y == domain_ub) {
307+
return DType(std::copysign(1.0, y) * std::numeric_limits<double>::infinity());
308+
} else {
309+
return DType(std::numeric_limits<double>::quiet_NaN());
310+
}
98311
}
99312
};
100313

0 commit comments

Comments
 (0)