cloudy  trunk
 All Data Structures Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
vectorize_sqrt.cpp
Go to the documentation of this file.
1 /* This file is part of Cloudy and is copyright (C)1978-2022 by Gary J. Ferland and
2  * others. For conditions of distribution and use see copyright notice in license.txt */
3 #include "cddefines.h"
4 #include "vectorize.h"
5 #include "vectorize_math.h"
6 #include "vectorize_sqrt_core.h"
7 
8 //
9 // Written by Peter A.M. van Hoof, Royal Observatory of Belgium, Brussels
10 //
11 // this file contains vectorized versions of the single and double variants of the sqrt()
12 // and hypot() functions. They are vectorized using AVX instructions, but also make use of
13 // AVX2, FMA, and AVX512 instructions when available. The basic algorithms for calculating
14 // the sqrt() functions were derived from the algorithm for calculating rsqrt() described
15 // here: http://en.wikipedia.org/wiki/Fast_inverse_square_root
16 //
17 // Alternatively one can also use the sqrt hardware instruction, but on some hardware
18 // the software implementation is faster. The hardware instruction is chosen as the
19 // default implementation below.
20 //
21 
22 #ifdef __AVX__
23 
24 #ifdef __AVX512F__
25 
26 inline v8df v1sqrtd(v8df x)
27 {
28  __mmask8 invalid1 = _mm512_cmp_pd_mask(x, zero, _CMP_LT_OQ);
29  __mmask8 invalid2 = _mm512_cmp_pd_mask(x, dbl_max, _CMP_NLE_UQ);
30  if( ! _mm512_kortestz(invalid1, invalid2) )
31  {
32  __mmask8 invalid = invalid1 | invalid2;
33  throw domain_error( DEMsg("v1sqrtd", x, invalid) );
34  }
35  return v1sqrtd_core(x);
36 }
37 
38 inline v8df v1hypotd(v8df x, v8df y)
39 {
40  v8di ix = _mm512_castpd_si512(x);
41  v8di iy = _mm512_castpd_si512(y);
42  ix = _mm512_and_si512(ix, sqrt_mask1);
43  iy = _mm512_and_si512(iy, sqrt_mask1);
44  x = _mm512_castsi512_pd(ix);
45  __mmask8 invalid1 = _mm512_cmp_pd_mask(x, dbl_max, _CMP_NLE_UQ);
46  y = _mm512_castsi512_pd(iy);
47  __mmask8 invalid2 = _mm512_cmp_pd_mask(y, dbl_max, _CMP_NLE_UQ);
48  if( ! _mm512_kortestz(invalid1, invalid2) )
49  throw domain_error( DEMsg("v1hypotd", x, invalid1, y, invalid2) );
50  return v1hypotd_core(x, y);
51 }
52 
53 inline v16sf v1sqrtf(v16sf x)
54 {
55  __mmask16 invalid1 = _mm512_cmp_ps_mask(x, zerof, _CMP_LT_OQ);
56  __mmask16 invalid2 = _mm512_cmp_ps_mask(x, flt_max, _CMP_NLE_UQ);
57  if( ! _mm512_kortestz(invalid1, invalid2) )
58  {
59  __mmask16 invalid = invalid1 | invalid2;
60  throw domain_error( DEMsg("v1sqrtf", x, invalid) );
61  }
62  return v1sqrtf_core(x);
63 }
64 
65 inline v16sf v1hypotf(v16sf x, v16sf y)
66 {
67  v16si ix = _mm512_castps_si512(x);
68  v16si iy = _mm512_castps_si512(y);
69  ix = _mm512_and_si512(ix, sqrt_mask1f);
70  iy = _mm512_and_si512(iy, sqrt_mask1f);
71  x = _mm512_castsi512_ps(ix);
72  __mmask16 invalid1 = _mm512_cmp_ps_mask(x, flt_max, _CMP_NLE_UQ);
73  y = _mm512_castsi512_ps(iy);
74  __mmask16 invalid2 = _mm512_cmp_ps_mask(y, flt_max, _CMP_NLE_UQ);
75  if( ! _mm512_kortestz(invalid1, invalid2) )
76  throw domain_error( DEMsg("v1hypotf", x, invalid1, y, invalid2) );
77  return v1hypotf_core(x, y);
78 }
79 
80 #else
81 
82 inline v4df v1sqrtd(v4df x)
83 {
84  v4df invalid1 = _mm256_cmp_pd(x, zero, _CMP_LT_OQ);
85  v4df invalid2 = _mm256_cmp_pd(x, dbl_max, _CMP_NLE_UQ);
86  v4df invalid = _mm256_or_pd(invalid1, invalid2);
87  if( ! _mm256_testz_pd(invalid, invalid) )
88  throw domain_error( DEMsg("v1sqrtd", x, invalid) );
89  return v1sqrtd_core(x);
90 }
91 
92 inline v4df v1hypotd(v4df x, v4df y)
93 {
94  v4df mask1 = _mm256_castsi256_pd(sqrt_mask1);
95  x = _mm256_and_pd(x, mask1);
96  v4df invalid1 = _mm256_cmp_pd(x, dbl_max, _CMP_NLE_UQ);
97  y = _mm256_and_pd(y, mask1);
98  v4df invalid2 = _mm256_cmp_pd(y, dbl_max, _CMP_NLE_UQ);
99  v4df invalid = _mm256_or_pd(invalid1, invalid2);
100  if( ! _mm256_testz_pd(invalid, invalid) )
101  throw domain_error( DEMsg("v1hypotd", x, invalid1, y, invalid2) );
102  return v1hypotd_core(x, y);
103 }
104 
105 inline v8sf v1sqrtf(v8sf x)
106 {
107  v8sf invalid1 = _mm256_cmp_ps(x, zerof, _CMP_LT_OQ);
108  v8sf invalid2 = _mm256_cmp_ps(x, flt_max, _CMP_NLE_UQ);
109  v8sf invalid = _mm256_or_ps(invalid1, invalid2);
110  if( ! _mm256_testz_ps(invalid, invalid) )
111  throw domain_error( DEMsg("v1sqrtf", x, invalid) );
112  return v1sqrtf_core(x);
113 }
114 
115 inline v8sf v1hypotf(v8sf x, v8sf y)
116 {
117  v8sf mask1 = _mm256_castsi256_ps(sqrt_mask1f);
118  x = _mm256_and_ps(x, mask1);
119  v8sf invalid1 = _mm256_cmp_ps(x, flt_max, _CMP_NLE_UQ);
120  y = _mm256_and_ps(y, mask1);
121  v8sf invalid2 = _mm256_cmp_ps(y, flt_max, _CMP_NLE_UQ);
122  v8sf invalid = _mm256_or_ps(invalid1, invalid2);
123  if( ! _mm256_testz_ps(invalid, invalid) )
124  throw domain_error( DEMsg("v1hypotf", x, invalid1, y, invalid2) );
125  return v1hypotf_core(x, y);
126 }
127 
128 #endif // __AVX512F__
129 
130 #else
131 
// Placeholder kernels for builds without AVX. They exist only so that the
// vecfun()/vecfun2() calls below have a vector-kernel argument to compile
// against; they are not meant to be executed and simply return 0.
inline int v1sqrtd(int)
{
	return 0;
}
inline int v1hypotd(int, int)
{
	return 0;
}
inline int v1sqrtf(int)
{
	return 0;
}
inline int v1hypotf(int, int)
{
	return 0;
}
137 
138 #endif // __AVX__
139 
140 // wrapper routines to give math functions C++ linkage
141 // this prevents warnings from the Oracle Studio compiler
inline double wr_sqrtd(double x)
{
	// scalar double-precision square root; wrapping the libm call gives it C++ linkage
	const double root = sqrt(x);
	return root;
}
146 
inline double wr_hypotd(double x, double y)
{
	// scalar double-precision hypot(); wrapping the libm call gives it C++ linkage
	const double length = hypot(x, y);
	return length;
}
151 
153 {
154  return sqrtf(x);
155 }
156 
158 {
159  return hypotf(x, y);
160 }
161 
162 void vsqrt(const double x[], double y[], long nlo, long nhi)
163 {
164  DEBUG_ENTRY( "vsqrt()" );
165 
166  vecfun( x, y, nlo, nhi, wr_sqrtd, v1sqrtd );
167 }
168 
169 void vhypot(const double x1[], const double x2[], double y[], long nlo, long nhi)
170 {
171  DEBUG_ENTRY( "vhypot()" );
172 
173  vecfun2( x1, x2, y, nlo, nhi, wr_hypotd, v1hypotd );
174 }
175 
176 void vsqrt(const sys_float x[], sys_float y[], long nlo, long nhi)
177 {
178  DEBUG_ENTRY( "vsqrt()" );
179 
180  vecfun( x, y, nlo, nhi, wr_sqrtf, v1sqrtf );
181 }
182 
183 void vhypot(const sys_float x1[], const sys_float x2[], sys_float y[], long nlo, long nhi)
184 {
185  DEBUG_ENTRY( "vhypot()" );
186 
187  vecfun2( x1, x2, y, nlo, nhi, wr_hypotf, v1hypotf );
188 }
189 
190 void vsqrt(double *y, double x0, double x1, double x2, double x3)
191 {
192  V1FUN_PD_4(sqrt, 1.);
193 }
194 
195 void vhypot(double *z, double x0, double y0, double x1, double y1, double x2, double y2, double x3, double y3)
196 {
197  V1FUN2_PD_4(hypot, 1.);
198 }
199 
200 void vsqrt(double *y, double x0, double x1, double x2, double x3, double x4, double x5, double x6, double x7)
201 {
202  V1FUN_PD_8(sqrt, 1.);
203 }
204 
206 {
207  V1FUN_PS_4(sqrt, 1.f);
208 }
209 
211  sys_float x3, sys_float y3)
212 {
213  V1FUN2_PS_4(hypot, 1.f);
214 }
215 
217  sys_float x6, sys_float x7)
218 {
219  V1FUN_PS_8(sqrt, 1.f);
220 }
221 
224  sys_float y6, sys_float x7, sys_float y7)
225 {
226  V1FUN2_PS_8(hypot, 1.f);
227 }
228 
230  sys_float x6, sys_float x7, sys_float x8, sys_float x9, sys_float x10, sys_float x11, sys_float x12,
231  sys_float x13, sys_float x14, sys_float x15)
232 {
233  V1FUN_PS_16(sqrt, 1.f);
234 }
#define V1FUN_PD_8(FUN, V)
static double x2[63]
static double x1[83]
#define V1FUN2_PS_8(FUN, V)
void vecfun2(const T x1[], const T x2[], T y[], long nlo, long nhi, T(*scalfun1)(T, T), V(*)(V, V))
int v1hypotf(int, int)
#define V1FUN2_PS_4(FUN, V)
void vhypot(const double x1[], const double x2[], double y[], long nlo, long nhi)
void zero(void)
Definition: zero.cpp:43
void vecfun(const T x[], T y[], long nlo, long nhi, T(*scalfun1)(T), V(*)(V))
#define V1FUN2_PD_4(FUN, V)
double wr_hypotd(double x, double y)
int v1sqrtf(int)
static double x0[83]
float sys_float
Definition: cddefines.h:127
sys_float wr_sqrtf(sys_float x)
#define V1FUN_PS_4(FUN, V)
#define V1FUN_PD_4(FUN, V)
void vsqrt(const double x[], double y[], long nlo, long nhi)
#define DEBUG_ENTRY(funcname)
Definition: cddefines.h:723
int v1sqrtd(int)
#define V1FUN_PS_8(FUN, V)
int v1hypotd(int, int)
#define V1FUN_PS_16(FUN, V)
sys_float wr_hypotf(sys_float x, sys_float y)
double wr_sqrtd(double x)