nearlink_sdr.phy.equalizer 源代码

"""TXS-10002-2025 均衡器: ZF / MMSE 频域与时域均衡。

对于 SparkLink SLE 的 PSK 调制 (帧类型 2-4),接收端需要
信道均衡来对抗多径衰落引起的符号间干扰 (ISI)。
"""


__all__ = [
    "equalize_1tap",
    "equalize_mmse_freq",
    "equalize_mmse_time",
    "equalize_zf",
    "estimate_channel_freq",
]


import numpy as np


[文档] def equalize_zf(rx_signal: np.ndarray, channel_freq: np.ndarray) -> np.ndarray: """零强制 (ZF) 频域均衡。 :param rx_signal: 接收信号 (时域)。 :param channel_freq: 信道频率响应 H[k], 长度等于 rx_signal 的 FFT 长度。 :returns: 均衡后的时域信号。 """ n = len(rx_signal) rx_freq = np.fft.fft(rx_signal, n) # 避免除零 h_safe = np.where(np.abs(channel_freq[:n]) > 1e-10, channel_freq[:n], 1e-10) eq_freq = rx_freq / h_safe return np.fft.ifft(eq_freq, n)
[文档] def equalize_mmse_freq( rx_signal: np.ndarray, channel_freq: np.ndarray, noise_var: float, ) -> np.ndarray: """MMSE 频域均衡。 H_mmse[k] = conj(H[k]) / (|H[k]|^2 + sigma^2) :param rx_signal: 接收信号 (时域)。 :param channel_freq: 信道频率响应 H[k]。 :param noise_var: 噪声方差 sigma^2。 :returns: 均衡后的时域信号。 """ n = len(rx_signal) rx_freq = np.fft.fft(rx_signal, n) h = channel_freq[:n] w = np.conj(h) / (np.abs(h) ** 2 + noise_var) eq_freq = rx_freq * w return np.fft.ifft(eq_freq, n)
[文档] def estimate_channel_freq( rx_signal: np.ndarray, tx_known: np.ndarray, n_fft: int | None = None, ) -> np.ndarray: """基于已知训练序列的频域信道估计 (LS)。 H_est[k] = Y[k] / X[k] :param rx_signal: 接收到的训练序列。 :param tx_known: 已知的发送训练序列。 :param n_fft: FFT 长度, 默认为信号长度。 :returns: 估计的信道频率响应。 """ if n_fft is None: n_fft = max(len(rx_signal), len(tx_known)) rx_freq = np.fft.fft(rx_signal, n_fft) tx_freq = np.fft.fft(tx_known, n_fft) # LS 估计, 避免除零 tx_safe = np.where(np.abs(tx_freq) > 1e-10, tx_freq, 1e-10) h_est = rx_freq / tx_safe return h_est
[文档] def equalize_mmse_time( rx_signal: np.ndarray, channel_taps: np.ndarray, noise_var: float, n_taps_eq: int = 11, ) -> np.ndarray: """MMSE 时域均衡器 (FIR Wiener 滤波器)。 计算 MMSE 最优 FIR 滤波器系数: w = R^{-1} p 其中 R = h自相关 + 噪声, p = h与期望延迟的互相关。 :param rx_signal: 接收信号。 :param channel_taps: 信道冲激响应 h[0], h[1], ..., h[L-1]。 :param noise_var: 噪声方差。 :param n_taps_eq: 均衡器 FIR 长度。 :returns: 均衡后的信号。 """ h = np.asarray(channel_taps, dtype=complex).flatten() l_h = len(h) l_eq = n_taps_eq # 自相关: r_hh[k] = sum_l h[l] * conj(h[l-k]) r_hh = np.correlate(h, h, mode='full') center_rh = l_h - 1 # 构造 Toeplitz R 矩阵 R = np.zeros((l_eq, l_eq), dtype=complex) for i in range(l_eq): for j in range(l_eq): lag = i - j idx = center_rh + lag if 0 <= idx < len(r_hh): R[i, j] = r_hh[idx] R += noise_var * np.eye(l_eq) # 最优延迟 (mode='same' 引入的组延迟补偿) delay = (l_eq - 1) // 2 # 互相关向量 p[i] = conj(h[delay - i]) p = np.zeros(l_eq, dtype=complex) for i in range(l_eq): k = delay - i if 0 <= k < l_h: p[i] = np.conj(h[k]) # 求解 Wiener-Hopf 方程 w = np.linalg.solve(R, p) # 应用 FIR 滤波 return np.convolve(rx_signal, w, mode='same')
[文档] def equalize_1tap( rx_symbols: np.ndarray, h_coeffs: np.ndarray, noise_var: float = 0.0, method: str = "mmse", ) -> np.ndarray: """逐符号 1-tap 均衡 (平坦衰落信道)。 :param rx_symbols: 接收符号。 :param h_coeffs: 每个符号的信道系数 (与 rx_symbols 等长)。 :param noise_var: 噪声方差, 仅 MMSE 使用。 :param method: "zf" 或 "mmse"。 :returns: 均衡后的符号。 """ h = np.asarray(h_coeffs, dtype=complex) if method == "zf": h_safe = np.where(np.abs(h) > 1e-10, h, 1e-10 + 0j) return rx_symbols / h_safe elif method == "mmse": w = np.conj(h) / (np.abs(h) ** 2 + noise_var) return rx_symbols * w else: raise ValueError(f"Unknown method: {method}")