{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 🎲 改进的随机批次生成器\n",
|
||
"\n",
|
||
"这个版本改进了数据生成策略:\n",
|
||
"- **随机文件选择**: 每次从所有训练文件中随机选择 n=4 个文件\n",
|
||
"- **随机样本采样**: 从选中的文件中随机采样指定数量的样本\n",
|
||
"- **提高数据多样性**: 避免按固定顺序处理文件,减少过拟合风险\n",
|
||
"- **可控批次大小**: 固定每批次样本数,确保训练稳定性"
]
},
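{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of this strategy, assuming the `.npz` layout and the `neural_logits_concatenated` key used by the loading cells below (`random_file_batch` is a hypothetical helper, not part of the repo):\n",
"\n",
"```python\n",
"import os, random\n",
"import numpy as np\n",
"\n",
"def random_file_batch(data_dir, n_files=4, batch_size=5000, split='train'):\n",
"    # Randomly pick n_files files of the split, then draw a fixed-size random sample.\n",
"    files = [f for f in os.listdir(data_dir) if f.endswith('.npz') and split in f]\n",
"    chosen = random.sample(files, min(n_files, len(files)))\n",
"    trials = np.concatenate([\n",
"        np.load(os.path.join(data_dir, f), allow_pickle=True)['neural_logits_concatenated']\n",
"        for f in chosen\n",
"    ])\n",
"    idx = np.random.choice(len(trials), min(batch_size, len(trials)), replace=False)\n",
"    return trials[idx]\n",
"```"
]
},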
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 环境配置与Utils"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Looking in indexes: https://download.pytorch.org/whl/cu126\n",
|
||
"Requirement already satisfied: torch in /usr/local/lib/python3.11/dist-packages (2.6.0+cu124)\n",
|
||
"Requirement already satisfied: torchvision in /usr/local/lib/python3.11/dist-packages (0.21.0+cu124)\n",
|
||
"Requirement already satisfied: torchaudio in /usr/local/lib/python3.11/dist-packages (2.6.0+cu124)\n",
|
||
"Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from torch) (3.18.0)\n",
|
||
"Requirement already satisfied: typing-extensions>=4.10.0 in /usr/local/lib/python3.11/dist-packages (from torch) (4.14.0)\n",
|
||
"Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch) (3.5)\n",
|
||
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch) (3.1.6)\n",
|
||
"Requirement already satisfied: fsspec in /usr/local/lib/python3.11/dist-packages (from torch) (2025.5.1)\n",
|
||
"Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n",
|
||
"Requirement already satisfied: nvidia-cuda-runtime-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n",
|
||
"Requirement already satisfied: nvidia-cuda-cupti-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n",
|
||
"Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /usr/local/lib/python3.11/dist-packages (from torch) (9.1.0.70)\n",
|
||
"Requirement already satisfied: nvidia-cublas-cu12==12.4.5.8 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.5.8)\n",
|
||
"Requirement already satisfied: nvidia-cufft-cu12==11.2.1.3 in /usr/local/lib/python3.11/dist-packages (from torch) (11.2.1.3)\n",
|
||
"Requirement already satisfied: nvidia-curand-cu12==10.3.5.147 in /usr/local/lib/python3.11/dist-packages (from torch) (10.3.5.147)\n",
|
||
"Requirement already satisfied: nvidia-cusolver-cu12==11.6.1.9 in /usr/local/lib/python3.11/dist-packages (from torch) (11.6.1.9)\n",
|
||
"Requirement already satisfied: nvidia-cusparse-cu12==12.3.1.170 in /usr/local/lib/python3.11/dist-packages (from torch) (12.3.1.170)\n",
|
||
"Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /usr/local/lib/python3.11/dist-packages (from torch) (0.6.2)\n",
|
||
"Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from torch) (2.21.5)\n",
|
||
"Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n",
|
||
"Requirement already satisfied: nvidia-nvjitlink-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n",
|
||
"Requirement already satisfied: triton==3.2.0 in /usr/local/lib/python3.11/dist-packages (from torch) (3.2.0)\n",
|
||
"Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch) (1.13.1)\n",
|
||
"Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch) (1.3.0)\n",
|
||
"Requirement already satisfied: numpy in /usr/local/lib/python3.11/dist-packages (from torchvision) (1.26.4)\n",
|
||
"Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.11/dist-packages (from torchvision) (11.2.1)\n",
|
||
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from jinja2->torch) (3.0.2)\n",
|
||
"Requirement already satisfied: mkl_fft in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (1.3.8)\n",
|
||
"Requirement already satisfied: mkl_random in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (1.2.4)\n",
|
||
"Requirement already satisfied: mkl_umath in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (0.1.1)\n",
|
||
"Requirement already satisfied: mkl in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (2025.2.0)\n",
|
||
"Requirement already satisfied: tbb4py in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (2022.2.0)\n",
|
||
"Requirement already satisfied: mkl-service in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (2.4.1)\n",
|
||
"Requirement already satisfied: intel-openmp<2026,>=2024 in /usr/local/lib/python3.11/dist-packages (from mkl->numpy->torchvision) (2024.2.0)\n",
|
||
"Requirement already satisfied: tbb==2022.* in /usr/local/lib/python3.11/dist-packages (from mkl->numpy->torchvision) (2022.2.0)\n",
|
||
"Requirement already satisfied: tcmlib==1.* in /usr/local/lib/python3.11/dist-packages (from tbb==2022.*->mkl->numpy->torchvision) (1.4.0)\n",
|
||
"Requirement already satisfied: intel-cmplr-lib-rt in /usr/local/lib/python3.11/dist-packages (from mkl_umath->numpy->torchvision) (2024.2.0)\n",
|
||
"Requirement already satisfied: intel-cmplr-lib-ur==2024.2.0 in /usr/local/lib/python3.11/dist-packages (from intel-openmp<2026,>=2024->mkl->numpy->torchvision) (2024.2.0)\n",
|
||
"Requirement already satisfied: jupyter==1.1.1 in /usr/local/lib/python3.11/dist-packages (1.1.1)\n",
|
||
"Requirement already satisfied: numpy<2.1.0,>=1.26.0 in /usr/local/lib/python3.11/dist-packages (1.26.4)\n",
|
||
"Requirement already satisfied: pandas==2.3.0 in /usr/local/lib/python3.11/dist-packages (2.3.0)\n",
|
||
"Requirement already satisfied: matplotlib==3.10.1 in /usr/local/lib/python3.11/dist-packages (3.10.1)\n",
|
||
"Requirement already satisfied: scipy==1.15.2 in /usr/local/lib/python3.11/dist-packages (1.15.2)\n",
|
||
"Requirement already satisfied: scikit-learn==1.6.1 in /usr/local/lib/python3.11/dist-packages (1.6.1)\n",
|
||
"Requirement already satisfied: lightgbm==4.3.0 in /usr/local/lib/python3.11/dist-packages (4.3.0)\n",
|
||
"Requirement already satisfied: tqdm==4.67.1 in /usr/local/lib/python3.11/dist-packages (4.67.1)\n",
|
||
"Requirement already satisfied: g2p_en==2.1.0 in /usr/local/lib/python3.11/dist-packages (2.1.0)\n",
|
||
"Requirement already satisfied: h5py==3.13.0 in /usr/local/lib/python3.11/dist-packages (3.13.0)\n",
|
||
"Requirement already satisfied: omegaconf==2.3.0 in /usr/local/lib/python3.11/dist-packages (2.3.0)\n",
|
||
"Requirement already satisfied: editdistance==0.8.1 in /usr/local/lib/python3.11/dist-packages (0.8.1)\n",
|
||
"Requirement already satisfied: huggingface-hub==0.33.1 in /usr/local/lib/python3.11/dist-packages (0.33.1)\n",
|
||
"Requirement already satisfied: transformers==4.53.0 in /usr/local/lib/python3.11/dist-packages (4.53.0)\n",
|
||
"Requirement already satisfied: tokenizers==0.21.2 in /usr/local/lib/python3.11/dist-packages (0.21.2)\n",
|
||
"Requirement already satisfied: accelerate==1.8.1 in /usr/local/lib/python3.11/dist-packages (1.8.1)\n",
|
||
"Requirement already satisfied: bitsandbytes==0.46.0 in /usr/local/lib/python3.11/dist-packages (0.46.0)\n",
|
||
"Requirement already satisfied: seaborn==0.13.2 in /usr/local/lib/python3.11/dist-packages (0.13.2)\n",
|
||
"Requirement already satisfied: notebook in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (6.5.4)\n",
|
||
"Requirement already satisfied: jupyter-console in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (6.1.0)\n",
|
||
"Requirement already satisfied: nbconvert in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (6.4.5)\n",
|
||
"Requirement already satisfied: ipykernel in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (6.17.1)\n",
|
||
"Requirement already satisfied: ipywidgets in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (8.1.5)\n",
|
||
"Requirement already satisfied: jupyterlab in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (3.6.8)\n",
|
||
"Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.11/dist-packages (from pandas==2.3.0) (2.9.0.post0)\n",
|
||
"Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.11/dist-packages (from pandas==2.3.0) (2025.2)\n",
|
||
"Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.11/dist-packages (from pandas==2.3.0) (2025.2)\n",
|
||
"Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (1.3.2)\n",
|
||
"Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (0.12.1)\n",
|
||
"Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (4.58.4)\n",
|
||
"Requirement already satisfied: kiwisolver>=1.3.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (1.4.8)\n",
|
||
"Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (25.0)\n",
|
||
"Requirement already satisfied: pillow>=8 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (11.2.1)\n",
|
||
"Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (3.0.9)\n",
|
||
"Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn==1.6.1) (1.5.1)\n",
|
||
"Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn==1.6.1) (3.6.0)\n",
|
||
"Requirement already satisfied: nltk>=3.2.4 in /usr/local/lib/python3.11/dist-packages (from g2p_en==2.1.0) (3.9.1)\n",
|
||
"Requirement already satisfied: inflect>=0.3.1 in /usr/local/lib/python3.11/dist-packages (from g2p_en==2.1.0) (7.5.0)\n",
|
||
"Requirement already satisfied: distance>=0.1.3 in /usr/local/lib/python3.11/dist-packages (from g2p_en==2.1.0) (0.1.3)\n",
|
||
"Requirement already satisfied: antlr4-python3-runtime==4.9.* in /usr/local/lib/python3.11/dist-packages (from omegaconf==2.3.0) (4.9.3)\n",
|
||
"Requirement already satisfied: PyYAML>=5.1.0 in /usr/local/lib/python3.11/dist-packages (from omegaconf==2.3.0) (6.0.2)\n",
|
||
"Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (3.18.0)\n",
|
||
"Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (2025.5.1)\n",
|
||
"Requirement already satisfied: requests in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (2.32.4)\n",
|
||
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (4.14.0)\n",
|
||
"Requirement already satisfied: hf-xet<2.0.0,>=1.1.2 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (1.1.5)\n",
|
||
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.11/dist-packages (from transformers==4.53.0) (2024.11.6)\n",
|
||
"Requirement already satisfied: safetensors>=0.4.3 in /usr/local/lib/python3.11/dist-packages (from transformers==4.53.0) (0.5.3)\n",
|
||
"Requirement already satisfied: psutil in /usr/local/lib/python3.11/dist-packages (from accelerate==1.8.1) (7.0.0)\n",
|
||
"Requirement already satisfied: torch>=2.0.0 in /usr/local/lib/python3.11/dist-packages (from accelerate==1.8.1) (2.6.0+cu124)\n",
|
||
"Requirement already satisfied: mkl_fft in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (1.3.8)\n",
|
||
"Requirement already satisfied: mkl_random in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (1.2.4)\n",
|
||
"Requirement already satisfied: mkl_umath in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (0.1.1)\n",
|
||
"Requirement already satisfied: mkl in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (2025.2.0)\n",
|
||
"Requirement already satisfied: tbb4py in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (2022.2.0)\n",
|
||
"Requirement already satisfied: mkl-service in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (2.4.1)\n",
|
||
"Requirement already satisfied: more_itertools>=8.5.0 in /usr/local/lib/python3.11/dist-packages (from inflect>=0.3.1->g2p_en==2.1.0) (10.7.0)\n",
|
||
"Requirement already satisfied: typeguard>=4.0.1 in /usr/local/lib/python3.11/dist-packages (from inflect>=0.3.1->g2p_en==2.1.0) (4.4.4)\n",
|
||
"Requirement already satisfied: click in /usr/local/lib/python3.11/dist-packages (from nltk>=3.2.4->g2p_en==2.1.0) (8.2.1)\n",
|
||
"Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.11/dist-packages (from python-dateutil>=2.8.2->pandas==2.3.0) (1.17.0)\n",
|
||
"Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (3.5)\n",
|
||
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (3.1.6)\n",
|
||
"Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n",
|
||
"Requirement already satisfied: nvidia-cuda-runtime-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n",
|
||
"Requirement already satisfied: nvidia-cuda-cupti-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n",
|
||
"Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (9.1.0.70)\n",
|
||
"Requirement already satisfied: nvidia-cublas-cu12==12.4.5.8 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.5.8)\n",
|
||
"Requirement already satisfied: nvidia-cufft-cu12==11.2.1.3 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (11.2.1.3)\n",
|
||
"Requirement already satisfied: nvidia-curand-cu12==10.3.5.147 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (10.3.5.147)\n",
|
||
"Requirement already satisfied: nvidia-cusolver-cu12==11.6.1.9 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (11.6.1.9)\n",
|
||
"Requirement already satisfied: nvidia-cusparse-cu12==12.3.1.170 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.3.1.170)\n",
|
||
"Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (0.6.2)\n",
|
||
"Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (2.21.5)\n",
|
||
"Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n",
|
||
"Requirement already satisfied: nvidia-nvjitlink-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n",
|
||
"Requirement already satisfied: triton==3.2.0 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (3.2.0)\n",
|
||
"Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (1.13.1)\n",
|
||
"Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch>=2.0.0->accelerate==1.8.1) (1.3.0)\n",
|
||
"Requirement already satisfied: debugpy>=1.0 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (1.8.0)\n",
|
||
"Requirement already satisfied: ipython>=7.23.1 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (7.34.0)\n",
|
||
"Requirement already satisfied: jupyter-client>=6.1.12 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (8.6.3)\n",
|
||
"Requirement already satisfied: matplotlib-inline>=0.1 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (0.1.7)\n",
|
||
"Requirement already satisfied: nest-asyncio in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (1.6.0)\n",
|
||
"Requirement already satisfied: pyzmq>=17 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (24.0.1)\n",
|
||
"Requirement already satisfied: tornado>=6.1 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (6.5.1)\n",
|
||
"Requirement already satisfied: traitlets>=5.1.0 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (5.7.1)\n",
|
||
"Requirement already satisfied: comm>=0.1.3 in /usr/local/lib/python3.11/dist-packages (from ipywidgets->jupyter==1.1.1) (0.2.2)\n",
|
||
"Requirement already satisfied: widgetsnbextension~=4.0.12 in /usr/local/lib/python3.11/dist-packages (from ipywidgets->jupyter==1.1.1) (4.0.14)\n",
|
||
"Requirement already satisfied: jupyterlab-widgets~=3.0.12 in /usr/local/lib/python3.11/dist-packages (from ipywidgets->jupyter==1.1.1) (3.0.15)\n",
|
||
"Requirement already satisfied: prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-console->jupyter==1.1.1) (3.0.51)\n",
|
||
"Requirement already satisfied: pygments in /usr/local/lib/python3.11/dist-packages (from jupyter-console->jupyter==1.1.1) (2.19.2)\n",
|
||
"Requirement already satisfied: jupyter-core in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (5.8.1)\n",
|
||
"Requirement already satisfied: jupyterlab-server~=2.19 in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (2.27.3)\n",
|
||
"Requirement already satisfied: jupyter-server<3,>=1.16.0 in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (2.12.5)\n",
|
||
"Requirement already satisfied: jupyter-ydoc~=0.2.4 in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (0.2.5)\n",
|
||
"Requirement already satisfied: jupyter-server-ydoc~=0.8.0 in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (0.8.0)\n",
|
||
"Requirement already satisfied: nbclassic in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (1.3.1)\n",
|
||
"Requirement already satisfied: argon2-cffi in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (25.1.0)\n",
|
||
"Requirement already satisfied: ipython-genutils in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (0.2.0)\n",
|
||
"Requirement already satisfied: nbformat in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (5.10.4)\n",
|
||
"Requirement already satisfied: Send2Trash>=1.8.0 in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (1.8.3)\n",
|
||
"Requirement already satisfied: terminado>=0.8.3 in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (0.18.1)\n",
|
||
"Requirement already satisfied: prometheus-client in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (0.22.1)\n",
|
||
"Requirement already satisfied: mistune<2,>=0.8.1 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.8.4)\n",
|
||
"Requirement already satisfied: jupyterlab-pygments in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.3.0)\n",
|
||
"Requirement already satisfied: entrypoints>=0.2.2 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.4)\n",
|
||
"Requirement already satisfied: bleach in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (6.2.0)\n",
|
||
"Requirement already satisfied: pandocfilters>=1.4.1 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (1.5.1)\n",
|
||
"Requirement already satisfied: testpath in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.6.0)\n",
|
||
"Requirement already satisfied: defusedxml in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.7.1)\n",
|
||
"Requirement already satisfied: beautifulsoup4 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (4.13.4)\n",
|
||
"Requirement already satisfied: nbclient<0.6.0,>=0.5.0 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.5.13)\n",
|
||
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (3.0.2)\n",
|
||
"Requirement already satisfied: intel-openmp<2026,>=2024 in /usr/local/lib/python3.11/dist-packages (from mkl->numpy<2.1.0,>=1.26.0) (2024.2.0)\n",
|
||
"Requirement already satisfied: tbb==2022.* in /usr/local/lib/python3.11/dist-packages (from mkl->numpy<2.1.0,>=1.26.0) (2022.2.0)\n",
|
||
"Requirement already satisfied: tcmlib==1.* in /usr/local/lib/python3.11/dist-packages (from tbb==2022.*->mkl->numpy<2.1.0,>=1.26.0) (1.4.0)\n",
|
||
"Requirement already satisfied: intel-cmplr-lib-rt in /usr/local/lib/python3.11/dist-packages (from mkl_umath->numpy<2.1.0,>=1.26.0) (2024.2.0)\n",
|
||
"Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub==0.33.1) (3.4.2)\n",
|
||
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub==0.33.1) (3.10)\n",
|
||
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub==0.33.1) (2.5.0)\n",
|
||
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub==0.33.1) (2025.6.15)\n",
|
||
"Requirement already satisfied: intel-cmplr-lib-ur==2024.2.0 in /usr/local/lib/python3.11/dist-packages (from intel-openmp<2026,>=2024->mkl->numpy<2.1.0,>=1.26.0) (2024.2.0)\n",
|
||
"Requirement already satisfied: setuptools>=18.5 in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (75.2.0)\n",
|
||
"Requirement already satisfied: jedi>=0.16 in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (0.19.2)\n",
|
||
"Requirement already satisfied: decorator in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (4.4.2)\n",
|
||
"Requirement already satisfied: pickleshare in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (0.7.5)\n",
|
||
"Requirement already satisfied: backcall in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (0.2.0)\n",
|
||
"Requirement already satisfied: pexpect>4.3 in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (4.9.0)\n",
|
||
"Requirement already satisfied: platformdirs>=2.5 in /usr/local/lib/python3.11/dist-packages (from jupyter-core->jupyterlab->jupyter==1.1.1) (4.3.8)\n",
|
||
"Requirement already satisfied: anyio>=3.1.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (4.9.0)\n",
|
||
"Requirement already satisfied: jupyter-events>=0.9.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (0.12.0)\n",
|
||
"Requirement already satisfied: jupyter-server-terminals in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (0.5.3)\n",
|
||
"Requirement already satisfied: overrides in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (7.7.0)\n",
|
||
"Requirement already satisfied: websocket-client in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.8.0)\n",
|
||
"Requirement already satisfied: jupyter-server-fileid<1,>=0.6.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-server-ydoc~=0.8.0->jupyterlab->jupyter==1.1.1) (0.9.3)\n",
|
||
"Requirement already satisfied: ypy-websocket<0.9.0,>=0.8.2 in /usr/local/lib/python3.11/dist-packages (from jupyter-server-ydoc~=0.8.0->jupyterlab->jupyter==1.1.1) (0.8.4)\n",
|
||
"Requirement already satisfied: y-py<0.7.0,>=0.6.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-ydoc~=0.2.4->jupyterlab->jupyter==1.1.1) (0.6.2)\n",
|
||
"Requirement already satisfied: babel>=2.10 in /usr/local/lib/python3.11/dist-packages (from jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (2.17.0)\n",
|
||
"Requirement already satisfied: json5>=0.9.0 in /usr/local/lib/python3.11/dist-packages (from jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (0.12.0)\n",
|
||
"Requirement already satisfied: jsonschema>=4.18.0 in /usr/local/lib/python3.11/dist-packages (from jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (4.24.0)\n",
|
||
"Requirement already satisfied: notebook-shim>=0.2.3 in /usr/local/lib/python3.11/dist-packages (from nbclassic->jupyterlab->jupyter==1.1.1) (0.2.4)\n",
|
||
"Requirement already satisfied: fastjsonschema>=2.15 in /usr/local/lib/python3.11/dist-packages (from nbformat->notebook->jupyter==1.1.1) (2.21.1)\n",
|
||
"Requirement already satisfied: wcwidth in /usr/local/lib/python3.11/dist-packages (from prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0->jupyter-console->jupyter==1.1.1) (0.2.13)\n",
|
||
"Requirement already satisfied: ptyprocess in /usr/local/lib/python3.11/dist-packages (from terminado>=0.8.3->notebook->jupyter==1.1.1) (0.7.0)\n",
|
||
"Requirement already satisfied: argon2-cffi-bindings in /usr/local/lib/python3.11/dist-packages (from argon2-cffi->notebook->jupyter==1.1.1) (21.2.0)\n",
|
||
"Requirement already satisfied: soupsieve>1.2 in /usr/local/lib/python3.11/dist-packages (from beautifulsoup4->nbconvert->jupyter==1.1.1) (2.7)\n",
|
||
"Requirement already satisfied: webencodings in /usr/local/lib/python3.11/dist-packages (from bleach->nbconvert->jupyter==1.1.1) (0.5.1)\n",
|
||
"Requirement already satisfied: sniffio>=1.1 in /usr/local/lib/python3.11/dist-packages (from anyio>=3.1.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.3.1)\n",
|
||
"Requirement already satisfied: parso<0.9.0,>=0.8.4 in /usr/local/lib/python3.11/dist-packages (from jedi>=0.16->ipython>=7.23.1->ipykernel->jupyter==1.1.1) (0.8.4)\n",
|
||
"Requirement already satisfied: attrs>=22.2.0 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=4.18.0->jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (25.3.0)\n",
|
||
"Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=4.18.0->jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (2025.4.1)\n",
|
||
"Requirement already satisfied: referencing>=0.28.4 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=4.18.0->jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (0.36.2)\n",
|
||
"Requirement already satisfied: rpds-py>=0.7.1 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=4.18.0->jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (0.25.1)\n",
|
||
"Requirement already satisfied: python-json-logger>=2.0.4 in /usr/local/lib/python3.11/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (3.3.0)\n",
|
||
"Requirement already satisfied: rfc3339-validator in /usr/local/lib/python3.11/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (0.1.4)\n",
|
||
"Requirement already satisfied: rfc3986-validator>=0.1.1 in /usr/local/lib/python3.11/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (0.1.1)\n",
|
||
"Requirement already satisfied: aiofiles<23,>=22.1.0 in /usr/local/lib/python3.11/dist-packages (from ypy-websocket<0.9.0,>=0.8.2->jupyter-server-ydoc~=0.8.0->jupyterlab->jupyter==1.1.1) (22.1.0)\n",
|
||
"Requirement already satisfied: aiosqlite<1,>=0.17.0 in /usr/local/lib/python3.11/dist-packages (from ypy-websocket<0.9.0,>=0.8.2->jupyter-server-ydoc~=0.8.0->jupyterlab->jupyter==1.1.1) (0.21.0)\n",
|
||
"Requirement already satisfied: cffi>=1.0.1 in /usr/local/lib/python3.11/dist-packages (from argon2-cffi-bindings->argon2-cffi->notebook->jupyter==1.1.1) (1.17.1)\n",
|
||
"Requirement already satisfied: pycparser in /usr/local/lib/python3.11/dist-packages (from cffi>=1.0.1->argon2-cffi-bindings->argon2-cffi->notebook->jupyter==1.1.1) (2.22)\n",
|
||
"Requirement already satisfied: fqdn in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.5.1)\n",
|
||
"Requirement already satisfied: isoduration in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (20.11.0)\n",
|
||
"Requirement already satisfied: jsonpointer>1.13 in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (3.0.0)\n",
|
||
"Requirement already satisfied: uri-template in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.3.0)\n",
|
||
"Requirement already satisfied: webcolors>=24.6.0 in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (24.11.1)\n",
|
||
"Requirement already satisfied: arrow>=0.15.0 in /usr/local/lib/python3.11/dist-packages (from isoduration->jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.3.0)\n",
|
||
"Requirement already satisfied: types-python-dateutil>=2.8.10 in /usr/local/lib/python3.11/dist-packages (from arrow>=0.15.0->isoduration->jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (2.9.0.20250516)\n",
|
||
"Obtaining file:///kaggle/working/nejm-brain-to-text\n",
|
||
" Preparing metadata (setup.py): started\n",
|
||
" Preparing metadata (setup.py): finished with status 'done'\n",
|
||
"Installing collected packages: nejm_b2txt_utils\n",
|
||
" Running setup.py develop for nejm_b2txt_utils\n",
|
||
"Successfully installed nejm_b2txt_utils-0.0.0\n"
|
||
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Cloning into 'nejm-brain-to-text'...\n"
]
}
],
"source": [
"%%bash\n",
|
||
"rm -rf /kaggle/working/nejm-brain-to-text/\n",
|
||
"git clone https://github.com/ZH-CEN/nejm-brain-to-text.git\n",
|
||
"cp /kaggle/input/brain-to-text-baseline-model/t15_copyTask.pkl /kaggle/working/nejm-brain-to-text/data/t15_copyTask.pkl\n",
|
||
"\n",
|
||
"ln -s /kaggle/input/brain-to-text-25/t15_pretrained_rnn_baseline/t15_pretrained_rnn_baseline /kaggle/working/nejm-brain-to-text/data\n",
|
||
"ln -s /kaggle/input/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final /kaggle/working/nejm-brain-to-text/data\n",
|
||
"ln -s /kaggle/input/rnn-pretagged-data /kaggle/working/nejm-brain-to-text/data/concatenated_data\n",
|
||
"\n",
|
||
"pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126\n",
|
||
"\n",
|
||
"pip install \\\n",
|
||
" jupyter==1.1.1 \\\n",
|
||
" \"numpy>=1.26.0,<2.1.0\" \\\n",
|
||
" pandas==2.3.0 \\\n",
|
||
" matplotlib==3.10.1 \\\n",
|
||
" scipy==1.15.2 \\\n",
|
||
" scikit-learn==1.6.1 \\\n",
|
||
" lightgbm==4.3.0 \\\n",
|
||
" tqdm==4.67.1 \\\n",
|
||
" g2p_en==2.1.0 \\\n",
|
||
" h5py==3.13.0 \\\n",
|
||
" omegaconf==2.3.0 \\\n",
|
||
" editdistance==0.8.1 \\\n",
|
||
" huggingface-hub==0.33.1 \\\n",
|
||
" transformers==4.53.0 \\\n",
|
||
" tokenizers==0.21.2 \\\n",
|
||
" accelerate==1.8.1 \\\n",
|
||
" bitsandbytes==0.46.0 \\\n",
|
||
" seaborn==0.13.2\n",
|
||
"cd /kaggle/working/nejm-brain-to-text/\n",
|
||
"pip install -e ."
|
||
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"==================================================\n",
|
||
"🔧 LightGBM GPU环境检查\n",
|
||
"==================================================\n",
|
||
"❌ 未检测到NVIDIA GPU或驱动\n",
|
||
"\n",
|
||
"✅ CUDA工具包:\n",
|
||
" Cuda compilation tools, release 12.5, V12.5.82\n"
]
}
],
"source": [
"# 🚀 LightGBM GPU支持检查与配置\n",
|
||
"\n",
|
||
"print(\"=\"*50)\n",
|
||
"print(\"🔧 LightGBM GPU环境检查\")\n",
|
||
"print(\"=\"*50)\n",
|
||
"\n",
|
||
"# 检查CUDA和GPU驱动\n",
|
||
"import subprocess\n",
|
||
"import sys\n",
|
||
"\n",
|
||
"def run_command(command):\n",
|
||
" \"\"\"运行命令并返回结果\"\"\"\n",
|
||
" try:\n",
|
||
" result = subprocess.run(command, shell=True, capture_output=True, text=True, timeout=10)\n",
|
||
" return result.stdout.strip(), result.returncode == 0\n",
|
||
" except Exception as e:\n",
|
||
" return str(e), False\n",
|
||
"\n",
|
||
"# 检查NVIDIA GPU\n",
|
||
"nvidia_output, nvidia_success = run_command(\"nvidia-smi --query-gpu=name,memory.total,driver_version --format=csv,noheader,nounits\")\n",
|
||
"if nvidia_success:\n",
|
||
" print(\"✅ NVIDIA GPU检测:\")\n",
|
||
" for line in nvidia_output.split('\\n'):\n",
|
||
" if line.strip():\n",
|
||
" print(f\" {line}\")\n",
|
||
"else:\n",
|
||
" print(\"❌ 未检测到NVIDIA GPU或驱动\")\n",
|
||
"\n",
|
||
"# 检查CUDA版本\n",
|
||
"cuda_output, cuda_success = run_command(\"nvcc --version\")\n",
|
||
"if cuda_success:\n",
|
||
" print(\"\\n✅ CUDA工具包:\")\n",
|
||
" # 提取CUDA版本\n",
|
||
" for line in cuda_output.split('\\n'):\n",
|
||
" if 'release' in line:\n",
|
||
" print(f\" {line.strip()}\")\n",
|
||
"else:\n",
|
||
" print(\"\\n❌ 未安装CUDA工具包\")\n"
]
},
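{
"cell_type": "markdown",
"metadata": {},
"source": [
"If a GPU is present, LightGBM can be pointed at it through its `device` parameter; a minimal sketch with toy data (this assumes a GPU-enabled LightGBM build, and the parameters are illustrative, not tuned):\n",
"\n",
"```python\n",
"import lightgbm as lgb\n",
"import numpy as np\n",
"\n",
"X = np.random.rand(1000, 16)\n",
"y = np.random.randint(0, 2, size=1000)\n",
"\n",
"# 'device': 'gpu' requires a GPU build of LightGBM; switch to 'cpu' otherwise.\n",
"params = {'objective': 'binary', 'device': 'gpu'}\n",
"booster = lgb.train(params, lgb.Dataset(X, label=y), num_boost_round=10)\n",
"print(booster.num_trees())\n",
"```"
]
},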
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/kaggle/working/nejm-brain-to-text\n"
]
}
],
"source": [
"%cd nejm-brain-to-text\n",
|
||
"import numpy as np\n",
|
||
"import os\n",
|
||
"import pickle\n",
|
||
"import matplotlib.pyplot as plt\n",
|
||
"import matplotlib\n",
|
||
"from g2p_en import G2p\n",
|
||
"import pandas as pd\n",
|
||
"import numpy as np\n",
|
||
"from nejm_b2txt_utils.general_utils import *\n",
|
||
"matplotlib.rcParams['pdf.fonttype'] = 42\n",
|
||
"matplotlib.rcParams['ps.fonttype'] = 42\n",
|
||
"matplotlib.rcParams['font.family'] = 'sans-serif'\n",
|
||
"matplotlib.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans', 'Arial Unicode MS', 'sans-serif']\n",
|
||
"matplotlib.rcParams['axes.unicode_minus'] = False\n"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/kaggle/working/nejm-brain-to-text/model_training\n"
]
}
],
"source": [
"%cd model_training/\n",
"from data_augmentations import gauss_smooth"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"LOGIT_TO_PHONEME = [\n",
"    'BLANK',\n",
"    'AA', 'AE', 'AH', 'AO', 'AW',\n",
"    'AY', 'B', 'CH', 'D', 'DH',\n",
"    'EH', 'ER', 'EY', 'F', 'G',\n",
"    'HH', 'IH', 'IY', 'JH', 'K',\n",
"    'L', 'M', 'N', 'NG', 'OW',\n",
"    'OY', 'P', 'R', 'S', 'SH',\n",
"    'T', 'TH', 'UH', 'UW', 'V',\n",
"    'W', 'Y', 'Z', 'ZH',\n",
"    ' | ',\n",
"]\n",
"# 全局配置\n",
|
||
"BALANCE_CONFIG = {\n",
|
||
" 'enable_balance': True, # 是否启用数据平衡\n",
|
||
" 'undersample_labels': [0, 40], # 需要下采样的标签 (blank等高频标签)\n",
|
||
" 'oversample_threshold': 0.5, # 过采样阈值 (相对于均值的比例)\n",
|
||
" 'random_state': 42 # 随机种子\n",
|
||
"}\n",
|
||
"# 全局PCA配置\n",
|
||
"PCA_CONFIG = {\n",
|
||
" 'enable_pca': True, # 是否启用PCA\n",
|
||
" 'n_components': None, # None=自动选择, 或指定具体数值\n",
|
||
" 'variance_threshold': 0.95, # 保留95%的方差\n",
|
||
" 'sample_size': 15000, # 用于拟合PCA的样本数\n",
|
||
"}\n",
|
||
"\n",
|
||
"# 全局PCA对象 (确保只拟合一次)\n",
|
||
"GLOBAL_PCA = {\n",
|
||
" 'scaler': None,\n",
|
||
" 'pca': None,\n",
|
||
" 'is_fitted': False,\n",
|
||
" 'n_components': None\n",
|
||
"}\n",
|
||
"# 设置数据目录和参数【PCA初始化】\n",
|
||
"data_dir = '../data/concatenated_data'\n",
|
||
"MAX_SAMPLES_PER_FILE = -1 # 每个文件最大样本数,可调整"
]
},
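{
"cell_type": "markdown",
"metadata": {},
"source": [
"For reference, each time step carries 41 RNN logits that map to a phoneme through `LOGIT_TO_PHONEME` by arg-max; a toy example (random logits, illustrative only):\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"logits = np.random.rand(41)     # one time step of RNN output (toy values)\n",
"label = int(np.argmax(logits))  # class index in 0..40\n",
"print(label, LOGIT_TO_PHONEME[label])\n",
"```"
]
},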
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 数据读取工作流"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 2️⃣ 数据加载与PCA降维"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# 🚀 内存友好的数据读取 - 分批加载策略 + PCA降维 【这里还缺一个采样】\n",
|
||
"import os\n",
|
||
"import numpy as np\n",
|
||
"import gc\n",
|
||
"from sklearn.decomposition import PCA\n",
|
||
"from sklearn.preprocessing import StandardScaler\n",
|
||
"import joblib\n",
|
||
"import matplotlib.pyplot as plt\n",
|
||
"import random\n",
|
||
"\n",
|
||
"\n",
|
||
"def load_data_batch(data_dir, data_type, max_samples_per_file=5000, verbose=True, random_shuffle_files=True):\n",
|
||
" \"\"\"\n",
|
||
" 分批加载指定类型的数据,支持随机文件顺序\n",
|
||
" \n",
|
||
" Args:\n",
|
||
" data_dir: 数据目录\n",
|
||
" data_type: 'train', 'val', 'test'\n",
|
||
" max_samples_per_file: 每个文件最大加载样本数\n",
|
||
" verbose: 是否打印每个文件的加载进度\n",
|
||
" random_shuffle_files: 是否随机打乱文件加载顺序\n",
|
||
" \n",
|
||
" Returns:\n",
|
||
" generator: 数据批次生成器\n",
|
||
" \"\"\"\n",
|
||
" files = [f for f in os.listdir(data_dir) if f.endswith('.npz') and data_type in f]\n",
|
||
" \n",
|
||
" # 随机打乱文件顺序\n",
|
||
" if random_shuffle_files:\n",
|
||
" random.shuffle(files)\n",
|
||
" if verbose:\n",
|
||
" print(f\" 已随机打乱 {len(files)} 个文件的加载顺序\")\n",
|
||
" \n",
|
||
" for file_idx, f in enumerate(files):\n",
|
||
" if verbose:\n",
|
||
" print(f\" 正在加载文件 {file_idx+1}/{len(files)}: {f}\")\n",
|
||
" \n",
|
||
" data = np.load(os.path.join(data_dir, f), allow_pickle=True)\n",
|
||
" trials = data['neural_logits_concatenated']\n",
|
||
" \n",
|
||
" # 限制每个文件的样本数\n",
|
||
" if len(trials) > max_samples_per_file and max_samples_per_file != -1:\n",
|
||
" # 随机选择样本而不是只取前N个\n",
|
||
" random_indices = np.random.choice(len(trials), max_samples_per_file, replace=False)\n",
|
||
" trials = trials[random_indices]\n",
|
||
" if verbose:\n",
|
||
" print(f\" 随机采样样本数至: {max_samples_per_file}\")\n",
|
||
" \n",
|
||
" yield trials, f\n",
|
||
" \n",
|
||
" # 清理内存\n",
|
||
" del data, trials\n",
|
||
" gc.collect()\n",
|
||
"\n",
|
||
"def extract_features_labels_batch(trials_batch, random_shuffle_trials=True):\n",
|
||
" \"\"\"\n",
|
||
" 从试验批次中提取特征和标签,支持随机打乱试验顺序\n",
|
||
" \n",
|
||
" Args:\n",
|
||
" trials_batch: 试验批次数据\n",
|
||
" random_shuffle_trials: 是否随机打乱批次内的试验顺序\n",
|
||
" \"\"\"\n",
|
||
" features = []\n",
|
||
" labels = []\n",
|
||
" \n",
|
||
" # 随机打乱试验顺序\n",
|
||
" if random_shuffle_trials and len(trials_batch) > 1:\n",
|
||
" trial_indices = list(range(len(trials_batch)))\n",
|
||
" random.shuffle(trial_indices)\n",
|
||
" trials_batch = trials_batch[trial_indices]\n",
|
||
" \n",
|
||
" for trial in trials_batch:\n",
|
||
" if trial.shape[0] > 0:\n",
|
||
" # 随机打乱时间步顺序\n",
|
||
" time_indices = list(range(trial.shape[0]))\n",
|
||
" if random_shuffle_trials:\n",
|
||
" random.shuffle(time_indices)\n",
|
||
" \n",
|
||
" for t in time_indices:\n",
|
||
" neural_features = trial[t, :7168] # 前7168维神经特征\n",
|
||
" rnn_logits = trial[t, 7168:] # 后41维RNN输出\n",
|
||
" phoneme_label = np.argmax(rnn_logits)\n",
|
||
" \n",
|
||
" features.append(neural_features)\n",
|
||
" labels.append(phoneme_label)\n",
|
||
" \n",
|
||
" return np.array(features), np.array(labels)\n",
|
||
"\n",
|
||
"def fit_global_pca(data_dir, config):\n",
|
||
" \"\"\"\n",
|
||
" 在训练数据上拟合全局PCA (只执行一次)\n",
|
||
" \"\"\"\n",
|
||
" if GLOBAL_PCA['is_fitted'] or not config['enable_pca']:\n",
|
||
" print(\"PCA已拟合或未启用,跳过拟合步骤\")\n",
|
||
" return\n",
|
||
" \n",
|
||
" print(f\"拟合全局PCA降维器...\")\n",
|
||
" print(f\" 配置: {config}\")\n",
|
||
" \n",
|
||
" # 收集训练样本(使用随机加载)\n",
|
||
" sample_features = []\n",
|
||
" collected_samples = 0\n",
|
||
" \n",
|
||
" # 设置随机种子以确保可重现性\n",
|
||
" random.seed(42)\n",
|
||
" np.random.seed(42)\n",
|
||
" \n",
|
||
" for trials_batch, filename in load_data_batch(\n",
|
||
" data_dir, 'train', 5000, verbose=False, random_shuffle_files=True\n",
|
||
" ):\n",
|
||
" features, labels = extract_features_labels_batch(trials_batch, random_shuffle_trials=True)\n",
|
||
" sample_features.append(features)\n",
|
||
" collected_samples += features.shape[0]\n",
|
||
" \n",
|
||
" if collected_samples >= config['sample_size']:\n",
|
||
" break\n",
|
||
" \n",
|
||
" if sample_features:\n",
|
||
" # 合并样本数据\n",
|
||
" X_sample = np.vstack(sample_features)[:config['sample_size']]\n",
|
||
" \n",
|
||
" # 再次随机打乱样本顺序\n",
|
||
" shuffle_indices = np.random.permutation(len(X_sample))\n",
|
||
" X_sample = X_sample[shuffle_indices]\n",
|
||
" \n",
|
||
" print(f\" 实际样本数: {X_sample.shape[0]}\")\n",
|
||
" print(f\" 原始特征数: {X_sample.shape[1]}\")\n",
|
||
" \n",
|
||
" # 标准化\n",
|
||
" GLOBAL_PCA['scaler'] = StandardScaler()\n",
|
||
" X_sample_scaled = GLOBAL_PCA['scaler'].fit_transform(X_sample)\n",
|
||
" \n",
|
||
" # 确定PCA成分数\n",
|
||
" if config['n_components'] is None:\n",
|
||
" print(f\" 自动选择PCA成分数...\")\n",
|
||
" pca_full = PCA()\n",
|
||
" pca_full.fit(X_sample_scaled)\n",
|
||
" \n",
|
||
" cumsum_var = np.cumsum(pca_full.explained_variance_ratio_)\n",
|
||
" optimal_components = np.argmax(cumsum_var >= config['variance_threshold']) + 1\n",
|
||
" GLOBAL_PCA['n_components'] = min(optimal_components, X_sample.shape[1])\n",
|
||
" \n",
|
||
" print(f\" 保留{config['variance_threshold']*100}%方差需要: {optimal_components} 个成分\")\n",
|
||
" print(f\" 选择成分数: {GLOBAL_PCA['n_components']}\")\n",
|
||
" else:\n",
|
||
" GLOBAL_PCA['n_components'] = config['n_components']\n",
|
||
" print(f\" 使用指定成分数: {GLOBAL_PCA['n_components']}\")\n",
|
||
" \n",
|
||
" # 拟合最终PCA\n",
|
||
" GLOBAL_PCA['pca'] = PCA(n_components=GLOBAL_PCA['n_components'], random_state=42)\n",
|
||
" GLOBAL_PCA['pca'].fit(X_sample_scaled)\n",
|
||
" GLOBAL_PCA['is_fitted'] = True\n",
|
||
" \n",
|
||
" # 保存模型\n",
|
||
" pca_path = \"global_pca_model.joblib\"\n",
|
||
" joblib.dump({\n",
|
||
" 'scaler': GLOBAL_PCA['scaler'], \n",
|
||
" 'pca': GLOBAL_PCA['pca'],\n",
|
||
" 'n_components': GLOBAL_PCA['n_components']\n",
|
||
" }, pca_path)\n",
|
||
" \n",
|
||
" print(f\" 全局PCA拟合完成\")\n",
|
||
" print(f\" 降维: {X_sample.shape[1]} → {GLOBAL_PCA['n_components']}\")\n",
|
||
" print(f\" 降维比例: {GLOBAL_PCA['n_components']/X_sample.shape[1]:.2%}\")\n",
|
||
" print(f\" 保留方差: {GLOBAL_PCA['pca'].explained_variance_ratio_.sum():.4f}\")\n",
|
||
" print(f\" 模型已保存: {pca_path}\")\n",
|
||
" \n",
|
||
" # 清理样本数据\n",
|
||
" del sample_features, X_sample, X_sample_scaled\n",
|
||
" gc.collect()\n",
|
||
" else:\n",
|
||
" print(\"无法收集样本数据用于PCA拟合\")\n",
|
||
"\n",
|
||
"def apply_pca_transform(features):\n",
|
||
" \"\"\"\n",
|
||
" 应用全局PCA变换\n",
|
||
" \"\"\"\n",
|
||
" if not PCA_CONFIG['enable_pca'] or not GLOBAL_PCA['is_fitted']:\n",
|
||
" return features\n",
|
||
" \n",
|
||
" # 标准化 + PCA变换\n",
|
||
" features_scaled = GLOBAL_PCA['scaler'].transform(features)\n",
|
||
" features_pca = GLOBAL_PCA['pca'].transform(features_scaled)\n",
|
||
" return features_pca"
]
},
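{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal usage sketch of the loaders above, under the `data_dir` and `PCA_CONFIG` defined earlier (only the first batch is processed, for illustration):\n",
"\n",
"```python\n",
"# Fit the global PCA once, then stream batches through the transform.\n",
"fit_global_pca(data_dir, PCA_CONFIG)\n",
"\n",
"for trials_batch, fname in load_data_batch(data_dir, 'train', max_samples_per_file=2000, verbose=False):\n",
"    X, y = extract_features_labels_batch(trials_batch)\n",
"    X_low = apply_pca_transform(X)  # (n_samples, n_components) once the PCA is fitted\n",
"    print(fname, X.shape, '->', X_low.shape)\n",
"    break  # sketch: stop after the first batch\n",
"```"
]
},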
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 📊 数据平衡策略 - 标签分布分析与采样优化"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# 【数据平衡核心实现】\n",
|
||
"def balance_dataset(X, y):\n",
|
||
" \"\"\"\n",
|
||
" 对数据集进行平衡处理:只做下采样到第三小的样本数目\n",
|
||
" \n",
|
||
" Args:\n",
|
||
" X: 特征数据\n",
|
||
" y: 标签数据\n",
|
||
" config: 平衡配置\n",
|
||
" \n",
|
||
" Returns:\n",
|
||
" X_balanced, y_balanced: 平衡后的数据\n",
|
||
" \"\"\"\n",
|
||
" if not config['enable_balance']:\n",
|
||
" print(\"🔕 数据平衡已禁用,返回原始数据\")\n",
|
||
" return X, y\n",
|
||
" \n",
|
||
" print(f\"\\n⚖️ 开始数据平衡处理(只下采样到第三小样本数)...\")\n",
|
||
" print(f\" 原始数据: {X.shape[0]:,} 样本\")\n",
|
||
" \n",
|
||
" # 分析当前分布,找到第三小的样本数\n",
|
||
" label_counts = Counter(y)\n",
|
||
" all_counts = [label_counts.get(i, 0) for i in range(41)] # 所有标签的样本数\n",
|
||
" non_zero_counts = [count for count in all_counts if count > 0] # 去除0样本的标签\n",
|
||
" \n",
|
||
" # 排序找到第三小的样本数\n",
|
||
" sorted_counts = sorted(non_zero_counts)\n",
|
||
" if len(sorted_counts) >= 3:\n",
|
||
" third_smallest_count = sorted_counts[2] # 第三小(索引2)\n",
|
||
" elif len(sorted_counts) >= 2:\n",
|
||
" third_smallest_count = sorted_counts[1] # 如果不足3个,用第二小\n",
|
||
" else:\n",
|
||
" third_smallest_count = sorted_counts[0] if sorted_counts else 1 # 如果不足2个,用最小的\n",
|
||
" \n",
|
||
" print(f\" 所有标签样本数: {sorted_counts[:10]}{'...' if len(sorted_counts) > 10 else ''}\")\n",
|
||
" print(f\" 第三小样本数: {third_smallest_count}\")\n",
|
||
" print(f\" 下采样策略: 所有标签都下采样到 {third_smallest_count}\")\n",
|
||
" \n",
|
||
" # 准备平衡后的数据\n",
|
||
" X_balanced = []\n",
|
||
" y_balanced = []\n",
|
||
" \n",
|
||
" random.seed(config['random_state'])\n",
|
||
" np.random.seed(config['random_state'])\n",
|
||
" \n",
|
||
" for label in range(41):\n",
|
||
" # 获取当前标签的所有样本\n",
|
||
" label_mask = (y == label)\n",
|
||
" X_label = X[label_mask]\n",
|
||
" y_label = y[label_mask]\n",
|
||
" current_count = len(y_label)\n",
|
||
" \n",
|
||
" if current_count == 0:\n",
|
||
" continue\n",
|
||
" \n",
|
||
" # 下采样到第三小样本数\n",
|
||
" if current_count > third_smallest_count:\n",
|
||
" # 随机下采样\n",
|
||
" indices = np.random.choice(current_count, third_smallest_count, replace=False)\n",
|
||
" X_resampled = X_label[indices]\n",
|
||
" y_resampled = y_label[indices]\n",
|
||
" print(f\" 📉 标签 {label}: {current_count} → {third_smallest_count} (下采样)\")\n",
|
||
" else:\n",
|
||
" # 保持所有样本(不进行上采样)\n",
|
||
" X_resampled = X_label\n",
|
||
" y_resampled = y_label\n",
|
||
" print(f\" ✅ 标签 {label}: {current_count} (保持不变)\")\n",
|
||
" \n",
|
||
" X_balanced.append(X_resampled)\n",
|
||
" y_balanced.append(y_resampled)\n",
|
||
" \n",
|
||
" # 合并所有平衡后的数据\n",
|
||
" X_balanced = np.vstack(X_balanced)\n",
|
||
" y_balanced = np.hstack(y_balanced)\n",
|
||
" \n",
|
||
" # 随机打乱\n",
|
||
" shuffle_indices = np.random.permutation(len(y_balanced))\n",
|
||
" X_balanced = X_balanced[shuffle_indices]\n",
|
||
" y_balanced = y_balanced[shuffle_indices]\n",
|
||
" \n",
|
||
" # 统计最终结果\n",
|
||
" final_counts = Counter(y_balanced)\n",
|
||
" print(f\"\\n ✅ 下采样完成: {X_balanced.shape[0]:,} 样本\")\n",
|
||
" print(f\" 数据变化: {X.shape[0]:,} → {X_balanced.shape[0]:,} ({X_balanced.shape[0]/X.shape[0]:.2f}x)\")\n",
|
||
" print(f\" 最终各标签样本数分布:\")\n",
|
||
" for label in range(41):\n",
|
||
" count = final_counts.get(label, 0)\n",
|
||
" if count > 0:\n",
|
||
" print(f\" 标签 {label}: {count}\")\n",
|
||
" \n",
|
||
" return X_balanced, y_balanced"
]
},
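{
"cell_type": "markdown",
"metadata": {},
"source": [
"A usage sketch with toy arrays (real features come from `extract_features_labels_batch`; the counts below are made up):\n",
"\n",
"```python\n",
"import numpy as np\n",
"\n",
"X_toy = np.random.rand(300, 8)\n",
"y_toy = np.random.choice([0, 3, 7, 40], size=300, p=[0.5, 0.2, 0.1, 0.2])\n",
"\n",
"# Caps every label at the third-smallest per-label count.\n",
"X_bal, y_bal = balance_dataset(X_toy, y_toy)\n",
"print(X_bal.shape, sorted(np.bincount(y_bal.astype(int), minlength=41), reverse=True)[:4])\n",
"```"
]
},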
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 🔄 集成数据平衡的内存友好数据加载器"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 🧪 数据平衡效果测试"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 🚀 改进版智能数据处理管道"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"🚀 创建智能数据处理管道...\n",
|
||
"✅ 管道创建完成,准备执行步骤1...\n"
]
}
],
"source": [
"# 🚀 改进版智能数据处理管道【没有解决分批训练的问题】\n",
|
||
"# 流程:分析分布 → 确定采样比率 → 拟合PCA(只下采样) → 数据处理(下采样+上采样+PCA)\n",
|
||
"\n",
|
||
"import numpy as np\n",
|
||
"import matplotlib.pyplot as plt\n",
|
||
"from collections import Counter\n",
|
||
"from sklearn.utils import resample\n",
|
||
"from sklearn.decomposition import PCA\n",
|
||
"from sklearn.preprocessing import StandardScaler\n",
|
||
"import joblib\n",
|
||
"import random\n",
|
||
"import gc\n",
|
||
"\n",
|
||
"class SmartDataPipeline:\n",
|
||
" \"\"\"\n",
|
||
" 智能数据处理管道\n",
|
||
" 步骤1: 分析数据分布,确定采样策略\n",
|
||
" 步骤2: 仅下采样拟合PCA参数\n",
|
||
" 步骤3: 数据处理时应用完整采样+PCA降维\n",
|
||
" \"\"\"\n",
|
||
" \n",
|
||
" def __init__(self, data_dir, random_state=42):\n",
|
||
" self.data_dir = data_dir\n",
|
||
" self.random_state = random_state\n",
|
||
" \n",
|
||
" # 步骤1: 分布分析结果\n",
|
||
" self.distribution_analysis = None\n",
|
||
" self.sampling_strategy = None\n",
|
||
" \n",
|
||
" # 步骤2: PCA参数(基于下采样数据拟合)\n",
|
||
" self.pca_scaler = None\n",
|
||
" self.pca_model = None\n",
|
||
" self.pca_components = None\n",
|
||
" self.pca_fitted = False\n",
|
||
" \n",
|
||
" # 配置参数\n",
|
||
" self.undersample_labels = [0, 40] # 需要下采样的标签\n",
|
||
" self.oversample_threshold = 0.5 # 过采样阈值(相对于均值)\n",
|
||
" self.pca_variance_threshold = 0.95 # PCA保留方差比例\n",
|
||
" self.pca_sample_size = 15000 # PCA拟合样本数\n",
|
||
" \n",
|
||
" # def step1_analyze_distribution(self, max_samples=100000):\n",
|
||
" # \"\"\"\n",
|
||
" # 步骤1: 分析数据分布,确定采样策略\n",
|
||
" # \"\"\"\n",
|
||
" # print(\"🔍 步骤1: 分析数据分布...\")\n",
|
||
" \n",
|
||
" # # 分析验证集分布(代表整体分布特征)\n",
|
||
" # all_labels = []\n",
|
||
" # for trials_batch, filename in load_data_batch(self.data_dir, 'val', 5000):\n",
|
||
" # _, labels = extract_features_labels_batch(trials_batch)\n",
|
||
" # all_labels.extend(labels.tolist())\n",
|
||
" # if len(all_labels) >= max_samples:\n",
|
||
" # break\n",
|
||
" \n",
|
||
" # # 统计分析\n",
|
||
" # label_counts = Counter(all_labels)\n",
|
||
" \n",
|
||
" # # 计算1-39标签的均值(排除0和40)\n",
|
||
" # counts_1_39 = [label_counts.get(i, 0) for i in range(1, 40)]\n",
|
||
" # target_mean = np.mean(counts_1_39)\n",
|
||
" \n",
|
||
" # # 生成采样策略\n",
|
||
" # sampling_strategy = {}\n",
|
||
" # for label in range(41):\n",
|
||
" # current_count = label_counts.get(label, 0)\n",
|
||
" \n",
|
||
" # if label in self.undersample_labels:\n",
|
||
" # # 下采样到均值水平\n",
|
||
" # target_count = int(target_mean)\n",
|
||
" # action = 'undersample' if current_count > target_count else 'keep'\n",
|
||
" # elif current_count < target_mean * self.oversample_threshold:\n",
|
||
" # # 过采样到阈值水平\n",
|
||
" # target_count = int(target_mean * self.oversample_threshold)\n",
|
||
" # action = 'oversample' if current_count < target_count else 'keep'\n",
|
||
" # else:\n",
|
||
" # # 保持不变\n",
|
||
" # target_count = current_count\n",
|
||
" # action = 'keep'\n",
|
||
" \n",
|
||
" # sampling_strategy[label] = {\n",
|
||
" # 'current_count': current_count,\n",
|
||
" # 'target_count': target_count,\n",
|
||
" # 'action': action\n",
|
||
" # }\n",
|
||
" \n",
|
||
" # self.distribution_analysis = {\n",
|
||
" # 'label_counts': label_counts,\n",
|
||
" # 'target_mean': target_mean,\n",
|
||
" # 'total_samples': len(all_labels)\n",
|
||
" # }\n",
|
||
" # self.sampling_strategy = sampling_strategy\n",
|
||
" \n",
|
||
" # print(f\" ✅ 分析完成: {len(all_labels):,} 样本\")\n",
|
||
" # print(f\" 📊 标签1-39均值: {target_mean:.0f}\")\n",
|
||
" # print(f\" 📉 下采样标签: {self.undersample_labels} → {target_mean:.0f}\")\n",
|
||
" # print(f\" 📈 过采样阈值: {self.oversample_threshold} × 均值 = {target_mean * self.oversample_threshold:.0f}\")\n",
|
||
" \n",
|
||
" # return self.distribution_analysis, self.sampling_strategy\n",
|
||
"\n",
|
||
"# 创建智能数据处理管道\n",
|
||
"print(\"🚀 创建智能数据处理管道...\")\n",
|
||
"pipeline = SmartDataPipeline(data_dir, random_state=42)\n",
|
||
"print(\"✅ 管道创建完成,准备执行步骤1...\")"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"步骤1和步骤2方法已添加到管道(第三小样本数策略)\n"
]
}
],
"source": [
"# 继续添加智能管道的其他方法【管道完善】- 修改为第三小样本数策略\n",
|
||
"\n",
|
||
"def step1_analyze_distribution(self, max_samples=100000):\n",
|
||
" \"\"\"\n",
|
||
" 步骤1: 分析数据分布,确定采样策略(改为第三小样本数策略)\n",
|
||
" \"\"\"\n",
|
||
" print(\"🔍 步骤1: 分析数据分布(第三小样本数策略)...\")\n",
|
||
" \n",
|
||
" # 分析验证集分布(代表整体分布特征)\n",
|
||
" all_labels = []\n",
|
||
" for trials_batch, filename in load_data_batch(self.data_dir, 'val', 5000):\n",
|
||
" _, labels = extract_features_labels_batch(trials_batch)\n",
|
||
" all_labels.extend(labels.tolist())\n",
|
||
" if len(all_labels) >= max_samples:\n",
|
||
" break\n",
|
||
" \n",
|
||
" # 统计分析\n",
|
||
" label_counts = Counter(all_labels)\n",
|
||
" \n",
|
||
" # 计算所有标签的样本数,找到第三小的数值\n",
|
||
" all_counts = [label_counts.get(i, 0) for i in range(41)]\n",
|
||
" non_zero_counts = [count for count in all_counts if count > 0]\n",
|
||
" sorted_counts = sorted(non_zero_counts)\n",
|
||
" \n",
|
||
" if len(sorted_counts) >= 3:\n",
|
||
" third_smallest_count = sorted_counts[2] # 第三小\n",
|
||
" elif len(sorted_counts) >= 2:\n",
|
||
" third_smallest_count = sorted_counts[1] # 如果不足3个,用第二小\n",
|
||
" else:\n",
|
||
" third_smallest_count = sorted_counts[0] if sorted_counts else 1 # 如果不足2个,用最小的\n",
|
||
" \n",
|
||
" print(f\" 所有标签样本数: {sorted_counts[:10]}{'...' if len(sorted_counts) > 10 else ''}\")\n",
|
||
" print(f\" 第三小样本数: {third_smallest_count}\")\n",
|
||
" \n",
|
||
" # 生成采样策略:所有标签都下采样到第三小,不进行过采样\n",
|
||
" sampling_strategy = {}\n",
|
||
" for label in range(41):\n",
|
||
" current_count = label_counts.get(label, 0)\n",
|
||
" \n",
|
||
" if current_count > third_smallest_count:\n",
|
||
" # 下采样到第三小样本数\n",
|
||
" target_count = third_smallest_count\n",
|
||
" action = 'undersample'\n",
|
||
" else:\n",
|
||
" # 保持现有样本数(不进行过采样)\n",
|
||
" target_count = current_count\n",
|
||
" action = 'keep'\n",
|
||
" \n",
|
||
" sampling_strategy[label] = {\n",
|
||
" 'current_count': current_count,\n",
|
||
" 'target_count': target_count,\n",
|
||
" 'action': action\n",
|
||
" }\n",
|
||
" \n",
|
||
" self.distribution_analysis = {\n",
|
||
" 'label_counts': label_counts,\n",
|
||
" 'target_third_smallest': third_smallest_count,\n",
|
||
" 'total_samples': len(all_labels),\n",
|
||
" 'sorted_counts': sorted_counts\n",
|
||
" }\n",
|
||
" self.sampling_strategy = sampling_strategy\n",
|
||
" \n",
|
||
" # 统计采样策略\n",
|
||
" undersample_count = sum(1 for s in sampling_strategy.values() if s['action'] == 'undersample')\n",
|
||
" keep_count = sum(1 for s in sampling_strategy.values() if s['action'] == 'keep')\n",
|
||
" \n",
|
||
" print(f\" ✅ 分析完成: {len(all_labels):,} 样本\")\n",
|
||
" print(f\" 📉 下采样标签: {undersample_count} 个 → {third_smallest_count}\")\n",
|
||
" print(f\" ✅ 保持不变: {keep_count} 个\")\n",
|
||
" print(f\" 🚫 不进行过采样\")\n",
|
||
" \n",
|
||
" return self.distribution_analysis, self.sampling_strategy\n",
|
||
"\n",
|
||
"def step2_fit_pca_with_undersampling(self):\n",
|
||
" \"\"\"\n",
|
||
" 步骤2: 仅对下采样数据拟合PCA参数(不进行过采样,避免PCA被过采样影响)\n",
|
||
" \"\"\"\n",
|
||
" if self.sampling_strategy is None:\n",
|
||
" raise ValueError(\"请先执行步骤1: step1_analyze_distribution()\")\n",
|
||
" \n",
|
||
" print(\"\\n🔧 步骤2: 拟合PCA参数(仅下采样,不过采样)...\")\n",
|
||
" \n",
|
||
" # 收集用于PCA拟合的样本(只下采样,不过采样)\n",
|
||
" pca_features = []\n",
|
||
" collected_samples = 0\n",
|
||
" \n",
|
||
" for trials_batch, filename in load_data_batch(self.data_dir, 'train', 3000, verbose=False):\n",
|
||
" features, labels = extract_features_labels_batch(trials_batch)\n",
|
||
" \n",
|
||
" # 对当前批次应用仅下采样策略\n",
|
||
" downsampled_features, downsampled_labels = self._apply_undersampling_only(features, labels)\n",
|
||
" \n",
|
||
" if downsampled_features.shape[0] > 0:\n",
|
||
" pca_features.append(downsampled_features)\n",
|
||
" collected_samples += downsampled_features.shape[0]\n",
|
||
" \n",
|
||
" if collected_samples >= self.pca_sample_size:\n",
|
||
" break\n",
|
||
" \n",
|
||
" if not pca_features:\n",
|
||
" raise ValueError(\"无法收集用于PCA拟合的样本,请检查数据或采样策略\")\n",
|
||
" \n",
|
||
" # 合并样本用于PCA拟合\n",
|
||
" X_pca_fit = np.vstack(pca_features)[:self.pca_sample_size]\n",
|
||
" print(f\" 用于PCA拟合的样本数: {X_pca_fit.shape[0]:,}\")\n",
|
||
" \n",
|
||
" # 标准化 + PCA\n",
|
||
" self.pca_scaler = StandardScaler()\n",
|
||
" X_scaled = self.pca_scaler.fit_transform(X_pca_fit)\n",
|
||
" \n",
|
||
" # 自动选择PCA成分数以保留指定方差\n",
|
||
" if self.pca_components is None:\n",
|
||
" pca_full = PCA()\n",
|
||
" pca_full.fit(X_scaled)\n",
|
||
" cumsum_var = np.cumsum(pca_full.explained_variance_ratio_)\n",
|
||
" optimal_components = np.argmax(cumsum_var >= self.pca_variance_threshold) + 1\n",
|
||
" self.pca_components = optimal_components\n",
|
||
" \n",
|
||
" self.pca_model = PCA(n_components=self.pca_components, random_state=self.random_state)\n",
|
||
" self.pca_model.fit(X_scaled)\n",
|
||
" self.pca_fitted = True\n",
|
||
" \n",
|
||
" print(f\" PCA拟合完成: 7168 → {self.pca_components}\")\n",
|
||
" print(f\" 保留方差: {self.pca_model.explained_variance_ratio_.sum():.4f}\")\n",
|
||
"\n",
|
||
"def _apply_undersampling_only(self, X, y):\n",
|
||
" \"\"\"\n",
|
||
" 仅对指定标签做下采样(不做过采样)- 修改为第三小样本数策略\n",
|
||
" \"\"\"\n",
|
||
" if self.sampling_strategy is None:\n",
|
||
" raise ValueError(\"请先执行步骤1: step1_analyze_distribution()\")\n",
|
||
" \n",
|
||
" X_result = []\n",
|
||
" y_result = []\n",
|
||
" \n",
|
||
" np.random.seed(self.random_state)\n",
|
||
" \n",
|
||
" for label in range(41):\n",
|
||
" label_mask = (y == label)\n",
|
||
" X_label = X[label_mask]\n",
|
||
" y_label = y[label_mask]\n",
|
||
" current_count = len(y_label)\n",
|
||
" \n",
|
||
" if current_count == 0:\n",
|
||
" continue\n",
|
||
" \n",
|
||
" strategy = self.sampling_strategy[label]\n",
|
||
" \n",
|
||
" if strategy['action'] == 'undersample' and current_count > strategy['target_count']:\n",
|
||
" # 下采样到第三小样本数\n",
|
||
" indices = np.random.choice(current_count, strategy['target_count'], replace=False)\n",
|
||
" X_resampled = X_label[indices]\n",
|
||
" y_resampled = y_label[indices]\n",
|
||
" else:\n",
|
||
" # 保持原样\n",
|
||
" X_resampled = X_label\n",
|
||
" y_resampled = y_label\n",
|
||
" \n",
|
||
" X_result.append(X_resampled)\n",
|
||
" y_result.append(y_resampled)\n",
|
||
" \n",
|
||
" if X_result:\n",
|
||
" return np.vstack(X_result), np.hstack(y_result)\n",
|
||
" else:\n",
|
||
" return np.array([]).reshape(0, X.shape[1]), np.array([])\n",
|
||
"\n",
|
||
"# 动态添加方法到类\n",
|
||
"SmartDataPipeline.step1_analyze_distribution = step1_analyze_distribution\n",
|
||
"SmartDataPipeline.step2_fit_pca_with_undersampling = step2_fit_pca_with_undersampling\n",
|
||
"SmartDataPipeline._apply_undersampling_only = _apply_undersampling_only\n",
|
||
"\n",
|
||
"print(\"步骤1和步骤2方法已添加到管道(第三小样本数策略)\")"
|
||
]
|
||
},
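{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal, self-contained sketch of the undersample-to-the-third-smallest-count rule used by `step1_analyze_distribution` above. The labels below are synthetic stand-ins; only the selection logic mirrors the pipeline."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Toy sketch of the third-smallest undersampling rule (synthetic labels, not real data)\n",
"import numpy as np\n",
"from collections import Counter\n",
"\n",
"rng = np.random.default_rng(0)\n",
"toy_labels = rng.choice([0, 1, 2, 3, 40], size=1000, p=[0.5, 0.2, 0.15, 0.1, 0.05])\n",
"\n",
"counts = Counter(toy_labels.tolist())\n",
"non_zero = sorted(c for c in counts.values() if c > 0)\n",
"# third-smallest non-zero count; fall back to the smallest available one\n",
"target = non_zero[2] if len(non_zero) >= 3 else non_zero[-1]\n",
"\n",
"balanced_idx = []\n",
"for label, c in counts.items():\n",
"    idx = np.flatnonzero(toy_labels == label)\n",
"    if c > target:  # undersample only; never oversample\n",
"        idx = rng.choice(idx, size=target, replace=False)\n",
"    balanced_idx.extend(idx.tolist())\n",
"\n",
"print('target per class:', target)\n",
"print('balanced distribution:', Counter(toy_labels[np.array(balanced_idx)].tolist()))"
]
},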
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 12,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"所有方法已添加到智能管道(第三小样本数策略)\n",
|
||
"\n",
|
||
"智能数据处理管道状态:\n",
|
||
" 步骤1 - 分布分析: 未完成\n",
|
||
" 步骤2 - PCA拟合: 未完成\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"# 添加智能管道的剩余方法 - 修改为第三小样本数策略\n",
|
||
"\n",
|
||
"def _apply_full_sampling(self, X, y):\n",
|
||
" \"\"\"\n",
|
||
" 应用完整的采样策略(修改为只下采样到第三小样本数)\n",
|
||
" \"\"\"\n",
|
||
" X_result = []\n",
|
||
" y_result = []\n",
|
||
" \n",
|
||
" np.random.seed(self.random_state)\n",
|
||
" \n",
|
||
" for label in range(41):\n",
|
||
" label_mask = (y == label)\n",
|
||
" X_label = X[label_mask]\n",
|
||
" y_label = y[label_mask]\n",
|
||
" current_count = len(y_label)\n",
|
||
" \n",
|
||
" if current_count == 0:\n",
|
||
" continue\n",
|
||
" \n",
|
||
" strategy = self.sampling_strategy[label]\n",
|
||
" target_count = strategy['target_count']\n",
|
||
" \n",
|
||
" if strategy['action'] == 'undersample' and current_count > target_count:\n",
|
||
" # 只进行下采样到第三小样本数\n",
|
||
" indices = np.random.choice(current_count, target_count, replace=False)\n",
|
||
" X_resampled = X_label[indices]\n",
|
||
" y_resampled = y_label[indices]\n",
|
||
" else:\n",
|
||
" # 保持原样(不进行过采样)\n",
|
||
" X_resampled = X_label\n",
|
||
" y_resampled = y_label\n",
|
||
" \n",
|
||
" X_result.append(X_resampled)\n",
|
||
" y_result.append(y_resampled)\n",
|
||
" \n",
|
||
" if X_result:\n",
|
||
" return np.vstack(X_result), np.hstack(y_result)\n",
|
||
" else:\n",
|
||
" return np.array([]).reshape(0, X.shape[1]), np.array([])\n",
|
||
"\n",
|
||
"def _apply_pca_transform(self, X):\n",
|
||
" \"\"\"\n",
|
||
" 应用PCA变换\n",
|
||
" \"\"\"\n",
|
||
" if not self.pca_fitted:\n",
|
||
" return X\n",
|
||
" \n",
|
||
" X_scaled = self.pca_scaler.transform(X)\n",
|
||
" X_pca = self.pca_model.transform(X_scaled)\n",
|
||
" return X_pca\n",
|
||
"\n",
|
||
"def step3_process_data(self, data_type, apply_sampling=None):\n",
|
||
" \"\"\"\n",
|
||
" 步骤3: 处理数据(采样+PCA降维)\n",
|
||
" \"\"\"\n",
|
||
" if not self.pca_fitted:\n",
|
||
" raise ValueError(\"请先执行步骤2: step2_fit_pca_with_undersampling()\")\n",
|
||
" \n",
|
||
" if apply_sampling is None:\n",
|
||
" apply_sampling = (data_type == 'train')\n",
|
||
" \n",
|
||
" print(f\"\\n处理{data_type}数据...\")\n",
|
||
" print(f\" 采样策略: {'启用(只下采样)' if apply_sampling else '禁用'}\")\n",
|
||
" \n",
|
||
" all_features = []\n",
|
||
" all_labels = []\n",
|
||
" \n",
|
||
" # 在内部关闭加载时的逐文件打印\n",
|
||
" for trials_batch, filename in load_data_batch(self.data_dir, data_type, 3000, verbose=False):\n",
|
||
" features, labels = extract_features_labels_batch(trials_batch)\n",
|
||
" \n",
|
||
" if apply_sampling:\n",
|
||
" features_sampled, labels_sampled = self._apply_full_sampling(features, labels)\n",
|
||
" else:\n",
|
||
" features_sampled, labels_sampled = features, labels\n",
|
||
" \n",
|
||
" if features_sampled.shape[0] > 0:\n",
|
||
" features_pca = self._apply_pca_transform(features_sampled)\n",
|
||
" all_features.append(features_pca)\n",
|
||
" all_labels.append(labels_sampled)\n",
|
||
" \n",
|
||
" if all_features:\n",
|
||
" X = np.vstack(all_features)\n",
|
||
" y = np.hstack(all_labels)\n",
|
||
" \n",
|
||
" shuffle_indices = np.random.permutation(len(y))\n",
|
||
" X = X[shuffle_indices]\n",
|
||
" y = y[shuffle_indices]\n",
|
||
" \n",
|
||
" print(f\" 完成: {X.shape[0]:,} 样本, {X.shape[1]} 特征\")\n",
|
||
" \n",
|
||
" del all_features, all_labels\n",
|
||
" gc.collect()\n",
|
||
" \n",
|
||
" return X, y\n",
|
||
" else:\n",
|
||
" return None, None\n",
|
||
"\n",
|
||
"def print_summary(self):\n",
|
||
" print(\"\\n智能数据处理管道状态:\")\n",
|
||
" print(f\" 步骤1 - 分布分析: {'完成' if self.distribution_analysis else '未完成'}\")\n",
|
||
" print(f\" 步骤2 - PCA拟合: {'完成' if self.pca_fitted else '未完成'}\")\n",
|
||
" \n",
|
||
" if self.distribution_analysis:\n",
|
||
" target_third_smallest = self.distribution_analysis['target_third_smallest']\n",
|
||
" print(f\" 第三小样本数: {target_third_smallest}\")\n",
|
||
" print(f\" 采样策略: 只下采样,不过采样\")\n",
|
||
" \n",
|
||
" if self.pca_fitted:\n",
|
||
" print(f\" PCA降维: 7168 → {self.pca_components} ({self.pca_components/7168:.1%})\")\n",
|
||
" print(f\" 保留方差: {self.pca_model.explained_variance_ratio_.sum():.4f}\")\n",
|
||
"\n",
|
||
"# 动态添加剩余方法到类\n",
|
||
"SmartDataPipeline._apply_full_sampling = _apply_full_sampling\n",
|
||
"SmartDataPipeline._apply_pca_transform = _apply_pca_transform\n",
|
||
"SmartDataPipeline.step3_process_data = step3_process_data\n",
|
||
"SmartDataPipeline.print_summary = print_summary\n",
|
||
"\n",
|
||
"print(\"所有方法已添加到智能管道(第三小样本数策略)\")\n",
|
||
"pipeline.print_summary()"
|
||
]
|
||
},
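{
"cell_type": "markdown",
"metadata": {},
"source": [
"A hedged, standalone sketch of the variance-threshold rule that `step2_fit_pca_with_undersampling` uses to pick `pca_components`, shown on small synthetic correlated data instead of the 7168-dim features."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: choose the PCA width that retains a target fraction of variance\n",
"import numpy as np\n",
"from sklearn.decomposition import PCA\n",
"from sklearn.preprocessing import StandardScaler\n",
"\n",
"rng = np.random.default_rng(42)\n",
"X = rng.normal(size=(500, 64)) @ rng.normal(size=(64, 64))  # correlated features\n",
"\n",
"X_scaled = StandardScaler().fit_transform(X)\n",
"cumsum_var = np.cumsum(PCA().fit(X_scaled).explained_variance_ratio_)\n",
"\n",
"threshold = 0.95\n",
"if cumsum_var[-1] >= threshold:\n",
"    n_components = int(np.argmax(cumsum_var >= threshold)) + 1\n",
"else:\n",
"    n_components = len(cumsum_var)  # threshold unreachable: keep everything\n",
"\n",
"print(f'components for {threshold:.0%} variance: {n_components} / {X.shape[1]}')"
]
},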
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 🎯 数据增强模块 - 时序神经数据增强"
]
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 10,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"🎯 神经数据增强器初始化完成\n",
|
||
" 白噪声标准差: 1.0\n",
|
||
" 常数偏移标准差: 0.2\n",
|
||
" 随机游走标准差: 0.0\n",
|
||
" 静态增益标准差: 0.0\n",
|
||
" 随机切割步数: 3\n",
|
||
" 平滑数据: True\n",
|
||
" 平滑核大小: 100, 标准差: 2\n",
|
||
"✅ 数据增强器创建完成!\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"# 🎯 时序神经数据增强类\n",
|
||
"\n",
|
||
"import numpy as np\n",
|
||
"import random\n",
|
||
"from scipy import ndimage\n",
|
||
"from scipy.ndimage import gaussian_filter1d\n",
|
||
"import gc\n",
|
||
"\n",
|
||
"class NeuralDataAugmenter:\n",
|
||
" \"\"\"\n",
|
||
" 时序神经数据增强器\n",
|
||
" 专门用于处理脑机接口的神经信号数据增强\n",
|
||
" \"\"\"\n",
|
||
" \n",
|
||
" def __init__(self, \n",
|
||
" white_noise_std=1.0,\n",
|
||
" constant_offset_std=0.2,\n",
|
||
" random_walk_std=0.0,\n",
|
||
" random_walk_axis=-1,\n",
|
||
" static_gain_std=0.0,\n",
|
||
" random_cut=3,\n",
|
||
" smooth_kernel_size=100,\n",
|
||
" smooth_data=True,\n",
|
||
" smooth_kernel_std=2,\n",
|
||
" random_state=42):\n",
|
||
" \"\"\"\n",
|
||
" 初始化数据增强器\n",
|
||
" \n",
|
||
" Args:\n",
|
||
" white_noise_std: 白噪声标准差\n",
|
||
" constant_offset_std: 常数偏移标准差\n",
|
||
" random_walk_std: 随机游走标准差\n",
|
||
" random_walk_axis: 随机游走应用的轴\n",
|
||
" static_gain_std: 静态增益标准差\n",
|
||
" random_cut: 随机切割时间步数\n",
|
||
" smooth_kernel_size: 平滑核大小\n",
|
||
" smooth_data: 是否平滑数据\n",
|
||
" smooth_kernel_std: 平滑核标准差\n",
|
||
" random_state: 随机种子\n",
|
||
" \"\"\"\n",
|
||
" self.white_noise_std = white_noise_std\n",
|
||
" self.constant_offset_std = constant_offset_std\n",
|
||
" self.random_walk_std = random_walk_std\n",
|
||
" self.random_walk_axis = random_walk_axis\n",
|
||
" self.static_gain_std = static_gain_std\n",
|
||
" self.random_cut = random_cut\n",
|
||
" self.smooth_kernel_size = smooth_kernel_size\n",
|
||
" self.smooth_data = smooth_data\n",
|
||
" self.smooth_kernel_std = smooth_kernel_std\n",
|
||
" self.random_state = random_state\n",
|
||
" \n",
|
||
" # 设置随机种子\n",
|
||
" np.random.seed(random_state)\n",
|
||
" random.seed(random_state)\n",
|
||
" \n",
|
||
" print(f\"🎯 神经数据增强器初始化完成\")\n",
|
||
" print(f\" 白噪声标准差: {white_noise_std}\")\n",
|
||
" print(f\" 常数偏移标准差: {constant_offset_std}\")\n",
|
||
" print(f\" 随机游走标准差: {random_walk_std}\")\n",
|
||
" print(f\" 静态增益标准差: {static_gain_std}\")\n",
|
||
" print(f\" 随机切割步数: {random_cut}\")\n",
|
||
" print(f\" 平滑数据: {smooth_data}\")\n",
|
||
" if smooth_data:\n",
|
||
" print(f\" 平滑核大小: {smooth_kernel_size}, 标准差: {smooth_kernel_std}\")\n",
|
||
" \n",
|
||
" def reconstruct_time_series(self, flattened_features, original_shape=(14, 512)):\n",
|
||
" \"\"\"\n",
|
||
" 将扁平化的特征重建为时序数据\n",
|
||
" \n",
|
||
" Args:\n",
|
||
" flattened_features: 扁平化的特征 (7168维)\n",
|
||
" original_shape: 原始时序形状 (时间步, 特征维度)\n",
|
||
" \n",
|
||
" Returns:\n",
|
||
" time_series: 重建的时序数据 (time_steps, features)\n",
|
||
" \"\"\"\n",
|
||
" # 假设前7168维是神经特征 (14 * 512 = 7168)\n",
|
||
" neural_data = flattened_features[:7168]\n",
|
||
" time_series = neural_data.reshape(original_shape)\n",
|
||
" return time_series\n",
|
||
" \n",
|
||
" def add_white_noise(self, data):\n",
|
||
" \"\"\"添加白噪声\"\"\"\n",
|
||
" if self.white_noise_std <= 0:\n",
|
||
" return data\n",
|
||
" \n",
|
||
" noise = np.random.normal(0, self.white_noise_std, data.shape)\n",
|
||
" return data + noise\n",
|
||
" \n",
|
||
" def add_constant_offset(self, data):\n",
|
||
" \"\"\"添加常数偏移\"\"\"\n",
|
||
" if self.constant_offset_std <= 0:\n",
|
||
" return data\n",
|
||
" \n",
|
||
" # 为每个通道添加不同的常数偏移\n",
|
||
" offset = np.random.normal(0, self.constant_offset_std, (1, data.shape[1]))\n",
|
||
" return data + offset\n",
|
||
" \n",
|
||
" def add_random_walk(self, data):\n",
|
||
" \"\"\"添加随机游走\"\"\"\n",
|
||
" if self.random_walk_std <= 0:\n",
|
||
" return data\n",
|
||
" \n",
|
||
" if self.random_walk_axis == -1: # 沿时间轴\n",
|
||
" walk = np.random.normal(0, self.random_walk_std, data.shape[0])\n",
|
||
" walk = np.cumsum(walk) # 累积求和形成随机游走\n",
|
||
" walk = walk.reshape(-1, 1) # 广播到所有通道\n",
|
||
" return data + walk\n",
|
||
" else:\n",
|
||
" walk = np.random.normal(0, self.random_walk_std, data.shape)\n",
|
||
" walk = np.cumsum(walk, axis=self.random_walk_axis)\n",
|
||
" return data + walk\n",
|
||
" \n",
|
||
" def apply_static_gain(self, data):\n",
|
||
" \"\"\"应用静态增益\"\"\"\n",
|
||
" if self.static_gain_std <= 0:\n",
|
||
" return data\n",
|
||
" \n",
|
||
" # 为每个通道应用不同的增益\n",
|
||
" gain = 1 + np.random.normal(0, self.static_gain_std, (1, data.shape[1]))\n",
|
||
" return data * gain\n",
|
||
" \n",
|
||
" def random_time_cut(self, data):\n",
|
||
" \"\"\"随机切割时间步\"\"\"\n",
|
||
" if self.random_cut <= 0 or data.shape[0] <= self.random_cut:\n",
|
||
" return data\n",
|
||
" \n",
|
||
" # 从开头随机切掉一些时间步\n",
|
||
" cut_steps = np.random.randint(0, min(self.random_cut + 1, data.shape[0]))\n",
|
||
" return data[cut_steps:]\n",
|
||
" \n",
|
||
" def smooth_data_func(self, data):\n",
|
||
" \"\"\"平滑数据\"\"\"\n",
|
||
" if not self.smooth_data or self.smooth_kernel_std <= 0:\n",
|
||
" return data\n",
|
||
" \n",
|
||
" # 对每个通道分别应用高斯平滑\n",
|
||
" smoothed_data = np.zeros_like(data)\n",
|
||
" for i in range(data.shape[1]):\n",
|
||
" smoothed_data[:, i] = gaussian_filter1d(\n",
|
||
" data[:, i], \n",
|
||
" sigma=self.smooth_kernel_std\n",
|
||
" )\n",
|
||
" return smoothed_data\n",
|
||
" \n",
|
||
" def augment_time_series(self, time_series_data):\n",
|
||
" \"\"\"\n",
|
||
" 对时序数据应用所有增强方法\n",
|
||
" \n",
|
||
" Args:\n",
|
||
" time_series_data: 时序数据 (time_steps, features)\n",
|
||
" \n",
|
||
" Returns:\n",
|
||
" augmented_data: 增强后的时序数据\n",
|
||
" \"\"\"\n",
|
||
" data = time_series_data.copy()\n",
|
||
" \n",
|
||
" # 1. 随机时间切割(在其他增强之前)\n",
|
||
" data = self.random_time_cut(data)\n",
|
||
" \n",
|
||
" # 2. 添加白噪声\n",
|
||
" data = self.add_white_noise(data)\n",
|
||
" \n",
|
||
" # 3. 添加常数偏移\n",
|
||
" data = self.add_constant_offset(data)\n",
|
||
" \n",
|
||
" # 4. 添加随机游走\n",
|
||
" data = self.add_random_walk(data)\n",
|
||
" \n",
|
||
" # 5. 应用静态增益\n",
|
||
" data = self.apply_static_gain(data)\n",
|
||
" \n",
|
||
" # 6. 平滑数据(在最后应用)\n",
|
||
" data = self.smooth_data_func(data)\n",
|
||
" \n",
|
||
" return data\n",
|
||
" \n",
|
||
" def flatten_time_series(self, time_series_data, target_length=7168):\n",
|
||
" \"\"\"\n",
|
||
" 将时序数据重新扁平化为目标长度\n",
|
||
" \n",
|
||
" Args:\n",
|
||
" time_series_data: 时序数据 (time_steps, features)\n",
|
||
" target_length: 目标扁平化长度\n",
|
||
" \n",
|
||
" Returns:\n",
|
||
" flattened_data: 扁平化的数据\n",
|
||
" \"\"\"\n",
|
||
" flattened = time_series_data.flatten()\n",
|
||
" \n",
|
||
" # 如果长度不够,用零填充\n",
|
||
" if len(flattened) < target_length:\n",
|
||
" padded = np.zeros(target_length)\n",
|
||
" padded[:len(flattened)] = flattened\n",
|
||
" return padded\n",
|
||
" # 如果长度超过,截断\n",
|
||
" elif len(flattened) > target_length:\n",
|
||
" return flattened[:target_length]\n",
|
||
" else:\n",
|
||
" return flattened\n",
|
||
" \n",
|
||
" def augment_neural_features(self, features_batch, augment_ratio=0.5):\n",
|
||
" \"\"\"\n",
|
||
" 对神经特征批次进行数据增强\n",
|
||
" \n",
|
||
" Args:\n",
|
||
" features_batch: 特征批次 (batch_size, 7168)\n",
|
||
" augment_ratio: 增强比例(0-1之间)\n",
|
||
" \n",
|
||
" Returns:\n",
|
||
" augmented_features: 增强后的特征(包含原始和增强的数据)\n",
|
||
" augmented_indices: 增强样本的索引\n",
|
||
" \"\"\"\n",
|
||
" batch_size = features_batch.shape[0]\n",
|
||
" n_augment = int(batch_size * augment_ratio)\n",
|
||
" \n",
|
||
" if n_augment == 0:\n",
|
||
" return features_batch, []\n",
|
||
" \n",
|
||
" # 随机选择要增强的样本\n",
|
||
" augment_indices = np.random.choice(batch_size, n_augment, replace=False)\n",
|
||
" \n",
|
||
" augmented_features = []\n",
|
||
" \n",
|
||
" for i, features in enumerate(features_batch):\n",
|
||
" if i in augment_indices:\n",
|
||
" # 重建时序数据\n",
|
||
" time_series = self.reconstruct_time_series(features)\n",
|
||
" \n",
|
||
" # 应用数据增强\n",
|
||
" augmented_time_series = self.augment_time_series(time_series)\n",
|
||
" \n",
|
||
" # 重新扁平化\n",
|
||
" augmented_features_flat = self.flatten_time_series(augmented_time_series)\n",
|
||
" \n",
|
||
" augmented_features.append(augmented_features_flat)\n",
|
||
" else:\n",
|
||
" augmented_features.append(features)\n",
|
||
" \n",
|
||
" return np.array(augmented_features), augment_indices.tolist()\n",
|
||
"\n",
|
||
"# 创建数据增强器实例\n",
|
||
"augmenter = NeuralDataAugmenter(\n",
|
||
" white_noise_std=1.0,\n",
|
||
" constant_offset_std=0.2,\n",
|
||
" random_walk_std=0.0,\n",
|
||
" random_walk_axis=-1,\n",
|
||
" static_gain_std=0.0,\n",
|
||
" random_cut=3,\n",
|
||
" smooth_kernel_size=100,\n",
|
||
" smooth_data=True,\n",
|
||
" smooth_kernel_std=2,\n",
|
||
" random_state=42\n",
|
||
")\n",
|
||
"\n",
|
||
"print(\"✅ 数据增强器创建完成!\")"
|
||
]
|
||
},
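{
"cell_type": "markdown",
"metadata": {},
"source": [
"`random_walk_std` is set to 0.0 above, so the walk branch never fires. This is a small sketch of what `add_random_walk` would do if it were enabled: cumulative Gaussian steps produce a slow baseline drift shared across channels."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: the (currently disabled) random-walk augmentation, i.e. random_walk_std > 0\n",
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(0)\n",
"data = np.zeros((14, 4))  # (time_steps, channels); zeros make the drift visible\n",
"\n",
"steps = rng.normal(0, 0.5, size=data.shape[0])  # independent per-timestep increments\n",
"walk = np.cumsum(steps).reshape(-1, 1)          # cumsum -> slowly drifting baseline\n",
"drifted = data + walk                           # same drift broadcast to every channel\n",
"\n",
"print('drift at t=0 :', round(float(drifted[0, 0]), 3))\n",
"print('drift at t=13:', round(float(drifted[-1, 0]), 3))"
]
},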
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 30,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"✅ 数据增强功能已集成到智能管道(已修复数组比较问题)\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"# 🔗 集成数据增强到智能管道\n",
|
||
"\n",
|
||
"def extract_features_labels_batch_with_augmentation(trials_batch, random_shuffle_trials=True, \n",
|
||
" apply_augmentation=True, augment_ratio=0.3):\n",
|
||
" \"\"\"\n",
|
||
" 从试验批次中提取特征和标签,支持数据增强\n",
|
||
" \n",
|
||
" Args:\n",
|
||
" trials_batch: 试验批次数据\n",
|
||
" random_shuffle_trials: 是否随机打乱试验顺序\n",
|
||
" apply_augmentation: 是否应用数据增强\n",
|
||
" augment_ratio: 数据增强比例\n",
|
||
" \"\"\"\n",
|
||
" features = []\n",
|
||
" labels = []\n",
|
||
" \n",
|
||
" # 随机打乱试验顺序\n",
|
||
" if random_shuffle_trials and len(trials_batch) > 1:\n",
|
||
" trial_indices = list(range(len(trials_batch)))\n",
|
||
" random.shuffle(trial_indices)\n",
|
||
" trials_batch = trials_batch[trial_indices]\n",
|
||
" \n",
|
||
" for trial in trials_batch:\n",
|
||
" if trial.shape[0] > 0:\n",
|
||
" # 随机打乱时间步顺序\n",
|
||
" time_indices = list(range(trial.shape[0]))\n",
|
||
" if random_shuffle_trials:\n",
|
||
" random.shuffle(time_indices)\n",
|
||
" \n",
|
||
" for t in time_indices:\n",
|
||
" neural_features = trial[t, :7168] # 前7168维神经特征\n",
|
||
" rnn_logits = trial[t, 7168:] # 后41维RNN输出\n",
|
||
" phoneme_label = np.argmax(rnn_logits)\n",
|
||
" \n",
|
||
" features.append(neural_features)\n",
|
||
" labels.append(phoneme_label)\n",
|
||
" \n",
|
||
" if not features:\n",
|
||
" return np.array([]), np.array([])\n",
|
||
" \n",
|
||
" features = np.array(features)\n",
|
||
" labels = np.array(labels)\n",
|
||
" \n",
|
||
" # 应用数据增强\n",
|
||
" if apply_augmentation and len(features) > 0:\n",
|
||
" print(f\" 应用数据增强 (比例: {augment_ratio})\")\n",
|
||
" features, augmented_indices = augmenter.augment_neural_features(features, augment_ratio)\n",
|
||
" print(f\" 增强样本数: {len(augmented_indices)}\")\n",
|
||
" \n",
|
||
" return features, labels\n",
|
||
"\n",
|
||
"def _apply_full_sampling(self, X, y):\n",
|
||
" \"\"\"\n",
|
||
" 应用完整的采样策略(修改为只下采样到第三小样本数)\n",
|
||
" \"\"\"\n",
|
||
" X_result = []\n",
|
||
" y_result = []\n",
|
||
" \n",
|
||
" np.random.seed(self.random_state)\n",
|
||
" \n",
|
||
" # 确保y是1维numpy数组,避免数组比较的歧义\n",
|
||
" y = np.asarray(y, dtype=int).flatten()\n",
|
||
" \n",
|
||
" for label in range(41):\n",
|
||
" label_mask = (y == label)\n",
|
||
" X_label = X[label_mask]\n",
|
||
" y_label = y[label_mask]\n",
|
||
" current_count = len(y_label)\n",
|
||
" \n",
|
||
" if current_count == 0:\n",
|
||
" continue\n",
|
||
" \n",
|
||
" strategy = self.sampling_strategy[label]\n",
|
||
" target_count = strategy['target_count']\n",
|
||
" \n",
|
||
" if strategy['action'] == 'undersample' and current_count > target_count:\n",
|
||
" # 只进行下采样到第三小样本数\n",
|
||
" indices = np.random.choice(current_count, target_count, replace=False)\n",
|
||
" X_resampled = X_label[indices]\n",
|
||
" y_resampled = y_label[indices]\n",
|
||
" else:\n",
|
||
" # 保持原样(不进行过采样)\n",
|
||
" X_resampled = X_label\n",
|
||
" y_resampled = y_label\n",
|
||
" \n",
|
||
" X_result.append(X_resampled)\n",
|
||
" y_result.append(y_resampled)\n",
|
||
" \n",
|
||
" if X_result:\n",
|
||
" return np.vstack(X_result), np.hstack(y_result)\n",
|
||
" else:\n",
|
||
" return np.array([]).reshape(0, X.shape[1]), np.array([])\n",
|
||
"\n",
|
||
"def _apply_pca_transform(self, X):\n",
|
||
" \"\"\"\n",
|
||
" 应用PCA变换\n",
|
||
" \"\"\"\n",
|
||
" if not self.pca_fitted:\n",
|
||
" return X\n",
|
||
" \n",
|
||
" X_scaled = self.pca_scaler.transform(X)\n",
|
||
" X_pca = self.pca_model.transform(X_scaled)\n",
|
||
" return X_pca\n",
|
||
"\n",
|
||
"def _apply_full_sampling_with_augmentation(self, X, y, apply_augmentation=True, augment_ratio=0.3):\n",
|
||
" \"\"\"\n",
|
||
" 应用完整的采样策略(修改为只下采样到第三小样本数)+ 数据增强\n",
|
||
" \"\"\"\n",
|
||
" X_result = []\n",
|
||
" y_result = []\n",
|
||
" \n",
|
||
" np.random.seed(self.random_state)\n",
|
||
" \n",
|
||
" # 确保y是1维numpy数组,避免数组比较的歧义\n",
|
||
" y = np.asarray(y, dtype=int).flatten()\n",
|
||
" \n",
|
||
" for label in range(41):\n",
|
||
" label_mask = (y == label)\n",
|
||
" X_label = X[label_mask]\n",
|
||
" y_label = y[label_mask]\n",
|
||
" current_count = len(y_label)\n",
|
||
" \n",
|
||
" if current_count == 0:\n",
|
||
" continue\n",
|
||
" \n",
|
||
" strategy = self.sampling_strategy[label]\n",
|
||
" target_count = strategy['target_count']\n",
|
||
" \n",
|
||
" if strategy['action'] == 'undersample' and current_count > target_count:\n",
|
||
" # 只进行下采样到第三小样本数\n",
|
||
" indices = np.random.choice(current_count, target_count, replace=False)\n",
|
||
" X_resampled = X_label[indices]\n",
|
||
" y_resampled = y_label[indices]\n",
|
||
" else:\n",
|
||
" # 保持原样(不进行过采样)\n",
|
||
" X_resampled = X_label\n",
|
||
" y_resampled = y_label\n",
|
||
" \n",
|
||
" # 对下采样后的数据应用数据增强\n",
|
||
" if apply_augmentation and len(X_resampled) > 0:\n",
|
||
" X_resampled, _ = augmenter.augment_neural_features(X_resampled, augment_ratio)\n",
|
||
" \n",
|
||
" X_result.append(X_resampled)\n",
|
||
" y_result.append(y_resampled)\n",
|
||
" \n",
|
||
" if X_result:\n",
|
||
" return np.vstack(X_result), np.hstack(y_result)\n",
|
||
" else:\n",
|
||
" return np.array([]).reshape(0, X.shape[1]), np.array([])\n",
|
||
"\n",
|
||
"def step3_process_data(self, data_type, apply_sampling=None):\n",
|
||
" \"\"\"\n",
|
||
" 步骤3: 处理数据(采样+PCA降维)- 原始版本(无数据增强)\n",
|
||
" \"\"\"\n",
|
||
" if not self.pca_fitted:\n",
|
||
" raise ValueError(\"请先执行步骤2: step2_fit_pca_with_undersampling()\")\n",
|
||
" \n",
|
||
" if apply_sampling is None:\n",
|
||
" apply_sampling = (data_type == 'train')\n",
|
||
" \n",
|
||
" print(f\"\\n处理{data_type}数据...\")\n",
|
||
" print(f\" 采样策略: {'启用(只下采样)' if apply_sampling else '禁用'}\")\n",
|
||
" \n",
|
||
" all_features = []\n",
|
||
" all_labels = []\n",
|
||
" \n",
|
||
" # 在内部关闭加载时的逐文件打印\n",
|
||
" for trials_batch, filename in load_data_batch(self.data_dir, data_type, 3000, verbose=False):\n",
|
||
" features, labels = extract_features_labels_batch(trials_batch)\n",
|
||
" \n",
|
||
" if apply_sampling:\n",
|
||
" features_sampled, labels_sampled = self._apply_full_sampling(features, labels)\n",
|
||
" else:\n",
|
||
" features_sampled, labels_sampled = features, labels\n",
|
||
" \n",
|
||
" if features_sampled.shape[0] > 0:\n",
|
||
" features_pca = self._apply_pca_transform(features_sampled)\n",
|
||
" all_features.append(features_pca)\n",
|
||
" all_labels.append(labels_sampled)\n",
|
||
" \n",
|
||
" if all_features:\n",
|
||
" X = np.vstack(all_features)\n",
|
||
" y = np.hstack(all_labels)\n",
|
||
" \n",
|
||
" shuffle_indices = np.random.permutation(len(y))\n",
|
||
" X = X[shuffle_indices]\n",
|
||
" y = y[shuffle_indices]\n",
|
||
" \n",
|
||
" print(f\" 完成: {X.shape[0]:,} 样本, {X.shape[1]} 特征\")\n",
|
||
" \n",
|
||
" del all_features, all_labels\n",
|
||
" gc.collect()\n",
|
||
" \n",
|
||
" return X, y\n",
|
||
" else:\n",
|
||
" return None, None\n",
|
||
"\n",
|
||
"def step3_process_data_with_augmentation(self, data_type, apply_sampling=None, \n",
|
||
" apply_augmentation=True, augment_ratio=0.3):\n",
|
||
" \"\"\"\n",
|
||
" 步骤3: 处理数据(采样+数据增强+PCA降维)\n",
|
||
" \"\"\"\n",
|
||
" if not self.pca_fitted:\n",
|
||
" raise ValueError(\"请先执行步骤2: step2_fit_pca_with_undersampling()\")\n",
|
||
" \n",
|
||
" if apply_sampling is None:\n",
|
||
" apply_sampling = (data_type == 'train')\n",
|
||
" \n",
|
||
" # 只对训练数据应用数据增强\n",
|
||
" if data_type != 'train':\n",
|
||
" apply_augmentation = False\n",
|
||
" \n",
|
||
" print(f\"\\n处理{data_type}数据...\")\n",
|
||
" print(f\" 采样策略: {'启用(只下采样)' if apply_sampling else '禁用'}\")\n",
|
||
" print(f\" 数据增强: {'启用' if apply_augmentation else '禁用'}\")\n",
|
||
" if apply_augmentation:\n",
|
||
" print(f\" 增强比例: {augment_ratio}\")\n",
|
||
" \n",
|
||
" all_features = []\n",
|
||
" all_labels = []\n",
|
||
" \n",
|
||
" # 在内部关闭加载时的逐文件打印\n",
|
||
" for trials_batch, filename in load_data_batch(self.data_dir, data_type, 3000, verbose=False):\n",
|
||
" # 使用带数据增强的特征提取函数\n",
|
||
" features, labels = extract_features_labels_batch_with_augmentation(\n",
|
||
" trials_batch, \n",
|
||
" random_shuffle_trials=True,\n",
|
||
" apply_augmentation=apply_augmentation,\n",
|
||
" augment_ratio=augment_ratio\n",
|
||
" )\n",
|
||
" \n",
|
||
" if apply_sampling and len(features) > 0:\n",
|
||
" features_sampled, labels_sampled = self._apply_full_sampling(features, labels)\n",
|
||
" else:\n",
|
||
" features_sampled, labels_sampled = features, labels\n",
|
||
" \n",
|
||
" if features_sampled.shape[0] > 0:\n",
|
||
" features_pca = self._apply_pca_transform(features_sampled)\n",
|
||
" all_features.append(features_pca)\n",
|
||
" all_labels.append(labels_sampled)\n",
|
||
" \n",
|
||
" if all_features:\n",
|
||
" X = np.vstack(all_features)\n",
|
||
" y = np.hstack(all_labels)\n",
|
||
" \n",
|
||
" shuffle_indices = np.random.permutation(len(y))\n",
|
||
" X = X[shuffle_indices]\n",
|
||
" y = y[shuffle_indices]\n",
|
||
" \n",
|
||
" print(f\" 完成: {X.shape[0]:,} 样本, {X.shape[1]} 特征\")\n",
|
||
" \n",
|
||
" del all_features, all_labels\n",
|
||
" gc.collect()\n",
|
||
" \n",
|
||
" return X, y\n",
|
||
" else:\n",
|
||
" return None, None\n",
|
||
"\n",
|
||
"# 动态添加所有方法到类(包括原始版本和增强版本)\n",
|
||
"# 确保添加所有必需的方法\n",
|
||
"SmartDataPipeline._apply_full_sampling = _apply_full_sampling\n",
|
||
"SmartDataPipeline._apply_pca_transform = _apply_pca_transform \n",
|
||
"SmartDataPipeline._apply_full_sampling_with_augmentation = _apply_full_sampling_with_augmentation\n",
|
||
"SmartDataPipeline.step3_process_data = step3_process_data\n",
|
||
"SmartDataPipeline.step3_process_data_with_augmentation = step3_process_data_with_augmentation\n",
|
||
"\n",
|
||
"print(\"✅ 数据增强功能已集成到智能管道(已修复数组比较问题)\")"
|
||
]
|
||
},
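{
"cell_type": "markdown",
"metadata": {},
"source": [
"A tiny sketch of the per-timestep row layout the extraction functions above assume: 7168 neural dimensions followed by 41 RNN logits, with the training label taken as the argmax of the logits. The trial here is synthetic."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: the 7168-neural + 41-logit row layout assumed by the extraction code\n",
"import numpy as np\n",
"\n",
"rng = np.random.default_rng(1)\n",
"trial = rng.normal(size=(5, 7168 + 41))  # one synthetic trial: 5 time steps\n",
"\n",
"neural = trial[:, :7168]                 # classifier inputs\n",
"logits = trial[:, 7168:]                 # RNN outputs used as the teacher signal\n",
"labels = np.argmax(logits, axis=1)       # one phoneme label per time step\n",
"\n",
"print('neural:', neural.shape, ' labels:', labels)"
]
},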
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 31,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"🧪 测试神经数据增强功能...\n",
|
||
"============================================================\n",
|
||
"模拟时序数据形状: (5, 14, 512)\n",
|
||
"扁平化特征形状: (5, 7168)\n",
|
||
"\n",
|
||
"测试单个样本的数据增强:\n",
|
||
" 原始样本统计: mean=-0.007, std=1.004\n",
|
||
" 重建时序形状: (14, 512)\n",
|
||
" 增强时序形状: (12, 512)\n",
|
||
" 增强样本统计: mean=0.005, std=0.599\n",
|
||
"\n",
|
||
"测试批量数据增强:\n",
|
||
" 原始批次形状: (5, 7168)\n",
|
||
" 增强批次形状: (5, 7168)\n",
|
||
" 增强样本索引: [3, 2, 1]\n",
|
||
"\n",
|
||
"统计比较:\n",
|
||
" 样本 0: ➡️ 原始\n",
|
||
" 原始: mean=-0.007, std=1.004\n",
|
||
" 处理: mean=-0.007, std=1.004\n",
|
||
" 样本 1: 🎯 增强\n",
|
||
" 原始: mean=0.008, std=1.000\n",
|
||
" 处理: mean=0.007, std=0.604\n",
|
||
" 样本 2: 🎯 增强\n",
|
||
" 原始: mean=0.011, std=1.005\n",
|
||
" 处理: mean=0.005, std=0.600\n",
|
||
" 样本 3: 🎯 增强\n",
|
||
" 原始: mean=-0.008, std=0.985\n",
|
||
" 处理: mean=-0.006, std=0.610\n",
|
||
" 样本 4: ➡️ 原始\n",
|
||
" 原始: mean=-0.021, std=1.015\n",
|
||
" 处理: mean=-0.021, std=1.015\n",
|
||
"\n",
|
||
"✅ 数据增强测试完成!\n",
|
||
"\n",
|
||
"📋 当前数据增强配置:\n",
|
||
" 白噪声标准差: 1.0\n",
|
||
" 常数偏移标准差: 0.2\n",
|
||
" 随机游走标准差: 0.0\n",
|
||
" 静态增益标准差: 0.0\n",
|
||
" 随机切割步数: 3\n",
|
||
" 数据平滑: True\n",
|
||
" 平滑核标准差: 2\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"# 🧪 测试数据增强功能\n",
|
||
"\n",
|
||
"print(\"🧪 测试神经数据增强功能...\")\n",
|
||
"print(\"=\" * 60)\n",
|
||
"\n",
|
||
"# 创建模拟的时序神经数据\n",
|
||
"np.random.seed(42)\n",
|
||
"batch_size = 5\n",
|
||
"n_timesteps = 14\n",
|
||
"n_features = 512\n",
|
||
"\n",
|
||
"# 模拟神经信号数据\n",
|
||
"mock_time_series = np.random.randn(batch_size, n_timesteps, n_features)\n",
|
||
"print(f\"模拟时序数据形状: {mock_time_series.shape}\")\n",
|
||
"\n",
|
||
"# 扁平化为7168维特征\n",
|
||
"mock_features = mock_time_series.reshape(batch_size, -1)\n",
|
||
"print(f\"扁平化特征形状: {mock_features.shape}\")\n",
|
||
"\n",
|
||
"# 测试单个样本的数据增强\n",
|
||
"print(f\"\\n测试单个样本的数据增强:\")\n",
|
||
"original_sample = mock_features[0]\n",
|
||
"print(f\" 原始样本统计: mean={original_sample.mean():.3f}, std={original_sample.std():.3f}\")\n",
|
||
"\n",
|
||
"# 重建时序数据\n",
|
||
"reconstructed_ts = augmenter.reconstruct_time_series(original_sample)\n",
|
||
"print(f\" 重建时序形状: {reconstructed_ts.shape}\")\n",
|
||
"\n",
|
||
"# 应用数据增强\n",
|
||
"augmented_ts = augmenter.augment_time_series(reconstructed_ts)\n",
|
||
"print(f\" 增强时序形状: {augmented_ts.shape}\")\n",
|
||
"\n",
|
||
"# 重新扁平化\n",
|
||
"augmented_sample = augmenter.flatten_time_series(augmented_ts)\n",
|
||
"print(f\" 增强样本统计: mean={augmented_sample.mean():.3f}, std={augmented_sample.std():.3f}\")\n",
|
||
"\n",
|
||
"# 测试批量数据增强\n",
|
||
"print(f\"\\n测试批量数据增强:\")\n",
|
||
"augmented_batch, aug_indices = augmenter.augment_neural_features(mock_features, augment_ratio=0.6)\n",
|
||
"print(f\" 原始批次形状: {mock_features.shape}\")\n",
|
||
"print(f\" 增强批次形状: {augmented_batch.shape}\")\n",
|
||
"print(f\" 增强样本索引: {aug_indices}\")\n",
|
||
"\n",
|
||
"# 比较原始和增强数据的统计特性\n",
|
||
"print(f\"\\n统计比较:\")\n",
|
||
"for i in range(batch_size):\n",
|
||
" original_stats = f\"mean={mock_features[i].mean():.3f}, std={mock_features[i].std():.3f}\"\n",
|
||
" augmented_stats = f\"mean={augmented_batch[i].mean():.3f}, std={augmented_batch[i].std():.3f}\"\n",
|
||
" status = \"🎯 增强\" if i in aug_indices else \"➡️ 原始\"\n",
|
||
" print(f\" 样本 {i}: {status}\")\n",
|
||
" print(f\" 原始: {original_stats}\")\n",
|
||
" print(f\" 处理: {augmented_stats}\")\n",
|
||
"\n",
|
||
"print(f\"\\n✅ 数据增强测试完成!\")\n",
|
||
"\n",
|
||
"# 显示数据增强配置\n",
|
||
"print(f\"\\n📋 当前数据增强配置:\")\n",
|
||
"print(f\" 白噪声标准差: {augmenter.white_noise_std}\")\n",
|
||
"print(f\" 常数偏移标准差: {augmenter.constant_offset_std}\")\n",
|
||
"print(f\" 随机游走标准差: {augmenter.random_walk_std}\")\n",
|
||
"print(f\" 静态增益标准差: {augmenter.static_gain_std}\")\n",
|
||
"print(f\" 随机切割步数: {augmenter.random_cut}\")\n",
|
||
"print(f\" 数据平滑: {augmenter.smooth_data}\")\n",
|
||
"if augmenter.smooth_data:\n",
|
||
" print(f\" 平滑核标准差: {augmenter.smooth_kernel_std}\")"
|
||
]
|
||
},
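{
"cell_type": "markdown",
"metadata": {},
"source": [
"The test above shows the trial shrinking from (14, 512) to (12, 512) after `random_time_cut`; `flatten_time_series` then zero-pads back to 7168, so every later value moves earlier in the flat vector. A minimal illustration of that shift:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch: random_time_cut + zero padding shifts the flattened feature layout\n",
"import numpy as np\n",
"\n",
"ts = np.arange(6, dtype=float).reshape(3, 2)  # tiny (time, channel) stand-in\n",
"cut = ts[1:]                                  # drop one leading time step\n",
"\n",
"flat = cut.flatten()\n",
"padded = np.zeros(ts.size)\n",
"padded[:flat.size] = flat                     # zero-pad the tail back to full length\n",
"\n",
"print('original:', ts.flatten())\n",
"print('cut+pad :', padded)  # values shifted forward; zeros fill the tail"
]
},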
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 🔥 执行智能数据处理管道"
]
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 14,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"🚀 开始执行智能数据处理管道...\n",
|
||
"============================================================\n",
|
||
"\n",
|
||
"======================🔍 STEP 1: 分析数据分布======================\n",
|
||
"🔍 步骤1: 分析数据分布(第三小样本数策略)...\n",
|
||
" 已随机打乱 41 个文件的加载顺序\n",
|
||
" 正在加载文件 1/41: t15.2024.02.25_val_concatenated.npz\n",
|
||
" 正在加载文件 2/41: t15.2023.10.08_val_concatenated.npz\n",
|
||
" 正在加载文件 3/41: t15.2025.03.30_val_concatenated.npz\n",
|
||
" 正在加载文件 4/41: t15.2023.11.19_val_concatenated.npz\n",
|
||
" 正在加载文件 5/41: t15.2023.11.26_val_concatenated.npz\n",
|
||
" 正在加载文件 6/41: t15.2024.07.28_val_concatenated.npz\n",
|
||
" 正在加载文件 7/41: t15.2024.07.21_val_concatenated.npz\n",
|
||
" 正在加载文件 8/41: t15.2023.09.29_val_concatenated.npz\n",
|
||
" 正在加载文件 9/41: t15.2025.01.10_val_concatenated.npz\n",
|
||
" 正在加载文件 10/41: t15.2025.04.13_val_concatenated.npz\n",
|
||
" 正在加载文件 11/41: t15.2024.07.19_val_concatenated.npz\n",
|
||
" 正在加载文件 12/41: t15.2023.11.04_val_concatenated.npz\n",
|
||
" 正在加载文件 13/41: t15.2023.11.03_val_concatenated.npz\n",
|
||
" 所有标签样本数: [29, 56, 76, 78, 87, 87, 89, 135, 136, 147]...\n",
|
||
" 第三小样本数: 76\n",
|
||
" ✅ 分析完成: 105,617 样本\n",
|
||
" 📉 下采样标签: 38 个 → 76\n",
|
||
" ✅ 保持不变: 3 个\n",
|
||
" 🚫 不进行过采样\n",
|
||
"\n",
|
||
"📊 采样策略总结:\n",
|
||
" 📉 下采样标签: 38 个\n",
|
||
" 📈 过采样标签: 0 个\n",
|
||
" ✅ 保持不变: 3 个\n",
|
||
"\n",
|
||
"✅ 步骤1完成!\n",
|
||
" 所有标签样本数: [29, 56, 76, 78, 87, 87, 89, 135, 136, 147]...\n",
|
||
" 第三小样本数: 76\n",
|
||
" ✅ 分析完成: 105,617 样本\n",
|
||
" 📉 下采样标签: 38 个 → 76\n",
|
||
" ✅ 保持不变: 3 个\n",
|
||
" 🚫 不进行过采样\n",
|
||
"\n",
|
||
"📊 采样策略总结:\n",
|
||
" 📉 下采样标签: 38 个\n",
|
||
" 📈 过采样标签: 0 个\n",
|
||
" ✅ 保持不变: 3 个\n",
|
||
"\n",
|
||
"✅ 步骤1完成!\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"# 🔥 执行智能数据处理管道【确定采样策略】\n",
|
||
"\n",
|
||
"print(\"🚀 开始执行智能数据处理管道...\")\n",
|
||
"print(\"=\" * 60)\n",
|
||
"\n",
|
||
"# 步骤1: 分析数据分布\n",
|
||
"print(\"\\n\" + \"🔍 STEP 1: 分析数据分布\".center(60, \"=\"))\n",
|
||
"distribution, strategy = pipeline.step1_analyze_distribution()\n",
|
||
"\n",
|
||
"# 显示采样策略总结\n",
|
||
"print(f\"\\n📊 采样策略总结:\")\n",
|
||
"undersample_count = sum(1 for s in strategy.values() if s['action'] == 'undersample')\n",
|
||
"oversample_count = sum(1 for s in strategy.values() if s['action'] == 'oversample')\n",
|
||
"keep_count = sum(1 for s in strategy.values() if s['action'] == 'keep')\n",
|
||
"\n",
|
||
"print(f\" 📉 下采样标签: {undersample_count} 个\")\n",
|
||
"print(f\" 📈 过采样标签: {oversample_count} 个\") \n",
|
||
"print(f\" ✅ 保持不变: {keep_count} 个\")\n",
|
||
"\n",
|
||
"print(\"\\n✅ 步骤1完成!\")"
|
||
]
|
||
},
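{
"cell_type": "markdown",
"metadata": {},
"source": [
"A hedged peek at the `strategy` dict returned by step 1 (this assumes the cell above has run); each entry records the current count, target count, and action for one label."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Inspect the step-1 sampling strategy (assumes `strategy` from the cell above)\n",
"kept = [lbl for lbl, s in strategy.items() if s['action'] == 'keep']\n",
"cut = [lbl for lbl, s in strategy.items() if s['action'] == 'undersample']\n",
"print('labels kept as-is  :', kept)\n",
"print('labels undersampled:', len(cut), 'of 41')\n",
"print('example entry      :', strategy[cut[0]] if cut else None)"
]
},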
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 18,
|
||
"metadata": {},
|
||
"outputs": [
|
||
{
|
||
"name": "stdout",
|
||
"output_type": "stream",
|
||
"text": [
|
||
"\n",
|
||
"=====================🔧 STEP 2: 拟合PCA参数======================\n",
|
||
"\n",
|
||
"🔧 步骤2: 拟合PCA参数(仅下采样,不过采样)...\n",
|
||
" 用于PCA拟合的样本数: 15,000\n",
|
||
" PCA拟合完成: 7168 → 1243\n",
|
||
" 保留方差: 0.9489\n",
|
||
" PCA拟合完成: 7168 → 1243\n",
|
||
" 保留方差: 0.9489\n"
|
||
]
|
||
}
|
||
],
|
||
"source": [
|
||
"# 步骤2: 拟合PCA参数【确定PCA策略】\n",
|
||
"print(\"\\n\" + \"🔧 STEP 2: 拟合PCA参数\".center(60, \"=\"))\n",
|
||
"pipeline.step2_fit_pca_with_undersampling()\n",
|
||
"\n",
|
||
"# print(\"\\n✅ 步骤2完成!\")\n",
|
||
"# pipeline.print_summary()"
|
||
]
|
||
},
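{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick, hedged sanity check on the transform fitted in step 2: push a random stand-in batch through `pipeline._apply_pca_transform` and confirm the reduced width matches `pipeline.pca_components`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sanity check (assumes `pipeline` from the cells above, after step2)\n",
"import numpy as np\n",
"\n",
"X_fake = np.random.randn(8, 7168)              # random batch in the raw feature space\n",
"X_red = pipeline._apply_pca_transform(X_fake)  # StandardScaler + PCA fitted in step2\n",
"\n",
"print('in :', X_fake.shape)\n",
"print('out:', X_red.shape)  # expected: (8, pipeline.pca_components)\n",
"assert X_red.shape == (8, pipeline.pca_components)"
]
},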
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 🚀 使用智能管道进行分批训练"
]
},
|
||
{
|
||
"cell_type": "code",
|
||
"execution_count": 34,
|
||
"metadata": {},
|
||
"outputs": [],
|
||
"source": [
|
||
"# 使用智能管道进行分批训练\n",
|
||
"import lightgbm as lgb\n",
|
||
"import time\n",
|
||
"from collections import Counter\n",
|
||
"import matplotlib.pyplot as plt\n",
|
||
"import random\n",
|
||
"import numpy as np\n",
|
||
"import os\n",
|
||
"import gc\n",
|
||
"\n",
|
||
"class SmartBatchTrainer:\n",
|
||
" \"\"\"\n",
|
||
" 智能分批训练器,集成智能数据管道\n",
|
||
" \"\"\"\n",
|
||
" \n",
|
||
" def __init__(self, pipeline, params=None, min_learning_rate=1e-4, t_0=50, t_mult=2):\n",
|
||
" self.pipeline = pipeline\n",
|
||
" self.model = None\n",
|
||
" self.training_history = {} # 改为字典,因为只有一次训练\n",
|
||
" self.batch_count = 0\n",
|
||
" self.min_learning_rate = min_learning_rate\n",
|
||
" self.lr_history = [] # 用于可视化\n",
|
||
" \n",
|
||
" # 带重启的余弦退火参数\n",
|
||
" self.t_0 = t_0 # 第一个重启周期的长度\n",
|
||
" self.t_mult = t_mult # 重启周期的乘数\n",
|
||
" \n",
|
||
" # 默认LightGBM参数(GPU优化)\n",
|
||
" self.params = params or {\n",
|
||
" 'objective': 'multiclass',\n",
|
||
" 'num_class': 41,\n",
|
||
" 'metric': 'multi_logloss',\n",
|
||
" 'boosting_type': 'gbdt',\n",
|
||
" 'device_type': 'cpu',\n",
|
||
" # 'gpu_platform_id': 0,\n",
|
||
" # 'gpu_device_id': 0,\n",
|
||
" 'max_bin': 255,\n",
|
||
" 'num_leaves': 127,\n",
|
||
" 'learning_rate': 0.10, #默认0.08\n",
|
||
" 'feature_fraction': 0.8,\n",
|
||
" 'bagging_fraction': 0.8,\n",
|
||
" 'bagging_freq': 5,\n",
|
||
" 'min_data_in_leaf': 20,\n",
|
||
" 'lambda_l1': 0.1,\n",
|
||
" 'lambda_l2': 0.1,\n",
|
||
" 'verbose': -1,\n",
|
||
" 'num_threads': -1\n",
|
||
" }\n",
|
||
" \n",
|
||
" self.initial_learning_rate = self.params.get('learning_rate', 0.08)\n",
|
||
" \n",
|
||
" print(f\"智能分批训练器创建完成\")\n",
|
||
" print(f\" LightGBM参数已配置:{self.params['device_type'].upper()}模式\")\n",
|
||
" print(f\" 学习率调度: 带重启的余弦退火 (从 {self.initial_learning_rate} 到 {self.min_learning_rate})\")\n",
|
||
" print(f\" 重启参数: T_0={self.t_0}, T_mult={self.t_mult}\")\n",
|
||
" \n",
|
||
" def prepare_validation_data(self):\n",
|
||
" \"\"\"准备验证数据(仅PCA,保持原始分布)\"\"\"\n",
|
||
" print(\"准备验证数据...\")\n",
|
||
" X_val, y_val = self.pipeline.step3_process_data('val', apply_sampling=False)\n",
|
||
" if X_val is None:\n",
|
||
" raise ValueError(\"无法加载验证数据\")\n",
|
||
" val_counts = Counter(y_val)\n",
|
||
" print(f\" 验证数据准备完成: {X_val.shape[0]:,} 样本\")\n",
|
||
" print(f\" 验证集分布 (标签0: {val_counts.get(0, 0):,}, 标签40: {val_counts.get(40, 0):,})\")\n",
|
||
" \n",
|
||
" # 缓存原始数组,便于计算accuracy\n",
|
||
" self._X_val_np = X_val\n",
|
||
" self._y_val_np = y_val\n",
|
||
" \n",
|
||
" return lgb.Dataset(X_val, label=y_val, free_raw_data=False)\n",
|
||
" \n",
|
||
" def get_training_batch_generator(self, n_files_per_batch=4, batch_size=8000):\n",
|
||
" \"\"\"改进的训练批次生成器:每次从所有文件中随机选择n个文件,然后随机采样\"\"\"\n",
|
||
" print(f\"准备改进的训练批次生成器...\")\n",
|
||
" print(f\" 每批次选择文件数: {n_files_per_batch}\")\n",
|
||
" print(f\" 每批次目标样本数: {batch_size:,}\")\n",
|
||
" \n",
|
||
" # 获取所有训练文件列表\n",
|
||
" all_train_files = [f for f in os.listdir(self.pipeline.data_dir) \n",
|
||
" if f.endswith('.npz') and 'train' in f]\n",
|
||
" \n",
|
||
" if len(all_train_files) < n_files_per_batch:\n",
|
||
" print(f\" 可用文件数({len(all_train_files)})少于每批次需要的文件数({n_files_per_batch})\")\n",
|
||
" n_files_per_batch = len(all_train_files)\n",
|
||
" \n",
|
||
" print(f\" 总计可用训练文件: {len(all_train_files)}\")\n",
|
||
" \n",
|
||
" batch_id = 0\n",
|
||
" while True: # 无限生成器,可以重复采样\n",
|
||
" batch_id += 1\n",
|
||
" \n",
|
||
" # 随机选择n个文件\n",
|
||
" selected_files = random.sample(all_train_files, n_files_per_batch)\n",
|
||
" \n",
|
||
" print(f\" 批次 {batch_id} - 随机选择的文件:\")\n",
|
||
" for i, f in enumerate(selected_files, 1):\n",
|
||
" print(f\" {i}. {f}\")\n",
|
||
" \n",
|
||
" # 从选中的文件中加载数据\n",
|
||
" all_features = []\n",
|
||
" all_labels = []\n",
|
||
" total_available_samples = 0\n",
|
||
" \n",
|
||
" for filename in selected_files:\n",
|
||
" # 加载文件数据\n",
|
||
" data = np.load(os.path.join(self.pipeline.data_dir, filename), allow_pickle=True)\n",
|
||
" trials = data['neural_logits_concatenated']\n",
|
||
" \n",
|
||
" # 提取特征和标签\n",
|
||
" features, labels = extract_features_labels_batch(trials)\n",
|
||
" \n",
|
||
" if features.shape[0] > 0:\n",
|
||
" all_features.append(features)\n",
|
||
" all_labels.append(labels)\n",
|
||
" total_available_samples += features.shape[0]\n",
|
||
" \n",
|
||
" # 清理单个文件数据\n",
|
||
" del data, trials\n",
|
||
" gc.collect()\n",
|
||
" \n",
|
||
" if all_features:\n",
|
||
" # 合并所有选中文件的数据\n",
|
||
" combined_features = np.vstack(all_features)\n",
|
||
" combined_labels = np.hstack(all_labels)\n",
|
||
" \n",
|
||
" print(f\" 合并后总样本数: {combined_features.shape[0]:,}\")\n",
|
||
" \n",
|
||
" # 随机采样到目标batch_size\n",
|
||
" if combined_features.shape[0] > batch_size:\n",
|
||
" # 随机选择batch_size个样本\n",
|
||
" sample_indices = np.random.choice(\n",
|
||
" combined_features.shape[0], \n",
|
||
" size=batch_size, \n",
|
||
" replace=False\n",
|
||
" )\n",
|
||
" sampled_features = combined_features[sample_indices]\n",
|
||
" sampled_labels = combined_labels[sample_indices]\n",
|
||
" print(f\" 随机采样到: {batch_size:,} 样本\")\n",
|
||
" else:\n",
|
||
" # 如果样本不足,使用所有样本\n",
|
||
" sampled_features = combined_features\n",
|
||
" sampled_labels = combined_labels\n",
|
||
" print(f\" 样本不足,使用全部: {sampled_features.shape[0]:,} 样本\")\n",
|
||
" \n",
|
||
" # 应用采样策略(平衡处理)\n",
|
||
" features_balanced, labels_balanced = self.pipeline._apply_full_sampling(\n",
|
||
" sampled_features, sampled_labels\n",
|
||
" )\n",
|
||
" \n",
|
||
" # 应用PCA降维\n",
|
||
" if features_balanced.shape[0] > 0:\n",
|
||
" features_pca = self.pipeline._apply_pca_transform(features_balanced)\n",
|
||
" \n",
|
||
" # 分析当前批次分布\n",
|
||
" batch_counts = Counter(labels_balanced)\n",
|
||
" \n",
|
||
" print(f\" 批次 {batch_id} 最终结果:\")\n",
|
||
" print(f\" 平衡后样本数: {features_pca.shape[0]:,}\")\n",
|
||
" print(f\" 特征维度: {features_pca.shape[1]}\")\n",
|
||
" print(f\" 分布: 标签0={batch_counts.get(0,0)}, 标签40={batch_counts.get(40,0)}\")\n",
|
||
" print(f\" \" + \"=\"*50)\n",
|
||
" \n",
|
||
" # 重要修复:设置 free_raw_data=False 避免增量训练失败\n",
|
||
" yield lgb.Dataset(features_pca, label=labels_balanced, free_raw_data=False), f\"batch_{batch_id}_files_{len(selected_files)}\"\n",
|
||
" \n",
|
||
" # 清理批次数据\n",
|
||
" del all_features, all_labels, combined_features, combined_labels\n",
|
||
" del sampled_features, sampled_labels, features_balanced, labels_balanced\n",
|
||
" gc.collect()\n",
|
||
" else:\n",
|
||
" print(f\" 批次 {batch_id} 无有效数据\")\n",
|
||
" continue\n",
|
||
" \n",
|
||
" def prepare_full_data(self):\n",
|
||
" \"\"\"一次性准备所有训练和验证数据\"\"\"\n",
|
||
" print(\"准备全量训练和验证数据...\")\n",
|
||
" \n",
|
||
" # 1. 准备验证数据 (保持原始分布)\n",
|
||
" X_val, y_val = self.pipeline.step3_process_data('val', apply_sampling=False)\n",
|
||
" if X_val is None:\n",
|
||
" raise ValueError(\"无法加载验证数据\")\n",
|
||
" val_counts = Counter(y_val)\n",
|
||
" print(f\" 验证数据准备完成: {X_val.shape[0]:,} 样本\")\n",
|
||
" print(f\" 验证集分布 (标签0: {val_counts.get(0, 0):,}, 标签40: {val_counts.get(40, 0):,})\")\n",
|
||
" val_data = lgb.Dataset(X_val, label=y_val, free_raw_data=False)\n",
|
||
" \n",
|
||
" # 2. 准备训练数据 (应用完整采样和PCA策略)\n",
|
||
" X_train, y_train = self.pipeline.step3_process_data('train', apply_sampling=True)\n",
|
||
" if X_train is None:\n",
|
||
" raise ValueError(\"无法加载训练数据\")\n",
|
||
" train_counts = Counter(y_train)\n",
|
||
" print(f\" 训练数据准备完成: {X_train.shape[0]:,} 样本, {X_train.shape[1]} 特征\")\n",
|
||
" print(f\" 训练集(采样后)分布 (标签0: {train_counts.get(0, 0):,}, 标签40: {train_counts.get(40, 0):,})\")\n",
|
||
" train_data = lgb.Dataset(X_train, label=y_train)\n",
|
||
" \n",
|
||
" return train_data, val_data, X_val, y_val\n",
|
||
" \n",
|
||
" def prepare_training_data(self):\n",
|
||
" \"\"\"准备训练数据(仅PCA,保持原始分布)\"\"\"\n",
|
||
" print(\"准备训练数据...\")\n",
|
||
" # 2. 准备训练数据 (应用完整采样和PCA策略)\n",
|
||
" X_train, y_train = self.pipeline.step3_process_data('train', apply_sampling=True)\n",
|
||
" if X_train is None:\n",
|
||
" raise ValueError(\"无法加载训练数据\")\n",
|
||
" train_counts = Counter(y_train)\n",
|
||
" print(f\" 训练数据准备完成: {X_train.shape[0]:,} 样本, {X_train.shape[1]} 特征\")\n",
|
||
" print(f\" 训练集(采样后)分布 (标签0: {train_counts.get(0, 0):,}, 标签40: {train_counts.get(40, 0):,})\")\n",
|
||
" \n",
|
||
" return lgb.Dataset(X_train, label=y_train, free_raw_data=False)\n",
|
||
" \n",
|
||
" # 带重启的余弦退火调度器函数\n",
|
||
" def _cosine_annealing_with_warm_restarts(self, current_round):\n",
|
||
" \"\"\"带重启的余弦退火调度器 (SGDR)\"\"\"\n",
|
||
" eta_max = self.initial_learning_rate\n",
|
||
" eta_min = self.min_learning_rate\n",
|
||
" \n",
|
||
" # 计算当前在哪个重启周期中\n",
|
||
" t_cur = current_round\n",
|
||
" t_i = self.t_0\n",
|
||
" \n",
|
||
" # 找到当前的重启周期\n",
|
||
" cycle = 0\n",
|
||
" while t_cur >= t_i:\n",
|
||
" t_cur -= t_i\n",
|
||
" cycle += 1\n",
|
||
" t_i *= self.t_mult\n",
|
||
" \n",
|
||
" # 在当前周期内的位置\n",
|
||
" progress = t_cur / t_i\n",
|
||
" \n",
|
||
" # 计算学习率\n",
|
||
" lr = eta_min + 0.5 * (eta_max - eta_min) * (1 + np.cos(np.pi * progress))\n",
|
||
" \n",
|
||
" return lr\n",
|
||
" \n",
|
||
" def train_incremental(self, num_boost_round=100, early_stopping_rounds=10, \n",
|
||
" n_files_per_batch=4, batch_size=8000, max_batches=None):\n",
|
||
" \"\"\"增量分批训练 - 支持自定义批次参数\"\"\"\n",
|
||
" print(f\"开始智能分批训练...\")\n",
|
||
" print(f\" 训练轮数 (每批次): {num_boost_round}\")\n",
|
||
" print(f\" 早停轮数: {early_stopping_rounds}\")\n",
|
||
" print(f\" 每批次文件数: {n_files_per_batch}\")\n",
|
||
" print(f\" 每批次样本数: {batch_size:,}\")\n",
|
||
" if max_batches:\n",
|
||
" print(f\" 最大批次数: {max_batches}\")\n",
|
||
" print(\"=\" * 60)\n",
|
||
" \n",
|
||
" # 准备验证数据\n",
|
||
" val_data = self.prepare_validation_data()\n",
|
||
" \n",
|
||
" print(f\"开始分批增量训练...\")\n",
|
||
" total_start_time = time.time()\n",
|
||
" \n",
|
||
" # 初始化训练历史\n",
|
||
" self.training_history = []\n",
|
||
" \n",
|
||
" # 创建改进的生成器\n",
|
||
" batch_generator = self.get_training_batch_generator(n_files_per_batch, batch_size)\n",
|
||
" \n",
|
||
" for train_data, batch_name in batch_generator:\n",
|
||
" self.batch_count += 1\n",
|
||
" batch_start_time = time.time()\n",
|
||
" \n",
|
||
" # 检查是否达到最大批次数\n",
|
||
" if max_batches and self.batch_count > max_batches:\n",
|
||
" print(f\"达到最大批次数 {max_batches},停止训练\")\n",
|
||
" break\n",
|
||
" \n",
|
||
" # 先构建数据集,使得可以安全访问 num_data()\n",
|
||
" try:\n",
|
||
" train_data.construct()\n",
|
||
" except Exception:\n",
|
||
" pass\n",
|
||
"\n",
|
||
" print(f\"\\n批次 {self.batch_count}: {batch_name}\")\n",
|
||
" try:\n",
|
||
" print(f\" 样本数: {train_data.num_data():,}\")\n",
|
||
" except Exception:\n",
|
||
" print(\" 样本数: (未构建,跳过显示)\")\n",
|
||
" \n",
|
||
" # 计算当前批次的学习率\n",
|
||
" current_lr = self._cosine_annealing_with_warm_restarts(\n",
|
||
" (self.batch_count - 1) * num_boost_round\n",
|
||
" )\n",
|
||
" \n",
|
||
" # 更新训练参数中的学习率\n",
|
||
" current_params = self.params.copy()\n",
|
||
" current_params['learning_rate'] = current_lr\n",
|
||
" \n",
|
||
" try:\n",
|
||
" # 训练参数\n",
|
||
" train_params = {\n",
|
||
" 'params': current_params,\n",
|
||
" 'train_set': train_data,\n",
|
||
" 'num_boost_round': num_boost_round,\n",
|
||
" 'valid_sets': [val_data],\n",
|
||
" 'valid_names': ['validation'],\n",
|
||
" 'callbacks': [\n",
|
||
" lgb.log_evaluation(period=1, show_stdv=False) # 1轮打印一次,减少重复\n",
|
||
" ]\n",
|
||
" }\n",
|
||
" \n",
|
||
" # 如果有早停设置\n",
|
||
" if early_stopping_rounds:\n",
|
||
" train_params['callbacks'].append(\n",
|
||
" lgb.early_stopping(early_stopping_rounds, verbose=False)\n",
|
||
" )\n",
|
||
" \n",
|
||
" # 增量训练\n",
|
||
" if self.model is None:\n",
|
||
" # 第一次训练\n",
|
||
" print(f\" 首次训练 (学习率: {current_lr:.6f})\")\n",
|
||
" self.model = lgb.train(**train_params)\n",
|
||
" else:\n",
|
||
" # 增量训练\n",
|
||
" print(f\" 增量训练 (学习率: {current_lr:.6f})\")\n",
|
||
" train_params['init_model'] = self.model\n",
|
||
" self.model = lgb.train(**train_params)\n",
|
||
" \n",
|
||
" # 验证 - 修复数组比较的歧义性问题\n",
|
||
" # 优先使用缓存的验证集数组,退回到val_data中的数据\n",
|
||
" Xv = getattr(self, '_X_val_np', None) \n",
|
||
" yv = getattr(self, '_y_val_np', None)\n",
|
||
" \n",
|
||
" if Xv is None or yv is None:\n",
|
||
" print(\" 警告: 无法获取验证数据,跳过准确率计算\")\n",
|
||
" val_accuracy = 0.0\n",
|
||
" else:\n",
|
||
" val_pred = self.model.predict(Xv)\n",
|
||
" \n",
|
||
" # 确保yv是1维numpy数组,避免数组比较的歧义\n",
|
||
" yv = np.asarray(yv, dtype=int).flatten()\n",
|
||
" \n",
|
||
" # 计算验证准确率\n",
|
||
" pred_labels = np.argmax(val_pred, axis=1)\n",
|
||
" pred_labels = np.asarray(pred_labels, dtype=int).flatten()\n",
|
||
" \n",
|
||
" # 确保两个数组形状一致\n",
|
||
" if len(pred_labels) != len(yv):\n",
|
||
" print(f\" 警告: 预测标签数({len(pred_labels)}) != 真实标签数({len(yv)})\")\n",
|
||
" min_len = min(len(pred_labels), len(yv))\n",
|
||
" pred_labels = pred_labels[:min_len]\n",
|
||
" yv = yv[:min_len]\n",
|
||
" \n",
|
||
" # 使用更安全的数组比较方式\n",
|
||
" try:\n",
|
||
" comparison = np.equal(pred_labels, yv)\n",
|
||
" val_accuracy = float(np.mean(comparison))\n",
|
||
" except Exception as e:\n",
|
||
" print(f\" 数组比较错误: {e}\")\n",
|
||
" val_accuracy = 0.0\n",
|
||
" \n",
|
||
" # 记录训练历史\n",
|
||
" batch_time = time.time() - batch_start_time\n",
|
||
" try:\n",
|
||
" samples_cnt = train_data.num_data()\n",
|
||
" except Exception:\n",
|
||
" samples_cnt = None\n",
|
||
" self.training_history.append({\n",
|
||
" 'batch': self.batch_count,\n",
|
||
" 'batch_name': batch_name,\n",
|
||
" 'val_accuracy': val_accuracy,\n",
|
||
" 'time': batch_time,\n",
|
||
" 'num_trees': self.model.num_trees(),\n",
|
||
" 'learning_rate': current_lr,\n",
|
||
" 'samples': samples_cnt\n",
|
||
" })\n",
|
||
" \n",
|
||
" print(f\" 批次完成:\")\n",
|
||
" print(f\" 验证准确率: {val_accuracy:.4f}\")\n",
|
||
" print(f\" 训练时间: {batch_time:.1f}秒\")\n",
|
||
" print(f\" 模型树数: {self.model.num_trees()}\")\n",
|
||
" print(f\" 当前学习率: {current_lr:.6f}\")\n",
|
||
" \n",
|
||
" except Exception as e:\n",
|
||
" print(f\" 批次训练失败: {e}\")\n",
|
||
" import traceback\n",
|
||
" traceback.print_exc()\n",
|
||
" continue\n",
|
||
" \n",
|
||
" # 训练完成\n",
|
||
" total_time = time.time() - total_start_time\n",
|
||
" print(f\"\\n增量训练完成!\")\n",
|
||
" print(f\" 总批次数: {len(self.training_history)}\")\n",
|
||
" print(f\" 总训练时间: {total_time:.1f}秒\")\n",
|
||
" \n",
|
||
" if self.training_history:\n",
|
||
" best_batch = max(self.training_history, key=lambda x: x['val_accuracy'])\n",
|
||
" print(f\" 最佳准确率: {best_batch['val_accuracy']:.4f} (批次 {best_batch['batch']})\")\n",
|
||
" final_accuracy = self.training_history[-1]['val_accuracy']\n",
|
||
" print(f\" 最终准确率: {final_accuracy:.4f}\")\n",
|
||
" \n",
|
||
" return self.model\n",
|
||
"\n",
|
||
" @staticmethod\n",
|
||
" def _ctc_collapse(seq, blank=0, drop_sep40=False):\n",
|
||
" out = []\n",
|
||
" prev = None\n",
|
||
" for s in seq:\n",
|
||
" if s == prev:\n",
|
||
" continue\n",
|
||
" prev = s\n",
|
||
" if s == blank:\n",
|
||
" continue\n",
|
||
" if drop_sep40 and s == 40:\n",
|
||
" continue\n",
|
||
" out.append(int(s))\n",
|
||
" return out\n",
|
||
"\n",
|
||
" @staticmethod\n",
|
||
" def _levenshtein(a, b):\n",
|
||
" # a, b are lists of ints\n",
|
||
" n, m = len(a), len(b)\n",
|
||
" if n == 0:\n",
|
||
" return m\n",
|
||
" if m == 0:\n",
|
||
" return n\n",
|
||
" dp = list(range(m + 1))\n",
|
||
" for i in range(1, n + 1):\n",
|
||
" prev = dp[0]\n",
|
||
" dp[0] = i\n",
|
||
" ai = a[i - 1]\n",
|
||
" for j in range(1, m + 1):\n",
|
||
" tmp = dp[j]\n",
|
||
" cost = 0 if ai == b[j - 1] else 1\n",
|
||
" dp[j] = min(dp[j] + 1, # deletion\n",
|
||
" dp[j - 1] + 1, # insertion\n",
|
||
" prev + cost) # substitution\n",
|
||
" prev = tmp\n",
|
||
" return dp[m]\n",
|
||
"\n",
|
||
" def evaluate_val_per_experiment(self, fraction=0.33, random_state=42, drop_sep40=False, max_trials_per_file=None):\n",
|
||
" \"\"\"使用所有验证文件,每个文件抽取33%的trial,按trial计算PER并求均值\"\"\"\n",
|
||
" if self.model is None:\n",
|
||
" raise RuntimeError(\"模型尚未训练,无法评估PER\")\n",
|
||
"\n",
|
||
" rng = np.random.RandomState(random_state)\n",
|
||
" val_files = [f for f in os.listdir(self.pipeline.data_dir) if f.endswith('.npz') and 'val' in f]\n",
|
||
" if not val_files:\n",
|
||
" raise FileNotFoundError(\"未找到验证集npz文件\")\n",
|
||
"\n",
|
||
" results_by_file = {}\n",
|
||
" per_list = []\n",
|
||
" corpus_edit = 0\n",
|
||
" corpus_len = 0\n",
|
||
" total_trials = 0\n",
|
||
"\n",
|
||
" for vf in sorted(val_files):\n",
|
||
" data = np.load(os.path.join(self.pipeline.data_dir, vf), allow_pickle=True)\n",
|
||
" trials = data['neural_logits_concatenated']\n",
|
||
" n_trials = len(trials)\n",
|
||
" if n_trials == 0:\n",
|
||
" results_by_file[vf] = {'n': 0, 'mean_PER': None}\n",
|
||
" continue\n",
|
||
" k = max(1, int(np.ceil(n_trials * fraction)))\n",
|
||
" idx = np.arange(n_trials)\n",
|
||
" idx = rng.choice(idx, size=k, replace=False)\n",
|
||
" if max_trials_per_file is not None:\n",
|
||
" k = min(k, max_trials_per_file)\n",
|
||
" idx = idx[:k]\n",
|
||
"\n",
|
||
" trial_pers = []\n",
|
||
" for ti in idx:\n",
|
||
" tr = trials[ti]\n",
|
||
" X_trial = tr[:, :7168]\n",
|
||
" rnn_logits = tr[:, 7168:]\n",
|
||
" # 变换到PCA空间\n",
|
||
" X_trial_pca = self.pipeline._apply_pca_transform(X_trial)\n",
|
||
" # 预测\n",
|
||
" pred_proba = self.model.predict(X_trial_pca)\n",
|
||
" y_pred_seq = np.argmax(pred_proba, axis=1)\n",
|
||
" y_true_seq = np.argmax(rnn_logits, axis=1)\n",
|
||
" # CTC折叠\n",
|
||
" pred_collapsed = self._ctc_collapse(y_pred_seq, blank=0, drop_sep40=drop_sep40)\n",
|
||
" true_collapsed = self._ctc_collapse(y_true_seq, blank=0, drop_sep40=drop_sep40)\n",
|
||
" if len(true_collapsed) == 0:\n",
|
||
" continue\n",
|
||
" ed = self._levenshtein(pred_collapsed, true_collapsed)\n",
|
||
" per = ed / len(true_collapsed)\n",
|
||
" trial_pers.append(per)\n",
|
||
" corpus_edit += ed\n",
|
||
" corpus_len += len(true_collapsed)\n",
|
||
" total_trials += 1\n",
|
||
"\n",
|
||
" if trial_pers:\n",
|
||
" results_by_file[vf] = {\n",
|
||
" 'n': len(trial_pers),\n",
|
||
" 'mean_PER': float(np.mean(trial_pers))\n",
|
||
" }\n",
|
||
" per_list.extend(trial_pers)\n",
|
||
" else:\n",
|
||
" results_by_file[vf] = {'n': 0, 'mean_PER': None}\n",
|
||
"\n",
|
||
" del data, trials\n",
|
||
" gc.collect()\n",
|
||
"\n",
|
||
" overall_mean = float(np.mean(per_list)) if per_list else None\n",
|
||
" corpus_per = float(corpus_edit / corpus_len) if corpus_len > 0 else None\n",
|
||
"\n",
|
||
" summary = {\n",
|
||
" 'overall_mean_PER': overall_mean,\n",
|
||
" 'corpus_PER': corpus_per,\n",
|
||
" 'total_trials': total_trials,\n",
|
||
" 'per_file': results_by_file\n",
|
||
" }\n",
|
||
" print(\"验证集PER评估完成\")\n",
|
||
" print(f\" 文件数: {len(val_files)} 评估trial数: {total_trials}\")\n",
|
||
" print(f\" 平均PER(逐trial取均值): {overall_mean}\")\n",
|
||
" print(f\" 语料级PER(总编辑距离/总长度): {corpus_per}\")\n",
|
||
" return summary"
|
||
]
|
||
},
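{
"cell_type": "markdown",
"metadata": {},
"source": [
"Two standalone sketches of the trainer's internals. The first re-derives the warm-restart cosine schedule (`_cosine_annealing_with_warm_restarts`) so the restart points at rounds 50 and 150 (T_0=50, T_mult=2) can be eyeballed; the second runs the CTC collapse + Levenshtein PER computation on toy sequences, reusing the class's `_levenshtein`. Names defined here are local to the sketch."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch 1: cosine annealing with warm restarts (same formula as the trainer)\n",
"import numpy as np\n",
"\n",
"def sgdr_lr(round_idx, eta_max=0.10, eta_min=1e-4, t_0=50, t_mult=2):\n",
"    t_cur, t_i = round_idx, t_0\n",
"    while t_cur >= t_i:  # walk forward through completed restart cycles\n",
"        t_cur -= t_i\n",
"        t_i *= t_mult\n",
"    return eta_min + 0.5 * (eta_max - eta_min) * (1 + np.cos(np.pi * t_cur / t_i))\n",
"\n",
"for r in [0, 25, 49, 50, 100, 149, 150]:\n",
"    print(f'round {r:3d}: lr = {sgdr_lr(r):.6f}')  # lr resets at rounds 50 and 150\n",
"\n",
"# Sketch 2: per-trial PER = edit_distance(pred, ref) / len(ref) after CTC collapse\n",
"def ctc_collapse(seq, blank=0):\n",
"    out, prev = [], None\n",
"    for s in seq:\n",
"        if s != prev and s != blank:\n",
"            out.append(int(s))\n",
"        prev = s\n",
"    return out\n",
"\n",
"pred = ctc_collapse([0, 7, 7, 0, 3, 3, 3, 0, 5])  # -> [7, 3, 5]\n",
"ref = ctc_collapse([0, 7, 0, 3, 0, 9])            # -> [7, 3, 9]\n",
"ed = SmartBatchTrainer._levenshtein(pred, ref)    # reuse the trainer's edit distance\n",
"print('pred:', pred, 'ref:', ref, 'PER:', ed / len(ref))"
]
},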
|
||
{
"cell_type": "code",
"execution_count": 28,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"✅ Data augmentation support added to SmartBatchTrainer\n"
]
}
],
"source": [
"# 🔧 Add data augmentation support to SmartBatchTrainer\n",
"\n",
"def prepare_training_data_with_augmentation(self, apply_augmentation=True, augment_ratio=0.3):\n",
"    \"\"\"Prepare training data (sampling + data augmentation + PCA).\"\"\"\n",
"    print(\"Preparing training data (with augmentation)...\")\n",
"\n",
"    if apply_augmentation:\n",
"        print(f\"  Data augmentation: enabled (ratio: {augment_ratio})\")\n",
"        X_train, y_train = self.pipeline.step3_process_data_with_augmentation(\n",
"            'train', apply_sampling=True, apply_augmentation=True, augment_ratio=augment_ratio\n",
"        )\n",
"    else:\n",
"        print(\"  Data augmentation: disabled\")\n",
"        X_train, y_train = self.pipeline.step3_process_data('train', apply_sampling=True)\n",
"\n",
"    if X_train is None:\n",
"        raise ValueError(\"Failed to load training data\")\n",
"\n",
"    train_counts = Counter(y_train)\n",
"    print(f\"  Training data ready: {X_train.shape[0]:,} samples, {X_train.shape[1]} features\")\n",
"    print(f\"  Training distribution after sampling+augmentation (label 0: {train_counts.get(0, 0):,}, label 40: {train_counts.get(40, 0):,})\")\n",
"\n",
"    return lgb.Dataset(X_train, label=y_train, free_raw_data=False)\n",
"\n",
"def prepare_full_data_with_augmentation(self, apply_augmentation=True, augment_ratio=0.3):\n",
"    \"\"\"Prepare all training and validation data in one pass (with augmentation).\"\"\"\n",
"    print(\"Preparing full training and validation data (with augmentation)...\")\n",
"\n",
"    # 1. Validation data: keep the original distribution, no augmentation\n",
"    X_val, y_val = self.pipeline.step3_process_data('val', apply_sampling=False)\n",
"    if X_val is None:\n",
"        raise ValueError(\"Failed to load validation data\")\n",
"    val_counts = Counter(y_val)\n",
"    print(f\"  Validation data ready: {X_val.shape[0]:,} samples\")\n",
"    print(f\"  Validation distribution (label 0: {val_counts.get(0, 0):,}, label 40: {val_counts.get(40, 0):,})\")\n",
"    val_data = lgb.Dataset(X_val, label=y_val, free_raw_data=False)\n",
"\n",
"    # 2. Training data: apply full sampling and augmentation\n",
"    if apply_augmentation:\n",
"        print(f\"  Training augmentation: enabled (ratio: {augment_ratio})\")\n",
"        X_train, y_train = self.pipeline.step3_process_data_with_augmentation(\n",
"            'train', apply_sampling=True, apply_augmentation=True, augment_ratio=augment_ratio\n",
"        )\n",
"    else:\n",
"        print(\"  Training augmentation: disabled\")\n",
"        X_train, y_train = self.pipeline.step3_process_data('train', apply_sampling=True)\n",
"\n",
"    if X_train is None:\n",
"        raise ValueError(\"Failed to load training data\")\n",
"    train_counts = Counter(y_train)\n",
"    print(f\"  Training data ready: {X_train.shape[0]:,} samples, {X_train.shape[1]} features\")\n",
"    print(f\"  Training distribution after sampling+augmentation (label 0: {train_counts.get(0, 0):,}, label 40: {train_counts.get(40, 0):,})\")\n",
"    train_data = lgb.Dataset(X_train, label=y_train)\n",
"\n",
"    return train_data, val_data, X_val, y_val\n",
"\n",
"def get_training_batch_generator_with_augmentation(self, n_files_per_batch=4, batch_size=8000,\n",
"                                                   apply_augmentation=True, augment_ratio=0.3):\n",
"    \"\"\"Improved batch generator: randomly pick n files each time, then randomly sample + augment.\"\"\"\n",
"    print(\"Preparing the improved training-batch generator (with augmentation)...\")\n",
"    print(f\"  Files selected per batch: {n_files_per_batch}\")\n",
"    print(f\"  Target samples per batch: {batch_size:,}\")\n",
"    print(f\"  Data augmentation: {'enabled' if apply_augmentation else 'disabled'}\")\n",
"    if apply_augmentation:\n",
"        print(f\"  Augmentation ratio: {augment_ratio}\")\n",
"\n",
"    # Collect the list of all training files\n",
"    all_train_files = [f for f in os.listdir(self.pipeline.data_dir)\n",
"                       if f.endswith('.npz') and 'train' in f]\n",
"\n",
"    if len(all_train_files) < n_files_per_batch:\n",
"        print(f\"  Available files ({len(all_train_files)}) fewer than requested per batch ({n_files_per_batch})\")\n",
"        n_files_per_batch = len(all_train_files)\n",
"\n",
"    print(f\"  Total training files available: {len(all_train_files)}\")\n",
"\n",
"    batch_id = 0\n",
"    while True:  # infinite generator; files can be re-sampled across batches\n",
"        batch_id += 1\n",
"\n",
"        # Randomly select n files\n",
"        selected_files = random.sample(all_train_files, n_files_per_batch)\n",
"\n",
"        print(f\"  Batch {batch_id} - randomly selected files:\")\n",
"        for i, f in enumerate(selected_files, 1):\n",
"            print(f\"    {i}. {f}\")\n",
"\n",
"        # Load data from the selected files\n",
"        all_features = []\n",
"        all_labels = []\n",
"        total_available_samples = 0\n",
"\n",
"        for filename in selected_files:\n",
"            # Load one file\n",
"            data = np.load(os.path.join(self.pipeline.data_dir, filename), allow_pickle=True)\n",
"            trials = data['neural_logits_concatenated']\n",
"\n",
"            # Extract features and labels (with augmentation)\n",
"            features, labels = extract_features_labels_batch_with_augmentation(\n",
"                trials,\n",
"                random_shuffle_trials=True,\n",
"                apply_augmentation=apply_augmentation,\n",
"                augment_ratio=augment_ratio\n",
"            )\n",
"\n",
"            if features.shape[0] > 0:\n",
"                all_features.append(features)\n",
"                all_labels.append(labels)\n",
"                total_available_samples += features.shape[0]\n",
"\n",
"            # Free per-file data\n",
"            del data, trials\n",
"            gc.collect()\n",
"\n",
"        if all_features:\n",
"            # Merge data from all selected files\n",
"            combined_features = np.vstack(all_features)\n",
"            combined_labels = np.hstack(all_labels)\n",
"\n",
"            print(f\"  Combined total samples: {combined_features.shape[0]:,}\")\n",
"\n",
"            # Randomly downsample to the target batch_size\n",
"            if combined_features.shape[0] > batch_size:\n",
"                # Randomly pick batch_size samples\n",
"                sample_indices = np.random.choice(\n",
"                    combined_features.shape[0],\n",
"                    size=batch_size,\n",
"                    replace=False\n",
"                )\n",
"                sampled_features = combined_features[sample_indices]\n",
"                sampled_labels = combined_labels[sample_indices]\n",
"                print(f\"  Randomly sampled down to: {batch_size:,} samples\")\n",
"            else:\n",
"                # Not enough samples: use everything\n",
"                sampled_features = combined_features\n",
"                sampled_labels = combined_labels\n",
"                print(f\"  Fewer samples than target, using all: {sampled_features.shape[0]:,} samples\")\n",
"\n",
"            # Apply the sampling strategy (class balancing)\n",
"            features_balanced, labels_balanced = self.pipeline._apply_full_sampling(\n",
"                sampled_features, sampled_labels\n",
"            )\n",
"\n",
"            # Apply PCA dimensionality reduction\n",
"            if features_balanced.shape[0] > 0:\n",
"                features_pca = self.pipeline._apply_pca_transform(features_balanced)\n",
"\n",
"                # Inspect the label distribution of the current batch\n",
"                batch_counts = Counter(labels_balanced)\n",
"\n",
"                print(f\"  Batch {batch_id} final result:\")\n",
"                print(f\"    Samples after balancing: {features_pca.shape[0]:,}\")\n",
"                print(f\"    Feature dimension: {features_pca.shape[1]}\")\n",
"                print(f\"    Distribution: label 0={batch_counts.get(0,0)}, label 40={batch_counts.get(40,0)}\")\n",
"                print(\"  \" + \"=\"*50)\n",
"\n",
"                # Important fix: free_raw_data=False, otherwise incremental training fails\n",
"                yield lgb.Dataset(features_pca, label=labels_balanced, free_raw_data=False), f\"batch_{batch_id}_files_{len(selected_files)}_aug_{apply_augmentation}\"\n",
"\n",
"            # Free per-batch data\n",
"            del all_features, all_labels, combined_features, combined_labels\n",
"            del sampled_features, sampled_labels, features_balanced, labels_balanced\n",
"            gc.collect()\n",
"        else:\n",
"            print(f\"  Batch {batch_id} produced no valid data\")\n",
"            continue\n",
"\n",
"# Attach the augmentation-aware methods to SmartBatchTrainer\n",
"SmartBatchTrainer.prepare_training_data_with_augmentation = prepare_training_data_with_augmentation\n",
"SmartBatchTrainer.prepare_full_data_with_augmentation = prepare_full_data_with_augmentation\n",
"SmartBatchTrainer.get_training_batch_generator_with_augmentation = get_training_batch_generator_with_augmentation\n",
"\n",
"print(\"✅ Data augmentation support added to SmartBatchTrainer\")"
]
},
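{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hedged sketch: one plausible meaning of `augment_ratio`. The real\n",
"# `extract_features_labels_batch_with_augmentation` is defined elsewhere in this\n",
"# notebook; this toy version only illustrates the idea of augmenting a random\n",
"# fraction of rows with Gaussian feature noise. `noise_std` is an assumption.\n",
"import numpy as np\n",
"\n",
"def augment_with_noise(features, labels, augment_ratio=0.3, noise_std=0.01, rng=None):\n",
"    \"\"\"Append noisy copies of a random `augment_ratio` fraction of the rows.\"\"\"\n",
"    rng = rng or np.random.default_rng(0)\n",
"    n_aug = int(features.shape[0] * augment_ratio)\n",
"    idx = rng.choice(features.shape[0], size=n_aug, replace=False)\n",
"    noisy = features[idx] + rng.normal(0.0, noise_std, size=features[idx].shape)\n",
"    return np.vstack([features, noisy]), np.concatenate([labels, labels[idx]])\n",
"\n",
"X_demo = np.random.randn(10, 4)\n",
"y_demo = np.arange(10)\n",
"X_aug, y_aug = augment_with_noise(X_demo, y_demo, augment_ratio=0.3)\n",
"print(X_aug.shape)  # (13, 4): 10 original rows + 3 augmented copies\n"
]
},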
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Smart batch trainer created\n",
"  LightGBM parameters configured: CPU mode\n",
"  Learning-rate schedule: cosine annealing with warm restarts (from 0.1 to 0.001)\n",
"  Restart parameters: T_0=50, T_mult=2\n"
]
}
],
"source": [
"trainer = SmartBatchTrainer(pipeline, min_learning_rate=0.001, t_0=30, t_mult=2)"
]
},
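{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hedged sketch of the SGDR schedule the trainer reports using. The trainer's own\n",
"# `_cosine_annealing_with_warm_restarts` is defined earlier; this standalone\n",
"# function shows the standard formula (Loshchilov & Hutter, 2017) under the\n",
"# assumption that the trainer follows it:\n",
"#   eta = eta_min + 0.5 * (eta_max - eta_min) * (1 + cos(pi * T_cur / T_i)),\n",
"# where T_i starts at T_0 and is multiplied by T_mult at each restart.\n",
"import math\n",
"\n",
"def sgdr_lr(round_idx, eta_max=0.1, eta_min=0.001, t_0=30, t_mult=2):\n",
"    t_cur, t_i = round_idx, t_0\n",
"    while t_cur >= t_i:  # walk forward through completed cycles\n",
"        t_cur -= t_i\n",
"        t_i *= t_mult\n",
"    return eta_min + 0.5 * (eta_max - eta_min) * (1 + math.cos(math.pi * t_cur / t_i))\n",
"\n",
"for r in [0, 15, 29, 30, 60, 89, 90]:\n",
"    print(r, round(sgdr_lr(r), 6))  # the rate resets to eta_max at rounds 30 and 90\n"
]
},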
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Starting smart batch training...\n",
"  Boosting rounds (per batch): 5\n",
"  Early-stopping rounds: 10\n",
"  Files per batch: 8\n",
"  Samples per batch: 50,000\n",
"  Max batches: 100\n",
"============================================================\n",
"Preparing validation data...\n",
"\n",
"Processing val data...\n",
"  Sampling strategy: disabled\n",
"  Done: 321,773 samples, 1243 features\n",
"  Validation data ready: 321,773 samples\n",
"  Validation distribution (label 0: 238,705, label 40: 35,425)\n",
"Starting batched incremental training...\n",
"Preparing the improved training-batch generator...\n",
"  Files selected per batch: 8\n",
"  Target samples per batch: 50,000\n",
"  Total training files available: 45\n",
"  Batch 1 - randomly selected files:\n",
"    1. t15.2023.11.26_train_concatenated.npz\n",
"    2. t15.2024.04.28_train_concatenated.npz\n",
"    3. t15.2023.10.01_train_concatenated.npz\n",
"    4. t15.2025.04.13_train_concatenated.npz\n",
"    5. t15.2024.02.25_train_concatenated.npz\n",
"    6. t15.2023.08.20_train_concatenated.npz\n",
"    7. t15.2023.12.08_train_concatenated.npz\n",
"    8. t15.2023.10.06_train_concatenated.npz\n",
"  Combined total samples: 307,619\n",
"  Randomly sampled down to: 50,000 samples\n",
"  Batch 1 final result:\n",
"    Samples after balancing: 2,836\n",
"    Feature dimension: 1243\n",
"    Distribution: label 0=76, label 40=76\n",
"  ==================================================\n",
"\n",
"Batch 1: batch_1_files_8\n",
"  Samples: 2,836\n",
"  Initial training (learning rate: 0.100000)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.11/dist-packages/lightgbm/basic.py:2421: UserWarning: Overriding the parameters from Reference Dataset.\n",
"  _log_warning('Overriding the parameters from Reference Dataset.')\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[1]\tvalidation's multi_logloss: 3.61118\n",
"[2]\tvalidation's multi_logloss: 3.59449\n",
"[3]\tvalidation's multi_logloss: 3.60069\n",
"[4]\tvalidation's multi_logloss: 3.58862\n",
"[5]\tvalidation's multi_logloss: 3.56763\n",
"  Batch complete:\n",
"    Validation accuracy: 0.0740\n",
"    Training time: 56.9 s\n",
"    Model trees: 205\n",
"    Current learning rate: 0.100000\n",
"  Batch 2 - randomly selected files:\n",
"    1. t15.2023.11.04_train_concatenated.npz\n",
"    2. t15.2023.10.08_train_concatenated.npz\n",
"    3. t15.2024.03.15_train_concatenated.npz\n",
"    4. t15.2023.12.08_train_concatenated.npz\n",
"    5. t15.2024.04.28_train_concatenated.npz\n",
"    6. t15.2023.09.29_train_concatenated.npz\n",
"    7. t15.2023.08.27_train_concatenated.npz\n",
"    8. t15.2025.03.30_train_concatenated.npz\n",
"  Combined total samples: 312,205\n",
"  Randomly sampled down to: 50,000 samples\n",
"  Batch 2 final result:\n",
"    Samples after balancing: 2,802\n",
"    Feature dimension: 1243\n",
"    Distribution: label 0=76, label 40=76\n",
"  ==================================================\n",
"\n",
"Batch 2: batch_2_files_8\n",
"  Samples: 2,802\n",
"  Incremental training (learning rate: 0.097577)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.11/dist-packages/lightgbm/basic.py:2421: UserWarning: Overriding the parameters from Reference Dataset.\n",
"  _log_warning('Overriding the parameters from Reference Dataset.')\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[6]\tvalidation's multi_logloss: 3.54169\n",
"[7]\tvalidation's multi_logloss: 3.53279\n",
"[8]\tvalidation's multi_logloss: 3.54027\n",
"[9]\tvalidation's multi_logloss: 3.522\n",
"[10]\tvalidation's multi_logloss: 3.5052\n",
"  Batch complete:\n",
"    Validation accuracy: 0.1107\n",
"    Training time: 75.2 s\n",
"    Model trees: 410\n",
"    Current learning rate: 0.097577\n",
"  Batch 3 - randomly selected files:\n",
"    1. t15.2023.10.01_train_concatenated.npz\n",
"    2. t15.2024.07.28_train_concatenated.npz\n",
"    3. t15.2025.01.12_train_concatenated.npz\n",
"    4. t15.2023.10.22_train_concatenated.npz\n",
"    5. t15.2025.03.30_train_concatenated.npz\n",
"    6. t15.2023.08.13_train_concatenated.npz\n",
"    7. t15.2024.05.10_train_concatenated.npz\n",
"    8. t15.2025.04.13_train_concatenated.npz\n",
"  Combined total samples: 293,792\n",
"  Randomly sampled down to: 50,000 samples\n",
"  Batch 3 final result:\n",
"    Samples after balancing: 2,871\n",
"    Feature dimension: 1243\n",
"    Distribution: label 0=76, label 40=76\n",
"  ==================================================\n",
"\n",
"Batch 3: batch_3_files_8\n",
"  Samples: 2,871\n",
"  Incremental training (learning rate: 0.090546)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.11/dist-packages/lightgbm/basic.py:2421: UserWarning: Overriding the parameters from Reference Dataset.\n",
"  _log_warning('Overriding the parameters from Reference Dataset.')\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[11]\tvalidation's multi_logloss: 3.53701\n",
"[12]\tvalidation's multi_logloss: 3.55692\n",
"[13]\tvalidation's multi_logloss: 3.57387\n",
"[14]\tvalidation's multi_logloss: 3.57509\n",
"[15]\tvalidation's multi_logloss: 3.57052\n",
"  Batch complete:\n",
"    Validation accuracy: 0.1049\n",
"    Training time: 83.5 s\n",
"    Model trees: 451\n",
"    Current learning rate: 0.090546\n",
"  Batch 4 - randomly selected files:\n",
"    1. t15.2024.03.15_train_concatenated.npz\n",
"    2. t15.2023.10.15_train_concatenated.npz\n",
"    3. t15.2024.02.25_train_concatenated.npz\n",
"    4. t15.2023.08.20_train_concatenated.npz\n",
"    5. t15.2023.10.22_train_concatenated.npz\n",
"    6. t15.2023.12.10_train_concatenated.npz\n",
"    7. t15.2023.10.20_train_concatenated.npz\n",
"    8. t15.2024.07.28_train_concatenated.npz\n",
"  Combined total samples: 335,983\n",
"  Randomly sampled down to: 50,000 samples\n",
"  Batch 4 final result:\n",
"    Samples after balancing: 2,840\n",
"    Feature dimension: 1243\n",
"    Distribution: label 0=76, label 40=76\n",
"  ==================================================\n",
"\n",
"Batch 4: batch_4_files_8\n",
"  Samples: 2,840\n",
"  Incremental training (learning rate: 0.079595)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.11/dist-packages/lightgbm/basic.py:2421: UserWarning: Overriding the parameters from Reference Dataset.\n",
"  _log_warning('Overriding the parameters from Reference Dataset.')\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[12]\tvalidation's multi_logloss: 3.51209\n",
"[13]\tvalidation's multi_logloss: 3.51032\n",
"[14]\tvalidation's multi_logloss: 3.51553\n",
"[15]\tvalidation's multi_logloss: 3.5119\n",
"[16]\tvalidation's multi_logloss: 3.50659\n",
"  Batch complete:\n",
"    Validation accuracy: 0.1218\n",
"    Training time: 89.2 s\n",
"    Model trees: 656\n",
"    Current learning rate: 0.079595\n",
"  Batch 5 - randomly selected files:\n",
"    1. t15.2023.08.27_train_concatenated.npz\n",
"    2. t15.2023.11.17_train_concatenated.npz\n",
"    3. t15.2024.07.19_train_concatenated.npz\n",
"    4. t15.2023.09.03_train_concatenated.npz\n",
"    5. t15.2023.08.20_train_concatenated.npz\n",
"    6. t15.2023.12.10_train_concatenated.npz\n",
"    7. t15.2023.12.08_train_concatenated.npz\n",
"    8. t15.2023.10.13_train_concatenated.npz\n",
"  Combined total samples: 319,334\n",
"  Randomly sampled down to: 50,000 samples\n",
"  Batch 5 final result:\n",
"    Samples after balancing: 2,839\n",
"    Feature dimension: 1243\n",
"    Distribution: label 0=76, label 40=76\n",
"  ==================================================\n",
"\n",
"Batch 5: batch_5_files_8\n",
"  Samples: 2,839\n",
"  Incremental training (learning rate: 0.065796)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.11/dist-packages/lightgbm/basic.py:2421: UserWarning: Overriding the parameters from Reference Dataset.\n",
"  _log_warning('Overriding the parameters from Reference Dataset.')\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[17]\tvalidation's multi_logloss: 3.48509\n",
"[18]\tvalidation's multi_logloss: 3.49194\n",
"[19]\tvalidation's multi_logloss: 3.50197\n",
"[20]\tvalidation's multi_logloss: 3.50542\n",
"[21]\tvalidation's multi_logloss: 3.51177\n",
"  Batch complete:\n",
"    Validation accuracy: 0.1273\n",
"    Training time: 101.3 s\n",
"    Model trees: 697\n",
"    Current learning rate: 0.065796\n",
"  Batch 6 - randomly selected files:\n",
"    1. t15.2023.10.13_train_concatenated.npz\n",
"    2. t15.2023.10.20_train_concatenated.npz\n",
"    3. t15.2024.02.25_train_concatenated.npz\n",
"    4. t15.2023.08.27_train_concatenated.npz\n",
"    5. t15.2024.04.28_train_concatenated.npz\n",
"    6. t15.2025.01.12_train_concatenated.npz\n",
"    7. t15.2023.08.25_train_concatenated.npz\n",
"    8. t15.2023.12.03_train_concatenated.npz\n",
"  Combined total samples: 258,917\n",
"  Randomly sampled down to: 50,000 samples\n",
"  Batch 6 final result:\n",
"    Samples after balancing: 2,826\n",
"    Feature dimension: 1243\n",
"    Distribution: label 0=76, label 40=76\n",
"  ==================================================\n",
"\n",
"Batch 6: batch_6_files_8\n",
"  Samples: 2,826\n",
"  Incremental training (learning rate: 0.050500)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.11/dist-packages/lightgbm/basic.py:2421: UserWarning: Overriding the parameters from Reference Dataset.\n",
"  _log_warning('Overriding the parameters from Reference Dataset.')\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[18]\tvalidation's multi_logloss: 3.48149\n",
"[19]\tvalidation's multi_logloss: 3.47673\n",
"[20]\tvalidation's multi_logloss: 3.47655\n",
"[21]\tvalidation's multi_logloss: 3.47924\n",
"[22]\tvalidation's multi_logloss: 3.47754\n",
"  Batch complete:\n",
"    Validation accuracy: 0.1308\n",
"    Training time: 112.3 s\n",
"    Model trees: 820\n",
"    Current learning rate: 0.050500\n",
"  Batch 7 - randomly selected files:\n",
"    1. t15.2023.12.29_train_concatenated.npz\n",
"    2. t15.2023.09.29_train_concatenated.npz\n",
"    3. t15.2023.09.01_train_concatenated.npz\n",
"    4. t15.2023.12.08_train_concatenated.npz\n",
"    5. t15.2024.05.10_train_concatenated.npz\n",
"    6. t15.2023.10.08_train_concatenated.npz\n",
"    7. t15.2025.04.13_train_concatenated.npz\n",
"    8. t15.2023.09.24_train_concatenated.npz\n",
"  Combined total samples: 338,426\n",
"  Randomly sampled down to: 50,000 samples\n",
"  Batch 7 final result:\n",
"    Samples after balancing: 2,858\n",
"    Feature dimension: 1243\n",
"    Distribution: label 0=76, label 40=76\n",
"  ==================================================\n",
"\n",
"Batch 7: batch_7_files_8\n",
"  Samples: 2,858\n",
"  Incremental training (learning rate: 0.035204)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.11/dist-packages/lightgbm/basic.py:2421: UserWarning: Overriding the parameters from Reference Dataset.\n",
"  _log_warning('Overriding the parameters from Reference Dataset.')\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[21]\tvalidation's multi_logloss: 3.47751\n",
"[22]\tvalidation's multi_logloss: 3.48293\n",
"[23]\tvalidation's multi_logloss: 3.48763\n",
"[24]\tvalidation's multi_logloss: 3.48841\n",
"[25]\tvalidation's multi_logloss: 3.49353\n",
"  Batch complete:\n",
"    Validation accuracy: 0.1311\n",
"    Training time: 137.4 s\n",
"    Model trees: 861\n",
"    Current learning rate: 0.035204\n",
"  Batch 8 - randomly selected files:\n",
"    1. t15.2023.12.08_train_concatenated.npz\n",
"    2. t15.2025.03.30_train_concatenated.npz\n",
"    3. t15.2023.11.03_train_concatenated.npz\n",
"    4. t15.2023.09.29_train_concatenated.npz\n",
"    5. t15.2024.03.15_train_concatenated.npz\n",
"    6. t15.2025.01.10_train_concatenated.npz\n",
"    7. t15.2023.08.27_train_concatenated.npz\n",
"    8. t15.2023.10.22_train_concatenated.npz\n",
"  Combined total samples: 307,185\n",
"  Randomly sampled down to: 50,000 samples\n",
"  Batch 8 final result:\n",
"    Samples after balancing: 2,812\n",
"    Feature dimension: 1243\n",
"    Distribution: label 0=76, label 40=76\n",
"  ==================================================\n",
"\n",
"Batch 8: batch_8_files_8\n",
"  Samples: 2,812\n",
"  Incremental training (learning rate: 0.021405)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.11/dist-packages/lightgbm/basic.py:2421: UserWarning: Overriding the parameters from Reference Dataset.\n",
"  _log_warning('Overriding the parameters from Reference Dataset.')\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[22]\tvalidation's multi_logloss: 3.47332\n",
"[23]\tvalidation's multi_logloss: 3.47467\n",
"[24]\tvalidation's multi_logloss: 3.4695\n",
"[25]\tvalidation's multi_logloss: 3.46523\n",
"[26]\tvalidation's multi_logloss: 3.46263\n",
"  Batch complete:\n",
"    Validation accuracy: 0.1359\n",
"    Training time: 126.0 s\n",
"    Model trees: 1066\n",
"    Current learning rate: 0.021405\n",
"  Batch 9 - randomly selected files:\n",
"    1. t15.2023.11.03_train_concatenated.npz\n",
"    2. t15.2024.03.08_train_concatenated.npz\n",
"    3. t15.2023.09.01_train_concatenated.npz\n",
"    4. t15.2023.08.18_train_concatenated.npz\n",
"    5. t15.2023.08.27_train_concatenated.npz\n",
"    6. t15.2023.11.19_train_concatenated.npz\n",
"    7. t15.2023.09.03_train_concatenated.npz\n",
"    8. t15.2024.02.25_train_concatenated.npz\n",
"  Combined total samples: 318,419\n",
"  Randomly sampled down to: 50,000 samples\n",
"  Batch 9 final result:\n",
"    Samples after balancing: 2,864\n",
"    Feature dimension: 1243\n",
"    Distribution: label 0=76, label 40=76\n",
"  ==================================================\n",
"\n",
"Batch 9: batch_9_files_8\n",
"  Samples: 2,864\n",
"  Incremental training (learning rate: 0.010454)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.11/dist-packages/lightgbm/basic.py:2421: UserWarning: Overriding the parameters from Reference Dataset.\n",
"  _log_warning('Overriding the parameters from Reference Dataset.')\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"[27]\tvalidation's multi_logloss: 3.45802\n",
"[28]\tvalidation's multi_logloss: 3.45485\n",
"[29]\tvalidation's multi_logloss: 3.45208\n",
"[30]\tvalidation's multi_logloss: 3.44677\n",
"[31]\tvalidation's multi_logloss: 3.44218\n",
"  Batch complete:\n",
"    Validation accuracy: 0.1416\n",
"    Training time: 143.2 s\n",
"    Model trees: 1271\n",
"    Current learning rate: 0.010454\n",
"  Batch 10 - randomly selected files:\n",
"    1. t15.2024.07.21_train_concatenated.npz\n",
"    2. t15.2024.06.14_train_concatenated.npz\n",
"    3. t15.2023.12.10_train_concatenated.npz\n",
"    4. t15.2024.07.19_train_concatenated.npz\n",
"    5. t15.2023.12.03_train_concatenated.npz\n",
"    6. t15.2023.11.04_train_concatenated.npz\n",
"    7. t15.2023.10.01_train_concatenated.npz\n",
"    8. t15.2023.08.25_train_concatenated.npz\n",
"  Combined total samples: 262,351\n",
"  Randomly sampled down to: 50,000 samples\n",
"  Batch 10 final result:\n",
"    Samples after balancing: 2,823\n",
"    Feature dimension: 1243\n",
"    Distribution: label 0=76, label 40=76\n",
"  ==================================================\n",
"\n",
"Batch 10: batch_10_files_8\n",
"  Samples: 2,823\n",
"  Incremental training (learning rate: 0.003423)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/usr/local/lib/python3.11/dist-packages/lightgbm/basic.py:2421: UserWarning: Overriding the parameters from Reference Dataset.\n",
"  _log_warning('Overriding the parameters from Reference Dataset.')\n"
]
}
],
"source": [
|
||
"# 改进的训练参数\n",
|
||
"IMPROVED_TRAINING_PARAMS = {\n",
|
||
" 'num_boost_round': 5, # 每批次的提升轮数\n",
|
||
" 'early_stopping_rounds': 10, # 早停轮数\n",
|
||
" 'n_files_per_batch': 8, # 快速验证用,减少到4\n",
|
||
" 'batch_size': 50000, # 快速验证用,减半\n",
|
||
" 'max_batches': 100 # 仅跑100个批次做冒烟测试\n",
|
||
"}\n",
|
||
"\n",
|
||
"# 开始使用改进的训练器\n",
|
||
"model = trainer.train_incremental(\n",
|
||
" num_boost_round=IMPROVED_TRAINING_PARAMS['num_boost_round'],\n",
|
||
" early_stopping_rounds=IMPROVED_TRAINING_PARAMS['early_stopping_rounds'],\n",
|
||
" n_files_per_batch=IMPROVED_TRAINING_PARAMS['n_files_per_batch'],\n",
|
||
" batch_size=IMPROVED_TRAINING_PARAMS['batch_size'],\n",
|
||
" max_batches=IMPROVED_TRAINING_PARAMS['max_batches']\n",
|
||
")\n",
|
||
"\n",
|
||
"# 训练完成后计算一次验证集PER(每个文件取33%试验)\n",
|
||
"per_summary = trainer.evaluate_val_per_experiment(fraction=0.33, random_state=42, drop_sep40=False, max_trials_per_file=5)\n",
|
||
"print(per_summary)"
|
||
]
|
||
},
|
||
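{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hedged sketch of the incremental-training pattern `train_incremental` appears to\n",
"# rely on (the method itself is defined earlier in the notebook). This is not the\n",
"# trainer's code: it only shows how LightGBM continues boosting from an existing\n",
"# booster via `init_model` / `keep_training_booster`, with the learning rate\n",
"# updated per batch. The toy data below is an assumption for demonstration.\n",
"import numpy as np\n",
"import lightgbm as lgb\n",
"\n",
"params = {'objective': 'multiclass', 'num_class': 3, 'verbose': -1}\n",
"booster = None\n",
"for batch_id, lr in enumerate([0.1, 0.05], start=1):\n",
"    X = np.random.randn(500, 8)\n",
"    y = np.random.randint(0, 3, size=500)\n",
"    # free_raw_data=False keeps the raw matrix alive, which continued training needs\n",
"    dtrain = lgb.Dataset(X, label=y, free_raw_data=False)\n",
"    params['learning_rate'] = lr\n",
"    booster = lgb.train(\n",
"        params, dtrain, num_boost_round=5,\n",
"        init_model=booster,            # continue from the previous batch's trees\n",
"        keep_training_booster=True     # keep the booster usable for further training\n",
"    )\n",
"    print(f\"batch {batch_id}: {booster.num_trees()} trees so far\")\n"
]
},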
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 📊 Training Results Analysis"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 📊 Analysis and visualization of training results\n",
"\n",
"print(\"📊 Analyzing smart batch training results...\")\n",
"print(\"=\" * 60)\n",
"\n",
"# Plot training progress\n",
"trainer.plot_training_progress()\n",
"\n",
"# Save the final model\n",
"final_model_path = \"smart_pipeline_final_model.txt\"\n",
"if trainer.model:\n",
"    trainer.model.save_model(final_model_path)\n",
"    print(f\"\\n💾 Final model saved: {final_model_path}\")\n",
"\n",
"# Detailed analysis\n",
"if trainer.training_history:\n",
"    print(f\"\\n📈 Detailed training analysis:\")\n",
"    print(f\"  🎯 Total training batches: {len(trainer.training_history)}\")\n",
"\n",
"    # Best batch\n",
"    best_batch = max(trainer.training_history, key=lambda x: x['val_accuracy'])\n",
"    print(f\"  🏆 Best validation accuracy: {best_batch['val_accuracy']:.4f} (batch {best_batch['batch']})\")\n",
"\n",
"    # Training efficiency\n",
"    total_training_time = sum(h['time'] for h in trainer.training_history)\n",
"    avg_batch_time = total_training_time / len(trainer.training_history)\n",
"    print(f\"  ⏱️ Total training time: {total_training_time:.1f} s\")\n",
"    print(f\"  ⏱️ Average batch time: {avg_batch_time:.1f} s\")\n",
"\n",
"    # Model complexity\n",
"    final_trees = trainer.training_history[-1]['num_trees']\n",
"    print(f\"  🌳 Final model tree count: {final_trees}\")\n",
"\n",
"    # Convergence analysis\n",
"    recent_accs = [h['val_accuracy'] for h in trainer.training_history[-3:]]\n",
"    if len(recent_accs) >= 2:\n",
"        acc_stability = max(recent_accs) - min(recent_accs)\n",
"        print(f\"  📈 Accuracy stability: {acc_stability:.4f} (range over the last 3 batches)\")\n",
"\n",
"        if acc_stability < 0.01:\n",
"            print(\"  ✅ Model has converged (accuracy change < 1%)\")\n",
"        else:\n",
"            print(\"  ⚠️ Model may need more training\")\n",
"\n",
"print(f\"\\n🎉 Smart batch training analysis complete!\")\n",
"print(f\"  💡 Used the improved class-balancing strategy and PCA reduction\")\n",
"print(f\"  💡 Training set used down- plus over-sampling; validation kept its original distribution\")\n",
"print(f\"  💡 Memory-friendly batched processing throughout\")"
]
},
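{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hedged sketch of what `plot_training_progress` could draw, based only on the\n",
"# history keys used above ('batch', 'val_accuracy', 'num_trees'). The trainer's\n",
"# own plotting method is defined earlier; this is an illustration, not its code.\n",
"import matplotlib.pyplot as plt\n",
"\n",
"def plot_history(history):\n",
"    batches = [h['batch'] for h in history]\n",
"    accs = [h['val_accuracy'] for h in history]\n",
"    trees = [h['num_trees'] for h in history]\n",
"    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12, 4))\n",
"    ax1.plot(batches, accs, 'o-')\n",
"    ax1.set_xlabel('Batch'); ax1.set_ylabel('Validation accuracy')\n",
"    ax2.plot(batches, trees, 's-')\n",
"    ax2.set_xlabel('Batch'); ax2.set_ylabel('Trees in model')\n",
"    plt.tight_layout(); plt.show()\n",
"\n",
"# Example with batch stats from the log above (accuracy 0.074 -> 0.142 over 9 batches):\n",
"demo = [{'batch': b, 'val_accuracy': a, 'num_trees': t} for b, a, t in\n",
"        [(1, 0.0740, 205), (5, 0.1273, 697), (9, 0.1416, 1271)]]\n",
"plot_history(demo)\n"
]
},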
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 🧪 Model Performance Evaluation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 🧪 Model performance evaluation\n",
"\n",
"from sklearn.metrics import classification_report, confusion_matrix\n",
"import numpy as np\n",
"\n",
"def evaluate_model_performance(model, pipeline, data_type='val'):\n",
"    \"\"\"Evaluate the model on the given data split.\"\"\"\n",
"    print(f\"🧪 Evaluating the model on the {data_type} split...\")\n",
"\n",
"    # Load data\n",
"    X, y = pipeline.step3_process_data(data_type, apply_sampling=False)\n",
"\n",
"    if X is None or y is None:\n",
"        print(f\"❌ Failed to load {data_type} data\")\n",
"        return None\n",
"\n",
"    print(f\"  📊 Dataset size: {X.shape[0]:,} samples, {X.shape[1]} features\")\n",
"\n",
"    # Predict\n",
"    start_time = time.time()\n",
"    y_pred_proba = model.predict(X)\n",
"    y_pred = y_pred_proba.argmax(axis=1)\n",
"    pred_time = time.time() - start_time\n",
"\n",
"    # Overall metrics\n",
"    accuracy = (y_pred == y).mean()\n",
"\n",
"    print(f\"  ⏱️ Prediction time: {pred_time:.2f} s\")\n",
"    print(f\"  🎯 Overall accuracy: {accuracy:.4f}\")\n",
"\n",
"    # Per-label performance\n",
"    from collections import Counter\n",
"    true_counts = Counter(y)\n",
"    pred_counts = Counter(y_pred)\n",
"\n",
"    print(f\"\\n📊 Label distribution comparison:\")\n",
"    print(\"Label | True | Predicted | Accuracy\")\n",
"    print(\"-\" * 40)\n",
"\n",
"    label_accuracies = {}\n",
"    for label in range(41):\n",
"        if label in true_counts:\n",
"            label_mask = (y == label)\n",
"            if label_mask.sum() > 0:\n",
"                label_acc = (y_pred[label_mask] == label).mean()\n",
"                label_accuracies[label] = label_acc\n",
"                true_count = true_counts.get(label, 0)\n",
"                pred_count = pred_counts.get(label, 0)\n",
"                print(f\"{label:4d} | {true_count:8,} | {pred_count:8,} | {label_acc:7.3f}\")\n",
"\n",
"    # Focus on the key labels\n",
"    print(f\"\\n🔍 Key-label performance:\")\n",
"    key_labels = [0, 40]  # the downsampled labels\n",
"    for label in key_labels:\n",
"        if label in label_accuracies:\n",
"            acc = label_accuracies[label]\n",
"            count = true_counts.get(label, 0)\n",
"            print(f\"  Label {label} (downsampling target): accuracy {acc:.4f}, samples {count:,}\")\n",
"\n",
"    # Minority-class performance\n",
"    minority_labels = [label for label, count in true_counts.items()\n",
"                       if count < 200 and label not in [0, 40]]\n",
"    if minority_labels:\n",
"        minority_accs = [label_accuracies.get(label, 0) for label in minority_labels[:5]]\n",
"        avg_minority_acc = np.mean(minority_accs) if minority_accs else 0\n",
"        print(f\"  Average minority-class accuracy (first 5): {avg_minority_acc:.4f}\")\n",
"\n",
"    # Confidence analysis\n",
"    max_proba = y_pred_proba.max(axis=1)\n",
"    print(f\"\\n📈 Prediction confidence:\")\n",
"    print(f\"  Mean confidence: {max_proba.mean():.4f}\")\n",
"    print(f\"  Median confidence: {np.median(max_proba):.4f}\")\n",
"    print(f\"  High-confidence predictions (>0.9): {(max_proba > 0.9).sum():,} / {len(max_proba):,} ({(max_proba > 0.9).mean():.2%})\")\n",
"\n",
"    return {\n",
"        'accuracy': accuracy,\n",
"        'prediction_time': pred_time,\n",
"        'label_accuracies': label_accuracies,\n",
"        'confidence_stats': {\n",
"            'mean': max_proba.mean(),\n",
"            'median': np.median(max_proba),\n",
"            'high_confidence_ratio': (max_proba > 0.9).mean()\n",
"        }\n",
"    }\n",
"\n",
"# Evaluate the model\n",
"if trainer.model:\n",
"    print(\"🧪 Starting model performance evaluation...\")\n",
"\n",
"    # Validation-set evaluation\n",
"    val_results = evaluate_model_performance(trainer.model, pipeline, 'val')\n",
"\n",
"    print(f\"\\n\" + \"=\"*60)\n",
"    print(\"🎉 Smart batch training + class balancing: evaluation complete!\")\n",
"    print(f\"✅ Full pipeline with class balancing and PCA reduction\")\n",
"    print(f\"✅ Memory-friendly batched training strategy\")\n",
"    print(f\"✅ Validation set kept its original distribution for an unbiased evaluation\")\n",
"else:\n",
"    print(\"❌ The model has not finished training; run this evaluation after training ends\")"
]
},
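{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Note: the cell above imports classification_report / confusion_matrix but does\n",
"# not use them. A hedged sketch of how they could supplement the manual per-label\n",
"# loop; the toy labels here are assumptions (with the real model you would pass\n",
"# y and y_pred from evaluate_model_performance instead):\n",
"import numpy as np\n",
"from sklearn.metrics import classification_report, confusion_matrix\n",
"\n",
"y_true_demo = np.array([0, 0, 1, 2, 2, 2])\n",
"y_pred_demo = np.array([0, 1, 1, 2, 2, 0])\n",
"\n",
"print(classification_report(y_true_demo, y_pred_demo, digits=3, zero_division=0))\n",
"print(confusion_matrix(y_true_demo, y_pred_demo))\n"
]
},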
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# ✅ Cosine annealing updated to the warm-restart version\n",
"\n",
"print(\"🎉 Cosine-annealing scheduler update complete!\")\n",
"\n",
"# Check whether the trainer exists; if not, it must be created first\n",
"if 'trainer' not in globals():\n",
"    print(\"⚠️ The trainer has not been created yet; run the earlier cells first\")\n",
"else:\n",
"    print(f\"✅ Now using: cosine annealing with warm restarts (SGDR)\")\n",
"    print(f\"  🔄 Restart parameters: T_0={trainer.t_0}, T_mult={trainer.t_mult}\")\n",
"    print(f\"  📈 Learning-rate range: {trainer.initial_learning_rate} → {trainer.min_learning_rate}\")\n",
"\n",
"    # Visualize the new schedule\n",
"    import matplotlib.pyplot as plt\n",
"    import numpy as np\n",
"\n",
"    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 5))\n",
"\n",
"    # Simulate the learning rate over 300 rounds\n",
"    rounds = list(range(300))\n",
"    old_lrs = []  # plain cosine annealing\n",
"    new_lrs = []  # cosine annealing with warm restarts\n",
"\n",
"    for r in rounds:\n",
"        # Plain cosine annealing (monotonically decreasing)\n",
"        old_lr = trainer.min_learning_rate + 0.5 * (trainer.initial_learning_rate - trainer.min_learning_rate) * (1 + np.cos(np.pi * r / 300))\n",
"        old_lrs.append(old_lr)\n",
"\n",
"        # Cosine annealing with warm restarts\n",
"        new_lr = trainer._cosine_annealing_with_warm_restarts(r)\n",
"        new_lrs.append(new_lr)\n",
"\n",
"    # Side-by-side comparison\n",
"    ax1.plot(rounds, old_lrs, 'b-', label='Plain cosine annealing', linewidth=2)\n",
"    ax1.set_xlabel('Training Round')\n",
"    ax1.set_ylabel('Learning Rate')\n",
"    ax1.set_title('Plain cosine annealing (monotonic decay)')\n",
"    ax1.grid(True, alpha=0.3)\n",
"    ax1.legend()\n",
"\n",
"    ax2.plot(rounds, new_lrs, 'r-', label='Cosine annealing with warm restarts', linewidth=2)\n",
"    ax2.set_xlabel('Training Round')\n",
"    ax2.set_ylabel('Learning Rate')\n",
"    ax2.set_title('Cosine annealing with warm restarts (SGDR)')\n",
"    ax2.grid(True, alpha=0.3)\n",
"    ax2.legend()\n",
"\n",
"    plt.tight_layout()\n",
"    plt.show()\n",
"\n",
"    print(\"📊 Learning-rate schedule comparison plotted\")\n",
"    print(\"  🔵 Old version: a single monotonically decreasing cosine curve\")\n",
"    print(\"  🔴 New version: periodic restarts; the rate returns to its maximum after each restart\")\n",
"    print(\"  💡 SGDR's advantage: repeated restarts can help the model escape local optima\")\n",
"\n",
"    # Show the restart points\n",
"    restart_points = []\n",
"    t_cur = 0\n",
"    t_i = trainer.t_0\n",
"    while t_cur < 300:\n",
"        restart_points.append(t_cur)\n",
"        t_cur += t_i\n",
"        t_i *= trainer.t_mult\n",
"\n",
"    print(f\"  🔄 Restart points within 300 rounds: {restart_points[:5]}...\")  # first 5 restart points"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 🧪 Test the new third-smallest-count sampling strategy\n",
"\n",
"print(\"🧪 Testing the new third-smallest-count sampling strategy...\")\n",
"print(\"=\" * 60)\n",
"\n",
"# Test on simulated data\n",
"np.random.seed(42)\n",
"random.seed(42)\n",
"\n",
"# Create simulated data\n",
"n_samples = 10000\n",
"n_features = 100\n",
"X_test = np.random.randn(n_samples, n_features)\n",
"\n",
"# Create an imbalanced label distribution\n",
"label_counts_target = {\n",
"    0: 5000, 1: 100, 2: 150, 3: 80, 4: 200, 5: 120, 6: 90, 7: 180, 8: 110, 9: 160,\n",
"    10: 140, 11: 170, 12: 130, 13: 190, 14: 105, 15: 95, 16: 175, 17: 125, 18: 155, 19: 135,\n",
"    20: 145, 21: 165, 22: 115, 23: 185, 24: 85, 25: 195, 26: 75, 27: 205, 28: 70, 29: 210,\n",
"    30: 65, 31: 215, 32: 60, 33: 220, 34: 55, 35: 225, 36: 50, 37: 230, 38: 45, 39: 235, 40: 3000\n",
"}\n",
"\n",
"y_test = []\n",
"for label, count in label_counts_target.items():\n",
"    y_test.extend([label] * min(count, n_samples - len(y_test)))\n",
"y_test = np.array(y_test)\n",
"X_test = X_test[:len(y_test)]\n",
"\n",
"# Shuffle\n",
"shuffle_idx = np.random.permutation(len(y_test))\n",
"X_test = X_test[shuffle_idx]\n",
"y_test = y_test[shuffle_idx]\n",
"\n",
"print(f\"Simulated data ready: {X_test.shape[0]:,} samples, {X_test.shape[1]} features\")\n",
"\n",
"# Show the original distribution\n",
"original_counts = Counter(y_test)\n",
"all_counts = [original_counts.get(i, 0) for i in range(41)]\n",
"non_zero_counts = [c for c in all_counts if c > 0]\n",
"sorted_counts = sorted(non_zero_counts)\n",
"\n",
"print(f\"Ten smallest class counts: {sorted_counts[:10]}\")\n",
"print(f\"Third-smallest count: {sorted_counts[2] if len(sorted_counts) >= 3 else 'N/A'}\")\n",
"\n",
"# Exercise the balance_dataset function\n",
"X_balanced, y_balanced = balance_dataset(X_test, y_test)\n",
"\n",
"# Show the balanced distribution\n",
"balanced_counts = Counter(y_balanced)\n",
"print(f\"\\nPer-label counts after balancing:\")\n",
"for label in range(41):\n",
"    original = original_counts.get(label, 0)\n",
"    balanced = balanced_counts.get(label, 0)\n",
"    if original > 0 or balanced > 0:\n",
"        status = \"📉\" if balanced < original else \"✅\" if balanced == original else \"📈\"\n",
"        print(f\"  {status} label {label:2d}: {original:4d} → {balanced:4d}\")\n",
"\n",
"print(f\"\\n✅ Test complete!\")\n",
"print(f\"  Original samples: {len(y_test):,}\")\n",
"print(f\"  Samples after balancing: {len(y_balanced):,}\")\n",
"print(f\"  Size change: {len(y_balanced)/len(y_test):.2f}x\")"
]
},
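{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hedged sketch of what a third-smallest-count strategy could look like. The real\n",
"# `balance_dataset` is defined earlier in the notebook; this toy version assumes\n",
"# the strategy is: downsample any class above the third-smallest class count to\n",
"# that count, leaving smaller classes untouched.\n",
"import numpy as np\n",
"from collections import Counter\n",
"\n",
"def balance_to_third_smallest(X, y, rng=None):\n",
"    rng = rng or np.random.default_rng(42)\n",
"    counts = Counter(y)\n",
"    cap = sorted(counts.values())[2]  # third-smallest class count\n",
"    keep = []\n",
"    for label in counts:\n",
"        idx = np.flatnonzero(y == label)\n",
"        if len(idx) > cap:\n",
"            idx = rng.choice(idx, size=cap, replace=False)\n",
"        keep.append(idx)\n",
"    keep = np.concatenate(keep)\n",
"    return X[keep], y[keep]\n",
"\n",
"X_demo = np.random.randn(60, 3)\n",
"y_demo = np.array([0]*30 + [1]*15 + [2]*10 + [3]*5)\n",
"Xb, yb = balance_to_third_smallest(X_demo, y_demo)\n",
"print(Counter(yb))  # class 0 capped at 15 (counts were 5, 10, 15, 30)\n"
]
}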
],
"metadata": {
"kaggle": {
"accelerator": "tpu1vmV38",
"dataSources": [
{
"databundleVersionId": 13056355,
"sourceId": 106809,
"sourceType": "competition"
}
],
"dockerImageVersionId": 31091,
"isGpuEnabled": false,
"isInternetEnabled": true,
"language": "python",
"sourceType": "notebook"
},
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.13"
}
},
"nbformat": 4,
"nbformat_minor": 4
}