forked from dotnet/machinelearning
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathIntUtils.cs
More file actions
321 lines (294 loc) · 14.2 KB
/
Copy pathIntUtils.cs
File metadata and controls
321 lines (294 loc) · 14.2 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT license.
// See the LICENSE file in the project root for more information.
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Security;
using Microsoft.ML.Internal.CpuMath.Core;
namespace Microsoft.ML.Internal.CpuMath
{
[BestFriend]
internal static class IntUtils
{
/// <summary>
/// Adds the 64-bit value <paramref name="src"/> to the 128-bit value held in
/// (<paramref name="dstHi"/>, <paramref name="dstLo"/>). A carry out of the high word is
/// dropped, that is, the addition is performed modulo 2^128.
/// </summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static void Add(ref ulong dstHi, ref ulong dstLo, ulong src)
{
    ulong lowSum = dstLo + src;
    dstLo = lowSum;

    // An unsigned sum that is smaller than one of its addends has wrapped around,
    // so a carry must be propagated into the high word.
    if (lowSum < src)
    {
        dstHi++;
    }
}
/// <summary>
/// Adds the 128-bit value (<paramref name="srcHi"/>, <paramref name="srcLo"/>) to the 128-bit
/// value held in (<paramref name="dstHi"/>, <paramref name="dstLo"/>). A carry out of the high
/// word is dropped, that is, the addition is performed modulo 2^128.
/// </summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static void Add(ref ulong dstHi, ref ulong dstLo, ulong srcHi, ulong srcLo)
{
    ulong lowSum = dstLo + srcLo;
    dstLo = lowSum;

    // The low words wrapped iff the sum is smaller than the addend; that is the carry.
    if (lowSum < srcLo)
    {
        dstHi++;
    }

    dstHi += srcHi;
}
/// <summary>
/// Subtracts the 64-bit value <paramref name="src"/> from the 128-bit value held in
/// (<paramref name="dstHi"/>, <paramref name="dstLo"/>). A borrow out of the high word is
/// dropped, that is, the subtraction is performed modulo 2^128.
/// </summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static void Sub(ref ulong dstHi, ref ulong dstLo, ulong src)
{
    // The low word underflows exactly when it is smaller than the subtrahend;
    // in that case one unit is borrowed from the high word first.
    bool needsBorrow = dstLo < src;
    if (needsBorrow)
    {
        dstHi--;
    }

    dstLo -= src;
}
/// <summary>
/// Subtracts the 128-bit value (<paramref name="srcHi"/>, <paramref name="srcLo"/>) from the
/// 128-bit value held in (<paramref name="dstHi"/>, <paramref name="dstLo"/>). A borrow out of
/// the high word is dropped, that is, the subtraction is performed modulo 2^128.
/// </summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static void Sub(ref ulong dstHi, ref ulong dstLo, ulong srcHi, ulong srcLo)
{
    dstHi -= srcHi;

    // The low words underflow exactly when dstLo is smaller than srcLo; borrow from the high word.
    bool needsBorrow = dstLo < srcLo;
    if (needsBorrow)
    {
        dstHi--;
    }

    dstLo -= srcLo;
}
/// <summary>
/// Compares two 128-bit unsigned values, each given as a high word followed by a low word:
/// a = (<paramref name="a1"/>, <paramref name="a0"/>) and b = (<paramref name="b1"/>, <paramref name="b0"/>).
/// Returns true if a is strictly less than b.
/// </summary>
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static bool LessThan(ulong a1, ulong a0, ulong b1, ulong b0)
{
    // The high words decide the ordering unless they are equal.
    if (a1 != b1)
    {
        return a1 < b1;
    }

    return a0 < b0;
}
/// <summary>
/// Divide the 128 bit value in <paramref name="lo"/> and <paramref name="hi"/> by <paramref name="den"/>,
/// returning the quotient and placing the remainder in <paramref name="rem"/>. Throws on overflow.
/// Note that <paramref name="lo"/> comes before <paramref name="hi"/>.
/// </summary>
#if !CORECLR
[DllImport(Thunk.NativePath), SuppressUnmanagedCodeSecurity]
private static extern ulong Div64(ulong lo, ulong hi, ulong den, out ulong rem);
#else
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static ulong Div64(ulong lo, ulong hi, ulong den, out ulong rem)
{
    // The exception types are fully qualified because this file does not import the System
    // namespace; the unqualified names would not resolve when CORECLR is defined.
    if (den == 0)
        throw new System.DivideByZeroException();

    // When hi >= den the quotient of (hi, lo) / den needs more than 64 bits.
    if (hi >= den)
        throw new System.OverflowException();

    return Div64Core(lo, hi, den, out rem);
}

// REVIEW: on Linux, the hardware divide-by-zero exception is not translated into
// a managed exception properly by CoreCLR so the process will crash. This is a temporary fix
// until CoreCLR addresses this issue. The managed wrapper above therefore validates the
// arguments before calling into the native "Div64" entry point.
[DllImport(Thunk.NativePath, CharSet = CharSet.Unicode, EntryPoint = "Div64"), SuppressUnmanagedCodeSecurity]
private static extern ulong Div64Core(ulong lo, ulong hi, ulong den, out ulong rem);
#endif
/// <summary>
/// Multiply the two 64-bit values <paramref name="a"/> and <paramref name="b"/> to produce a 128-bit result.
/// Implemented natively; this is a P/Invoke declaration bound to <c>Thunk.NativePath</c>.
/// </summary>
/// <param name="a">The first 64-bit factor.</param>
/// <param name="b">The second 64-bit factor.</param>
/// <param name="hi">
/// Receives the high 64 bits of the product. (This matches how the 128-bit DivRound overload in this
/// file consumes the result: it treats the out value as the high word and the return value as the low word.)
/// </param>
/// <returns>The low 64 bits of the product, as consumed by the caller in this file.</returns>
[DllImport(Thunk.NativePath), SuppressUnmanagedCodeSecurity]
private static extern ulong Mul64(ulong a, ulong b, out ulong hi);
/// <summary>
/// Divides the 128-bit value (<paramref name="hi"/>, <paramref name="lo"/>) by <paramref name="den"/>
/// and rounds the result to the nearest integer, with ties going to the even value (unbiased rounding).
/// Throws on overflow. Note that <paramref name="lo"/> comes before <paramref name="hi"/>.
/// </summary>
public static ulong DivRound(ulong lo, ulong hi, ulong den)
{
    // Truncating division first; the remainder drives the rounding decision.
    ulong rem;
    ulong quo = Div64(lo, hi, den, out rem);
    Contracts.Assert(rem < den);

    // distanceUp is how far the exact quotient lies below quo + 1 (in units of 1/den),
    // while rem is how far it lies above quo.
    ulong distanceUp = den - rem;
    bool roundUp;
    if (rem > distanceUp)
    {
        // Strictly closer to quo + 1.
        roundUp = true;
    }
    else
    {
        // Exact tie: round up only when quo is odd, so that the result is even.
        roundUp = rem == distanceUp && (quo & 1) == 1;
    }

    // The increment itself may overflow 64 bits, hence "checked".
    return roundUp ? checked(quo + 1) : quo;
}
/// <summary>
/// Divide the 128-bit numerator by the 128-bit denominator and round to nearest using unbiased
/// rounding (ties go to the even value). Throws on overflow.
/// Note that <paramref name="numLo"/> comes before <paramref name="numHi"/>, and
/// <paramref name="denLo"/> comes before <paramref name="denHi"/>.
/// </summary>
/// <param name="numLo">Low 64 bits of the numerator.</param>
/// <param name="numHi">High 64 bits of the numerator.</param>
/// <param name="denLo">Low 64 bits of the denominator.</param>
/// <param name="denHi">High 64 bits of the denominator. When zero, the 64-bit denominator overload is used.</param>
/// <returns>The rounded quotient.</returns>
public static ulong DivRound(ulong numLo, ulong numHi, ulong denLo, ulong denHi)
{
// If the denominator fits in 64 bits use the simple overload above.
if (denHi == 0)
return DivRound(numLo, numHi, denLo);
// At this point, we're guaranteed that the quotient doesn't overflow (since denHi > 0),
// but rounding might still overflow.
// Our goal is to set quo to the correct value and modify numXx to contain the remainder,
// then fall through to the rounding code at the end.
ulong quo;
if (LessThan(numHi, numLo, denHi, denLo))
{
// Since num < den, quotient is zero and num is already the remainder.
quo = 0;
}
else if ((long)denHi < 0)
{
// The high bit of den is set, so den <= num < den * 2. Thus, quotient is one and the
// remainder is num - den.
// (num >= den was established by the failed LessThan test above, so num's high bit is set too.)
Contracts.Assert((long)numHi < 0);
Sub(ref numHi, ref numLo, denHi, denLo);
Contracts.Assert(LessThan(numHi, numLo, denHi, denLo));
quo = 1;
}
else
{
// Shift num and den so that denHi has its high bit set. This requires 3 ulongs for num.
int cbitShiftLeft = CbitHighZero(denHi);
int cbitShiftRight = 64 - cbitShiftLeft;
// denHi is non-zero and its high bit is clear on this path, so the shift count is in [1, 63].
Contracts.Assert(0 < cbitShiftLeft & cbitShiftLeft < 64);
denHi = (denHi << cbitShiftLeft) | (denLo >> cbitShiftRight);
denLo <<= cbitShiftLeft;
// The shifted numerator is (numEx, numHi, numLo). Note that the high bit of numEx must be zero,
// since cbitShiftRight > 0.
ulong numEx = numHi >> cbitShiftRight;
numHi = (numHi << cbitShiftLeft) | (numLo >> cbitShiftRight);
numLo <<= cbitShiftLeft;
Contracts.Assert((long)numEx >= 0);
Contracts.Assert((long)denHi < 0);
// numEx < denHi is the condition under which the 128-by-64 division below does not overflow.
Contracts.Assert(numEx < denHi);
// Get the trial quotient and remainder by dividing (numEx, numHi) by denHi, and storing the
// remainder in numHi.
quo = Div64(numHi, numEx, denHi, out numHi);
// Note that the quotient could be slightly too big, but never too small. To see this:
//
// * Use notation [ABC] as short-hand for A*X*X + B*X + C, where X = 2^64 is the "base"
// and A, B, C are the "digits", so 0 <= A < X, 0 <= B < X, and 0 <= C < X.
//
// * Given: numerator [ABC] and denominator [DE].
// * The shifting above ensures X/2 <= D.
// * Given: Q and R with [AB] = D*Q + R, and 0 <= R <= D - 1.
// * Then [ABC] = [DE]*Q - E*Q + R*X + C = [DE]*Q + [RC] - E*Q.
// * Thus [RC] - E*Q is the signed remainder when using quotient Q. Note that it is not
// necessarily normalized to be between 0 and [DE], so Q is not necessarily the correct quotient.
// * However [RC] - E*Q <= [RC] < [DE] (since R < D), so Q is definitely NOT too small.
// * However [RC] - E*Q can clearly be negative, implying that Q might be too big.
// * Note that decreasing Q by x increases the remainder by x*[DE]. To get the correct quotient,
// we need the remainder r to satisfy 0 <= r < [DE].
// * Since D >= X/2 (by construction), [DE] >= X*X/2.
// * Trivially, [RC] - E*Q > -X*X >= -2*[DE]. This demonstrates that Q may need
// to decrease by at most two for the remainder to become non-negative.
//
// We can actually produce a tighter bound. Let k = cbitShiftLeft. Then 1 <= k <= 63.
// We know that the low k bits of E are zero and only the low k bits of A can be non-zero.
// Then A < 2^k and E = e*2^k where 0 <= e < 2^(64-k). We need to demonstrate that
// (1) [RC] - E*Q + [DE] >= 0,
// since that is equivalent to Q being too large by at most one.
// Suppose (1) is false. That is, suppose [RC] - E*Q + [DE] < 0. Then
// [DE] < E*Q - [RC] <= E*Q,
// so
// (2) D*X < (Q-1)*E.
// Note Q = (A*X+B-R) / D <= ((2^k-1)*X + (X-1)) / (X/2) < 2^(k+1). So Q < 2^(k+1). Since Q
// is an integer, this implies
// (3) Q <= 2^(k+1) - 1.
// Then
// 2^127 = X*X/2
// <= D*X since D >= X/2
// < (Q-1)*E by (2)
// <= (2^(k+1)-2)*e*2^k by (3)
// <= 2*(2^k-1)*(2^(64-k)-1)*2^k
// = 2^(65+k) - 2^(2k+1) - 2^65 + 2^(k+1)
// The only value of k that has a hope to make this true is k=63 (recall that k <= 63),
// in which case the right hand side is:
// = 2^128 - 2^127 - 2^65 + 2^64
// = 2^127 - 2^64
// Which is a contradiction, so our supposition that [RC] - E*Q + [DE] < 0 is impossible,
// implying that (1) holds, implying that Q is too big by at most one. Also note that Q is
// too big iff [RC] - E*Q < 0 iff [RC] < E*Q.
// Compute E*Q = denLo * quo, stored in (p1, p0).
ulong p1;
ulong p0 = Mul64(denLo, quo, out p1);
// See whether [RC] < E*Q, which is true iff [RC] - E*Q < 0 iff Q is too big.
if (LessThan(numHi, numLo, p1, p0))
{
// Need to decrement quo and add the denominator into the remainder.
// (p1, p0) still holds E times the ORIGINAL trial quotient; adding [DE] here compensates
// for the decrement, so the subtraction below yields the remainder for the corrected quo.
Contracts.Assert(quo > 1);
quo--;
Add(ref numHi, ref numLo, denHi, denLo);
}
// Subtract E*Q from the remainder.
// The intermediate sum above may wrap, but Add and Sub both work modulo 2^128, so the final
// difference is still the true (non-negative) remainder.
Sub(ref numHi, ref numLo, p1, p0);
}
// At this point, num is the remainder, so num < den.
// (On the shifted path both num and den are scaled by the same power of two, which does not
// affect the comparison of the remainder against half of the denominator below.)
Contracts.Assert(LessThan(numHi, numLo, denHi, denLo));
// Set den = den - num and then compare to num, to determine whether the remainder is closer
// to zero or to the denominator. If there is a tie, round to the even value. Note that the increment
// of quo might overflow, hence the "checked".
Sub(ref denHi, ref denLo, numHi, numLo);
if (LessThan(denHi, denLo, numHi, numLo) || (quo & 1) == 1 && denHi == numHi && denLo == numLo)
quo = checked(quo + 1);
return quo;
}
/// <summary>
/// Counts the zero bits above the most significant set bit of <paramref name="u"/>.
/// Returns 64 when <paramref name="u"/> is zero.
/// </summary>
private static int CbitHighZero(ulong u)
{
    if (u == 0)
    {
        return 64;
    }

    // Binary search for the highest set bit: probe the top 32, 16, 8, 4, 2 and finally 1 bit(s).
    // Whenever the probed bits are all zero, count them and shift them out so that the next,
    // narrower probe looks at the bits that followed them.
    int leadingZeros = 0;
    for (int width = 32; width > 0; width >>= 1)
    {
        bool topBitsClear = (u >> (64 - width)) == 0;
        if (topBitsClear)
        {
            leadingZeros += width;
            u <<= width;
        }
    }

    return leadingZeros;
}
/// <summary>
/// Multiply <paramref name="a"/> and <paramref name="b"/> and divide by <paramref name="den"/>,
/// returning the quotient and placing the remainder in <paramref name="rem"/>. Throws on overflow.
/// </summary>
#if !CORECLR
[DllImport(Thunk.NativePath), SuppressUnmanagedCodeSecurity]
private static extern ulong MulDiv64Core(ulong a, ulong b, ulong den, out ulong rem);

public static ulong MulDiv64(ulong a, ulong b, ulong den, out ulong rem)
{
    // Straight pass-through to the native implementation.
    return MulDiv64Core(a, b, den, out rem);
}
#else
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static ulong MulDiv64(ulong a, ulong b, ulong den, out ulong rem)
{
    // The exception types are fully qualified because this file does not import the System
    // namespace; the unqualified names would not resolve when CORECLR is defined.
    if (den == 0)
        throw new System.DivideByZeroException();

    // TryMulDiv64 reports overflow through its return value; translate that into an exception.
    ulong quo;
    if (!TryMulDiv64(a, b, den, out quo, out rem))
        throw new System.OverflowException();
    return quo;
}
#endif
/// <summary>
/// Native entry point behind <see cref="TryMulDiv64"/>.
/// Multiply <paramref name="a"/> and <paramref name="b"/> and divide by <paramref name="den"/>,
/// placing the quotient in <paramref name="quo"/> and the remainder in <paramref name="rem"/>.
/// Returns true on success. On overflow, places zero in the out parameters and returns false.
/// </summary>
// NOTE(review): no [return: MarshalAs(...)] is specified, so the bool return value uses the default
// P/Invoke marshalling. The native signature is not visible from this file - confirm that the native
// return type matches the default, otherwise an explicit MarshalAs attribute is needed.
[DllImport(Thunk.NativePath), SuppressUnmanagedCodeSecurity]
private static extern bool TryMulDiv64Core(ulong a, ulong b, ulong den, out ulong quo, out ulong rem);
/// <summary>
/// Multiply <paramref name="a"/> and <paramref name="b"/> and divide by <paramref name="den"/>,
/// placing the quotient in <paramref name="quo"/> and the remainder in <paramref name="rem"/>.
/// Returns true on success. On overflow, places zero in the out parameters and returns false.
/// This is a managed pass-through to the native implementation.
/// </summary>
public static bool TryMulDiv64(ulong a, ulong b, ulong den, out ulong quo, out ulong rem)
{
return TryMulDiv64Core(a, b, den, out quo, out rem);
}
}
}