{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 环境配置与Utils" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Looking in indexes: https://download.pytorch.org/whl/cu126\n", "Requirement already satisfied: torch in /usr/local/lib/python3.11/dist-packages (2.6.0+cu124)\n", "Requirement already satisfied: torchvision in /usr/local/lib/python3.11/dist-packages (0.21.0+cu124)\n", "Requirement already satisfied: torchaudio in /usr/local/lib/python3.11/dist-packages (2.6.0+cu124)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from torch) (3.18.0)\n", "Requirement already satisfied: typing-extensions>=4.10.0 in /usr/local/lib/python3.11/dist-packages (from torch) (4.14.0)\n", "Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch) (3.5)\n", "Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch) (3.1.6)\n", "Requirement already satisfied: fsspec in /usr/local/lib/python3.11/dist-packages (from torch) (2025.5.1)\n", "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n", "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n", "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n", "Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /usr/local/lib/python3.11/dist-packages (from torch) (9.1.0.70)\n", "Requirement already satisfied: nvidia-cublas-cu12==12.4.5.8 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.5.8)\n", "Requirement already satisfied: nvidia-cufft-cu12==11.2.1.3 in /usr/local/lib/python3.11/dist-packages (from torch) (11.2.1.3)\n", "Requirement already satisfied: nvidia-curand-cu12==10.3.5.147 in /usr/local/lib/python3.11/dist-packages (from torch) (10.3.5.147)\n", "Requirement already satisfied: nvidia-cusolver-cu12==11.6.1.9 in /usr/local/lib/python3.11/dist-packages (from torch) (11.6.1.9)\n", "Requirement already satisfied: nvidia-cusparse-cu12==12.3.1.170 in /usr/local/lib/python3.11/dist-packages (from torch) (12.3.1.170)\n", "Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /usr/local/lib/python3.11/dist-packages (from torch) (0.6.2)\n", "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from torch) (2.21.5)\n", "Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n", "Requirement already satisfied: nvidia-nvjitlink-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n", "Requirement already satisfied: triton==3.2.0 in /usr/local/lib/python3.11/dist-packages (from torch) (3.2.0)\n", "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch) (1.13.1)\n", "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch) (1.3.0)\n", "Requirement already satisfied: numpy in /usr/local/lib/python3.11/dist-packages (from torchvision) (1.26.4)\n", "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.11/dist-packages (from torchvision) (11.2.1)\n", "Requirement already satisfied: MarkupSafe>=2.0 in 
/usr/local/lib/python3.11/dist-packages (from jinja2->torch) (3.0.2)\n", "Requirement already satisfied: mkl_fft in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (1.3.8)\n", "Requirement already satisfied: mkl_random in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (1.2.4)\n", "Requirement already satisfied: mkl_umath in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (0.1.1)\n", "Requirement already satisfied: mkl in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (2025.2.0)\n", "Requirement already satisfied: tbb4py in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (2022.2.0)\n", "Requirement already satisfied: mkl-service in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (2.4.1)\n", "Requirement already satisfied: intel-openmp<2026,>=2024 in /usr/local/lib/python3.11/dist-packages (from mkl->numpy->torchvision) (2024.2.0)\n", "Requirement already satisfied: tbb==2022.* in /usr/local/lib/python3.11/dist-packages (from mkl->numpy->torchvision) (2022.2.0)\n", "Requirement already satisfied: tcmlib==1.* in /usr/local/lib/python3.11/dist-packages (from tbb==2022.*->mkl->numpy->torchvision) (1.4.0)\n", "Requirement already satisfied: intel-cmplr-lib-rt in /usr/local/lib/python3.11/dist-packages (from mkl_umath->numpy->torchvision) (2024.2.0)\n", "Requirement already satisfied: intel-cmplr-lib-ur==2024.2.0 in /usr/local/lib/python3.11/dist-packages (from intel-openmp<2026,>=2024->mkl->numpy->torchvision) (2024.2.0)\n", "Requirement already satisfied: jupyter==1.1.1 in /usr/local/lib/python3.11/dist-packages (1.1.1)\n", "Requirement already satisfied: numpy<2.1.0,>=1.26.0 in /usr/local/lib/python3.11/dist-packages (1.26.4)\n", "Requirement already satisfied: pandas==2.3.0 in /usr/local/lib/python3.11/dist-packages (2.3.0)\n", "Requirement already satisfied: matplotlib==3.10.1 in /usr/local/lib/python3.11/dist-packages (3.10.1)\n", "Requirement already satisfied: scipy==1.15.2 in /usr/local/lib/python3.11/dist-packages (1.15.2)\n", "Requirement already satisfied: scikit-learn==1.6.1 in /usr/local/lib/python3.11/dist-packages (1.6.1)\n", "Requirement already satisfied: lightgbm==4.3.0 in /usr/local/lib/python3.11/dist-packages (4.3.0)\n", "Requirement already satisfied: tqdm==4.67.1 in /usr/local/lib/python3.11/dist-packages (4.67.1)\n", "Requirement already satisfied: g2p_en==2.1.0 in /usr/local/lib/python3.11/dist-packages (2.1.0)\n", "Requirement already satisfied: h5py==3.13.0 in /usr/local/lib/python3.11/dist-packages (3.13.0)\n", "Requirement already satisfied: omegaconf==2.3.0 in /usr/local/lib/python3.11/dist-packages (2.3.0)\n", "Requirement already satisfied: editdistance==0.8.1 in /usr/local/lib/python3.11/dist-packages (0.8.1)\n", "Requirement already satisfied: huggingface-hub==0.33.1 in /usr/local/lib/python3.11/dist-packages (0.33.1)\n", "Requirement already satisfied: transformers==4.53.0 in /usr/local/lib/python3.11/dist-packages (4.53.0)\n", "Requirement already satisfied: tokenizers==0.21.2 in /usr/local/lib/python3.11/dist-packages (0.21.2)\n", "Requirement already satisfied: accelerate==1.8.1 in /usr/local/lib/python3.11/dist-packages (1.8.1)\n", "Requirement already satisfied: bitsandbytes==0.46.0 in /usr/local/lib/python3.11/dist-packages (0.46.0)\n", "Requirement already satisfied: seaborn==0.13.2 in /usr/local/lib/python3.11/dist-packages (0.13.2)\n", "Requirement already satisfied: notebook in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) 
(6.5.4)\n", "Requirement already satisfied: jupyter-console in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (6.1.0)\n", "Requirement already satisfied: nbconvert in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (6.4.5)\n", "Requirement already satisfied: ipykernel in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (6.17.1)\n", "Requirement already satisfied: ipywidgets in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (8.1.5)\n", "Requirement already satisfied: jupyterlab in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (3.6.8)\n", "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.11/dist-packages (from pandas==2.3.0) (2.9.0.post0)\n", "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.11/dist-packages (from pandas==2.3.0) (2025.2)\n", "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.11/dist-packages (from pandas==2.3.0) (2025.2)\n", "Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (1.3.2)\n", "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (0.12.1)\n", "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (4.58.4)\n", "Requirement already satisfied: kiwisolver>=1.3.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (1.4.8)\n", "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (25.0)\n", "Requirement already satisfied: pillow>=8 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (11.2.1)\n", "Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (3.0.9)\n", "Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn==1.6.1) (1.5.1)\n", "Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn==1.6.1) (3.6.0)\n", "Requirement already satisfied: nltk>=3.2.4 in /usr/local/lib/python3.11/dist-packages (from g2p_en==2.1.0) (3.9.1)\n", "Requirement already satisfied: inflect>=0.3.1 in /usr/local/lib/python3.11/dist-packages (from g2p_en==2.1.0) (7.5.0)\n", "Requirement already satisfied: distance>=0.1.3 in /usr/local/lib/python3.11/dist-packages (from g2p_en==2.1.0) (0.1.3)\n", "Requirement already satisfied: antlr4-python3-runtime==4.9.* in /usr/local/lib/python3.11/dist-packages (from omegaconf==2.3.0) (4.9.3)\n", "Requirement already satisfied: PyYAML>=5.1.0 in /usr/local/lib/python3.11/dist-packages (from omegaconf==2.3.0) (6.0.2)\n", "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (3.18.0)\n", "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (2025.5.1)\n", "Requirement already satisfied: requests in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (2.32.4)\n", "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (4.14.0)\n", "Requirement already satisfied: hf-xet<2.0.0,>=1.1.2 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (1.1.5)\n", "Requirement already satisfied: regex!=2019.12.17 in 
/usr/local/lib/python3.11/dist-packages (from transformers==4.53.0) (2024.11.6)\n", "Requirement already satisfied: safetensors>=0.4.3 in /usr/local/lib/python3.11/dist-packages (from transformers==4.53.0) (0.5.3)\n", "Requirement already satisfied: psutil in /usr/local/lib/python3.11/dist-packages (from accelerate==1.8.1) (7.0.0)\n", "Requirement already satisfied: torch>=2.0.0 in /usr/local/lib/python3.11/dist-packages (from accelerate==1.8.1) (2.6.0+cu124)\n", "Requirement already satisfied: mkl_fft in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (1.3.8)\n", "Requirement already satisfied: mkl_random in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (1.2.4)\n", "Requirement already satisfied: mkl_umath in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (0.1.1)\n", "Requirement already satisfied: mkl in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (2025.2.0)\n", "Requirement already satisfied: tbb4py in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (2022.2.0)\n", "Requirement already satisfied: mkl-service in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (2.4.1)\n", "Requirement already satisfied: more_itertools>=8.5.0 in /usr/local/lib/python3.11/dist-packages (from inflect>=0.3.1->g2p_en==2.1.0) (10.7.0)\n", "Requirement already satisfied: typeguard>=4.0.1 in /usr/local/lib/python3.11/dist-packages (from inflect>=0.3.1->g2p_en==2.1.0) (4.4.4)\n", "Requirement already satisfied: click in /usr/local/lib/python3.11/dist-packages (from nltk>=3.2.4->g2p_en==2.1.0) (8.2.1)\n", "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.11/dist-packages (from python-dateutil>=2.8.2->pandas==2.3.0) (1.17.0)\n", "Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (3.5)\n", "Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (3.1.6)\n", "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n", "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n", "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n", "Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (9.1.0.70)\n", "Requirement already satisfied: nvidia-cublas-cu12==12.4.5.8 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.5.8)\n", "Requirement already satisfied: nvidia-cufft-cu12==11.2.1.3 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (11.2.1.3)\n", "Requirement already satisfied: nvidia-curand-cu12==10.3.5.147 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (10.3.5.147)\n", "Requirement already satisfied: nvidia-cusolver-cu12==11.6.1.9 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (11.6.1.9)\n", "Requirement already satisfied: nvidia-cusparse-cu12==12.3.1.170 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.3.1.170)\n", "Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in 
/usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (0.6.2)\n", "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (2.21.5)\n", "Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n", "Requirement already satisfied: nvidia-nvjitlink-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n", "Requirement already satisfied: triton==3.2.0 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (3.2.0)\n", "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (1.13.1)\n", "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch>=2.0.0->accelerate==1.8.1) (1.3.0)\n", "Requirement already satisfied: debugpy>=1.0 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (1.8.0)\n", "Requirement already satisfied: ipython>=7.23.1 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (7.34.0)\n", "Requirement already satisfied: jupyter-client>=6.1.12 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (8.6.3)\n", "Requirement already satisfied: matplotlib-inline>=0.1 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (0.1.7)\n", "Requirement already satisfied: nest-asyncio in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (1.6.0)\n", "Requirement already satisfied: pyzmq>=17 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (24.0.1)\n", "Requirement already satisfied: tornado>=6.1 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (6.5.1)\n", "Requirement already satisfied: traitlets>=5.1.0 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (5.7.1)\n", "Requirement already satisfied: comm>=0.1.3 in /usr/local/lib/python3.11/dist-packages (from ipywidgets->jupyter==1.1.1) (0.2.2)\n", "Requirement already satisfied: widgetsnbextension~=4.0.12 in /usr/local/lib/python3.11/dist-packages (from ipywidgets->jupyter==1.1.1) (4.0.14)\n", "Requirement already satisfied: jupyterlab-widgets~=3.0.12 in /usr/local/lib/python3.11/dist-packages (from ipywidgets->jupyter==1.1.1) (3.0.15)\n", "Requirement already satisfied: prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-console->jupyter==1.1.1) (3.0.51)\n", "Requirement already satisfied: pygments in /usr/local/lib/python3.11/dist-packages (from jupyter-console->jupyter==1.1.1) (2.19.2)\n", "Requirement already satisfied: jupyter-core in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (5.8.1)\n", "Requirement already satisfied: jupyterlab-server~=2.19 in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (2.27.3)\n", "Requirement already satisfied: jupyter-server<3,>=1.16.0 in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (2.12.5)\n", "Requirement already satisfied: jupyter-ydoc~=0.2.4 in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (0.2.5)\n", "Requirement already satisfied: jupyter-server-ydoc~=0.8.0 in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (0.8.0)\n", 
"Requirement already satisfied: nbclassic in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (1.3.1)\n", "Requirement already satisfied: argon2-cffi in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (25.1.0)\n", "Requirement already satisfied: ipython-genutils in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (0.2.0)\n", "Requirement already satisfied: nbformat in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (5.10.4)\n", "Requirement already satisfied: Send2Trash>=1.8.0 in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (1.8.3)\n", "Requirement already satisfied: terminado>=0.8.3 in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (0.18.1)\n", "Requirement already satisfied: prometheus-client in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (0.22.1)\n", "Requirement already satisfied: mistune<2,>=0.8.1 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.8.4)\n", "Requirement already satisfied: jupyterlab-pygments in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.3.0)\n", "Requirement already satisfied: entrypoints>=0.2.2 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.4)\n", "Requirement already satisfied: bleach in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (6.2.0)\n", "Requirement already satisfied: pandocfilters>=1.4.1 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (1.5.1)\n", "Requirement already satisfied: testpath in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.6.0)\n", "Requirement already satisfied: defusedxml in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.7.1)\n", "Requirement already satisfied: beautifulsoup4 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (4.13.4)\n", "Requirement already satisfied: nbclient<0.6.0,>=0.5.0 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.5.13)\n", "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (3.0.2)\n", "Requirement already satisfied: intel-openmp<2026,>=2024 in /usr/local/lib/python3.11/dist-packages (from mkl->numpy<2.1.0,>=1.26.0) (2024.2.0)\n", "Requirement already satisfied: tbb==2022.* in /usr/local/lib/python3.11/dist-packages (from mkl->numpy<2.1.0,>=1.26.0) (2022.2.0)\n", "Requirement already satisfied: tcmlib==1.* in /usr/local/lib/python3.11/dist-packages (from tbb==2022.*->mkl->numpy<2.1.0,>=1.26.0) (1.4.0)\n", "Requirement already satisfied: intel-cmplr-lib-rt in /usr/local/lib/python3.11/dist-packages (from mkl_umath->numpy<2.1.0,>=1.26.0) (2024.2.0)\n", "Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub==0.33.1) (3.4.2)\n", "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub==0.33.1) (3.10)\n", "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub==0.33.1) (2.5.0)\n", "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub==0.33.1) (2025.6.15)\n", "Requirement already satisfied: intel-cmplr-lib-ur==2024.2.0 in 
/usr/local/lib/python3.11/dist-packages (from intel-openmp<2026,>=2024->mkl->numpy<2.1.0,>=1.26.0) (2024.2.0)\n", "Requirement already satisfied: setuptools>=18.5 in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (75.2.0)\n", "Requirement already satisfied: jedi>=0.16 in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (0.19.2)\n", "Requirement already satisfied: decorator in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (4.4.2)\n", "Requirement already satisfied: pickleshare in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (0.7.5)\n", "Requirement already satisfied: backcall in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (0.2.0)\n", "Requirement already satisfied: pexpect>4.3 in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (4.9.0)\n", "Requirement already satisfied: platformdirs>=2.5 in /usr/local/lib/python3.11/dist-packages (from jupyter-core->jupyterlab->jupyter==1.1.1) (4.3.8)\n", "Requirement already satisfied: anyio>=3.1.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (4.9.0)\n", "Requirement already satisfied: jupyter-events>=0.9.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (0.12.0)\n", "Requirement already satisfied: jupyter-server-terminals in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (0.5.3)\n", "Requirement already satisfied: overrides in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (7.7.0)\n", "Requirement already satisfied: websocket-client in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.8.0)\n", "Requirement already satisfied: jupyter-server-fileid<1,>=0.6.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-server-ydoc~=0.8.0->jupyterlab->jupyter==1.1.1) (0.9.3)\n", "Requirement already satisfied: ypy-websocket<0.9.0,>=0.8.2 in /usr/local/lib/python3.11/dist-packages (from jupyter-server-ydoc~=0.8.0->jupyterlab->jupyter==1.1.1) (0.8.4)\n", "Requirement already satisfied: y-py<0.7.0,>=0.6.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-ydoc~=0.2.4->jupyterlab->jupyter==1.1.1) (0.6.2)\n", "Requirement already satisfied: babel>=2.10 in /usr/local/lib/python3.11/dist-packages (from jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (2.17.0)\n", "Requirement already satisfied: json5>=0.9.0 in /usr/local/lib/python3.11/dist-packages (from jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (0.12.0)\n", "Requirement already satisfied: jsonschema>=4.18.0 in /usr/local/lib/python3.11/dist-packages (from jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (4.24.0)\n", "Requirement already satisfied: notebook-shim>=0.2.3 in /usr/local/lib/python3.11/dist-packages (from nbclassic->jupyterlab->jupyter==1.1.1) (0.2.4)\n", "Requirement already satisfied: fastjsonschema>=2.15 in /usr/local/lib/python3.11/dist-packages (from nbformat->notebook->jupyter==1.1.1) (2.21.1)\n", "Requirement already satisfied: wcwidth in /usr/local/lib/python3.11/dist-packages (from prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0->jupyter-console->jupyter==1.1.1) (0.2.13)\n", "Requirement already satisfied: ptyprocess in 
/usr/local/lib/python3.11/dist-packages (from terminado>=0.8.3->notebook->jupyter==1.1.1) (0.7.0)\n", "Requirement already satisfied: argon2-cffi-bindings in /usr/local/lib/python3.11/dist-packages (from argon2-cffi->notebook->jupyter==1.1.1) (21.2.0)\n", "Requirement already satisfied: soupsieve>1.2 in /usr/local/lib/python3.11/dist-packages (from beautifulsoup4->nbconvert->jupyter==1.1.1) (2.7)\n", "Requirement already satisfied: webencodings in /usr/local/lib/python3.11/dist-packages (from bleach->nbconvert->jupyter==1.1.1) (0.5.1)\n", "Requirement already satisfied: sniffio>=1.1 in /usr/local/lib/python3.11/dist-packages (from anyio>=3.1.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.3.1)\n", "Requirement already satisfied: parso<0.9.0,>=0.8.4 in /usr/local/lib/python3.11/dist-packages (from jedi>=0.16->ipython>=7.23.1->ipykernel->jupyter==1.1.1) (0.8.4)\n", "Requirement already satisfied: attrs>=22.2.0 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=4.18.0->jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (25.3.0)\n", "Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=4.18.0->jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (2025.4.1)\n", "Requirement already satisfied: referencing>=0.28.4 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=4.18.0->jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (0.36.2)\n", "Requirement already satisfied: rpds-py>=0.7.1 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=4.18.0->jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (0.25.1)\n", "Requirement already satisfied: python-json-logger>=2.0.4 in /usr/local/lib/python3.11/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (3.3.0)\n", "Requirement already satisfied: rfc3339-validator in /usr/local/lib/python3.11/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (0.1.4)\n", "Requirement already satisfied: rfc3986-validator>=0.1.1 in /usr/local/lib/python3.11/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (0.1.1)\n", "Requirement already satisfied: aiofiles<23,>=22.1.0 in /usr/local/lib/python3.11/dist-packages (from ypy-websocket<0.9.0,>=0.8.2->jupyter-server-ydoc~=0.8.0->jupyterlab->jupyter==1.1.1) (22.1.0)\n", "Requirement already satisfied: aiosqlite<1,>=0.17.0 in /usr/local/lib/python3.11/dist-packages (from ypy-websocket<0.9.0,>=0.8.2->jupyter-server-ydoc~=0.8.0->jupyterlab->jupyter==1.1.1) (0.21.0)\n", "Requirement already satisfied: cffi>=1.0.1 in /usr/local/lib/python3.11/dist-packages (from argon2-cffi-bindings->argon2-cffi->notebook->jupyter==1.1.1) (1.17.1)\n", "Requirement already satisfied: pycparser in /usr/local/lib/python3.11/dist-packages (from cffi>=1.0.1->argon2-cffi-bindings->argon2-cffi->notebook->jupyter==1.1.1) (2.22)\n", "Requirement already satisfied: fqdn in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.5.1)\n", "Requirement already satisfied: isoduration in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (20.11.0)\n", "Requirement already satisfied: jsonpointer>1.13 in /usr/local/lib/python3.11/dist-packages (from 
jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (3.0.0)\n", "Requirement already satisfied: uri-template in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.3.0)\n", "Requirement already satisfied: webcolors>=24.6.0 in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (24.11.1)\n", "Requirement already satisfied: arrow>=0.15.0 in /usr/local/lib/python3.11/dist-packages (from isoduration->jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.3.0)\n", "Requirement already satisfied: types-python-dateutil>=2.8.10 in /usr/local/lib/python3.11/dist-packages (from arrow>=0.15.0->isoduration->jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (2.9.0.20250516)\n", "Obtaining file:///kaggle/working/nejm-brain-to-text\n", " Preparing metadata (setup.py): started\n", " Preparing metadata (setup.py): finished with status 'done'\n", "Installing collected packages: nejm_b2txt_utils\n", " Running setup.py develop for nejm_b2txt_utils\n", "Successfully installed nejm_b2txt_utils-0.0.0\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ "Cloning into 'nejm-brain-to-text'...\n", "Updating files: 100% (2633/2633), done.\n" ] } ], "source": [ "%%bash\n", "rm -rf /kaggle/working/nejm-brain-to-text/\n", "git clone https://github.com/ZH-CEN/nejm-brain-to-text.git\n", "cp /kaggle/input/brain-to-text-baseline-model/t15_copyTask.pkl /kaggle/working/nejm-brain-to-text/data/t15_copyTask.pkl\n", "\n", "ln -s /kaggle/input/brain-to-text-25/t15_pretrained_rnn_baseline/t15_pretrained_rnn_baseline /kaggle/working/nejm-brain-to-text/data\n", "ln -s /kaggle/input/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final /kaggle/working/nejm-brain-to-text/data\n", "ln -s /kaggle/input/rnn-pretagged-data /kaggle/working/nejm-brain-to-text/data/concatenated_data\n", "\n", "pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126\n", "\n", "pip install \\\n", " jupyter==1.1.1 \\\n", " \"numpy>=1.26.0,<2.1.0\" \\\n", " pandas==2.3.0 \\\n", " matplotlib==3.10.1 \\\n", " scipy==1.15.2 \\\n", " scikit-learn==1.6.1 \\\n", " lightgbm==4.3.0 \\\n", " tqdm==4.67.1 \\\n", " g2p_en==2.1.0 \\\n", " h5py==3.13.0 \\\n", " omegaconf==2.3.0 \\\n", " editdistance==0.8.1 \\\n", " huggingface-hub==0.33.1 \\\n", " transformers==4.53.0 \\\n", " tokenizers==0.21.2 \\\n", " accelerate==1.8.1 \\\n", " bitsandbytes==0.46.0 \\\n", " seaborn==0.13.2\n", "cd /kaggle/working/nejm-brain-to-text/\n", "pip install -e ." 
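] }, { "cell_type": "markdown", "metadata": {}, "source": [ "补充一个可选的 LightGBM GPU 冒烟测试(示意性示例, 非原流程的一部分): pip 安装的 LightGBM 是否带 GPU 支持取决于其构建方式, 下面用极小的随机占位数据 `X_demo` / `y_demo` 尝试 `device='gpu'` 训练, 失败则说明需要回退 CPU, 与下一单元的 nvidia-smi / nvcc 检查互为补充。" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# (可选) LightGBM GPU 冒烟测试: 示意性示例, 数据为随机占位\n", "import numpy as np\n", "import lightgbm as lgb\n", "\n", "X_demo = np.random.rand(200, 16)       # 随机演示特征\n", "y_demo = np.random.randint(0, 2, 200)  # 随机演示二分类标签\n", "\n", "try:\n", "    clf = lgb.LGBMClassifier(device='gpu', n_estimators=5)\n", "    clf.fit(X_demo, y_demo)\n", "    print('✅ LightGBM GPU 训练可用')\n", "except Exception as e:\n", "    print(f'⚠️ LightGBM GPU 不可用, 将回退 CPU: {e}')"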
] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "==================================================\n", "🔧 LightGBM GPU环境检查\n", "==================================================\n", "❌ 未检测到NVIDIA GPU或驱动\n", "\n", "✅ CUDA工具包:\n", " Cuda compilation tools, release 12.5, V12.5.82\n" ] } ], "source": [ "# 🚀 LightGBM GPU支持检查与配置\n", "\n", "print(\"=\"*50)\n", "print(\"🔧 LightGBM GPU环境检查\")\n", "print(\"=\"*50)\n", "\n", "# 检查CUDA和GPU驱动\n", "import subprocess\n", "import sys\n", "\n", "def run_command(command):\n", " \"\"\"运行命令并返回结果\"\"\"\n", " try:\n", " result = subprocess.run(command, shell=True, capture_output=True, text=True, timeout=10)\n", " return result.stdout.strip(), result.returncode == 0\n", " except Exception as e:\n", " return str(e), False\n", "\n", "# 检查NVIDIA GPU\n", "nvidia_output, nvidia_success = run_command(\"nvidia-smi --query-gpu=name,memory.total,driver_version --format=csv,noheader,nounits\")\n", "if nvidia_success:\n", " print(\"✅ NVIDIA GPU检测:\")\n", " for line in nvidia_output.split('\\n'):\n", " if line.strip():\n", " print(f\" {line}\")\n", "else:\n", " print(\"❌ 未检测到NVIDIA GPU或驱动\")\n", "\n", "# 检查CUDA版本\n", "cuda_output, cuda_success = run_command(\"nvcc --version\")\n", "if cuda_success:\n", " print(\"\\n✅ CUDA工具包:\")\n", " # 提取CUDA版本\n", " for line in cuda_output.split('\\n'):\n", " if 'release' in line:\n", " print(f\" {line.strip()}\")\n", "else:\n", " print(\"\\n❌ 未安装CUDA工具包\")\n" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/kaggle/working/nejm-brain-to-text\n" ] } ], "source": [ "%cd /kaggle/working/nejm-brain-to-text\n", "import numpy as np\n", "import os\n", "import pickle\n", "import matplotlib.pyplot as plt\n", "import matplotlib\n", "from g2p_en import G2p\n", "import pandas as pd\n", "import numpy as np\n", "from nejm_b2txt_utils.general_utils import *\n", "matplotlib.rcParams['pdf.fonttype'] = 42\n", "matplotlib.rcParams['ps.fonttype'] = 42\n", "matplotlib.rcParams['font.family'] = 'sans-serif'\n", "matplotlib.rcParams['font.sans-serif'] = ['SimHei', 'DejaVu Sans', 'Arial Unicode MS', 'sans-serif']\n", "matplotlib.rcParams['axes.unicode_minus'] = False\n" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "/kaggle/working/nejm-brain-to-text/model_training\n" ] } ], "source": [ "%cd model_training/\n", "from data_augmentations import gauss_smooth" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "LOGIT_TO_PHONEME = [\n", " 'BLANK',\n", " 'AA', 'AE', 'AH', 'AO', 'AW',\n", " 'AY', 'B', 'CH', 'D', 'DH',\n", " 'EH', 'ER', 'EY', 'F', 'G',\n", " 'HH', 'IH', 'IY', 'JH', 'K',\n", " 'L', 'M', 'N', 'NG', 'OW',\n", " 'OY', 'P', 'R', 'S', 'SH',\n", " 'T', 'TH', 'UH', 'UW', 'V',\n", " 'W', 'Y', 'Z', 'ZH',\n", " ' | ',\n", "]\n", "# 全局配置\n", "BALANCE_CONFIG = {\n", " 'enable_balance': True, # 是否启用数据平衡\n", " 'undersample_labels': [0, 40], # 需要下采样的标签 (blank等高频标签)\n", " 'oversample_threshold': 0.5, # 过采样阈值 (相对于均值的比例)\n", " 'random_state': 42 # 随机种子\n", "}\n", "# 全局PCA配置\n", "PCA_CONFIG = {\n", " 'enable_pca': True, # 是否启用PCA\n", " 'n_components': None, # None=自动选择, 或指定具体数值\n", " 'variance_threshold': 0.95, # 保留95%的方差\n", " 'sample_size': 15000, # 用于拟合PCA的样本数\n", "}\n", "\n", "# 全局PCA对象 (确保只拟合一次)\n", "GLOBAL_PCA = {\n", " 'scaler': None,\n", " 'pca': None,\n", " 
'is_fitted': False,\n", " 'n_components': None\n", "}\n", "# 设置数据目录和参数【PCA初始化】\n", "data_dir = '/kaggle/working/nejm-brain-to-text/data/concatenated_data'\n", "MAX_SAMPLES_PER_FILE = -1 # 每个文件最大样本数,可调整" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 数据读取工作流" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 2️⃣ 数据加载与PCA降维" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [], "source": [ "# 🚀 内存友好的数据读取 - 分批加载策略 + PCA降维 【这里还缺一个采样】\n", "\n", "import os\n", "import numpy as np\n", "import gc\n", "from sklearn.decomposition import PCA\n", "from sklearn.preprocessing import StandardScaler\n", "import joblib\n", "import matplotlib.pyplot as plt\n", "\n", "\n", "def load_data_batch(data_dir, data_type, max_samples_per_file=5000):\n", " \"\"\"\n", " 分批加载指定类型的数据\n", " \n", " Args:\n", " data_dir: 数据目录\n", " data_type: 'train', 'val', 'test'\n", " max_samples_per_file: 每个文件最大加载样本数\n", " \n", " Returns:\n", " generator: 数据批次生成器\n", " \"\"\"\n", " files = [f for f in os.listdir(data_dir) if f.endswith('.npz') and data_type in f]\n", " \n", " for file_idx, f in enumerate(files):\n", " print(f\" 正在加载文件 {file_idx+1}/{len(files)}: {f}\")\n", " \n", " data = np.load(os.path.join(data_dir, f), allow_pickle=True)\n", " trials = data['neural_logits_concatenated']\n", " \n", " # 限制每个文件的样本数\n", " if len(trials) > max_samples_per_file and max_samples_per_file != -1:\n", " trials = trials[:max_samples_per_file]\n", " print(f\" 限制样本数至: {max_samples_per_file}\")\n", " \n", " yield trials, f\n", " \n", " # 清理内存\n", " del data, trials\n", " gc.collect()\n", "\n", "def extract_features_labels_batch(trials_batch):\n", " \"\"\"\n", " 从试验批次中提取特征和标签\n", " \"\"\"\n", " features = []\n", " labels = []\n", " \n", " for trial in trials_batch:\n", " if trial.shape[0] > 0:\n", " for t in range(trial.shape[0]):\n", " neural_features = trial[t, :7168] # 前7168维神经特征\n", " rnn_logits = trial[t, 7168:] # 后41维RNN输出\n", " phoneme_label = np.argmax(rnn_logits)\n", " \n", " features.append(neural_features)\n", " labels.append(phoneme_label)\n", " \n", " return np.array(features), np.array(labels)\n", "\n", "def fit_global_pca(data_dir, config):\n", " \"\"\"\n", " 在训练数据上拟合全局PCA (只执行一次)\n", " \"\"\"\n", " if GLOBAL_PCA['is_fitted'] or not config['enable_pca']:\n", " print(\"🔧 PCA已拟合或未启用,跳过拟合步骤\")\n", " return\n", " \n", " print(f\"\\n🔧 拟合全局PCA降维器...\")\n", " print(f\" 配置: {config}\")\n", " \n", " # 收集训练样本\n", " sample_features = []\n", " collected_samples = 0\n", " \n", " for trials_batch, filename in load_data_batch(data_dir, 'train', 5000):\n", " features, labels = extract_features_labels_batch(trials_batch)\n", " sample_features.append(features)\n", " collected_samples += features.shape[0]\n", " \n", " if collected_samples >= config['sample_size']:\n", " break\n", " \n", " if sample_features:\n", " # 合并样本数据\n", " X_sample = np.vstack(sample_features)[:config['sample_size']]\n", " print(f\" 实际样本数: {X_sample.shape[0]}\")\n", " print(f\" 原始特征数: {X_sample.shape[1]}\")\n", " \n", " # 标准化\n", " GLOBAL_PCA['scaler'] = StandardScaler()\n", " X_sample_scaled = GLOBAL_PCA['scaler'].fit_transform(X_sample)\n", " \n", " # 确定PCA成分数\n", " if config['n_components'] is None:\n", " print(f\" 🔍 自动选择PCA成分数...\")\n", " pca_full = PCA()\n", " pca_full.fit(X_sample_scaled)\n", " \n", " cumsum_var = np.cumsum(pca_full.explained_variance_ratio_)\n", " optimal_components = np.argmax(cumsum_var >= config['variance_threshold']) + 1\n", " GLOBAL_PCA['n_components'] = min(optimal_components, 
X_sample.shape[1])\n", " \n", " print(f\" 保留{config['variance_threshold']*100}%方差需要: {optimal_components} 个成分\")\n", " print(f\" 选择成分数: {GLOBAL_PCA['n_components']}\")\n", " else:\n", " GLOBAL_PCA['n_components'] = config['n_components']\n", " print(f\" 使用指定成分数: {GLOBAL_PCA['n_components']}\")\n", " \n", " # 拟合最终PCA\n", " GLOBAL_PCA['pca'] = PCA(n_components=GLOBAL_PCA['n_components'], random_state=42)\n", " GLOBAL_PCA['pca'].fit(X_sample_scaled)\n", " GLOBAL_PCA['is_fitted'] = True\n", " \n", " # 保存模型\n", " pca_path = \"global_pca_model.joblib\"\n", " joblib.dump({\n", " 'scaler': GLOBAL_PCA['scaler'], \n", " 'pca': GLOBAL_PCA['pca'],\n", " 'n_components': GLOBAL_PCA['n_components']\n", " }, pca_path)\n", " \n", " print(f\" ✅ 全局PCA拟合完成!\")\n", " print(f\" 降维: {X_sample.shape[1]} → {GLOBAL_PCA['n_components']}\")\n", " print(f\" 降维比例: {GLOBAL_PCA['n_components']/X_sample.shape[1]:.2%}\")\n", " print(f\" 保留方差: {GLOBAL_PCA['pca'].explained_variance_ratio_.sum():.4f}\")\n", " print(f\" 模型已保存: {pca_path}\")\n", " \n", " # 清理样本数据\n", " del sample_features, X_sample, X_sample_scaled\n", " gc.collect()\n", " else:\n", " print(\"❌ 无法收集样本数据用于PCA拟合\")\n", "\n", "def apply_pca_transform(features):\n", " \"\"\"\n", " 应用全局PCA变换\n", " \"\"\"\n", " if not PCA_CONFIG['enable_pca'] or not GLOBAL_PCA['is_fitted']:\n", " return features\n", " \n", " # 标准化 + PCA变换\n", " features_scaled = GLOBAL_PCA['scaler'].transform(features)\n", " features_pca = GLOBAL_PCA['pca'].transform(features_scaled)\n", " return features_pca" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 📊 数据平衡策略 - 标签分布分析与采样优化" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "# 【采样核心实现】\n", "def balance_dataset(X, y, config=BALANCE_CONFIG):\n", " \"\"\"\n", " 对数据集进行平衡处理:下采样 + 过采样\n", " \n", " Args:\n", " X: 特征数据\n", " y: 标签数据\n", " config: 平衡配置\n", " \n", " Returns:\n", " X_balanced, y_balanced: 平衡后的数据\n", " \"\"\"\n", " if not config['enable_balance']:\n", " print(\"🔕 数据平衡已禁用,返回原始数据\")\n", " return X, y\n", " \n", " print(f\"\\n⚖️ 开始数据平衡处理...\")\n", " print(f\" 原始数据: {X.shape[0]:,} 样本\")\n", " \n", " # 分析当前分布 (只考虑1-39号标签的均值)\n", " label_counts = Counter(y)\n", " counts_exclude_0_40 = [label_counts.get(i, 0) for i in range(1, 40)] # 1-39号标签\n", " mean_count = np.mean(counts_exclude_0_40) # 只计算1-39号标签的均值\n", " \n", " print(f\" 均值样本数 (标签1-39): {mean_count:.0f}\")\n", " print(f\" 下采样标签: {config['undersample_labels']}\")\n", " print(f\" 过采样阈值: {config['oversample_threshold']} * 均值\")\n", " \n", " # 准备平衡后的数据\n", " X_balanced = []\n", " y_balanced = []\n", " \n", " random.seed(config['random_state'])\n", " np.random.seed(config['random_state'])\n", " \n", " for label in range(41):\n", " # 获取当前标签的所有样本\n", " label_mask = (y == label)\n", " X_label = X[label_mask]\n", " y_label = y[label_mask]\n", " current_count = len(y_label)\n", " \n", " if current_count == 0:\n", " continue\n", " \n", " # 决定采样策略\n", " if label in config['undersample_labels']:\n", " # 下采样到均值水平\n", " target_count = int(mean_count)\n", " if current_count > target_count:\n", " # 下采样\n", " indices = np.random.choice(current_count, target_count, replace=False)\n", " X_resampled = X_label[indices]\n", " y_resampled = y_label[indices]\n", " print(f\" 📉 标签 {label}: {current_count} → {target_count} (下采样)\")\n", " else:\n", " X_resampled = X_label\n", " y_resampled = y_label\n", " print(f\" ➡️ 标签 {label}: {current_count} (无需下采样)\")\n", " \n", " elif current_count < mean_count * config['oversample_threshold']:\n", " # 过采样到阈值水平\n", " 
target_count = int(mean_count * config['oversample_threshold'])\n", " if current_count < target_count:\n", " # 过采样\n", " X_resampled, y_resampled = resample(\n", " X_label, y_label, \n", " n_samples=target_count, \n", " random_state=config['random_state']\n", " )\n", " print(f\" 📈 标签 {label}: {current_count} → {target_count} (过采样)\")\n", " else:\n", " X_resampled = X_label\n", " y_resampled = y_label\n", " print(f\" ➡️ 标签 {label}: {current_count} (无需过采样)\")\n", " else:\n", " # 保持不变\n", " X_resampled = X_label\n", " y_resampled = y_label\n", " print(f\" ✅ 标签 {label}: {current_count} (已平衡)\")\n", " \n", " X_balanced.append(X_resampled)\n", " y_balanced.append(y_resampled)\n", " \n", " # 合并所有平衡后的数据\n", " X_balanced = np.vstack(X_balanced)\n", " y_balanced = np.hstack(y_balanced)\n", " \n", " # 随机打乱\n", " shuffle_indices = np.random.permutation(len(y_balanced))\n", " X_balanced = X_balanced[shuffle_indices]\n", " y_balanced = y_balanced[shuffle_indices]\n", " \n", " print(f\" ✅ 平衡完成: {X_balanced.shape[0]:,} 样本\")\n", " print(f\" 数据变化: {X.shape[0]:,} → {X_balanced.shape[0]:,} ({X_balanced.shape[0]/X.shape[0]:.2f}x)\")\n", " \n", " return X_balanced, y_balanced\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 🔄 集成数据平衡的内存友好数据加载器" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 🧪 数据平衡效果测试" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 🚀 改进版智能数据处理管道" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "🚀 创建智能数据处理管道...\n", "✅ 管道创建完成,准备执行步骤1...\n" ] } ], "source": [ "# 🚀 改进版智能数据处理管道【没有解决分批训练的问题】\n", "# 流程:分析分布 → 确定采样比率 → 拟合PCA(只下采样) → 数据处理(下采样+上采样+PCA)\n", "\n", "import numpy as np\n", "import matplotlib.pyplot as plt\n", "from collections import Counter\n", "from sklearn.utils import resample\n", "from sklearn.decomposition import PCA\n", "from sklearn.preprocessing import StandardScaler\n", "import joblib\n", "import random\n", "import gc\n", "\n", "class SmartDataPipeline:\n", " \"\"\"\n", " 智能数据处理管道\n", " 步骤1: 分析数据分布,确定采样策略\n", " 步骤2: 仅下采样拟合PCA参数\n", " 步骤3: 数据处理时应用完整采样+PCA降维\n", " \"\"\"\n", " \n", " def __init__(self, data_dir, random_state=42):\n", " self.data_dir = data_dir\n", " self.random_state = random_state\n", " \n", " # 步骤1: 分布分析结果\n", " self.distribution_analysis = None\n", " self.sampling_strategy = None\n", " \n", " # 步骤2: PCA参数(基于下采样数据拟合)\n", " self.pca_scaler = None\n", " self.pca_model = None\n", " self.pca_components = None\n", " self.pca_fitted = False\n", " \n", " # 配置参数\n", " self.undersample_labels = [0, 40] # 需要下采样的标签\n", " self.oversample_threshold = 0.5 # 过采样阈值(相对于均值)\n", " self.pca_variance_threshold = 0.95 # PCA保留方差比例\n", " self.pca_sample_size = 15000 # PCA拟合样本数\n", " \n", " def step1_analyze_distribution(self, max_samples=100000):\n", " \"\"\"\n", " 步骤1: 分析数据分布,确定采样策略\n", " \"\"\"\n", " print(\"🔍 步骤1: 分析数据分布...\")\n", " \n", " # 分析验证集分布(代表整体分布特征)\n", " all_labels = []\n", " for trials_batch, filename in load_data_batch(self.data_dir, 'val', 5000):\n", " _, labels = extract_features_labels_batch(trials_batch)\n", " all_labels.extend(labels.tolist())\n", " if len(all_labels) >= max_samples:\n", " break\n", " \n", " # 统计分析\n", " label_counts = Counter(all_labels)\n", " \n", " # 计算1-39标签的均值(排除0和40)\n", " counts_1_39 = [label_counts.get(i, 0) for i in range(1, 40)]\n", " target_mean = np.mean(counts_1_39)\n", " \n", " # 生成采样策略\n", " sampling_strategy = {}\n", " for label in range(41):\n", " current_count = label_counts.get(label, 0)\n", " 
\n", " if label in self.undersample_labels:\n", " # 下采样到均值水平\n", " target_count = int(target_mean)\n", " action = 'undersample' if current_count > target_count else 'keep'\n", " elif current_count < target_mean * self.oversample_threshold:\n", " # 过采样到阈值水平\n", " target_count = int(target_mean * self.oversample_threshold)\n", " action = 'oversample' if current_count < target_count else 'keep'\n", " else:\n", " # 保持不变\n", " target_count = current_count\n", " action = 'keep'\n", " \n", " sampling_strategy[label] = {\n", " 'current_count': current_count,\n", " 'target_count': target_count,\n", " 'action': action\n", " }\n", " \n", " self.distribution_analysis = {\n", " 'label_counts': label_counts,\n", " 'target_mean': target_mean,\n", " 'total_samples': len(all_labels)\n", " }\n", " self.sampling_strategy = sampling_strategy\n", " \n", " print(f\" ✅ 分析完成: {len(all_labels):,} 样本\")\n", " print(f\" 📊 标签1-39均值: {target_mean:.0f}\")\n", " print(f\" 📉 下采样标签: {self.undersample_labels} → {target_mean:.0f}\")\n", " print(f\" 📈 过采样阈值: {self.oversample_threshold} × 均值 = {target_mean * self.oversample_threshold:.0f}\")\n", " \n", " return self.distribution_analysis, self.sampling_strategy\n", "\n", "# 创建智能数据处理管道\n", "print(\"🚀 创建智能数据处理管道...\")\n", "pipeline = SmartDataPipeline(data_dir, random_state=42)\n", "print(\"✅ 管道创建完成,准备执行步骤1...\")" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "✅ 步骤2方法已添加到管道\n" ] } ], "source": [ "# 继续添加智能管道的其他方法【管道完善】\n", "\n", "def step2_fit_pca_with_undersampling(self):\n", " \"\"\"\n", " 步骤2: 仅对下采样数据拟合PCA参数(不进行过采样,避免PCA被过采样影响)\n", " \"\"\"\n", " if self.sampling_strategy is None:\n", " raise ValueError(\"请先执行步骤1: step1_analyze_distribution()\")\n", " \n", " print(\"\\n🔧 步骤2: 拟合PCA参数(仅下采样,不过采样)...\")\n", " \n", " # 收集用于PCA拟合的样本(只下采样,不过采样)\n", " pca_features = []\n", " collected_samples = 0\n", " \n", " for trials_batch, filename in load_data_batch(self.data_dir, 'train', 3000):\n", " features, labels = extract_features_labels_batch(trials_batch)\n", " \n", " # 对当前批次应用仅下采样策略\n", " downsampled_features, downsampled_labels = self._apply_undersampling_only(features, labels)\n", " \n", " if downsampled_features.shape[0] > 0:\n", " pca_features.append(downsampled_features)\n", " collected_samples += downsampled_features.shape[0]\n", " \n", " if collected_samples >= self.pca_sample_size:\n", " break\n", " \n", " if pca_features:\n", " # 合并样本\n", " X_pca_sample = np.vstack(pca_features)[:self.pca_sample_size]\n", " print(f\" 📦 PCA拟合样本: {X_pca_sample.shape[0]:,} 个下采样样本\")\n", " print(f\" 🔢 原始特征维度: {X_pca_sample.shape[1]}\")\n", " \n", " # 标准化\n", " self.pca_scaler = StandardScaler()\n", " X_scaled = self.pca_scaler.fit_transform(X_pca_sample)\n", " \n", " # 确定PCA成分数\n", " pca_full = PCA()\n", " pca_full.fit(X_scaled)\n", " cumsum_var = np.cumsum(pca_full.explained_variance_ratio_)\n", " optimal_components = np.argmax(cumsum_var >= self.pca_variance_threshold) + 1\n", " self.pca_components = min(optimal_components, X_pca_sample.shape[1])\n", " \n", " # 拟合最终PCA\n", " self.pca_model = PCA(n_components=self.pca_components, random_state=self.random_state)\n", " self.pca_model.fit(X_scaled)\n", " self.pca_fitted = True\n", " \n", " # 保存PCA模型\n", " pca_path = \"smart_pipeline_pca.joblib\"\n", " joblib.dump({\n", " 'scaler': self.pca_scaler,\n", " 'pca': self.pca_model,\n", " 'components': self.pca_components\n", " }, pca_path)\n", " \n", " print(f\" ✅ PCA拟合完成!\")\n", " print(f\" 降维: 
{X_pca_sample.shape[1]} → {self.pca_components}\")\n", " print(f\" 降维比例: {self.pca_components/X_pca_sample.shape[1]:.2%}\")\n", " print(f\" 保留方差: {self.pca_model.explained_variance_ratio_.sum():.4f}\")\n", " print(f\" 模型保存: {pca_path}\")\n", " \n", " # 清理内存\n", " del pca_features, X_pca_sample, X_scaled\n", " gc.collect()\n", " else:\n", " raise ValueError(\"无法收集PCA拟合样本\")\n", "\n", "def _apply_undersampling_only(self, X, y):\n", " \"\"\"\n", " 仅应用下采样策略(用于PCA拟合)\n", " \"\"\"\n", " X_result = []\n", " y_result = []\n", " \n", " np.random.seed(self.random_state)\n", " \n", " for label in range(41):\n", " label_mask = (y == label)\n", " X_label = X[label_mask]\n", " y_label = y[label_mask]\n", " current_count = len(y_label)\n", " \n", " if current_count == 0:\n", " continue\n", " \n", " strategy = self.sampling_strategy[label]\n", " \n", " if strategy['action'] == 'undersample' and current_count > strategy['target_count']:\n", " # 下采样\n", " indices = np.random.choice(current_count, strategy['target_count'], replace=False)\n", " X_resampled = X_label[indices]\n", " y_resampled = y_label[indices]\n", " else:\n", " # 保持原样\n", " X_resampled = X_label\n", " y_resampled = y_label\n", " \n", " X_result.append(X_resampled)\n", " y_result.append(y_resampled)\n", " \n", " if X_result:\n", " return np.vstack(X_result), np.hstack(y_result)\n", " else:\n", " return np.array([]).reshape(0, X.shape[1]), np.array([])\n", "\n", "# 动态添加方法到类\n", "SmartDataPipeline.step2_fit_pca_with_undersampling = step2_fit_pca_with_undersampling\n", "SmartDataPipeline._apply_undersampling_only = _apply_undersampling_only\n", "\n", "print(\"✅ 步骤2方法已添加到管道\")" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "✅ 所有方法已添加到智能管道\n", "\n", "📋 智能数据处理管道状态:\n", " 🔍 步骤1 - 分布分析: ❌ 未完成\n", " 🔧 步骤2 - PCA拟合: ❌ 未完成\n", "\n", "🎯 使用流程:\n", " 1. pipeline.step1_analyze_distribution()\n", " 2. pipeline.step2_fit_pca_with_undersampling()\n", " 3. 
pipeline.step3_process_data('train') # 训练集\n", " pipeline.step3_process_data('val') # 验证集\n" ] } ], "source": [ "# 添加智能管道的剩余方法\n", "\n", "def _apply_full_sampling(self, X, y):\n", " \"\"\"\n", " 应用完整的采样策略(下采样+过采样)\n", " \"\"\"\n", " X_result = []\n", " y_result = []\n", " \n", " np.random.seed(self.random_state)\n", " \n", " for label in range(41):\n", " label_mask = (y == label)\n", " X_label = X[label_mask]\n", " y_label = y[label_mask]\n", " current_count = len(y_label)\n", " \n", " if current_count == 0:\n", " continue\n", " \n", " strategy = self.sampling_strategy[label]\n", " target_count = strategy['target_count']\n", " \n", " if strategy['action'] == 'undersample' and current_count > target_count:\n", " # 下采样\n", " indices = np.random.choice(current_count, target_count, replace=False)\n", " X_resampled = X_label[indices]\n", " y_resampled = y_label[indices]\n", " elif strategy['action'] == 'oversample' and current_count < target_count:\n", " # 过采样\n", " X_resampled, y_resampled = resample(\n", " X_label, y_label, \n", " n_samples=target_count, \n", " random_state=self.random_state\n", " )\n", " else:\n", " # 保持原样\n", " X_resampled = X_label\n", " y_resampled = y_label\n", " \n", " X_result.append(X_resampled)\n", " y_result.append(y_resampled)\n", " \n", " if X_result:\n", " return np.vstack(X_result), np.hstack(y_result)\n", " else:\n", " return np.array([]).reshape(0, X.shape[1]), np.array([])\n", "\n", "def _apply_pca_transform(self, X):\n", " \"\"\"\n", " 应用PCA变换\n", " \"\"\"\n", " if not self.pca_fitted:\n", " return X\n", " \n", " X_scaled = self.pca_scaler.transform(X)\n", " X_pca = self.pca_model.transform(X_scaled)\n", " return X_pca\n", "\n", "def step3_process_data(self, data_type, apply_sampling=None):\n", " \"\"\"\n", " 步骤3: 处理数据(采样+PCA降维)\n", " \n", " Args:\n", " data_type: 'train', 'val', 'test'\n", " apply_sampling: 是否应用采样策略,None=训练集应用,验证/测试集不应用\n", " \"\"\"\n", " if not self.pca_fitted:\n", " raise ValueError(\"请先执行步骤2: step2_fit_pca_with_undersampling()\")\n", " \n", " if apply_sampling is None:\n", " apply_sampling = (data_type == 'train')\n", " \n", " print(f\"\\n🔄 步骤3: 处理{data_type}数据...\")\n", " print(f\" 采样策略: {'启用' if apply_sampling else '禁用'}\")\n", " \n", " all_features = []\n", " all_labels = []\n", " \n", " for trials_batch, filename in load_data_batch(self.data_dir, data_type, 3000):\n", " features, labels = extract_features_labels_batch(trials_batch)\n", " \n", " # 应用采样策略\n", " if apply_sampling:\n", " features_sampled, labels_sampled = self._apply_full_sampling(features, labels)\n", " else:\n", " features_sampled, labels_sampled = features, labels\n", " \n", " # 应用PCA降维\n", " if features_sampled.shape[0] > 0:\n", " features_pca = self._apply_pca_transform(features_sampled)\n", " all_features.append(features_pca)\n", " all_labels.append(labels_sampled)\n", " \n", " if all_features:\n", " X = np.vstack(all_features)\n", " y = np.hstack(all_labels)\n", " \n", " # 随机打乱\n", " shuffle_indices = np.random.permutation(len(y))\n", " X = X[shuffle_indices]\n", " y = y[shuffle_indices]\n", " \n", " print(f\" ✅ 处理完成: {X.shape[0]:,} 样本, {X.shape[1]} 特征\")\n", " \n", " # 清理内存\n", " del all_features, all_labels\n", " gc.collect()\n", " \n", " return X, y\n", " else:\n", " return None, None\n", "\n", "def print_summary(self):\n", " \"\"\"\n", " 打印管道状态总结\n", " \"\"\"\n", " print(\"\\n📋 智能数据处理管道状态:\")\n", " print(f\" 🔍 步骤1 - 分布分析: {'✅ 完成' if self.distribution_analysis else '❌ 未完成'}\")\n", " print(f\" 🔧 步骤2 - PCA拟合: {'✅ 完成' if self.pca_fitted else '❌ 未完成'}\")\n", " \n", 
" if self.distribution_analysis:\n", " target_mean = self.distribution_analysis['target_mean']\n", " print(f\" 📊 标签1-39均值: {target_mean:.0f}\")\n", " \n", " if self.pca_fitted:\n", " print(f\" 🔬 PCA降维: 7168 → {self.pca_components} ({self.pca_components/7168:.1%})\")\n", " print(f\" 📈 保留方差: {self.pca_model.explained_variance_ratio_.sum():.4f}\")\n", " \n", " print(f\"\\n🎯 使用流程:\")\n", " print(f\" 1. pipeline.step1_analyze_distribution()\")\n", " print(f\" 2. pipeline.step2_fit_pca_with_undersampling()\")\n", " print(f\" 3. pipeline.step3_process_data('train') # 训练集\")\n", " print(f\" pipeline.step3_process_data('val') # 验证集\")\n", "\n", "# 动态添加剩余方法到类\n", "SmartDataPipeline._apply_full_sampling = _apply_full_sampling\n", "SmartDataPipeline._apply_pca_transform = _apply_pca_transform\n", "SmartDataPipeline.step3_process_data = step3_process_data\n", "SmartDataPipeline.print_summary = print_summary\n", "\n", "print(\"✅ 所有方法已添加到智能管道\")\n", "pipeline.print_summary()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 🔥 执行智能数据处理管道" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "🚀 开始执行智能数据处理管道...\n", "============================================================\n", "\n", "======================🔍 STEP 1: 分析数据分布======================\n", "🔍 步骤1: 分析数据分布...\n", " 正在加载文件 1/41: t15.2023.11.17_val_concatenated.npz\n", " 正在加载文件 2/41: t15.2023.12.17_val_concatenated.npz\n", " 正在加载文件 3/41: t15.2023.10.15_val_concatenated.npz\n", " 正在加载文件 4/41: t15.2023.10.08_val_concatenated.npz\n", " 正在加载文件 5/41: t15.2025.01.10_val_concatenated.npz\n", " 正在加载文件 6/41: t15.2023.12.08_val_concatenated.npz\n", " 正在加载文件 7/41: t15.2024.03.08_val_concatenated.npz\n", " 正在加载文件 8/41: t15.2024.03.15_val_concatenated.npz\n", " 正在加载文件 9/41: t15.2025.03.14_val_concatenated.npz\n", " 正在加载文件 10/41: t15.2024.02.25_val_concatenated.npz\n", " 正在加载文件 11/41: t15.2025.03.30_val_concatenated.npz\n", " 正在加载文件 12/41: t15.2023.09.29_val_concatenated.npz\n", " 正在加载文件 13/41: t15.2023.09.01_val_concatenated.npz\n", " ✅ 分析完成: 101,906 样本\n", " 📊 标签1-39均值: 389\n", " 📉 下采样标签: [0, 40] → 389\n", " 📈 过采样阈值: 0.5 × 均值 = 194\n", "\n", "📊 采样策略总结:\n", " 📉 下采样标签: 2 个\n", " 📈 过采样标签: 11 个\n", " ✅ 保持不变: 28 个\n", "\n", "✅ 步骤1完成!\n" ] } ], "source": [ "# 🔥 执行智能数据处理管道【确定采样策略】\n", "\n", "print(\"🚀 开始执行智能数据处理管道...\")\n", "print(\"=\" * 60)\n", "\n", "# 步骤1: 分析数据分布\n", "print(\"\\n\" + \"🔍 STEP 1: 分析数据分布\".center(60, \"=\"))\n", "distribution, strategy = pipeline.step1_analyze_distribution()\n", "\n", "# 显示采样策略总结\n", "print(f\"\\n📊 采样策略总结:\")\n", "undersample_count = sum(1 for s in strategy.values() if s['action'] == 'undersample')\n", "oversample_count = sum(1 for s in strategy.values() if s['action'] == 'oversample')\n", "keep_count = sum(1 for s in strategy.values() if s['action'] == 'keep')\n", "\n", "print(f\" 📉 下采样标签: {undersample_count} 个\")\n", "print(f\" 📈 过采样标签: {oversample_count} 个\") \n", "print(f\" ✅ 保持不变: {keep_count} 个\")\n", "\n", "print(\"\\n✅ 步骤1完成!\")" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "\n", "=====================🔧 STEP 2: 拟合PCA参数======================\n", "\n", "🔧 步骤2: 拟合PCA参数(仅下采样,不过采样)...\n", " 正在加载文件 1/45: t15.2025.04.13_train_concatenated.npz\n", " 正在加载文件 2/45: t15.2024.07.21_train_concatenated.npz\n", " 正在加载文件 3/45: t15.2024.03.17_train_concatenated.npz\n", " 📦 PCA拟合样本: 15,000 个下采样样本\n", " 🔢 原始特征维度: 7168\n" ] }, { "ename": "KeyboardInterrupt", 
"evalue": "", "output_type": "error", "traceback": [ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", "\u001b[0;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)", "\u001b[0;32m/tmp/ipykernel_36/3241517313.py\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0;31m# 步骤2: 拟合PCA参数【确定PCA策略】\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"\\n\"\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;34m\"🔧 STEP 2: 拟合PCA参数\"\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcenter\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m60\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"=\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mpipeline\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep2_fit_pca_with_undersampling\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 4\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 5\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"\\n✅ 步骤2完成!\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/tmp/ipykernel_36/3022750261.py\u001b[0m in \u001b[0;36mstep2_fit_pca_with_undersampling\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 39\u001b[0m \u001b[0;31m# 确定PCA成分数\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 40\u001b[0m \u001b[0mpca_full\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mPCA\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 41\u001b[0;31m \u001b[0mpca_full\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX_scaled\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 42\u001b[0m \u001b[0mcumsum_var\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcumsum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mpca_full\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mexplained_variance_ratio_\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 43\u001b[0m \u001b[0moptimal_components\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0margmax\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcumsum_var\u001b[0m \u001b[0;34m>=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpca_variance_threshold\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m+\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/sklearn/base.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(estimator, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1387\u001b[0m )\n\u001b[1;32m 1388\u001b[0m ):\n\u001b[0;32m-> 1389\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfit_method\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mestimator\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1390\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1391\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mwrapper\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", 
"\u001b[0;32m/usr/local/lib/python3.11/dist-packages/sklearn/decomposition/_pca.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, X, y)\u001b[0m\n\u001b[1;32m 440\u001b[0m \u001b[0mReturns\u001b[0m \u001b[0mthe\u001b[0m \u001b[0minstance\u001b[0m \u001b[0mitself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 441\u001b[0m \"\"\"\n\u001b[0;32m--> 442\u001b[0;31m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_fit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 443\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 444\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/sklearn/decomposition/_pca.py\u001b[0m in \u001b[0;36m_fit\u001b[0;34m(self, X)\u001b[0m\n\u001b[1;32m 540\u001b[0m \u001b[0;31m# Call different fits for either full or truncated SVD\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 541\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_fit_svd_solver\u001b[0m \u001b[0;32min\u001b[0m \u001b[0;34m(\u001b[0m\u001b[0;34m\"full\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"covariance_eigh\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 542\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_fit_full\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_components\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mxp\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mis_array_api_compliant\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 543\u001b[0m \u001b[0;32melif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_fit_svd_solver\u001b[0m \u001b[0;32min\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m\"arpack\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m\"randomized\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 544\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_fit_truncated\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mn_components\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mxp\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/sklearn/decomposition/_pca.py\u001b[0m in \u001b[0;36m_fit_full\u001b[0;34m(self, X, n_components, xp, is_array_api_compliant)\u001b[0m\n\u001b[1;32m 581\u001b[0m \u001b[0;31m# solver by default though (assuming both are built against the\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 582\u001b[0m \u001b[0;31m# same BLAS).\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 583\u001b[0;31m \u001b[0mU\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mS\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mVt\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlinalg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msvd\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX_centered\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfull_matrices\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 584\u001b[0m 
\u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 585\u001b[0m \u001b[0mU\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mS\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mVt\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mxp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlinalg\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msvd\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX_centered\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfull_matrices\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/scipy/linalg/_decomp_svd.py\u001b[0m in \u001b[0;36msvd\u001b[0;34m(a, full_matrices, compute_uv, overwrite_a, check_finite, lapack_driver)\u001b[0m\n\u001b[1;32m 160\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 161\u001b[0m \u001b[0;31m# perform decomposition\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 162\u001b[0;31m u, s, v, info = gesXd(a1, compute_uv=compute_uv, lwork=lwork,\n\u001b[0m\u001b[1;32m 163\u001b[0m full_matrices=full_matrices, overwrite_a=overwrite_a)\n\u001b[1;32m 164\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n", "\u001b[0;31mKeyboardInterrupt\u001b[0m: " ] } ], "source": [ "# 步骤2: 拟合PCA参数【确定PCA策略】\n", "print(\"\\n\" + \"🔧 STEP 2: 拟合PCA参数\".center(60, \"=\"))\n", "pipeline.step2_fit_pca_with_undersampling()\n", "\n", "print(\"\\n✅ 步骤2完成!\")\n", "pipeline.print_summary()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 🚀 使用智能管道进行分批训练" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# 🚀 使用智能管道进行分批训练\n", "\n", "import lightgbm as lgb\n", "import time\n", "from collections import Counter\n", "import matplotlib.pyplot as plt\n", "\n", "class SmartBatchTrainer:\n", " \"\"\"\n", " 智能分批训练器,集成智能数据管道\n", " \"\"\"\n", " \n", " def __init__(self, pipeline, params=None, min_learning_rate=1e-4):\n", " self.pipeline = pipeline\n", " self.model = None\n", " self.training_history = {} # 改为字典,因为只有一次训练\n", " self.batch_count = 0\n", " self.min_learning_rate = min_learning_rate\n", " self.lr_history = [] # 用于可视化\n", " \n", " # 默认LightGBM参数(GPU优化)\n", " self.params = params or {\n", " 'objective': 'multiclass',\n", " 'num_class': 41,\n", " 'metric': 'multi_logloss',\n", " 'boosting_type': 'gbdt',\n", " 'device_type': 'cpu',\n", " # 'gpu_platform_id': 0,\n", " # 'gpu_device_id': 0,\n", " 'max_bin': 255,\n", " 'num_leaves': 127,\n", " 'learning_rate': 0.08, #默认0.08\n", " 'feature_fraction': 0.8,\n", " 'bagging_fraction': 0.8,\n", " 'bagging_freq': 5,\n", " 'min_data_in_leaf': 20,\n", " 'lambda_l1': 0.1,\n", " 'lambda_l2': 0.1,\n", " 'verbose': -1,\n", " 'num_threads': -1\n", " }\n", " \n", " self.initial_learning_rate = self.params.get('learning_rate', 0.08)\n", " \n", " print(f\"🎯 智能分批训练器创建完成\")\n", " print(f\" 🔧 LightGBM参数已配置:{self.params['device_type'].upper()}模式\")\n", " print(f\" 💡 学习率调度: 余弦退火 (从 {self.initial_learning_rate} 到 {self.min_learning_rate})\")\n", " \n", " def prepare_validation_data(self):\n", " \"\"\"\n", " 准备验证数据(仅PCA,保持原始分布)\n", " \"\"\"\n", " print(\"🔄 准备验证数据...\")\n", " X_val, y_val = self.pipeline.step3_process_data('val', apply_sampling=False)\n", " if X_val is None:\n", " raise ValueError(\"无法加载验证数据\")\n", " val_counts = Counter(y_val)\n", " print(f\" ✅ 验证数据准备完成: {X_val.shape[0]:,} 样本\")\n", " print(f\" 📊 验证集分布 (标签0: {val_counts.get(0, 0):,}, 标签40: {val_counts.get(40, 0):,})\")\n", " \n", " return lgb.Dataset(X_val, label=y_val, 
free_raw_data=False)\n", " \n", " def get_training_batch_generator(self):\n", " \"\"\"\n", " 获取训练批次生成器(平衡采样+PCA)\n", " \"\"\"\n", " print(\"🔄 准备训练批次生成器...\")\n", " \n", " # 使用管道的批次生成器\n", " for trials_batch, filename in load_data_batch(self.pipeline.data_dir, 'train', 2000):\n", " features, labels = extract_features_labels_batch(trials_batch)\n", " \n", " # 应用完整采样策略\n", " features_sampled, labels_sampled = self.pipeline._apply_full_sampling(features, labels)\n", " \n", " # 应用PCA降维\n", " if features_sampled.shape[0] > 0:\n", " features_pca = self.pipeline._apply_pca_transform(features_sampled)\n", " \n", " # 分析当前批次分布\n", " batch_counts = Counter(labels_sampled)\n", " \n", " print(f\" 📦 批次: {filename}\")\n", " print(f\" 样本数: {features_pca.shape[0]:,}\")\n", " print(f\" 平衡后分布: 标签0={batch_counts.get(0,0)}, 标签40={batch_counts.get(40,0)}\")\n", " \n", " yield lgb.Dataset(features_pca, label=labels_sampled), filename\n", " \n", " def prepare_full_data(self):\n", " \"\"\"\n", " 一次性准备所有训练和验证数据\n", " \"\"\"\n", " print(\"🔄 准备全量训练和验证数据...\")\n", " \n", " # 1. 准备验证数据 (保持原始分布)\n", " X_val, y_val = self.pipeline.step3_process_data('val', apply_sampling=False)\n", " if X_val is None:\n", " raise ValueError(\"无法加载验证数据\")\n", " val_counts = Counter(y_val)\n", " print(f\" ✅ 验证数据准备完成: {X_val.shape[0]:,} 样本\")\n", " print(f\" 📊 验证集分布 (标签0: {val_counts.get(0, 0):,}, 标签40: {val_counts.get(40, 0):,})\")\n", " val_data = lgb.Dataset(X_val, label=y_val, free_raw_data=False)\n", " \n", " # 2. 准备训练数据 (应用完整采样和PCA策略)\n", " X_train, y_train = self.pipeline.step3_process_data('train', apply_sampling=True)\n", " if X_train is None:\n", " raise ValueError(\"无法加载训练数据\")\n", " train_counts = Counter(y_train)\n", " print(f\" ✅ 训练数据准备完成: {X_train.shape[0]:,} 样本, {X_train.shape[1]} 特征\")\n", " print(f\" 📊 训练集(采样后)分布 (标签0: {train_counts.get(0, 0):,}, 标签40: {train_counts.get(40, 0):,})\")\n", " train_data = lgb.Dataset(X_train, label=y_train)\n", " \n", " return train_data, val_data, X_val, y_val\n", " \n", " def prepare_training_data(self):\n", " \"\"\"\n", " 准备训练数据(仅PCA,保持原始分布)\n", " \"\"\"\n", " print(\"🔄 准备训练数据...\")\n", " # 2. 
准备训练数据 (应用完整采样和PCA策略)\n", " X_train, y_train = self.pipeline.step3_process_data('train', apply_sampling=True)\n", " if X_train is None:\n", " raise ValueError(\"无法加载训练数据\")\n", " train_counts = Counter(y_train)\n", " print(f\" ✅ 训练数据准备完成: {X_train.shape[0]:,} 样本, {X_train.shape[1]} 特征\")\n", " print(f\" 📊 训练集(采样后)分布 (标签0: {train_counts.get(0, 0):,}, 标签40: {train_counts.get(40, 0):,})\")\n", " \n", " return lgb.Dataset(X_train, label=y_train, free_raw_data=False)\n", " \n", " # 余弦退火调度器函数\n", " def _cosine_annealing_scheduler(self, current_round, t_max):\n", " eta_max = self.initial_learning_rate\n", " eta_min = self.min_learning_rate\n", " lr = eta_min + 0.5 * (eta_max - eta_min) * (1 + np.cos(np.pi * current_round / t_max))\n", " return lr\n", " \n", " def train_incremental(self, num_boost_round=100, early_stopping_rounds=10):\n", " \"\"\"\n", " 增量分批训练\n", " \"\"\"\n", " print(f\"\\n🚀 开始智能分批训练...\")\n", " print(f\" 📝 训练轮数 (每批次): {num_boost_round}\")\n", " print(f\" ⏹️ 早停轮数: {early_stopping_rounds}\")\n", " print(\"=\" * 60)\n", " \n", " # 准备验证数据\n", " val_data = self.prepare_validation_data()\n", " \n", " print(f\"\\n🔄 开始分批增量训练...\")\n", " total_start_time = time.time()\n", " \n", " # ⭐️ 新增: 为学习率调度器定义T_max\n", " # 我们将每个批次的训练视为一个完整的退火周期\n", " t_max_per_batch = num_boost_round\n", " \n", " for train_data, filename in self.get_training_batch_generator():\n", " self.batch_count += 1\n", " batch_start_time = time.time()\n", " self.last_batch_lr_history = [] # 重置每个批次的LR历史\n", " \n", " print(f\"\\n📈 批次 {self.batch_count}: {filename}\")\n", " \n", " # ⭐️ 新增: 创建学习率调度回调 和 记录回调\n", " lr_scheduler_callback = lgb.reset_parameter(\n", " learning_rate=lambda current_round: self._cosine_annealing_scheduler(current_round, t_max_per_batch)\n", " )\n", "\n", " # 这个简单的回调用于记录每个周期的学习率,以便后续可视化\n", " def record_lr_callback(env):\n", " self.last_batch_lr_history.append(env.model.params['learning_rate'])\n", "\n", " # 组合所有回调\n", " training_callbacks = [\n", " lgb.early_stopping(stopping_rounds=early_stopping_rounds, verbose=True),\n", " lgb.log_evaluation(period=10), # 每10轮打印一次\n", " lr_scheduler_callback,\n", " record_lr_callback\n", " ]\n", "\n", " # 训练当前批次\n", " current_model_args = {\n", " 'params': self.params,\n", " 'train_set': train_data,\n", " 'num_boost_round': num_boost_round,\n", " 'valid_sets': [val_data],\n", " 'valid_names': ['validation'],\n", " 'callbacks': training_callbacks\n", " }\n", " \n", " if self.model is None:\n", " print(\" 🎯 初始模型训练...\")\n", " self.model = lgb.train(**current_model_args)\n", " else:\n", " print(\" ⚡ 增量训练...\")\n", " current_model_args['init_model'] = self.model\n", " self.model = lgb.train(**current_model_args)\n", "\n", " # 记录训练历史\n", " batch_time = time.time() - batch_start_time\n", " \n", " # 评估当前模型\n", " val_pred = self.model.predict(self.X_val)\n", " val_accuracy = (val_pred.argmax(axis=1) == self.y_val).mean()\n", " \n", " batch_info = {\n", " 'batch': self.batch_count,\n", " 'filename': filename,\n", " 'time': batch_time,\n", " 'val_accuracy': val_accuracy,\n", " 'num_trees': self.model.num_trees(),\n", " 'lr_history': self.last_batch_lr_history.copy() # 保存当前批次的LR历史\n", " }\n", " \n", " self.training_history.append(batch_info)\n", " \n", " print(f\" ✅ 批次完成: {batch_time:.1f}秒\")\n", " print(f\" 📊 验证准确率: {val_accuracy:.4f}\")\n", " print(f\" 🌳 模型树数: {self.model.num_trees()}\")\n", " \n", " model_path = f\"smart_batch_model_batch_{self.batch_count}.txt\"\n", " self.model.save_model(model_path)\n", " print(f\" 💾 模型已保存: {model_path}\")\n", " \n", " total_time = 
time.time() - total_start_time\n", " print(f\"\\n🎉 智能分批训练完成!\")\n", " print(f\" ⏱️ 总训练时间: {total_time:.1f}秒\")\n", " print(f\" 📊 处理批次数: {self.batch_count}\")\n", " print(f\" 🌳 最终模型树数: {self.model.num_trees()}\")\n", " \n", " return self.model\n", " \n", " def train(self, num_boost_round=1000, early_stopping_rounds=50):\n", " \"\"\"\n", " 执行一次性全量训练\n", " \"\"\"\n", " print(f\"\\n🚀 开始全量数据训练...\")\n", " print(f\" 📝 训练轮数: {num_boost_round}\")\n", " print(f\" ⏹️ 早停轮数: {early_stopping_rounds}\")\n", " print(\"=\" * 60)\n", " \n", " # 准备数据\n", " train_data, val_data, X_val, y_val = self.prepare_full_data()\n", " \n", " start_time = time.time()\n", " \n", " # 定义学习率调度和记录回调\n", " lr_scheduler_callback = lgb.reset_parameter(\n", " learning_rate=lambda current_round: self._cosine_annealing_scheduler(current_round, num_boost_round)\n", " )\n", " def record_lr_callback(env):\n", " self.lr_history.append(env.model.params['learning_rate'])\n", " \n", " training_callbacks = [\n", " lgb.early_stopping(stopping_rounds=early_stopping_rounds, verbose=True),\n", " lgb.log_evaluation(period=1), # 每100轮打印日志\n", " lr_scheduler_callback,\n", " record_lr_callback\n", " ]\n", " \n", " # 训练模型\n", " print(\"\\n📈 开始模型训练...\")\n", " self.model = lgb.train(\n", " params=self.params,\n", " train_set=train_data,\n", " num_boost_round=num_boost_round,\n", " valid_sets=[val_data],\n", " valid_names=['validation'],\n", " callbacks=training_callbacks\n", " )\n", " \n", " training_time = time.time() - start_time\n", " \n", " # 评估模型\n", " val_pred = self.model.predict(X_val)\n", " val_accuracy = (val_pred.argmax(axis=1) == y_val).mean()\n", " \n", " # 记录训练历史\n", " self.training_history = {\n", " 'time': training_time,\n", " 'val_accuracy': val_accuracy,\n", " 'num_trees': self.model.num_trees(),\n", " 'lr_history': self.lr_history,\n", " 'best_iteration': self.model.best_iteration\n", " }\n", " \n", " print(f\"\\n🎉 全量数据训练完成!\")\n", " print(f\" ⏱️ 总训练时间: {training_time:.1f}秒\")\n", " print(f\" 🌳 最终模型树数: {self.model.num_trees()} (最佳轮次: {self.model.best_iteration})\")\n", " print(f\" 🎯 最终验证准确率: {val_accuracy:.4f}\")\n", " \n", " # 保存模型\n", " model_path = \"full_train_model.txt\"\n", " self.model.save_model(model_path)\n", " print(f\" 💾 模型已保存: {model_path}\")\n", " \n", " return self.model\n", " \n", " def plot_training_progress(self):\n", " \"\"\"\n", " 绘制训练进度\n", " \"\"\"\n", " if not self.training_history:\n", " print(\"❌ 没有训练历史记录\")\n", " return\n", " \n", " # ⭐️ 修改: 增加学习率的可视化图表\n", " fig, ((ax1, ax2), (ax3, ax4), (ax5, ax6)) = plt.subplots(3, 2, figsize=(15, 15))\n", " \n", " batches = [h['batch'] for h in self.training_history]\n", " accuracies = [h['val_accuracy'] for h in self.training_history]\n", " times = [h['time'] for h in self.training_history]\n", " trees = [h['num_trees'] for h in self.training_history]\n", " \n", " # 1. 验证准确率\n", " ax1.plot(batches, accuracies, 'b-o', linewidth=2, markersize=6)\n", " ax1.set_xlabel('Training Batch')\n", " ax1.set_ylabel('Validation Accuracy')\n", " ax1.set_title('Validation Accuracy Progress')\n", " ax1.grid(True, alpha=0.3)\n", " ax1.set_ylim(0, 1)\n", " \n", " # 2. 批次训练时间\n", " ax2.bar(batches, times, color='green', alpha=0.7)\n", " ax2.set_xlabel('Training Batch')\n", " ax2.set_ylabel('Training Time (seconds)')\n", " ax2.set_title('Training Time per Batch')\n", " ax2.grid(True, alpha=0.3)\n", " \n", " # 3. 
模型树数增长\n", " ax3.plot(batches, trees, 'r-s', linewidth=2, markersize=6)\n", " ax3.set_xlabel('Training Batch')\n", " ax3.set_ylabel('Number of Trees')\n", " ax3.set_title('Model Complexity Growth')\n", " ax3.grid(True, alpha=0.3)\n", " \n", " # 4. 累计准确率提升\n", " ax4.plot(batches, [acc - accuracies[0] for acc in accuracies], 'purple', linewidth=2, marker='D')\n", " ax4.set_xlabel('Training Batch')\n", " ax4.set_ylabel('Accuracy Improvement')\n", " ax4.set_title('Cumulative Accuracy Improvement')\n", " ax4.grid(True, alpha=0.3)\n", " ax4.axhline(y=0, color='black', linestyle='--', alpha=0.5)\n", "\n", " # ⭐️ 新增: 5. 最后一个批次的学习率曲线\n", " last_lr_history = self.training_history[-1]['lr_history']\n", " ax5.plot(range(len(last_lr_history)), last_lr_history, color='orange', marker='.')\n", " ax5.set_xlabel('Boosting Round in Last Batch')\n", " ax5.set_ylabel('Learning Rate')\n", " ax5.set_title(f'Cosine Annealing LR in Last Batch (Batch {batches[-1]})')\n", " ax5.grid(True, alpha=0.3)\n", " \n", " # 隐藏第六个子图\n", " ax6.axis('off')\n", "\n", " plt.tight_layout()\n", " plt.show()\n", " \n", " # 打印统计信息\n", " print(f\"\\n📈 训练进度统计:\")\n", " print(f\" 🎯 初始准确率: {accuracies[0]:.4f}\")\n", " print(f\" 🎯 最终准确率: {accuracies[-1]:.4f}\")\n", " print(f\" 📈 准确率提升: {accuracies[-1] - accuracies[0]:.4f}\")\n", " print(f\" ⏱️ 平均批次时间: {np.mean(times):.1f}秒\")\n", " print(f\" 🌳 最终模型树数: {trees[-1]}\")\n", "\n", "\n", "print(\"🚀 创建智能分批训练器...\")\n", "# 实例化时可以传入最小学习率\n", "trainer = SmartBatchTrainer(pipeline, min_learning_rate=0.001) \n", "print(\"✅ 训练器创建完成,准备开始训练!\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# # 全量训练\n", "\n", "# print(\"🔥 开始智能分批训练!\")\n", "# print(\"=\" * 80)\n", "\n", "# # 训练参数\n", "# TRAINING_PARAMS = {\n", "# 'num_boost_round': 300, # 每批次的提升轮数\n", "# 'early_stopping_rounds': 15 # 早停轮数\n", "# }\n", "\n", "# print(f\"📝 训练配置:\")\n", "# print(f\" 训练轮数: {TRAINING_PARAMS['num_boost_round']}\")\n", "# print(f\" 早停轮数: {TRAINING_PARAMS['early_stopping_rounds']}\")\n", "# print(f\" 数据平衡: 启用(下采样标签0,40 + 过采样少数类)\")\n", "# print(f\" PCA降维: 7168 → {pipeline.pca_components} 特征\")\n", "\n", "# print(f\"\\n🚀 启动训练...\")\n", "\n", "# # 开始训练\n", "# model = trainer.train(\n", "# num_boost_round=TRAINING_PARAMS['num_boost_round'],\n", "# early_stopping_rounds=TRAINING_PARAMS['early_stopping_rounds']\n", "# )" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 📊 训练结果分析" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# 📊 训练结果分析和可视化\n", "\n", "print(\"📊 分析智能分批训练结果...\")\n", "print(\"=\" * 60)\n", "\n", "# 显示训练进度图表\n", "trainer.plot_training_progress()\n", "\n", "# 保存最终模型\n", "final_model_path = \"smart_pipeline_final_model.txt\"\n", "if trainer.model:\n", " trainer.model.save_model(final_model_path)\n", " print(f\"\\n💾 最终模型已保存: {final_model_path}\")\n", "\n", "# 详细分析\n", "if trainer.training_history:\n", " print(f\"\\n📈 详细训练分析:\")\n", " print(f\" 🎯 训练批次总数: {len(trainer.training_history)}\")\n", " \n", " # 最佳批次\n", " best_batch = max(trainer.training_history, key=lambda x: x['val_accuracy'])\n", " print(f\" 🏆 最佳验证准确率: {best_batch['val_accuracy']:.4f} (批次 {best_batch['batch']})\")\n", " \n", " # 训练效率\n", " total_training_time = sum(h['time'] for h in trainer.training_history)\n", " avg_batch_time = total_training_time / len(trainer.training_history)\n", " print(f\" ⏱️ 总训练时间: {total_training_time:.1f}秒\")\n", " print(f\" ⏱️ 平均批次时间: {avg_batch_time:.1f}秒\")\n", " \n", " # 模型复杂度\n", " final_trees = 
trainer.training_history[-1]['num_trees']\n", " print(f\" 🌳 最终模型树数: {final_trees}\")\n", " \n", " # 收敛性分析\n", " recent_accs = [h['val_accuracy'] for h in trainer.training_history[-3:]]\n", " if len(recent_accs) >= 2:\n", " acc_stability = max(recent_accs) - min(recent_accs)\n", " print(f\" 📈 准确率稳定性: {acc_stability:.4f} (最近3批次方差)\")\n", " \n", " if acc_stability < 0.01:\n", " print(\" ✅ 模型已收敛 (准确率变化 < 1%)\")\n", " else:\n", " print(\" ⚠️ 模型可能需要更多训练\")\n", "\n", "print(f\"\\n🎉 智能分批训练分析完成!\")\n", "print(f\" 💡 使用了改进的数据平衡策略和PCA降维\")\n", "print(f\" 💡 训练集应用了下采样+过采样,验证集保持原始分布\")\n", "print(f\" 💡 实现了内存友好的分批处理\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 🧪 模型性能评估" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# 🧪 模型性能评估\n", "\n", "from sklearn.metrics import classification_report, confusion_matrix\n", "import numpy as np\n", "\n", "def evaluate_model_performance(model, pipeline, data_type='val'):\n", " \"\"\"\n", " 评估模型在指定数据集上的性能\n", " \"\"\"\n", " print(f\"🧪 评估模型在{data_type}数据集上的性能...\")\n", " \n", " # 加载数据\n", " X, y = pipeline.step3_process_data(data_type, apply_sampling=False)\n", " \n", " if X is None or y is None:\n", " print(f\"❌ 无法加载{data_type}数据\")\n", " return None\n", " \n", " print(f\" 📊 数据集大小: {X.shape[0]:,} 样本, {X.shape[1]} 特征\")\n", " \n", " # 预测\n", " start_time = time.time()\n", " y_pred_proba = model.predict(X)\n", " y_pred = y_pred_proba.argmax(axis=1)\n", " pred_time = time.time() - start_time\n", " \n", " # 计算性能指标\n", " accuracy = (y_pred == y).mean()\n", " \n", " print(f\" ⏱️ 预测时间: {pred_time:.2f}秒\")\n", " print(f\" 🎯 整体准确率: {accuracy:.4f}\")\n", " \n", " # 分析各类别性能\n", " from collections import Counter\n", " true_counts = Counter(y)\n", " pred_counts = Counter(y_pred)\n", " \n", " print(f\"\\n📊 标签分布对比:\")\n", " print(\"标签 | 真实数量 | 预测数量 | 准确率\")\n", " print(\"-\" * 40)\n", " \n", " label_accuracies = {}\n", " for label in range(41):\n", " if label in true_counts:\n", " label_mask = (y == label)\n", " if label_mask.sum() > 0:\n", " label_acc = (y_pred[label_mask] == label).mean()\n", " label_accuracies[label] = label_acc\n", " true_count = true_counts.get(label, 0)\n", " pred_count = pred_counts.get(label, 0)\n", " print(f\"{label:4d} | {true_count:8,} | {pred_count:8,} | {label_acc:7.3f}\")\n", " \n", " # 重点分析关键标签\n", " print(f\"\\n🔍 关键标签性能分析:\")\n", " key_labels = [0, 40] # 下采样的标签\n", " for label in key_labels:\n", " if label in label_accuracies:\n", " acc = label_accuracies[label]\n", " count = true_counts.get(label, 0)\n", " print(f\" 标签 {label} (下采样目标): 准确率 {acc:.4f}, 样本数 {count:,}\")\n", " \n", " # 少数类性能\n", " minority_labels = [label for label, count in true_counts.items() \n", " if count < 200 and label not in [0, 40]]\n", " if minority_labels:\n", " minority_accs = [label_accuracies.get(label, 0) for label in minority_labels[:5]]\n", " avg_minority_acc = np.mean(minority_accs) if minority_accs else 0\n", " print(f\" 少数类平均准确率 (前5个): {avg_minority_acc:.4f}\")\n", " \n", " # 置信度分析\n", " max_proba = y_pred_proba.max(axis=1)\n", " print(f\"\\n📈 预测置信度分析:\")\n", " print(f\" 平均置信度: {max_proba.mean():.4f}\")\n", " print(f\" 置信度中位数: {np.median(max_proba):.4f}\")\n", " print(f\" 高置信度预测 (>0.9): {(max_proba > 0.9).sum():,} / {len(max_proba):,} ({(max_proba > 0.9).mean():.2%})\")\n", " \n", " return {\n", " 'accuracy': accuracy,\n", " 'prediction_time': pred_time,\n", " 'label_accuracies': label_accuracies,\n", " 'confidence_stats': {\n", " 'mean': max_proba.mean(),\n", " 'median': np.median(max_proba),\n", " 
'high_confidence_ratio': (max_proba > 0.9).mean()\n", " }\n", " }\n", "\n", "# 评估模型性能\n", "if trainer.model:\n", " print(\"🧪 开始模型性能评估...\")\n", " \n", " # 验证集评估\n", " val_results = evaluate_model_performance(trainer.model, pipeline, 'val')\n", " \n", " print(f\"\\n\" + \"=\"*60)\n", " print(\"🎉 智能分批训练+数据平衡 评估完成!\")\n", " print(f\"✅ 实现了数据平衡和PCA降维的完整流程\")\n", " print(f\"✅ 使用了内存友好的分批训练策略\")\n", " print(f\"✅ 保持了验证集的原始分布以确保评估客观性\")\n", "else:\n", " print(\"❌ 模型尚未训练完成,请等待训练结束后运行此评估\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# 测试集总评-连接语言模型" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "smart_pipeline_final_model.txt" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "ename": "", "evalue": "", "output_type": "error", "traceback": [ "\u001b[1;31mFailed to connect to the remote Jupyter Server 'https://kkb-production.jupyter-proxy.kaggle.net/'. Verify the server is running and reachable. (Failed to connect to the remote Jupyter Server 'https://kkb-production.jupyter-proxy.kaggle.net/'. Verify the server is running and reachable. (request to https://kkb-production.jupyter-proxy.kaggle.net/k/261889069/eyJhbGciOiJkaXIiLCJlbmMiOiJBMTI4Q0JDLUhTMjU2IiwidHlwIjoiSldUIn0..3JPUo77-E5Unxk40FcVuDw.UQ5HO58Y63DL5av9cYv59hBnb4Rw6GbfyzcRyPp9ID-u0ODR4KJuqXcaUS7TKXpddj60a_dRVtxSjqjhxD7xtc5fM80xoPpibRRjVKonb_HwqUKs_96UIdvPfI_MeKXYJ3Tb0AXf-5TxLoOaYyps8zaC5bp8r7jzr1uNTM56M7RH09kDMCNnIhvD7zWEZJlQULZ3sY6N8v36OVsY05q5c6ZnVePk92Qw-buRKiNK5bIo4qmSjUssmdP5SqMShwc3.iAgJSIm0bnGknjcE5jhAvQ/proxy/api/kernels?1757936253840 failed, reason: Client network socket disconnected before secure TLS connection was established).)." ] } ], "source": [ "# 🚀 加载预训练的 LightGBM 分类模型\n", "\n", "import lightgbm as lgb\n", "import os\n", "\n", "def load_lgbm_model(model_path):\n", " \"\"\" \n", " 加载预训练的 LightGBM 模型\n", " \n", " Args:\n", " model_path: 模型文件路径\n", " \n", " Returns:\n", " lgb.Booster: 加载的模型\n", " \"\"\"\n", " if not os.path.exists(model_path):\n", " raise FileNotFoundError(f\"模型文件不存在: {model_path}\")\n", " \n", " print(f\"📂 正在加载模型: {model_path}\")\n", " \n", " # 加载模型\n", " model = lgb.Booster(model_file=model_path)\n", " \n", " print(f\"✅ 模型加载成功!\")\n", " print(f\" 🌳 模型树数: {model.num_trees()}\")\n", " print(f\" 📊 特征数: {model.num_feature()}\")\n", " print(f\" 🏷️ 类别数: {model.num_model_per_iteration()}\")\n", " \n", " # 显示模型基本信息\n", " model_info = {\n", " 'num_trees': model.num_trees(),\n", " 'num_features': model.num_feature(),\n", " 'num_classes': model.num_model_per_iteration(),\n", " 'objective': model.params.get('objective', 'unknown'),\n", " 'boosting_type': model.params.get('boosting_type', 'unknown')\n", " }\n", " \n", " print(f\"\\n📋 模型详细信息:\")\n", " for key, value in model_info.items():\n", " print(f\" {key}: {value}\")\n", " \n", " return model\n", "\n", "# 加载我们训练好的模型\n", "MODEL_PATH = \"full_train_model.txt\"\n", "\n", "try:\n", " lgbm_model = load_lgbm_model(MODEL_PATH)\n", " print(f\"\\n🎉 LightGBM 模型加载完成,准备用于推理!\")\n", " \n", "except FileNotFoundError as e:\n", " print(f\"❌ 错误: {e}\")\n", " print(f\"💡 请确保模型文件 '{MODEL_PATH}' 存在于当前目录\")\n", " lgbm_model = None\n", " \n", "except Exception as e:\n", " print(f\"❌ 加载模型时发生错误: {e}\")\n", " lgbm_model = None" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# 🧪 测试模型预测功能\n", "\n", "def test_model_prediction(model, test_features=None):\n", " \"\"\"\n", " 测试模型的预测功能\n", " \n", " 
Args:\n", " model: LightGBM 模型\n", " test_features: 测试特征数据,如果为None则创建虚拟数据\n", " \"\"\"\n", " if model is None:\n", " print(\"❌ 模型未加载,无法进行测试\")\n", " return\n", " \n", " print(\"🧪 测试模型预测功能...\")\n", " \n", " # 如果没有提供测试数据,创建虚拟数据\n", " if test_features is None:\n", " print(\" 📝 创建虚拟测试数据...\")\n", " # 根据模型期望的特征数创建随机数据\n", " num_features = model.num_feature()\n", " num_samples = 5\n", " test_features = np.random.randn(num_samples, num_features).astype(np.float32)\n", " print(f\" 🔢 虚拟数据形状: {test_features.shape}\")\n", " \n", " try:\n", " # 进行预测\n", " predictions = model.predict(test_features)\n", " \n", " print(f\" ✅ 预测成功!\")\n", " print(f\" 📊 预测形状: {predictions.shape}\")\n", " print(f\" 🎯 预测范例 (前3个样本的前5个类别概率):\")\n", " \n", " for i in range(min(3, predictions.shape[0])):\n", " pred_probs = predictions[i][:5] # 只显示前5个类别\n", " predicted_class = np.argmax(predictions[i])\n", " max_prob = np.max(predictions[i])\n", " \n", " print(f\" 样本 {i+1}: 预测类别={predicted_class}, 置信度={max_prob:.4f}\")\n", " print(f\" 前5类概率: {pred_probs}\")\n", " \n", " print(f\" 📈 预测置信度统计:\")\n", " max_probs = np.max(predictions, axis=1)\n", " print(f\" 平均置信度: {np.mean(max_probs):.4f}\")\n", " print(f\" 最高置信度: {np.max(max_probs):.4f}\")\n", " print(f\" 最低置信度: {np.min(max_probs):.4f}\")\n", " \n", " return True\n", " \n", " except Exception as e:\n", " print(f\" ❌ 预测失败: {e}\")\n", " return False\n", "\n", "# 测试加载的模型\n", "if lgbm_model is not None:\n", " test_success = test_model_prediction(lgbm_model)\n", " \n", " if test_success:\n", " print(f\"\\n🎉 模型测试成功! 可以用于实际推理任务\")\n", " print(f\"💡 模型期望输入: {lgbm_model.num_feature()} 维特征向量\")\n", " print(f\"💡 模型输出: {lgbm_model.num_model_per_iteration()} 个类别的概率分布\")\n", " else:\n", " print(f\"\\n❌ 模型测试失败,请检查模型文件\")\n", "else:\n", " print(\"❌ 模型未加载,跳过测试\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 测试集的实义测试" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "✅ 神经数据处理函数定义完成\n" ] } ], "source": [ "# 🚀 仿照RNN评估流程处理测试集数据 + LightGBM预测\n", "\n", "import h5py\n", "import torch\n", "import torch.nn.functional as F\n", "from scipy.ndimage import gaussian_filter1d\n", "import os\n", "from tqdm import tqdm\n", "\n", "def load_h5py_file(file_path, b2txt_csv_df):\n", " data = {\n", " 'neural_features': [],\n", " 'n_time_steps': [],\n", " 'seq_class_ids': [],\n", " 'seq_len': [],\n", " 'transcriptions': [],\n", " 'sentence_label': [],\n", " 'session': [],\n", " 'block_num': [],\n", " 'trial_num': [],\n", " 'corpus': [],\n", " }\n", " # Open the hdf5 file for that day\n", " with h5py.File(file_path, 'r') as f:\n", "\n", " keys = list(f.keys())\n", "\n", " # For each trial in the selected trials in that day\n", " for key in keys:\n", " g = f[key]\n", "\n", " neural_features = g['input_features'][:] # pyright: ignore[reportIndexIssue]\n", " n_time_steps = g.attrs['n_time_steps']\n", " seq_class_ids = g['seq_class_ids'][:] if 'seq_class_ids' in g else None # type: ignore\n", " seq_len = g.attrs['seq_len'] if 'seq_len' in g.attrs else None\n", " transcription = g['transcription'][:] if 'transcription' in g else None # type: ignore\n", " sentence_label = g.attrs['sentence_label'][:] if 'sentence_label' in g.attrs else None # pyright: ignore[reportIndexIssue]\n", " session = g.attrs['session']\n", " block_num = g.attrs['block_num']\n", " trial_num = g.attrs['trial_num']\n", "\n", " # match this trial up with the csv to get the corpus name\n", " year, month, day = session.split('.')[1:] # pyright: 
ignore[reportAttributeAccessIssue]\n", " date = f'{year}-{month}-{day}'\n", " row = b2txt_csv_df[(b2txt_csv_df['Date'] == date) & (b2txt_csv_df['Block number'] == block_num)]\n", " corpus_name = row['Corpus'].values[0]\n", "\n", " data['neural_features'].append(neural_features)\n", " data['n_time_steps'].append(n_time_steps)\n", " data['seq_class_ids'].append(seq_class_ids)\n", " data['seq_len'].append(seq_len)\n", " data['transcriptions'].append(transcription)\n", " data['sentence_label'].append(sentence_label)\n", " data['session'].append(session)\n", " data['block_num'].append(block_num)\n", " data['trial_num'].append(trial_num)\n", " data['corpus'].append(corpus_name)\n", " return data\n", "\n", "def gauss_smooth_torch(inputs, device, smooth_kernel_std=2, smooth_kernel_size=100, padding='valid'):\n", " \"\"\"\n", " PyTorch版本的高斯平滑 (仿照data_augmentations.py)\n", " \"\"\"\n", " # 创建高斯核\n", " inp = np.zeros(smooth_kernel_size, dtype=np.float32)\n", " inp[smooth_kernel_size // 2] = 1\n", " gaussKernel = gaussian_filter1d(inp, smooth_kernel_std)\n", " \n", " # 过滤小值\n", " validIdx = np.argwhere(gaussKernel > 0.01)\n", " if len(validIdx) > 0:\n", " gaussKernel = gaussKernel[validIdx.flatten()]\n", " gaussKernel = np.squeeze(gaussKernel / np.sum(gaussKernel))\n", " \n", " # 转换为PyTorch张量\n", " gaussKernel = torch.tensor(gaussKernel, dtype=inputs.dtype, device=device)\n", " gaussKernel = gaussKernel.view(1, 1, -1) # [1, 1, kernel_size]\n", " \n", " # 准备卷积\n", " B, T, C = inputs.shape\n", " inputs = inputs.permute(0, 2, 1) # [B, C, T]\n", " gaussKernel = gaussKernel.repeat(C, 1, 1) # [C, 1, kernel_size]\n", " \n", " # 执行卷积\n", " smoothed = F.conv1d(inputs, gaussKernel, padding=padding, groups=C)\n", " return smoothed.permute(0, 2, 1) # [B, T, C]\n", "\n", "def apply_patch_processing(x, patch_size=14, patch_stride=4):\n", " \"\"\"\n", " 应用patch处理 (仿照rnn_model.py的forward方法)\n", " \n", " Args:\n", " x: 输入张量 [batch, timesteps, features]\n", " patch_size: patch大小\n", " patch_stride: patch步长\n", " \n", " Returns:\n", " 处理后的张量 [batch, num_patches, patch_size * features]\n", " \"\"\"\n", " if patch_size <= 0:\n", " return x\n", " \n", " x = x.unsqueeze(1) # [batches, 1, timesteps, feature_dim]\n", " x = x.permute(0, 3, 1, 2) # [batches, feature_dim, 1, timesteps]\n", " \n", " # 使用unfold提取patches\n", " x_unfold = x.unfold(3, patch_size, patch_stride) # [batches, feature_dim, 1, num_patches, patch_size]\n", " \n", " # 移除虚拟维度并重新排列\n", " x_unfold = x_unfold.squeeze(2) # [batches, feature_dim, num_patches, patch_size]\n", " x_unfold = x_unfold.permute(0, 2, 3, 1) # [batches, num_patches, patch_size, feature_dim]\n", " \n", " # 展平最后两个维度\n", " x = x_unfold.reshape(x_unfold.size(0), x_unfold.size(1), -1) # [batch, num_patches, patch_size * features]\n", " \n", " return x\n", "\n", "def process_neural_data_for_lgbm(neural_input, device, model_args):\n", " \"\"\"\n", " 仿照RNN处理流程:高斯平滑 + patch处理\n", " \n", " Args:\n", " neural_input: 神经数据 [batch, time, features]\n", " device: PyTorch设备\n", " model_args: 模型参数配置\n", " \n", " Returns:\n", " 处理后的特征数据,准备输入LightGBM\n", " \"\"\"\n", " # 1. 高斯平滑\n", " smoothed_input = gauss_smooth_torch(\n", " inputs=neural_input,\n", " device=device,\n", " smooth_kernel_std=model_args.get('smooth_kernel_std', 2),\n", " smooth_kernel_size=model_args.get('smooth_kernel_size', 100),\n", " padding='valid'\n", " )\n", " \n", " # 2. 
Patch处理\n", " patch_size = model_args.get('patch_size', 14)\n", " patch_stride = model_args.get('patch_stride', 4)\n", " \n", " if patch_size > 0:\n", " patched_input = apply_patch_processing(\n", " smoothed_input, \n", " patch_size=patch_size, \n", " patch_stride=patch_stride\n", " )\n", " # 展平为2D: [batch * num_patches, patch_size * features]\n", " batch_size, num_patches, patch_features = patched_input.shape\n", " features_2d = patched_input.reshape(-1, patch_features)\n", " else:\n", " # 如果不使用patch,直接展平\n", " features_2d = smoothed_input.reshape(-1, smoothed_input.shape[-1])\n", " \n", " return features_2d.cpu().numpy()\n", "\n", "print(\"✅ 神经数据处理函数定义完成\")" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "t15.2023.08.11\tt15.2023.09.29\tt15.2023.11.04\tt15.2024.02.25\tt15.2024.07.19\n", "t15.2023.08.13\tt15.2023.10.01\tt15.2023.11.17\tt15.2024.03.03\tt15.2024.07.21\n", "t15.2023.08.18\tt15.2023.10.06\tt15.2023.11.19\tt15.2024.03.08\tt15.2024.07.28\n", "t15.2023.08.20\tt15.2023.10.08\tt15.2023.11.26\tt15.2024.03.15\tt15.2025.01.10\n", "t15.2023.08.25\tt15.2023.10.13\tt15.2023.12.03\tt15.2024.03.17\tt15.2025.01.12\n", "t15.2023.08.27\tt15.2023.10.15\tt15.2023.12.08\tt15.2024.04.25\tt15.2025.03.14\n", "t15.2023.09.01\tt15.2023.10.20\tt15.2023.12.10\tt15.2024.04.28\tt15.2025.03.16\n", "t15.2023.09.03\tt15.2023.10.22\tt15.2023.12.17\tt15.2024.05.10\tt15.2025.03.30\n", "t15.2023.09.24\tt15.2023.11.03\tt15.2023.12.29\tt15.2024.06.14\tt15.2025.04.13\n" ] } ], "source": [ "!ls data/hdf5_data_final\t" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# 🔥 测试集数据加载与LightGBM预测\n", "\n", "class TestSetPredictor:\n", " \"\"\"\n", " 测试集预测器 - 仿照RNN评估流程\n", " \"\"\"\n", " \n", " def __init__(self, lgbm_model, pipeline, data_dir, device='cpu'):\n", " self.lgbm_model = lgbm_model\n", " self.pipeline = pipeline\n", " self.data_dir = data_dir\n", " self.device = torch.device(device)\n", " \n", " # 配置参数 (仿照RNN模型参数)\n", " self.model_args = {\n", " 'smooth_kernel_std': 2,\n", " 'smooth_kernel_size': 100,\n", " 'patch_size': 14,\n", " 'patch_stride': 4,\n", " }\n", " \n", " print(f\"🎯 测试集预测器初始化完成\")\n", " print(f\" 设备: {self.device}\")\n", " print(f\" Patch配置: size={self.model_args['patch_size']}, stride={self.model_args['patch_stride']}\")\n", " print(f\" 平滑配置: std={self.model_args['smooth_kernel_std']}, size={self.model_args['smooth_kernel_size']}\")\n", " \n", " def load_test_sessions(self):\n", " \"\"\"\n", " 加载所有测试会话数据\n", " \"\"\"\n", " print(f\"🔍 扫描测试数据目录: {self.data_dir}\")\n", " \n", " test_data = {}\n", " total_trials = 0\n", " \n", " # 扫描数据目录中的所有会话\n", " for session_name in os.listdir(self.data_dir):\n", " session_path = os.path.join(self.data_dir, session_name)\n", " if not os.path.isdir(session_path):\n", " continue\n", " \n", " # 查找测试数据文件\n", " test_file = os.path.join(session_path, 'data_test.hdf5')\n", " if os.path.exists(test_file):\n", " print(f\" 📂 发现测试会话: {session_name}\")\n", " \n", " try:\n", " # 加载数据 (传入空的CSV,因为测试集不需要)\n", " data = load_h5py_file(test_file, None)\n", " test_data[session_name] = data\n", " \n", " num_trials = len(data['neural_features'])\n", " total_trials += num_trials\n", " \n", " print(f\" ✅ 加载成功: {num_trials} 个试验\")\n", " print(f\" 📊 神经特征形状: {data['neural_features'][0].shape if num_trials > 0 else 'N/A'}\")\n", " \n", " except Exception as e:\n", " print(f\" ❌ 加载失败: {e}\")\n", " \n", " print(f\"\\n📊 测试数据加载总结:\")\n", " print(f\" 会话数: 
{len(test_data)}\")\n", " print(f\" 总试验数: {total_trials}\")\n", " \n", " return test_data\n", " \n", " def predict_test_set(self, test_data):\n", " \"\"\"\n", " 对测试集进行预测\n", " \"\"\"\n", " if self.lgbm_model is None:\n", " raise ValueError(\"LightGBM模型未加载\")\n", " \n", " if not self.pipeline.pca_fitted:\n", " raise ValueError(\"PCA模型未拟合\")\n", " \n", " print(f\"\\n🚀 开始测试集预测...\")\n", " \n", " # 统计总试验数\n", " total_trials = sum(len(data['neural_features']) for data in test_data.values())\n", " \n", " results = {\n", " 'session': [],\n", " 'block': [],\n", " 'trial': [],\n", " 'predicted_sequence': [], # 完整的音素序列\n", " 'true_sequence': [], # 真实的音素序列\n", " 'sentence_label': [], # 句子标签(如果有的话)\n", " 'logits': [], # 原始预测概率\n", " 'sequence_length': [], # 预测序列长度\n", " 'true_length': [] # 真实序列长度\n", " }\n", " \n", " with tqdm(total=total_trials, desc='LightGBM预测进度', unit='trial') as pbar:\n", " for session_name, data in test_data.items():\n", " print(f\"\\n📈 处理会话: {session_name}\")\n", " \n", " for trial_idx in range(len(data['neural_features'])):\n", " # 1. 获取神经数据\n", " neural_input = data['neural_features'][trial_idx]\n", " \n", " # 2. 添加批次维度并转换为张量\n", " neural_input = np.expand_dims(neural_input, axis=0)\n", " neural_tensor = torch.tensor(neural_input, device=self.device, dtype=torch.float32)\n", " \n", " # 3. 应用RNN式的处理流程:高斯平滑 + patch处理\n", " processed_features = process_neural_data_for_lgbm(\n", " neural_tensor, self.device, self.model_args\n", " )\n", " \n", " # 4. 应用PCA降维\n", " if processed_features.shape[0] > 0:\n", " features_pca = self.pipeline._apply_pca_transform(processed_features)\n", " \n", " # 5. LightGBM预测 - 获取完整序列\n", " predictions = self.lgbm_model.predict(features_pca)\n", " \n", " # 6. 处理预测结果 - 保持序列形式\n", " if len(predictions.shape) > 1:\n", " # 每一行是一个时间步/patch的预测\n", " logits_sequence = predictions # [num_patches, 41]\n", " else:\n", " # 单个预测,扩展为序列\n", " logits_sequence = predictions.reshape(1, -1) # [1, 41]\n", " \n", " # 7. 转换为音素序列 (仿照RNN后处理) TODO:这里可以做过滤!!!!!\n", " predicted_classes = np.argmax(logits_sequence, axis=-1) # [num_patches]\n", " \n", " # 8. 后处理音素序列 (仿照evaluate_model.py)\n", " # 移除blank (0)\n", " pred_seq = [int(p) for p in predicted_classes if p != 0]\n", " # 移除连续重复\n", " pred_seq = [pred_seq[i] for i in range(len(pred_seq)) if i == 0 or pred_seq[i] != pred_seq[i-1]]\n", " # 转换为音素符号\n", " predicted_phoneme_sequence = [LOGIT_TO_PHONEME[p] for p in pred_seq]\n", " \n", " # 8. 读取真实音素序列(如果存在)\n", " true_phoneme_sequence = []\n", " sentence_label = \"\"\n", " true_length = 0\n", " \n", " if 'seq_class_ids' in data and 'seq_len' in data:\n", " # 仿照evaluate_model.py的处理方式\n", " true_seq = data['seq_class_ids'][trial_idx][0:data['seq_len'][trial_idx]]\n", " true_phoneme_sequence = [LOGIT_TO_PHONEME[p] for p in true_seq]\n", " true_length = len(true_phoneme_sequence)\n", " \n", " if 'sentence_label' in data:\n", " sentence_label = data['sentence_label'][trial_idx]\n", " \n", " # 9. 
存储结果\n", " results['session'].append(session_name)\n", " results['block'].append(data['block_num'][trial_idx])\n", " results['trial'].append(data['trial_num'][trial_idx])\n", " results['predicted_sequence'].append(predicted_phoneme_sequence)\n", " results['true_sequence'].append(true_phoneme_sequence)\n", " results['sentence_label'].append(sentence_label)\n", " results['logits'].append(logits_sequence.tolist())\n", " results['sequence_length'].append(len(predicted_phoneme_sequence))\n", " results['true_length'].append(true_length)\n", " \n", " pbar.update(1)\n", " \n", " print(f\"\\n🎉 测试集预测完成!\")\n", " print(f\" 总预测数: {len(results['predicted_sequence'])}\")\n", " print(f\" 平均预测序列长度: {np.mean(results['sequence_length']):.1f}\")\n", " print(f\" 预测序列长度范围: {min(results['sequence_length'])} - {max(results['sequence_length'])}\")\n", " \n", " # 统计真实序列情况\n", " has_true_seq = sum(1 for seq in results['true_sequence'] if seq)\n", " if has_true_seq > 0:\n", " true_lengths = [length for length in results['true_length'] if length > 0]\n", " print(f\" 包含真实序列: {has_true_seq} / {len(results['predicted_sequence'])} 个试验\")\n", " if true_lengths:\n", " print(f\" 平均真实序列长度: {np.mean(true_lengths):.1f}\")\n", " print(f\" 真实序列长度范围: {min(true_lengths)} - {max(true_lengths)}\")\n", " else:\n", " print(f\" ⚠️ 测试集无真实序列标签,无法计算WER\")\n", " \n", " return results\n", " \n", " def save_results(self, results, output_path=\"test_predictions.csv\"):\n", " \"\"\"\n", " 保存预测结果\n", " \"\"\"\n", " import pandas as pd\n", " \n", " df = pd.DataFrame(results)\n", " df.to_csv(output_path, index=False)\n", " \n", " print(f\"💾 预测结果已保存: {output_path}\")\n", " print(f\"📊 结果统计:\")\n", " print(f\" 预测样本数: {len(df)}\")\n", " print(f\" 平均预测序列长度: {df['sequence_length'].mean():.1f}\")\n", " \n", " # 检查是否有真实序列\n", " has_true_seq = sum(1 for seq in df['true_sequence'] if seq and len(seq) > 0)\n", " if has_true_seq > 0:\n", " print(f\" 包含真实序列: {has_true_seq} 个试验\")\n", " print(f\" 平均真实序列长度: {df['true_length'].mean():.1f}\")\n", " \n", " # 统计预测音素分布\n", " all_predicted_phonemes = []\n", " for seq in df['predicted_sequence']:\n", " if isinstance(seq, list):\n", " all_predicted_phonemes.extend(seq)\n", " \n", " if all_predicted_phonemes:\n", " from collections import Counter\n", " phoneme_counts = Counter(all_predicted_phonemes)\n", " print(f\" 预测音素分布 (前10):\")\n", " for phoneme, count in phoneme_counts.most_common(10):\n", " print(f\" {phoneme}: {count}\")\n", " \n", " return df\n", "\n", "print(\"✅ 测试集预测器类定义完成\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# 🚀 执行测试集预测\n", "\n", "# 检查必要组件是否准备就绪\n", "if lgbm_model is None:\n", " print(\"❌ LightGBM模型未加载,请先运行模型加载代码\")\n", "elif not hasattr(pipeline, 'pca_fitted') or not pipeline.pca_fitted:\n", " print(\"❌ PCA模型未拟合,请先运行智能管道的步骤1和步骤2\")\n", "else:\n", " print(\"✅ 所有组件准备就绪,开始测试集预测...\")\n", " \n", " # 配置测试数据路径\n", " TEST_DATA_DIR = \"/kaggle/working/nejm-brain-to-text/data/hdf5_data_final\"\n", " \n", " # 创建预测器\n", " predictor = TestSetPredictor(\n", " lgbm_model=lgbm_model,\n", " pipeline=pipeline,\n", " data_dir=TEST_DATA_DIR,\n", " device='cpu' # 可以改为'cuda'如果有GPU\n", " )\n", " \n", " print(f\"\\n\" + \"=\"*60)\n", " print(\"🔍 第1步: 加载测试集数据\")\n", " print(\"=\"*60)\n", " \n", " # 加载测试数据\n", " test_data = predictor.load_test_sessions()\n", " \n", " if not test_data:\n", " print(\"❌ 未找到测试数据,请检查数据路径\")\n", " else:\n", " print(f\"\\n\" + \"=\"*60)\n", " print(\"🔮 第2步: 执行LightGBM预测\")\n", " print(\"=\"*60)\n", " \n", " # 执行预测\n", " try:\n", " 
prediction_results = predictor.predict_test_set(test_data)\n", " \n", " print(f\"\\n\" + \"=\"*60)\n", " print(\"💾 第3步: 保存预测结果\")\n", " print(\"=\"*60)\n", " \n", " # 保存结果\n", " results_df = predictor.save_results(\n", " prediction_results, \n", " \"lgbm_test_predictions.csv\"\n", " )\n", " \n", " print(f\"\\n🎉 测试集预测流程完成!\")\n", " print(f\" 📁 数据路径: {TEST_DATA_DIR}\")\n", " print(f\" 📊 处理会话: {len(test_data)} 个\")\n", " print(f\" 🎯 预测样本: {len(prediction_results['predicted_phonemes'])} 个\")\n", " print(f\" 💾 结果文件: lgbm_test_predictions.csv\")\n", " \n", " except Exception as e:\n", " print(f\"❌ 预测过程中发生错误: {e}\")\n", " import traceback\n", " traceback.print_exc()" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# 📊 测试集预测结果分析\n", "\n", "def analyze_test_predictions(csv_path=\"lgbm_test_predictions.csv\"):\n", " \"\"\"\n", " 分析测试集预测结果\n", " \"\"\"\n", " if not os.path.exists(csv_path):\n", " print(f\"❌ 结果文件不存在: {csv_path}\")\n", " return\n", " \n", " print(f\"📊 分析测试集预测结果: {csv_path}\")\n", " \n", " import pandas as pd\n", " import matplotlib.pyplot as plt\n", " \n", " # 加载结果\n", " df = pd.read_csv(csv_path)\n", " \n", " print(f\"\\n📈 基本统计:\")\n", " print(f\" 总预测数: {len(df):,}\")\n", " print(f\" 会话数: {df['session'].nunique()}\")\n", " print(f\" 平均置信度: {df['prediction_confidence'].mean():.4f}\")\n", " print(f\" 置信度标准差: {df['prediction_confidence'].std():.4f}\")\n", " \n", " # 置信度分布\n", " plt.figure(figsize=(15, 10))\n", " \n", " # 1. 置信度直方图\n", " plt.subplot(2, 3, 1)\n", " plt.hist(df['prediction_confidence'], bins=50, alpha=0.7, color='skyblue', edgecolor='black')\n", " plt.xlabel('预测置信度')\n", " plt.ylabel('频数')\n", " plt.title('预测置信度分布')\n", " plt.grid(True, alpha=0.3)\n", " \n", " # 2. 音素分布 (前15个)\n", " plt.subplot(2, 3, 2)\n", " phoneme_counts = df['predicted_phonemes'].value_counts().head(15)\n", " phoneme_counts.plot(kind='bar', color='lightcoral')\n", " plt.xlabel('音素')\n", " plt.ylabel('预测次数')\n", " plt.title('预测音素分布 (前15)')\n", " plt.xticks(rotation=45)\n", " plt.grid(True, alpha=0.3)\n", " \n", " # 3. 各会话的预测数量\n", " plt.subplot(2, 3, 3)\n", " session_counts = df['session'].value_counts()\n", " session_counts.plot(kind='bar', color='lightgreen')\n", " plt.xlabel('会话')\n", " plt.ylabel('预测数量')\n", " plt.title('各会话预测数量')\n", " plt.xticks(rotation=45)\n", " plt.grid(True, alpha=0.3)\n", " \n", " # 4. 各会话的平均置信度\n", " plt.subplot(2, 3, 4)\n", " session_confidence = df.groupby('session')['prediction_confidence'].mean()\n", " session_confidence.plot(kind='bar', color='gold')\n", " plt.xlabel('会话')\n", " plt.ylabel('平均置信度')\n", " plt.title('各会话平均置信度')\n", " plt.xticks(rotation=45)\n", " plt.grid(True, alpha=0.3)\n", " \n", " # 5. 高置信度预测的音素分布\n", " plt.subplot(2, 3, 5)\n", " high_conf_df = df[df['prediction_confidence'] > 0.8]\n", " if len(high_conf_df) > 0:\n", " high_conf_phonemes = high_conf_df['predicted_phonemes'].value_counts().head(10)\n", " high_conf_phonemes.plot(kind='bar', color='orange')\n", " plt.xlabel('音素')\n", " plt.ylabel('高置信度预测次数')\n", " plt.title(f'高置信度(>0.8)音素分布\\n总数: {len(high_conf_df)}')\n", " plt.xticks(rotation=45)\n", " else:\n", " plt.text(0.5, 0.5, '无高置信度预测', ha='center', va='center')\n", " plt.title('高置信度预测分布')\n", " plt.grid(True, alpha=0.3)\n", " \n", " # 6. 
置信度箱线图(按会话)\n", " plt.subplot(2, 3, 6)\n", " import seaborn as sns\n", " sns.boxplot(data=df, x='session', y='prediction_confidence')\n", " plt.xlabel('会话')\n", " plt.ylabel('置信度')\n", " plt.title('各会话置信度分布')\n", " plt.xticks(rotation=45)\n", " plt.grid(True, alpha=0.3)\n", " \n", " plt.tight_layout()\n", " plt.show()\n", " \n", " # 详细统计\n", " print(f\"\\n📋 详细统计:\")\n", " print(f\" 音素类别数: {df['predicted_phonemes'].nunique()}\")\n", " print(f\" 最常预测音素: {df['predicted_phonemes'].mode().iloc[0]} ({df['predicted_phonemes'].value_counts().iloc[0]} 次)\")\n", " print(f\" 高置信度预测 (>0.9): {(df['prediction_confidence'] > 0.9).sum()} / {len(df)} ({(df['prediction_confidence'] > 0.9).mean():.2%})\")\n", " print(f\" 低置信度预测 (<0.5): {(df['prediction_confidence'] < 0.5).sum()} / {len(df)} ({(df['prediction_confidence'] < 0.5).mean():.2%})\")\n", " \n", " return df\n", "\n", "# 如果预测结果文件存在,自动分析\n", "if os.path.exists(\"lgbm_test_predictions.csv\"):\n", " print(\"🔍 发现预测结果文件,开始自动分析...\")\n", " results_analysis = analyze_test_predictions()\n", "else:\n", " print(\"💡 运行上面的预测代码后,将自动分析结果\")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 🔍 序列预测逻辑说明\n", "\n", "你说得对!我们需要的是完整的音素序列,而不是单个预测的平均值。\n", "\n", "### 🎯 RNN vs LightGBM 序列处理对比\n", "\n", "#### **RNN的处理方式:**\n", "```\n", "神经数据 → 高斯平滑 → Patch分割 → RNN → 序列logits → 音素序列\n", "[1,100,512] → [1,86,512] → [1,19,7168] → RNN → [1,19,41] → ['AE','T','SH',...]\n", "```\n", "\n", "#### **我们的LightGBM处理方式:**\n", "```\n", "神经数据 → 高斯平滑 → Patch分割 → PCA → LightGBM → 序列logits → 音素序列 \n", "[1,100,512] → [1,86,512] → [19,7168] → [19,PCA_dim] → LightGBM → [19,41] → ['AE','T','SH',...]\n", "```\n", "\n", "### 🔄 关键修改\n", "1. **保持序列维度**: 不再对patch预测取平均,而是保持每个patch的独立预测\n", "2. **后处理序列**: 像RNN一样进行blank移除和重复音素合并\n", "3. **输出格式**: 每个试验输出完整的音素序列列表,而不是单个音素\n", "\n", "这样我们就能得到与RNN相同格式的序列输出!" 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# 🎯 序列预测结果展示和分析(含真实对比和WER计算)\n", "\n", "import pandas as pd\n", "import ast\n", "import editdistance\n", "\n", "def display_sequence_predictions_with_wer(csv_path=\"lgbm_test_predictions.csv\", num_examples=5):\n", " \"\"\"\n", " 展示序列预测结果,真实结果对比,并计算WER\n", " \"\"\"\n", " if not os.path.exists(csv_path):\n", " print(f\"❌ 结果文件不存在: {csv_path}\")\n", " return\n", " \n", " df = pd.read_csv(csv_path)\n", " df['predicted_sequence'] = df['predicted_sequence'].apply(lambda x: ast.literal_eval(x) if isinstance(x, str) else x)\n", " if 'true_sequence' in df.columns:\n", " df['true_sequence'] = df['true_sequence'].apply(lambda x: ast.literal_eval(x) if isinstance(x, str) else x)\n", " else:\n", " print(\"⚠️ 没有 true_sequence 列,无法对比和计算WER\")\n", " return\n", " \n", " print(f\"\\n📊 序列预测统计:\")\n", " print(f\" 总试验数: {len(df):,}\")\n", " print(f\" 平均序列长度: {df['sequence_length'].mean():.1f}\")\n", " print(f\" 序列长度范围: {df['sequence_length'].min()} - {df['sequence_length'].max()}\")\n", " \n", " # 随机展示若干例子\n", " sample_df = df.sample(min(num_examples, len(df)), random_state=42)\n", " print(f\"\\n🎭 预测序列与真实序列对比 (随机 {num_examples} 个):\")\n", " print(\"=\" * 80)\n", " for idx, row in sample_df.iterrows():\n", " pred_seq = row['predicted_sequence']\n", " true_seq = row['true_sequence']\n", " print(f\"Trial {idx+1}:\")\n", " print(f\" True: {' '.join(true_seq)}\")\n", " print(f\" Predicted: {' '.join(pred_seq)}\")\n", " ed = editdistance.eval(true_seq, pred_seq)\n", " print(f\" Edit Distance: {ed} / {len(true_seq)} = {ed/len(true_seq):.2%}\")\n", " print(\"-\" * 40)\n", " \n", " # 计算整体WER\n", " total_ed = 0\n", " total_len = 0\n", " for idx, row in df.iterrows():\n", " pred_seq = row['predicted_sequence']\n", " true_seq = row['true_sequence']\n", " if true_seq:\n", " ed = editdistance.eval(true_seq, pred_seq)\n", " total_ed += ed\n", " total_len += len(true_seq)\n", " if total_len > 0:\n", " print(f\"\\nAggregate Phoneme WER: {total_ed} / {total_len} = {total_ed/total_len:.2%}\")\n", " else:\n", " print(\"No ground truth available for WER calculation.\")\n", " return df\n", "\n", "# 如果预测结果文件存在,展示序列预测结果和WER\n", "if os.path.exists(\"lgbm_test_predictions.csv\"):\n", " print(\"🎯 展示序列预测结果与WER...\")\n", " sequence_results = display_sequence_predictions_with_wer()\n", "else:\n", " print(\"💡 运行预测代码后,将展示序列预测结果和WER\")\n" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def edit_distance(seq1, seq2):\n", " \"\"\"\n", " 计算两个序列之间的编辑距离 (Levenshtein distance)\n", " 用于计算 WER (Word Error Rate)\n", " \"\"\"\n", " m, n = len(seq1), len(seq2)\n", " \n", " # 创建动态规划表\n", " dp = [[0] * (n + 1) for _ in range(m + 1)]\n", " \n", " # 初始化边界条件\n", " for i in range(m + 1):\n", " dp[i][0] = i\n", " for j in range(n + 1):\n", " dp[0][j] = j\n", " \n", " # 填充动态规划表\n", " for i in range(1, m + 1):\n", " for j in range(1, n + 1):\n", " if seq1[i-1] == seq2[j-1]:\n", " dp[i][j] = dp[i-1][j-1]\n", " else:\n", " dp[i][j] = 1 + min(\n", " dp[i-1][j], # 删除\n", " dp[i][j-1], # 插入\n", " dp[i-1][j-1] # 替换\n", " )\n", " \n", " return dp[m][n]\n", "\n", "def calculate_wer(predicted_seq, true_seq):\n", " \"\"\"\n", " 计算序列的 Word Error Rate (WER)\n", " WER = 编辑距离 / 真实序列长度\n", " \"\"\"\n", " if len(true_seq) == 0:\n", " return 1.0 if len(predicted_seq) > 0 else 0.0\n", " \n", " edit_dist = edit_distance(predicted_seq, true_seq)\n", " wer = edit_dist / len(true_seq)\n", " \n", " return wer\n", "\n", "print(\"✅ WER 
计算功能已定义\")" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"# 🚀 运行测试集预测(序列级别,包含WER计算)\n",
"\n",
"print(\"🔥 实例化 TestSetPredictor...\")\n",
"predictor = TestSetPredictor(\n",
"    lgbm_model=lgbm_model,\n",
"    pipeline=pipeline,\n",
"    data_dir=TEST_DATA_DIR,\n",
"    device='cpu'\n",
")\n",
"\n",
"print(\"\\n🎯 开始测试集序列预测...\")\n",
"test_data = predictor.load_test_sessions()\n",
"results = predictor.predict_test_set(test_data)\n",
"\n",
"print(\"\\n💾 保存预测结果...\")\n",
"df_results = predictor.save_results(results, \"lgbm_test_predictions_with_wer.csv\")\n",
"\n",
"print(\"\\n📊 计算整体WER...\")\n",
"total_ed = 0\n",
"total_len = 0\n",
"valid_predictions = 0\n",
"\n",
"# results 是字典(键为列名、值为列表),按索引逐个试验遍历\n",
"for i in range(len(results['predicted_sequence'])):\n",
"    pred_seq = results['predicted_sequence'][i]\n",
"    true_seq = results['true_sequence'][i]\n",
"    \n",
"    if true_seq and len(true_seq) > 0:\n",
"        wer = calculate_wer(pred_seq, true_seq)\n",
"        ed = edit_distance(pred_seq, true_seq)\n",
"        total_ed += ed\n",
"        total_len += len(true_seq)\n",
"        valid_predictions += 1\n",
"        \n",
"        print(f\"Session {results['session'][i]}, Block {results['block'][i]}, Trial {results['trial'][i]}: WER = {wer:.2%}\")\n",
"\n",
"if total_len > 0:\n",
"    overall_wer = total_ed / total_len\n",
"    print(f\"\\n🎯 整体序列 WER: {total_ed} / {total_len} = {overall_wer:.2%}\")\n",
"    print(f\"✅ 有效预测数量: {valid_predictions}\")\n",
"else:\n",
"    print(\"❌ 没有可用的真实序列用于WER计算\")\n",
"\n",
"print(\"\\n🎉 测试集序列预测完成!\")" ] } ], "metadata": { "kaggle": { "accelerator": "tpu1vmV38", "dataSources": [ { "databundleVersionId": 13056355, "sourceId": 106809, "sourceType": "competition" } ], "dockerImageVersionId": 31091, "isGpuEnabled": false, "isInternetEnabled": true, "language": "python", "sourceType": "notebook" }, "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" } }, "nbformat": 4, "nbformat_minor": 4 }