linear_congruential_engine: Fixes for __lce_alg_picker (#81080)
This fixes two major mistakes in the implementation of `linear_congruential_engine` that allowed it to produce incorrect output. Specifically, these mistakes are in `__lce_alg_picker`, which is used to determine whether Schrage's algorithm is valid and needed. The first mistake is in the definition of `_OverflowOK`. The code comment and the description of [D65041](https://reviews.llvm.org/D65041) both indicate that it's supposed to be true iff `m` is a power of two. However, the definition used does not work out to that, and instead is true whenever `m` is even. This could result in `linear_congruential_engine` using an invalid implementation, as it would incorrectly assume that any integer overflow can't change the result. I changed the implementation to one that accurately checks if `m` is a power of two. Technically, this implementation has an edge case where it considers `0` to be a power of two, but in this case this is actually accurate behavior, as `m = 0` indicates a modulus of 2^w where w is the size of `result_type` in bits, which *is* a power of two. The second mistake is in the static assert. The original static assert erroneously included an unnecessary `a != 0 || m != 0`. Combined with the `|| !_MightOverflow`, this actually resulted in the static assert being impossible to fail. Applying De Morgan's law and expanding `_MightOverflow` gives that the only way this static assert can be triggered is if `a == 0 && m == 0 && a != 0 && m != 0 && ...`, which clearly cannot be true. I simply removed the explicit checks against `a` and `m`, as the intended checks are already included in `_MightOverflow` and `_SchrageOK`, and their inclusion doesn't provide any obvious semantic benefit. This should fix all the current instances where `linear_congruential_engine` uses an invalid implementation. This technically isn't a complete implementation, though, since the static assert will cause some instantiations of `linear_congruential_engine` not disallowed by the standard from compiling. However, this should still be an improvement, as all compiling instantiations of `linear_congruential_engine` should use a valid implementation. Fixing the cases where the static assert triggers will require adding additional implementations, some of which will be fairly non-trivial, so I'd rather leave those for another PR so they don't hold up these more important fixes. Fixes #33554
This commit is contained in:
parent
0f1847cb2c
commit
fc027e10ba
@ -31,10 +31,10 @@ template <unsigned long long __a,
|
||||
unsigned long long __m,
|
||||
unsigned long long _Mp,
|
||||
bool _MightOverflow = (__a != 0 && __m != 0 && __m - 1 > (_Mp - __c) / __a),
|
||||
bool _OverflowOK = ((__m | (__m - 1)) > __m), // m = 2^n
|
||||
bool _OverflowOK = ((__m & (__m - 1)) == 0ull), // m = 2^n
|
||||
bool _SchrageOK = (__a != 0 && __m != 0 && __m % __a <= __m / __a)> // r <= q
|
||||
struct __lce_alg_picker {
|
||||
static_assert(__a != 0 || __m != 0 || !_MightOverflow || _OverflowOK || _SchrageOK,
|
||||
static_assert(!_MightOverflow || _OverflowOK || _SchrageOK,
|
||||
"The current values of a, c, and m cannot generate a number "
|
||||
"within bounds of linear_congruential_engine.");
|
||||
|
||||
|
@ -22,13 +22,13 @@ int main(int, char**)
|
||||
{
|
||||
typedef unsigned long long T;
|
||||
|
||||
// m might overflow, but the overflow is OK so it shouldn't use schrage's algorithm
|
||||
// m might overflow, but the overflow is OK so it shouldn't use Schrage's algorithm
|
||||
typedef std::linear_congruential_engine<T, 25214903917ull, 1, (1ull << 48)> E1;
|
||||
E1 e1;
|
||||
// make sure the right algorithm was used
|
||||
assert(e1() == 25214903918);
|
||||
assert(e1() == 205774354444503);
|
||||
assert(e1() == 158051849450892);
|
||||
assert(e1() == 25214903918ull);
|
||||
assert(e1() == 205774354444503ull);
|
||||
assert(e1() == 158051849450892ull);
|
||||
// make sure result is in bounds
|
||||
assert(e1() < (1ull << 48));
|
||||
assert(e1() < (1ull << 48));
|
||||
@ -37,33 +37,48 @@ int main(int, char**)
|
||||
assert(e1() < (1ull << 48));
|
||||
|
||||
// m might overflow. The overflow is not OK and result will be in bounds
|
||||
// so we should use shrage's algorithm
|
||||
typedef std::linear_congruential_engine<T, (1ull<<2), 0, (1ull<<63) + 1> E2;
|
||||
// so we should use Schrage's algorithm
|
||||
typedef std::linear_congruential_engine<T, (1ull << 32), 0, (1ull << 63) + 1> E2;
|
||||
E2 e2;
|
||||
// make sure shrage's algorithm is used (it would be 0s otherwise)
|
||||
assert(e2() == 4);
|
||||
assert(e2() == 16);
|
||||
assert(e2() == 64);
|
||||
// make sure Schrage's algorithm is used (it would be 0s after the first otherwise)
|
||||
assert(e2() == (1ull << 32));
|
||||
assert(e2() == (1ull << 63) - 1ull);
|
||||
assert(e2() == (1ull << 63) - (1ull << 33) + 1ull);
|
||||
// make sure result is in bounds
|
||||
assert(e2() < (1ull<<48) + 1);
|
||||
assert(e2() < (1ull<<48) + 1);
|
||||
assert(e2() < (1ull<<48) + 1);
|
||||
assert(e2() < (1ull<<48) + 1);
|
||||
assert(e2() < (1ull<<48) + 1);
|
||||
assert(e2() < (1ull << 63) + 1);
|
||||
assert(e2() < (1ull << 63) + 1);
|
||||
assert(e2() < (1ull << 63) + 1);
|
||||
assert(e2() < (1ull << 63) + 1);
|
||||
assert(e2() < (1ull << 63) + 1);
|
||||
|
||||
// m will not overflow so we should not use shrage's algorithm
|
||||
typedef std::linear_congruential_engine<T, 1ull, 1, (1ull<<48)> E3;
|
||||
// m might overflow. The overflow is not OK and result will be in bounds
|
||||
// so we should use Schrage's algorithm. m is even
|
||||
typedef std::linear_congruential_engine<T, 0x18000001ull, 0x12347ull, (3ull << 56)> E3;
|
||||
E3 e3;
|
||||
// make sure the correct algorithm was used
|
||||
assert(e3() == 2);
|
||||
assert(e3() == 3);
|
||||
assert(e3() == 4);
|
||||
// make sure Schrage's algorithm is used
|
||||
assert(e3() == 402727752ull);
|
||||
assert(e3() == 162159612030764687ull);
|
||||
assert(e3() == 108176466184989142ull);
|
||||
// make sure result is in bounds
|
||||
assert(e3() < (1ull<<48));
|
||||
assert(e3() < (1ull<<48));
|
||||
assert(e3() < (1ull<<48));
|
||||
assert(e3() < (1ull<<48));
|
||||
assert(e2() < (1ull<<48));
|
||||
assert(e3() < (3ull << 56));
|
||||
assert(e3() < (3ull << 56));
|
||||
assert(e3() < (3ull << 56));
|
||||
assert(e3() < (3ull << 56));
|
||||
assert(e3() < (3ull << 56));
|
||||
|
||||
// m will not overflow so we should not use Schrage's algorithm
|
||||
typedef std::linear_congruential_engine<T, 1ull, 1, (1ull << 48)> E4;
|
||||
E4 e4;
|
||||
// make sure the correct algorithm was used
|
||||
assert(e4() == 2ull);
|
||||
assert(e4() == 3ull);
|
||||
assert(e4() == 4ull);
|
||||
// make sure result is in bounds
|
||||
assert(e4() < (1ull << 48));
|
||||
assert(e4() < (1ull << 48));
|
||||
assert(e4() < (1ull << 48));
|
||||
assert(e4() < (1ull << 48));
|
||||
assert(e4() < (1ull << 48));
|
||||
|
||||
return 0;
|
||||
}
|
@ -15,6 +15,7 @@
|
||||
|
||||
#include <random>
|
||||
#include <cassert>
|
||||
#include <climits>
|
||||
|
||||
#include "test_macros.h"
|
||||
|
||||
@ -35,19 +36,41 @@ template <class T>
|
||||
void
|
||||
test()
|
||||
{
|
||||
test1<T, 0, 0, 0>();
|
||||
test1<T, 0, 1, 2>();
|
||||
test1<T, 1, 1, 2>();
|
||||
const int W = sizeof(T) * CHAR_BIT;
|
||||
const T M(static_cast<T>(-1));
|
||||
const T A(static_cast<T>((static_cast<T>(1) << (W / 2)) - 1));
|
||||
|
||||
// Cases where m = 0
|
||||
test1<T, 0, 0, 0>();
|
||||
test1<T, A, 0, 0>();
|
||||
test1<T, 0, 1, 0>();
|
||||
test1<T, A, 1, 0>();
|
||||
|
||||
// Cases where m = 2^n for n < w
|
||||
test1<T, 0, 0, 256>();
|
||||
test1<T, 5, 0, 256>();
|
||||
test1<T, 0, 1, 256>();
|
||||
test1<T, 5, 1, 256>();
|
||||
|
||||
// Cases where m is odd and a = 0
|
||||
test1<T, 0, 0, M>();
|
||||
test1<T, 0, M - 2, M>();
|
||||
test1<T, 0, M - 1, M>();
|
||||
|
||||
// Cases where m is odd and m % a <= m / a (Schrage)
|
||||
test1<T, A, 0, M>();
|
||||
test1<T, A, M - 2, M>();
|
||||
test1<T, A, M - 1, M>();
|
||||
|
||||
/*
|
||||
// Cases where m is odd and m % a > m / a (not implemented)
|
||||
test1<T, M - 2, 0, M>();
|
||||
test1<T, M - 2, M - 2, M>();
|
||||
test1<T, M - 2, M - 1, M>();
|
||||
test1<T, M - 1, 0, M>();
|
||||
test1<T, M - 1, M - 2, M>();
|
||||
test1<T, M - 1, M - 1, M>();
|
||||
*/
|
||||
}
|
||||
|
||||
int main(int, char**)
|
||||
|
@ -35,19 +35,41 @@ template <class T>
|
||||
void
|
||||
test()
|
||||
{
|
||||
test1<T, 0, 0, 0>();
|
||||
test1<T, 0, 1, 2>();
|
||||
test1<T, 1, 1, 2>();
|
||||
const int W = sizeof(T) * CHAR_BIT;
|
||||
const T M(static_cast<T>(-1));
|
||||
const T A(static_cast<T>((static_cast<T>(1) << (W / 2)) - 1));
|
||||
|
||||
// Cases where m = 0
|
||||
test1<T, 0, 0, 0>();
|
||||
test1<T, A, 0, 0>();
|
||||
test1<T, 0, 1, 0>();
|
||||
test1<T, A, 1, 0>();
|
||||
|
||||
// Cases where m = 2^n for n < w
|
||||
test1<T, 0, 0, 256>();
|
||||
test1<T, 5, 0, 256>();
|
||||
test1<T, 0, 1, 256>();
|
||||
test1<T, 5, 1, 256>();
|
||||
|
||||
// Cases where m is odd and a = 0
|
||||
test1<T, 0, 0, M>();
|
||||
test1<T, 0, M - 2, M>();
|
||||
test1<T, 0, M - 1, M>();
|
||||
|
||||
// Cases where m is odd and m % a <= m / a (Schrage)
|
||||
test1<T, A, 0, M>();
|
||||
test1<T, A, M - 2, M>();
|
||||
test1<T, A, M - 1, M>();
|
||||
|
||||
/*
|
||||
// Cases where m is odd and m % a > m / a (not implemented)
|
||||
test1<T, M - 2, 0, M>();
|
||||
test1<T, M - 2, M - 2, M>();
|
||||
test1<T, M - 2, M - 1, M>();
|
||||
test1<T, M - 1, 0, M>();
|
||||
test1<T, M - 1, M - 2, M>();
|
||||
test1<T, M - 1, M - 1, M>();
|
||||
*/
|
||||
}
|
||||
|
||||
int main(int, char**)
|
||||
|
@ -33,19 +33,41 @@ template <class T>
|
||||
void
|
||||
test()
|
||||
{
|
||||
test1<T, 0, 0, 0>();
|
||||
test1<T, 0, 1, 2>();
|
||||
test1<T, 1, 1, 2>();
|
||||
const int W = sizeof(T) * CHAR_BIT;
|
||||
const T M(static_cast<T>(-1));
|
||||
const T A(static_cast<T>((static_cast<T>(1) << (W / 2)) - 1));
|
||||
|
||||
// Cases where m = 0
|
||||
test1<T, 0, 0, 0>();
|
||||
test1<T, A, 0, 0>();
|
||||
test1<T, 0, 1, 0>();
|
||||
test1<T, A, 1, 0>();
|
||||
|
||||
// Cases where m = 2^n for n < w
|
||||
test1<T, 0, 0, 256>();
|
||||
test1<T, 5, 0, 256>();
|
||||
test1<T, 0, 1, 256>();
|
||||
test1<T, 5, 1, 256>();
|
||||
|
||||
// Cases where m is odd and a = 0
|
||||
test1<T, 0, 0, M>();
|
||||
test1<T, 0, M - 2, M>();
|
||||
test1<T, 0, M - 1, M>();
|
||||
|
||||
// Cases where m is odd and m % a <= m / a (Schrage)
|
||||
test1<T, A, 0, M>();
|
||||
test1<T, A, M - 2, M>();
|
||||
test1<T, A, M - 1, M>();
|
||||
|
||||
/*
|
||||
// Cases where m is odd and m % a > m / a (not implemented)
|
||||
test1<T, M - 2, 0, M>();
|
||||
test1<T, M - 2, M - 2, M>();
|
||||
test1<T, M - 2, M - 1, M>();
|
||||
test1<T, M - 1, 0, M>();
|
||||
test1<T, M - 1, M - 2, M>();
|
||||
test1<T, M - 1, M - 1, M>();
|
||||
*/
|
||||
}
|
||||
|
||||
int main(int, char**)
|
||||
|
@ -66,19 +66,41 @@ template <class T>
|
||||
void
|
||||
test()
|
||||
{
|
||||
test1<T, 0, 0, 0>();
|
||||
test1<T, 0, 1, 2>();
|
||||
test1<T, 1, 1, 2>();
|
||||
const int W = sizeof(T) * CHAR_BIT;
|
||||
const T M(static_cast<T>(-1));
|
||||
const T A(static_cast<T>((static_cast<T>(1) << (W / 2)) - 1));
|
||||
|
||||
// Cases where m = 0
|
||||
test1<T, 0, 0, 0>();
|
||||
test1<T, A, 0, 0>();
|
||||
test1<T, 0, 1, 0>();
|
||||
test1<T, A, 1, 0>();
|
||||
|
||||
// Cases where m = 2^n for n < w
|
||||
test1<T, 0, 0, 256>();
|
||||
test1<T, 5, 0, 256>();
|
||||
test1<T, 0, 1, 256>();
|
||||
test1<T, 5, 1, 256>();
|
||||
|
||||
// Cases where m is odd and a = 0
|
||||
test1<T, 0, 0, M>();
|
||||
test1<T, 0, M - 2, M>();
|
||||
test1<T, 0, M - 1, M>();
|
||||
|
||||
// Cases where m is odd and m % a <= m / a (Schrage)
|
||||
test1<T, A, 0, M>();
|
||||
test1<T, A, M - 2, M>();
|
||||
test1<T, A, M - 1, M>();
|
||||
|
||||
/*
|
||||
// Cases where m is odd and m % a > m / a (not implemented)
|
||||
test1<T, M - 2, 0, M>();
|
||||
test1<T, M - 2, M - 2, M>();
|
||||
test1<T, M - 2, M - 1, M>();
|
||||
test1<T, M - 1, 0, M>();
|
||||
test1<T, M - 1, M - 2, M>();
|
||||
test1<T, M - 1, M - 1, M>();
|
||||
*/
|
||||
}
|
||||
|
||||
int main(int, char**)
|
||||
|
Loading…
x
Reference in New Issue
Block a user