Skip to content

Commit 1ae99e5

Browse files
uwimtnoc0lour
authored andcommitted
tests: Improve coverage for mapping submodule
Adds missing tests to the tests/ subdirectory in the test_mapping.py file.
1 parent 8acbe04 commit 1ae99e5

File tree

5 files changed

+114
-1
lines changed

5 files changed

+114
-1
lines changed

poetry.lock

+15-1
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

+1
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@ jaxtyping = "^0.2.36"
3838
pre-commit = "*"
3939
pytest = "*"
4040
pytest-coverage = "*"
41+
pytest-datafiles = "*"
4142
black = "*"
4243
flake8 = "*"
4344
pylint = "*"
71.9 KB
Binary file not shown.
7.75 KB
Binary file not shown.

tests/mokka/test_mapping.py

+98
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,14 @@
11
from mokka import mapping
2+
import mokka.utils.bitops.torch
3+
import mokka.utils.bitops.numpy
24
import torch
35
import numpy as np
6+
import pytest
7+
8+
from pathlib import Path
9+
10+
FIXTURE_DIR = Path(__file__).parent.resolve() / "data"
11+
412

513
# PyTorch tests
614

@@ -43,6 +51,56 @@ def test_qam_constellation_mapper():
4351
assert np.allclose(symbols, reference_symbols)
4452

4553

54+
def test_custom_constellation_mapper():
55+
m = 4
56+
bits = mokka.utils.bitops.numpy.generate_all_input_bits(m)
57+
symbols = mokka.utils.bitops.torch.bits2idx(torch.from_numpy(bits))
58+
mapper = mapping.torch.CustomConstellationMapper(m, symbols)
59+
reference_symbols = mapper.get_constellation().detach().numpy().flatten()
60+
assert np.allclose(symbols, reference_symbols)
61+
62+
63+
def test_separated_constellation_mapper():
64+
m = 4
65+
for i in range(0, 3):
66+
if i == 0:
67+
mapper = mapping.torch.SeparatedConstellationMapper(m, qam_init=True)
68+
symbols = mapper.get_constellation().detach().numpy().flatten()
69+
reference_symbols = mapping.numpy.QAM(m).get_constellation().flatten()
70+
assert np.allclose(symbols, reference_symbols)
71+
elif i == 1:
72+
mapper = mapping.torch.SeparatedConstellationMapper(m, m_real=1)
73+
symbols = mapper.get_constellation()
74+
assert symbols.shape[0] == 16
75+
elif i == 2:
76+
mapper = mapping.torch.SeparatedConstellationMapper(m, m_imag=1)
77+
symbols = mapper.get_constellation()
78+
assert symbols.shape[0] == 16
79+
80+
81+
@pytest.mark.datafiles(
82+
FIXTURE_DIR / "AWGN_16QAM_mapper.bin", FIXTURE_DIR / "AWGN_16QAM_demapper.bin"
83+
)
84+
def test_constellation_demapper(datafiles):
85+
m = 4
86+
bits = torch.from_numpy(
87+
np.array([[0, 1, 0, 0], [1, 0, 0, 0], [1, 1, 1, 1], [0, 1, 0, 1]], dtype=int)
88+
)
89+
for df in datafiles.iterdir():
90+
if "demapper" in df.name:
91+
demapper_dict = torch.load(df, map_location=torch.device("cpu"))
92+
elif "mapper" in df.name:
93+
mapper_dict = torch.load(df, map_location=torch.device("cpu"))
94+
else:
95+
raise ValueError("Neither mapper nor demapper in filename")
96+
mapper = mapping.torch.ConstellationMapper(m).load_model(mapper_dict)
97+
symbols = mapper(bits).flatten()
98+
demapper = mapping.torch.ConstellationDemapper(m).load_model(demapper_dict)
99+
llrs = demapper(symbols)
100+
rx_bits = (llrs.detach().numpy() < 0).astype(int)
101+
assert np.allclose(bits, rx_bits)
102+
103+
46104
def test_classical_demapper():
47105
m = 4
48106
sigma = torch.tensor(0.1)
@@ -59,6 +117,24 @@ def test_classical_demapper():
59117
assert np.allclose(bits, rx_bits)
60118

61119

120+
def test_classical_demapper_symbolwise():
121+
m = 4
122+
sigma = torch.tensor(0.1)
123+
bits = torch.from_numpy(
124+
np.array([[0, 1, 0, 0], [1, 0, 0, 0], [1, 1, 1, 1], [0, 1, 0, 1]], dtype=int)
125+
)
126+
onehot = mokka.utils.bitops.torch.bits_to_onehot(bits)
127+
mapper = mapping.torch.QAMConstellationMapper(m)
128+
symbols = mapper(bits).flatten()
129+
demapper = mapping.torch.ClassicalDemapper(
130+
sigma, mapper.get_constellation().flatten(), bitwise=False
131+
)
132+
q_value = demapper(symbols)
133+
print(q_value)
134+
rx_onehot = (q_value.detach().numpy() >= 1.0000e-5).astype(float)
135+
assert np.allclose(onehot, rx_onehot)
136+
137+
62138
def test_gaussian_demapper():
63139
m = 4
64140
bits = torch.from_numpy(
@@ -73,3 +149,25 @@ def test_gaussian_demapper():
73149
print(bits)
74150
print(rx_bits)
75151
assert np.allclose(bits, rx_bits)
152+
153+
154+
def test_separated_simple_demapper():
155+
m = 4
156+
bits = torch.from_numpy(
157+
np.array([[0, 1, 0, 0], [1, 0, 0, 0], [1, 1, 1, 1], [0, 1, 0, 1]], dtype=int)
158+
)
159+
mapper = mapping.torch.QAMConstellationMapper(m)
160+
symbols = mapper(bits).flatten()
161+
demapper = mapping.torch.SeparatedSimpleDemapper(m, demapper_width=128)
162+
llrs = demapper(symbols)
163+
rx_bits = (llrs.detach().numpy() < 0).astype(int)
164+
assert not np.allclose(bits, rx_bits)
165+
166+
167+
def test_pcssampler():
168+
m = 4
169+
batchsize = 10
170+
sampler = mapping.torch.PCSSampler(m)
171+
indices = sampler(batchsize).detach().numpy()
172+
print(indices)
173+
assert len(indices) == batchsize and max(indices) << 2 ^ m and min(indices) >= 0

0 commit comments

Comments
 (0)