nearlink_sdr.common.polar 源代码

"""Polar 编码器与 SC(逐次消去)解码器 -- TXS-10002-2025 标准 6.9.1.4 节。"""


__all__ = [
    "RATE_TABLE",
    "RELIABILITY_SEQ_1024",
    "VALID_CODE_LENGTHS",
    "PolarDecoder",
    "PolarEncoder",
    "get_info_bit_count",
    "get_polar_decoder",
]


from functools import lru_cache

import numpy as np

try:
    from nearlink_sdr_accel import RustPolarDecoder as _RustPolarDecoder
    from nearlink_sdr_accel import RustPolarEncoder as _RustPolarEncoder
    _HAS_RUST_ACCEL = True
except ImportError:
    _HAS_RUST_ACCEL = False

# N_max=1024 的可靠性序列 Q_0^{N_max-1}
# 按可靠性升序排列,索引 i 对应比特索引 Q_i。
RELIABILITY_SEQ_1024 = [
    0,
    1,
    2,
    4,
    8,
    16,
    32,
    3,
    5,
    64,
    9,
    6,
    17,
    10,
    18,
    128,
    12,
    33,
    65,
    20,
    256,
    34,
    24,
    36,
    7,
    129,
    66,
    512,
    11,
    40,
    68,
    130,
    19,
    13,
    48,
    14,
    72,
    257,
    21,
    132,
    35,
    258,
    26,
    513,
    80,
    37,
    25,
    22,
    136,
    260,
    264,
    38,
    514,
    96,
    67,
    41,
    144,
    28,
    69,
    42,
    516,
    49,
    74,
    272,
    160,
    520,
    288,
    528,
    192,
    544,
    70,
    44,
    131,
    81,
    50,
    73,
    15,
    320,
    133,
    52,
    23,
    134,
    384,
    76,
    137,
    82,
    56,
    27,
    97,
    39,
    259,
    84,
    138,
    145,
    261,
    29,
    43,
    98,
    515,
    88,
    140,
    30,
    146,
    71,
    262,
    265,
    161,
    576,
    45,
    100,
    640,
    51,
    148,
    46,
    75,
    266,
    273,
    517,
    104,
    162,
    53,
    193,
    152,
    77,
    164,
    768,
    268,
    274,
    518,
    54,
    83,
    57,
    521,
    112,
    135,
    78,
    289,
    194,
    85,
    276,
    522,
    58,
    168,
    139,
    99,
    86,
    60,
    280,
    89,
    290,
    529,
    524,
    196,
    141,
    101,
    147,
    176,
    142,
    530,
    321,
    31,
    200,
    90,
    545,
    292,
    322,
    532,
    263,
    149,
    102,
    105,
    304,
    296,
    163,
    92,
    47,
    267,
    385,
    546,
    324,
    208,
    386,
    150,
    153,
    165,
    106,
    55,
    328,
    536,
    577,
    548,
    113,
    154,
    79,
    269,
    108,
    578,
    224,
    166,
    519,
    552,
    195,
    270,
    641,
    523,
    275,
    580,
    291,
    59,
    169,
    560,
    114,
    277,
    156,
    87,
    197,
    116,
    170,
    61,
    531,
    525,
    642,
    281,
    278,
    526,
    177,
    293,
    388,
    91,
    584,
    769,
    198,
    172,
    120,
    201,
    336,
    62,
    282,
    143,
    103,
    178,
    294,
    93,
    644,
    202,
    592,
    323,
    392,
    297,
    770,
    107,
    180,
    151,
    209,
    284,
    648,
    94,
    204,
    298,
    400,
    608,
    352,
    325,
    533,
    155,
    210,
    305,
    547,
    300,
    109,
    184,
    534,
    537,
    115,
    167,
    225,
    326,
    306,
    772,
    157,
    656,
    329,
    110,
    117,
    212,
    171,
    776,
    330,
    226,
    549,
    538,
    387,
    308,
    216,
    416,
    271,
    279,
    158,
    337,
    550,
    672,
    118,
    332,
    579,
    540,
    389,
    173,
    121,
    553,
    199,
    784,
    179,
    228,
    338,
    312,
    704,
    390,
    174,
    554,
    581,
    393,
    283,
    122,
    448,
    353,
    561,
    203,
    63,
    340,
    394,
    527,
    582,
    556,
    181,
    295,
    285,
    232,
    124,
    205,
    182,
    643,
    562,
    286,
    585,
    299,
    354,
    211,
    401,
    185,
    396,
    344,
    586,
    645,
    593,
    535,
    240,
    206,
    95,
    327,
    564,
    800,
    402,
    356,
    307,
    301,
    417,
    213,
    568,
    832,
    588,
    186,
    646,
    404,
    227,
    896,
    594,
    418,
    302,
    649,
    771,
    360,
    539,
    111,
    331,
    214,
    309,
    188,
    449,
    217,
    408,
    609,
    596,
    551,
    650,
    229,
    159,
    420,
    310,
    541,
    773,
    610,
    657,
    333,
    119,
    600,
    339,
    218,
    368,
    652,
    230,
    391,
    313,
    450,
    542,
    334,
    233,
    555,
    774,
    175,
    123,
    658,
    612,
    341,
    777,
    220,
    314,
    424,
    395,
    673,
    583,
    355,
    287,
    183,
    234,
    125,
    557,
    660,
    616,
    342,
    316,
    241,
    778,
    563,
    345,
    452,
    397,
    403,
    207,
    674,
    558,
    785,
    432,
    357,
    187,
    236,
    664,
    624,
    587,
    780,
    705,
    126,
    242,
    565,
    398,
    346,
    456,
    358,
    405,
    303,
    569,
    244,
    595,
    189,
    566,
    676,
    361,
    706,
    589,
    215,
    786,
    647,
    348,
    419,
    406,
    464,
    680,
    801,
    362,
    590,
    409,
    570,
    788,
    597,
    572,
    219,
    311,
    708,
    598,
    601,
    651,
    421,
    792,
    802,
    611,
    602,
    410,
    231,
    688,
    653,
    248,
    369,
    190,
    364,
    654,
    659,
    335,
    480,
    315,
    221,
    370,
    613,
    422,
    425,
    451,
    614,
    543,
    235,
    412,
    343,
    372,
    775,
    317,
    222,
    426,
    453,
    237,
    559,
    833,
    804,
    712,
    834,
    661,
    808,
    779,
    617,
    604,
    433,
    720,
    816,
    836,
    347,
    897,
    243,
    662,
    454,
    318,
    675,
    618,
    898,
    781,
    376,
    428,
    665,
    736,
    567,
    840,
    625,
    238,
    359,
    457,
    399,
    787,
    591,
    678,
    434,
    677,
    349,
    245,
    458,
    666,
    620,
    363,
    127,
    191,
    782,
    407,
    436,
    626,
    571,
    465,
    681,
    246,
    707,
    350,
    599,
    668,
    790,
    460,
    249,
    682,
    573,
    411,
    803,
    789,
    709,
    365,
    440,
    628,
    689,
    374,
    423,
    466,
    793,
    250,
    371,
    481,
    574,
    413,
    603,
    366,
    468,
    655,
    900,
    805,
    615,
    684,
    710,
    429,
    794,
    252,
    373,
    605,
    848,
    690,
    713,
    632,
    482,
    806,
    427,
    904,
    414,
    223,
    663,
    692,
    835,
    619,
    472,
    455,
    796,
    809,
    714,
    721,
    837,
    716,
    864,
    810,
    606,
    912,
    722,
    696,
    377,
    435,
    817,
    319,
    621,
    812,
    484,
    430,
    838,
    667,
    488,
    239,
    378,
    459,
    622,
    627,
    437,
    380,
    818,
    461,
    496,
    669,
    679,
    724,
    841,
    629,
    351,
    467,
    438,
    737,
    251,
    462,
    442,
    441,
    469,
    247,
    683,
    842,
    738,
    899,
    670,
    783,
    849,
    820,
    728,
    928,
    791,
    367,
    901,
    630,
    685,
    844,
    633,
    711,
    253,
    691,
    824,
    902,
    686,
    740,
    850,
    375,
    444,
    470,
    483,
    415,
    485,
    905,
    795,
    473,
    634,
    744,
    852,
    960,
    865,
    693,
    797,
    906,
    715,
    807,
    474,
    636,
    694,
    254,
    717,
    575,
    913,
    798,
    811,
    379,
    697,
    431,
    607,
    489,
    866,
    723,
    486,
    908,
    718,
    813,
    476,
    856,
    839,
    725,
    698,
    914,
    752,
    868,
    819,
    814,
    439,
    929,
    490,
    623,
    671,
    739,
    916,
    463,
    843,
    381,
    497,
    930,
    821,
    726,
    961,
    872,
    492,
    631,
    729,
    700,
    443,
    741,
    845,
    920,
    382,
    822,
    851,
    730,
    498,
    880,
    742,
    445,
    471,
    635,
    932,
    687,
    903,
    825,
    500,
    846,
    745,
    826,
    732,
    446,
    962,
    936,
    475,
    853,
    867,
    637,
    907,
    487,
    695,
    746,
    828,
    753,
    854,
    857,
    504,
    799,
    255,
    964,
    909,
    719,
    477,
    915,
    638,
    748,
    944,
    869,
    491,
    699,
    754,
    858,
    478,
    968,
    383,
    910,
    815,
    976,
    870,
    917,
    727,
    493,
    873,
    701,
    931,
    756,
    860,
    499,
    731,
    823,
    922,
    874,
    918,
    502,
    933,
    743,
    760,
    881,
    494,
    702,
    921,
    501,
    876,
    847,
    992,
    447,
    733,
    827,
    934,
    882,
    937,
    963,
    747,
    505,
    855,
    924,
    734,
    829,
    965,
    938,
    884,
    506,
    749,
    945,
    966,
    755,
    859,
    940,
    830,
    911,
    871,
    639,
    888,
    479,
    946,
    750,
    969,
    508,
    861,
    757,
    970,
    919,
    875,
    862,
    758,
    948,
    977,
    923,
    972,
    761,
    877,
    952,
    495,
    703,
    935,
    978,
    883,
    762,
    503,
    925,
    878,
    735,
    993,
    885,
    939,
    994,
    980,
    926,
    764,
    941,
    967,
    886,
    831,
    947,
    507,
    889,
    984,
    751,
    942,
    996,
    971,
    890,
    509,
    949,
    973,
    1000,
    892,
    950,
    863,
    759,
    1008,
    510,
    979,
    953,
    763,
    974,
    954,
    879,
    981,
    982,
    927,
    995,
    765,
    956,
    887,
    985,
    997,
    986,
    943,
    891,
    998,
    766,
    511,
    988,
    1001,
    951,
    1002,
    893,
    975,
    894,
    1009,
    955,
    1004,
    1010,
    957,
    983,
    958,
    987,
    1012,
    999,
    1016,
    767,
    989,
    1003,
    990,
    1005,
    959,
    1011,
    1013,
    895,
    1006,
    1014,
    1017,
    1018,
    991,
    1020,
    1007,
    1015,
    1019,
    1021,
    1022,
    1023,
]

# Rate matching table (table 23): rate_str -> {N: K}
RATE_TABLE = {
    "1/4": {1024: 256, 512: 96, 256: 48, 128: 24, 64: 12},
    "3/8": {1024: 384, 512: 160, 256: 80, 128: 40, 64: 20},
    "1/2": {1024: 512, 512: 224, 256: 112, 128: 56, 64: 28},
    "5/8": {1024: 640, 512: 288, 256: 144, 128: 72, 64: 36},
    "3/4": {1024: 768, 512: 352, 256: 176, 128: 88, 64: 44},
    "7/8": {1024: 896, 512: 416, 256: 208, 128: 104, 64: 52},
}

VALID_CODE_LENGTHS = {64, 128, 256, 512, 1024}


[文档] def get_info_bit_count(rate_str: str, N: int) -> int: """从速率匹配表中查询给定码率和码长对应的信息位数 K。""" if rate_str not in RATE_TABLE: raise ValueError(f"Unsupported rate '{rate_str}'. Valid: {list(RATE_TABLE.keys())}") if N not in RATE_TABLE[rate_str]: raise ValueError(f"Unsupported code length {N} for rate '{rate_str}'.") return RATE_TABLE[rate_str][N]
@lru_cache(maxsize=16) def _get_reliability_sequence(N: int) -> tuple[int, ...]: """从 N_max=1024 可靠性序列中提取码长 N 的可靠性序列。 过滤比特索引 < N 的条目,保持可靠性顺序不变。 """ return tuple(q for q in RELIABILITY_SEQ_1024 if q < N) def _get_frozen_and_info_sets(N: int, K: int) -> tuple[set[int], set[int]]: """确定冻结位和信息位的位置集合。 可靠性序列中最后 K 个(最可靠)为信息位, 前 N-K 个(最不可靠)为冻结位。 """ seq = _get_reliability_sequence(N) frozen_set = set(seq[: N - K]) info_set = set(seq[N - K :]) return frozen_set, info_set
[文档] class PolarEncoder: """使用 GF(2) 上蝶形(递归)算法的 Polar 编码器。""" def __init__(self, N: int, K: int) -> None: if N not in VALID_CODE_LENGTHS: raise ValueError(f"N must be in {sorted(VALID_CODE_LENGTHS)}, got {N}") if K >= N or K < 1: raise ValueError(f"K must satisfy 1 <= K < N, got K={K}, N={N}") self.N = N self.K = K self.frozen_set, self.info_set = _get_frozen_and_info_sets(N, K) seq = _get_reliability_sequence(N) self.info_positions = sorted(seq[N - K :]) self._info_pos_arr = np.array(self.info_positions, dtype=np.intp) # Rust 加速 self._rust_encoder = None if _HAS_RUST_ACCEL: self._rust_encoder = _RustPolarEncoder( N, np.array(self.info_positions, dtype=np.int64), )
[文档] def encode(self, info_bits: np.ndarray) -> np.ndarray: """将 K 个信息位编码为 N 个编码位。 :param info_bits: 长度为 K 的比特数组(取值 0 或 1)。 :returns: 长度为 N 的码字(取值 0 或 1)。 """ info_bits = np.asarray(info_bits, dtype=np.int8) if info_bits.shape != (self.K,): raise ValueError(f"Expected {self.K} info bits, got shape {info_bits.shape}") if self._rust_encoder is not None: return np.asarray(self._rust_encoder.encode(info_bits), dtype=np.int8) # Build encoder input u — vectorized insert u = np.zeros(self.N, dtype=np.int8) u[self._info_pos_arr] = info_bits # Butterfly encoding: vectorized GF(2) transform d = u.copy() stage = 1 while stage < self.N: for j in range(0, self.N, 2 * stage): d[j:j + stage] ^= d[j + stage:j + 2 * stage] stage <<= 1 return d
[文档] class PolarDecoder: """Successive Cancellation (SC) decoder in LLR domain。 使用预分配的二维数组存储中间 LLR 和 partial sums, 避免递归中频繁分配内存。 """ def __init__(self, N: int, K: int) -> None: if N not in VALID_CODE_LENGTHS: raise ValueError(f"N must be in {sorted(VALID_CODE_LENGTHS)}, got {N}") if K >= N or K < 1: raise ValueError(f"K must satisfy 1 <= K < N, got K={K}, N={N}") self.N = N self.K = K self.n = int(np.log2(N)) self.frozen_set, self.info_set = _get_frozen_and_info_sets(N, K) seq = _get_reliability_sequence(N) self.info_positions = sorted(seq[N - K :]) self._info_pos_arr = np.array(self.info_positions, dtype=np.intp) self._is_frozen = np.ones(N, dtype=bool) for pos in self.info_positions: self._is_frozen[pos] = False # 预分配 (n+1) 层, 每层 N 元素 self._L = np.zeros((self.n + 1, N), dtype=np.float64) self._B = np.zeros((self.n + 1, N), dtype=np.int8) # SSC: 预计算每个子树的类型 (rate-0 / rate-1 / partial) self._node_type = self._build_node_types() # Rust 加速: 可用时创建 Rust 解码器实例 self._rust_decoder = None if _HAS_RUST_ACCEL: self._rust_decoder = _RustPolarDecoder( self.n, self._is_frozen, np.array(self.info_positions, dtype=np.int64), ) def _build_node_types(self) -> dict[tuple[int, int], int]: """预计算所有子树的类型: 0=rate-0, 1=rate-1, 2=partial。 使用前缀和做 O(1) 范围查询, 迭代遍历避免递归开销。 rate-0/rate-1 子树不再向下遍历, 与 SSC 剪枝对齐。 """ N = self.N frozen = self._is_frozen # 前缀和: prefix[i] = sum(frozen[0:i]) prefix = np.empty(N + 1, dtype=np.int32) prefix[0] = 0 np.cumsum(frozen, out=prefix[1:]) node_type: dict[tuple[int, int], int] = {} stack = [(0, N)] while stack: start, length = stack.pop() cnt = int(prefix[start + length] - prefix[start]) if cnt == length: node_type[(start, length)] = 0 elif cnt == 0: node_type[(start, length)] = 1 else: node_type[(start, length)] = 2 if length > 1: half = length >> 1 stack.append((start, half)) stack.append((start + half, half)) return node_type
[文档] def decode(self, llr: np.ndarray) -> np.ndarray: """使用 SC 算法将 N 个信道 LLR 解码为 K 个信息位。 约定:正 LLR 表示比特 0 的可能性更高。 """ llr = np.asarray(llr, dtype=np.float64) if llr.shape != (self.N,): raise ValueError(f"Expected {self.N} LLRs, got shape {llr.shape}") if self._rust_decoder is not None: return np.asarray(self._rust_decoder.decode(llr), dtype=np.int8) self._L[0, :self.N] = llr self._sc(0, self.N, 0) return self._B[self.n, :self.N][self._info_pos_arr].copy()
def _sc(self, start: int, length: int, depth: int) -> None: """SSC 优化的 SC 解码。 rate-0 子树: 所有位为冻结位, 直接置零。 rate-1 子树: 硬判决 LLR 得到编码位, 极性变换得到信息位。 partial 子树: 标准 f/g 递归。 """ L = self._L B = self._B node_t = self._node_type.get((start, length), 2) if node_t == 0: # rate-0: 全冻结 → 解码位和部分和均为零 end = start + length B[self.n, start:end] = 0 B[depth, start:end] = 0 return if node_t == 1: # rate-1: 全信息位 end = start + length # 硬判决得到该层的编码位 x x = np.where(L[depth, start:end] < 0, 1, 0).astype(np.int8) # 部分和 = 编码位 (供父节点 g 操作和合并使用) B[depth, start:end] = x if length == 1: B[self.n, start] = x[0] else: # 极性变换 x → u (F^⊗n 自逆) u = x.copy() step = length while step >= 2: half_s = step >> 1 u_view = u.reshape(-1, step) u_view[:, :half_s] ^= u_view[:, half_s:] step >>= 1 B[self.n, start:end] = u return # partial: 标准 SC 递归 if length == 1: if self._is_frozen[start]: B[self.n, start] = 0 else: B[self.n, start] = 0 if L[depth, start] >= 0.0 else 1 return half = length >> 1 mid = start + half end = start + length d1 = depth + 1 # f 操作 a = L[depth, start:mid] b = L[depth, mid:end] dst = L[d1, start:mid] np.minimum(np.abs(a), np.abs(b), out=dst) dst *= np.sign(a) * np.sign(b) self._sc(start, half, d1) # g 操作 left_bits = B[d1, start:mid] dst = L[d1, mid:end] np.subtract(1.0, 2.0 * left_bits, out=dst) dst *= a dst += b self._sc(mid, half, d1) # 合并 partial sums left = B[d1, start:mid] right = B[d1, mid:end] B[depth, start:mid] = left ^ right B[depth, mid:end] = right
# --------------------------------------------------------------------------- # PolarDecoder 实例缓存 (避免重复初始化 + SSC 构建开销) # --------------------------------------------------------------------------- _decoder_cache: dict[tuple[int, int], PolarDecoder] = {}
[文档] def get_polar_decoder(n: int, k: int) -> PolarDecoder: """获取 PolarDecoder 实例, 相同 (N, K) 复用已创建的对象。""" key = (n, k) dec = _decoder_cache.get(key) if dec is None: dec = PolarDecoder(n, k) _decoder_cache[key] = dec return dec