1
1
from mokka import mapping
2
+ import mokka .utils .bitops .torch
3
+ import mokka .utils .bitops .numpy
2
4
import torch
3
5
import numpy as np
6
+ import pytest
7
+
8
+ from pathlib import Path
9
+
10
+ FIXTURE_DIR = Path (__file__ ).parent .resolve () / "data"
11
+
4
12
5
13
# PyTorch tests
6
14
@@ -43,6 +51,56 @@ def test_qam_constellation_mapper():
43
51
assert np .allclose (symbols , reference_symbols )
44
52
45
53
54
+ def test_custom_constellation_mapper ():
55
+ m = 4
56
+ bits = mokka .utils .bitops .numpy .generate_all_input_bits (m )
57
+ symbols = mokka .utils .bitops .torch .bits2idx (torch .from_numpy (bits ))
58
+ mapper = mapping .torch .CustomConstellationMapper (m , symbols )
59
+ reference_symbols = mapper .get_constellation ().detach ().numpy ().flatten ()
60
+ assert np .allclose (symbols , reference_symbols )
61
+
62
+
63
+ def test_separated_constellation_mapper ():
64
+ m = 4
65
+ for i in range (0 , 3 ):
66
+ if i == 0 :
67
+ mapper = mapping .torch .SeparatedConstellationMapper (m , qam_init = True )
68
+ symbols = mapper .get_constellation ().detach ().numpy ().flatten ()
69
+ reference_symbols = mapping .numpy .QAM (m ).get_constellation ().flatten ()
70
+ assert np .allclose (symbols , reference_symbols )
71
+ elif i == 1 :
72
+ mapper = mapping .torch .SeparatedConstellationMapper (m , m_real = 1 )
73
+ symbols = mapper .get_constellation ()
74
+ assert symbols .shape [0 ] == 16
75
+ elif i == 2 :
76
+ mapper = mapping .torch .SeparatedConstellationMapper (m , m_imag = 1 )
77
+ symbols = mapper .get_constellation ()
78
+ assert symbols .shape [0 ] == 16
79
+
80
+
81
+ @pytest .mark .datafiles (
82
+ FIXTURE_DIR / "AWGN_16QAM_mapper.bin" , FIXTURE_DIR / "AWGN_16QAM_demapper.bin"
83
+ )
84
+ def test_constellation_demapper (datafiles ):
85
+ m = 4
86
+ bits = torch .from_numpy (
87
+ np .array ([[0 , 1 , 0 , 0 ], [1 , 0 , 0 , 0 ], [1 , 1 , 1 , 1 ], [0 , 1 , 0 , 1 ]], dtype = int )
88
+ )
89
+ for df in datafiles .iterdir ():
90
+ if "demapper" in df .name :
91
+ demapper_dict = torch .load (df , map_location = torch .device ("cpu" ))
92
+ elif "mapper" in df .name :
93
+ mapper_dict = torch .load (df , map_location = torch .device ("cpu" ))
94
+ else :
95
+ raise ValueError ("Neither mapper nor demapper in filename" )
96
+ mapper = mapping .torch .ConstellationMapper (m ).load_model (mapper_dict )
97
+ symbols = mapper (bits ).flatten ()
98
+ demapper = mapping .torch .ConstellationDemapper (m ).load_model (demapper_dict )
99
+ llrs = demapper (symbols )
100
+ rx_bits = (llrs .detach ().numpy () < 0 ).astype (int )
101
+ assert np .allclose (bits , rx_bits )
102
+
103
+
46
104
def test_classical_demapper ():
47
105
m = 4
48
106
sigma = torch .tensor (0.1 )
@@ -59,6 +117,24 @@ def test_classical_demapper():
59
117
assert np .allclose (bits , rx_bits )
60
118
61
119
120
+ def test_classical_demapper_symbolwise ():
121
+ m = 4
122
+ sigma = torch .tensor (0.1 )
123
+ bits = torch .from_numpy (
124
+ np .array ([[0 , 1 , 0 , 0 ], [1 , 0 , 0 , 0 ], [1 , 1 , 1 , 1 ], [0 , 1 , 0 , 1 ]], dtype = int )
125
+ )
126
+ onehot = mokka .utils .bitops .torch .bits_to_onehot (bits )
127
+ mapper = mapping .torch .QAMConstellationMapper (m )
128
+ symbols = mapper (bits ).flatten ()
129
+ demapper = mapping .torch .ClassicalDemapper (
130
+ sigma , mapper .get_constellation ().flatten (), bitwise = False
131
+ )
132
+ q_value = demapper (symbols )
133
+ print (q_value )
134
+ rx_onehot = (q_value .detach ().numpy () >= 1.0000e-5 ).astype (float )
135
+ assert np .allclose (onehot , rx_onehot )
136
+
137
+
62
138
def test_gaussian_demapper ():
63
139
m = 4
64
140
bits = torch .from_numpy (
@@ -73,3 +149,25 @@ def test_gaussian_demapper():
73
149
print (bits )
74
150
print (rx_bits )
75
151
assert np .allclose (bits , rx_bits )
152
+
153
+
154
+ def test_separated_simple_demapper ():
155
+ m = 4
156
+ bits = torch .from_numpy (
157
+ np .array ([[0 , 1 , 0 , 0 ], [1 , 0 , 0 , 0 ], [1 , 1 , 1 , 1 ], [0 , 1 , 0 , 1 ]], dtype = int )
158
+ )
159
+ mapper = mapping .torch .QAMConstellationMapper (m )
160
+ symbols = mapper (bits ).flatten ()
161
+ demapper = mapping .torch .SeparatedSimpleDemapper (m , demapper_width = 128 )
162
+ llrs = demapper (symbols )
163
+ rx_bits = (llrs .detach ().numpy () < 0 ).astype (int )
164
+ assert not np .allclose (bits , rx_bits )
165
+
166
+
167
+ def test_pcssampler ():
168
+ m = 4
169
+ batchsize = 10
170
+ sampler = mapping .torch .PCSSampler (m )
171
+ indices = sampler (batchsize ).detach ().numpy ()
172
+ print (indices )
173
+ assert len (indices ) == batchsize and max (indices ) << 2 ^ m and min (indices ) >= 0
0 commit comments