-
Notifications
You must be signed in to change notification settings - Fork 1
/
seeding.py
120 lines (120 loc) · 4.35 KB
/
seeding.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "human-visibility",
"metadata": {},
"outputs": [],
"source": [
"import hashlib\n",
"import numpy as np\n",
"import os\n",
"import random as _random\n",
"import struct\n",
"import sys\n",
"\n",
"from gym import error\n",
"\n",
"def np_random(seed=None):\n",
"    \"\"\"Build a NumPy ``RandomState`` seeded from a hashed seed.\n",
"\n",
"    Args:\n",
"        seed (Optional[int]): non-negative integer, or None to draw a fresh\n",
"            seed from ``os.urandom`` via ``create_seed``.\n",
"\n",
"    Returns:\n",
"        tuple: ``(rng, seed)`` where ``rng`` is the seeded\n",
"        ``np.random.RandomState`` and ``seed`` is the integer seed that was\n",
"        actually used.\n",
"\n",
"    Raises:\n",
"        error.Error: if ``seed`` is given but is not a non-negative int.\n",
"    \"\"\"\n",
"    valid = seed is None or (isinstance(seed, int) and seed >= 0)\n",
"    if not valid:\n",
"        raise error.Error('Seed must be a non-negative integer or omitted, not {}'.format(seed))\n",
"\n",
"    seed = create_seed(seed)\n",
"    # The hashed seed is split into 32-bit words for RandomState.seed().\n",
"    state_words = _int_list_from_bigint(hash_seed(seed))\n",
"\n",
"    generator = np.random.RandomState()\n",
"    generator.seed(state_words)\n",
"    return generator, seed\n",
"\n",
"def hash_seed(seed=None, max_bytes=8):\n",
"    \"\"\"Hash an integer seed so that nearby seeds give unrelated PRNG states.\n",
"\n",
"    Any given evaluation is likely to have many PRNGs active at once (most\n",
"    commonly because the environment runs in multiple processes), and there\n",
"    is literature showing that linearly correlated seeds across PRNGs can\n",
"    correlate their outputs:\n",
"        http://blogs.unity3d.com/2015/01/07/a-primer-on-repeatable-random-numbers/\n",
"        http://stackoverflow.com/questions/1554958/how-different-do-random-seeds-need-to-be\n",
"        http://dl.acm.org/citation.cfm?id=1276928\n",
"    Seeds are therefore passed through SHA-512 before use. This is not meant\n",
"    to be cryptographically strong; it only removes simple correlations.\n",
"\n",
"    Args:\n",
"        seed (Optional[int]): seed to hash; None draws a fresh seed from an\n",
"            operating-system randomness source via ``create_seed``.\n",
"        max_bytes: number of leading digest bytes kept in the result.\n",
"\n",
"    Returns:\n",
"        int: non-negative integer built from the first ``max_bytes`` bytes of\n",
"        the SHA-512 digest of ``str(seed)``.\n",
"    \"\"\"\n",
"    seed = create_seed(max_bytes=max_bytes) if seed is None else seed\n",
"    digest = hashlib.sha512(str(seed).encode('utf8')).digest()\n",
"    return _bigint_from_bytes(digest[:max_bytes])\n",
"\n",
"def create_seed(a=None, max_bytes=8):\n",
"    \"\"\"Create a strong random seed.\n",
"\n",
"    Otherwise, Python 2 would seed using the system time, which might be\n",
"    non-robust especially in the presence of concurrency.\n",
"\n",
"    Args:\n",
"        a (Optional[int, str]): None draws ``max_bytes`` bytes from\n",
"            ``os.urandom``; a str is mixed through SHA-512; an int is\n",
"            reduced modulo ``2 ** (8 * max_bytes)``.\n",
"        max_bytes: maximum number of bytes to use in the seed.\n",
"\n",
"    Returns:\n",
"        int: a non-negative seed smaller than ``2 ** (8 * max_bytes)``.\n",
"\n",
"    Raises:\n",
"        error.Error: if ``a`` is not None, a str, or an int.\n",
"    \"\"\"\n",
"    # Adapted from https://svn.python.org/projects/python/tags/r32/Lib/random.py\n",
"    if a is None:\n",
"        a = _bigint_from_bytes(os.urandom(max_bytes))\n",
"    elif isinstance(a, str):\n",
"        a = a.encode('utf8')\n",
"        # Digest first: the result is truncated to max_bytes, and with the raw\n",
"        # text first, any two strings sharing a max_bytes-long prefix (e.g.\n",
"        # 'experiment_1' / 'experiment_2') produced the same seed.\n",
"        a = hashlib.sha512(a).digest() + a\n",
"        a = _bigint_from_bytes(a[:max_bytes])\n",
"    elif isinstance(a, int):\n",
"        a = a % 2**(8 * max_bytes)\n",
"    else:\n",
"        raise error.Error('Invalid type for seed: {} ({})'.format(type(a), a))\n",
"\n",
"    return a\n",
"\n",
"def _bigint_from_bytes(bytes):\n",
"    \"\"\"Assemble a non-negative int from native unsigned-int words, lowest first.\n",
"\n",
"    The input is zero-padded up to a whole number of ``struct`` ``'I'`` words.\n",
"    Each word is unpacked in native byte order, and word ``i`` is weighted by\n",
"    ``2 ** (word_bits * i)``.\n",
"\n",
"    Args:\n",
"        bytes: the bytes to convert. The name shadows the builtin but is kept\n",
"            for compatibility with keyword callers. The argument is not\n",
"            modified.\n",
"\n",
"    Returns:\n",
"        int: the assembled value; 0 for empty input.\n",
"    \"\"\"\n",
"    # Use struct's own size for 'I' (instead of a hard-coded 4) so the\n",
"    # padding always matches what struct.unpack expects.\n",
"    sizeof_int = struct.calcsize('I')\n",
"    # Pad only up to the next word boundary. The old formula added a whole\n",
"    # extra zero word when the length was already aligned, and its in-place\n",
"    # += mutated a caller's bytearray.\n",
"    data = bytes + b'\\0' * (-len(bytes) % sizeof_int)\n",
"    int_count = len(data) // sizeof_int\n",
"    unpacked = struct.unpack('{}I'.format(int_count), data)\n",
"    # NOTE(review): native byte order means the result can differ between\n",
"    # little- and big-endian hosts; '<I' would make seeds portable but would\n",
"    # change values on big-endian machines.\n",
"    return sum(word << (8 * sizeof_int * i) for i, word in enumerate(unpacked))\n",
"\n",
"def _int_list_from_bigint(bigint):\n",
"    \"\"\"Split a non-negative integer into 32-bit words, least significant first.\n",
"\n",
"    Args:\n",
"        bigint (int): value to split; must be >= 0.\n",
"\n",
"    Returns:\n",
"        list[int]: the base-2**32 digits of ``bigint``, lowest first; ``[0]``\n",
"        when ``bigint`` is 0.\n",
"\n",
"    Raises:\n",
"        error.Error: if ``bigint`` is negative.\n",
"    \"\"\"\n",
"    if bigint < 0:\n",
"        raise error.Error('Seed must be non-negative, not {}'.format(bigint))\n",
"    if bigint == 0:\n",
"        # Special case: the loop below would otherwise return an empty list.\n",
"        return [0]\n",
"\n",
"    words = []\n",
"    remaining = bigint\n",
"    while remaining > 0:\n",
"        remaining, low = divmod(remaining, 2 ** 32)\n",
"        words.append(low)\n",
"    return words"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.5"
}
},
"nbformat": 4,
"nbformat_minor": 5
}