4160 lines
		
	
	
		
			221 KiB
		
	
	
	
		
			Plaintext
		
	
	
	
	
	
			
		
		
	
	
			4160 lines
		
	
	
		
			221 KiB
		
	
	
	
		
			Plaintext
		
	
	
	
	
	
| {
 | ||
|  "cells": [
 | ||
|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "metadata": {},
 | ||
|    "source": [
 | ||
|     "# 环境配置 与 utils"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": 1,
 | ||
|    "metadata": {},
 | ||
|    "outputs": [
 | ||
|     {
 | ||
|      "name": "stdout",
 | ||
|      "output_type": "stream",
 | ||
|      "text": [
 | ||
|       "Looking in indexes: https://download.pytorch.org/whl/cu126\n",
 | ||
|       "Requirement already satisfied: torch in /usr/local/lib/python3.11/dist-packages (2.6.0+cu124)\n",
 | ||
|       "Requirement already satisfied: torchvision in /usr/local/lib/python3.11/dist-packages (0.21.0+cu124)\n",
 | ||
|       "Requirement already satisfied: torchaudio in /usr/local/lib/python3.11/dist-packages (2.6.0+cu124)\n",
 | ||
|       "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from torch) (3.18.0)\n",
 | ||
|       "Requirement already satisfied: typing-extensions>=4.10.0 in /usr/local/lib/python3.11/dist-packages (from torch) (4.14.0)\n",
 | ||
|       "Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch) (3.5)\n",
 | ||
|       "Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch) (3.1.6)\n",
 | ||
|       "Requirement already satisfied: fsspec in /usr/local/lib/python3.11/dist-packages (from torch) (2025.5.1)\n",
 | ||
|       "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n",
 | ||
|       "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n",
 | ||
|       "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n",
 | ||
|       "Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /usr/local/lib/python3.11/dist-packages (from torch) (9.1.0.70)\n",
 | ||
|       "Requirement already satisfied: nvidia-cublas-cu12==12.4.5.8 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.5.8)\n",
 | ||
|       "Requirement already satisfied: nvidia-cufft-cu12==11.2.1.3 in /usr/local/lib/python3.11/dist-packages (from torch) (11.2.1.3)\n",
 | ||
|       "Requirement already satisfied: nvidia-curand-cu12==10.3.5.147 in /usr/local/lib/python3.11/dist-packages (from torch) (10.3.5.147)\n",
 | ||
|       "Requirement already satisfied: nvidia-cusolver-cu12==11.6.1.9 in /usr/local/lib/python3.11/dist-packages (from torch) (11.6.1.9)\n",
 | ||
|       "Requirement already satisfied: nvidia-cusparse-cu12==12.3.1.170 in /usr/local/lib/python3.11/dist-packages (from torch) (12.3.1.170)\n",
 | ||
|       "Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /usr/local/lib/python3.11/dist-packages (from torch) (0.6.2)\n",
 | ||
|       "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from torch) (2.21.5)\n",
 | ||
|       "Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n",
 | ||
|       "Requirement already satisfied: nvidia-nvjitlink-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n",
 | ||
|       "Requirement already satisfied: triton==3.2.0 in /usr/local/lib/python3.11/dist-packages (from torch) (3.2.0)\n",
 | ||
|       "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch) (1.13.1)\n",
 | ||
|       "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch) (1.3.0)\n",
 | ||
|       "Requirement already satisfied: numpy in /usr/local/lib/python3.11/dist-packages (from torchvision) (1.26.4)\n",
 | ||
|       "Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.11/dist-packages (from torchvision) (11.2.1)\n",
 | ||
|       "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from jinja2->torch) (3.0.2)\n",
 | ||
|       "Requirement already satisfied: mkl_fft in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (1.3.8)\n",
 | ||
|       "Requirement already satisfied: mkl_random in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (1.2.4)\n",
 | ||
|       "Requirement already satisfied: mkl_umath in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (0.1.1)\n",
 | ||
|       "Requirement already satisfied: mkl in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (2025.2.0)\n",
 | ||
|       "Requirement already satisfied: tbb4py in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (2022.2.0)\n",
 | ||
|       "Requirement already satisfied: mkl-service in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (2.4.1)\n",
 | ||
|       "Requirement already satisfied: intel-openmp<2026,>=2024 in /usr/local/lib/python3.11/dist-packages (from mkl->numpy->torchvision) (2024.2.0)\n",
 | ||
|       "Requirement already satisfied: tbb==2022.* in /usr/local/lib/python3.11/dist-packages (from mkl->numpy->torchvision) (2022.2.0)\n",
 | ||
|       "Requirement already satisfied: tcmlib==1.* in /usr/local/lib/python3.11/dist-packages (from tbb==2022.*->mkl->numpy->torchvision) (1.4.0)\n",
 | ||
|       "Requirement already satisfied: intel-cmplr-lib-rt in /usr/local/lib/python3.11/dist-packages (from mkl_umath->numpy->torchvision) (2024.2.0)\n",
 | ||
|       "Requirement already satisfied: intel-cmplr-lib-ur==2024.2.0 in /usr/local/lib/python3.11/dist-packages (from intel-openmp<2026,>=2024->mkl->numpy->torchvision) (2024.2.0)\n",
 | ||
|       "Requirement already satisfied: jupyter==1.1.1 in /usr/local/lib/python3.11/dist-packages (1.1.1)\n",
 | ||
|       "Requirement already satisfied: numpy<2.1.0,>=1.26.0 in /usr/local/lib/python3.11/dist-packages (1.26.4)\n",
 | ||
|       "Requirement already satisfied: pandas==2.3.0 in /usr/local/lib/python3.11/dist-packages (2.3.0)\n",
 | ||
|       "Requirement already satisfied: matplotlib==3.10.1 in /usr/local/lib/python3.11/dist-packages (3.10.1)\n",
 | ||
|       "Requirement already satisfied: scipy==1.15.2 in /usr/local/lib/python3.11/dist-packages (1.15.2)\n",
 | ||
|       "Requirement already satisfied: scikit-learn==1.6.1 in /usr/local/lib/python3.11/dist-packages (1.6.1)\n",
 | ||
|       "Requirement already satisfied: tqdm==4.67.1 in /usr/local/lib/python3.11/dist-packages (4.67.1)\n",
 | ||
|       "Requirement already satisfied: g2p_en==2.1.0 in /usr/local/lib/python3.11/dist-packages (2.1.0)\n",
 | ||
|       "Requirement already satisfied: h5py==3.13.0 in /usr/local/lib/python3.11/dist-packages (3.13.0)\n",
 | ||
|       "Requirement already satisfied: omegaconf==2.3.0 in /usr/local/lib/python3.11/dist-packages (2.3.0)\n",
 | ||
|       "Requirement already satisfied: editdistance==0.8.1 in /usr/local/lib/python3.11/dist-packages (0.8.1)\n",
 | ||
|       "Requirement already satisfied: huggingface-hub==0.33.1 in /usr/local/lib/python3.11/dist-packages (0.33.1)\n",
 | ||
|       "Requirement already satisfied: transformers==4.53.0 in /usr/local/lib/python3.11/dist-packages (4.53.0)\n",
 | ||
|       "Requirement already satisfied: tokenizers==0.21.2 in /usr/local/lib/python3.11/dist-packages (0.21.2)\n",
 | ||
|       "Requirement already satisfied: accelerate==1.8.1 in /usr/local/lib/python3.11/dist-packages (1.8.1)\n",
 | ||
|       "Requirement already satisfied: bitsandbytes==0.46.0 in /usr/local/lib/python3.11/dist-packages (0.46.0)\n",
 | ||
|       "Requirement already satisfied: notebook in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (6.5.4)\n",
 | ||
|       "Requirement already satisfied: jupyter-console in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (6.1.0)\n",
 | ||
|       "Requirement already satisfied: nbconvert in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (6.4.5)\n",
 | ||
|       "Requirement already satisfied: ipykernel in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (6.17.1)\n",
 | ||
|       "Requirement already satisfied: ipywidgets in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (8.1.5)\n",
 | ||
|       "Requirement already satisfied: jupyterlab in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (3.6.8)\n",
 | ||
|       "Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.11/dist-packages (from pandas==2.3.0) (2.9.0.post0)\n",
 | ||
|       "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.11/dist-packages (from pandas==2.3.0) (2025.2)\n",
 | ||
|       "Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.11/dist-packages (from pandas==2.3.0) (2025.2)\n",
 | ||
|       "Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (1.3.2)\n",
 | ||
|       "Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (0.12.1)\n",
 | ||
|       "Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (4.58.4)\n",
 | ||
|       "Requirement already satisfied: kiwisolver>=1.3.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (1.4.8)\n",
 | ||
|       "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (25.0)\n",
 | ||
|       "Requirement already satisfied: pillow>=8 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (11.2.1)\n",
 | ||
|       "Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (3.0.9)\n",
 | ||
|       "Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn==1.6.1) (1.5.1)\n",
 | ||
|       "Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn==1.6.1) (3.6.0)\n",
 | ||
|       "Requirement already satisfied: nltk>=3.2.4 in /usr/local/lib/python3.11/dist-packages (from g2p_en==2.1.0) (3.9.1)\n",
 | ||
|       "Requirement already satisfied: inflect>=0.3.1 in /usr/local/lib/python3.11/dist-packages (from g2p_en==2.1.0) (7.5.0)\n",
 | ||
|       "Requirement already satisfied: distance>=0.1.3 in /usr/local/lib/python3.11/dist-packages (from g2p_en==2.1.0) (0.1.3)\n",
 | ||
|       "Requirement already satisfied: antlr4-python3-runtime==4.9.* in /usr/local/lib/python3.11/dist-packages (from omegaconf==2.3.0) (4.9.3)\n",
 | ||
|       "Requirement already satisfied: PyYAML>=5.1.0 in /usr/local/lib/python3.11/dist-packages (from omegaconf==2.3.0) (6.0.2)\n",
 | ||
|       "Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (3.18.0)\n",
 | ||
|       "Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (2025.5.1)\n",
 | ||
|       "Requirement already satisfied: requests in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (2.32.4)\n",
 | ||
|       "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (4.14.0)\n",
 | ||
|       "Requirement already satisfied: hf-xet<2.0.0,>=1.1.2 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (1.1.5)\n",
 | ||
|       "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.11/dist-packages (from transformers==4.53.0) (2024.11.6)\n",
 | ||
|       "Requirement already satisfied: safetensors>=0.4.3 in /usr/local/lib/python3.11/dist-packages (from transformers==4.53.0) (0.5.3)\n",
 | ||
|       "Requirement already satisfied: psutil in /usr/local/lib/python3.11/dist-packages (from accelerate==1.8.1) (7.0.0)\n",
 | ||
|       "Requirement already satisfied: torch>=2.0.0 in /usr/local/lib/python3.11/dist-packages (from accelerate==1.8.1) (2.6.0+cu124)\n",
 | ||
|       "Requirement already satisfied: mkl_fft in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (1.3.8)\n",
 | ||
|       "Requirement already satisfied: mkl_random in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (1.2.4)\n",
 | ||
|       "Requirement already satisfied: mkl_umath in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (0.1.1)\n",
 | ||
|       "Requirement already satisfied: mkl in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (2025.2.0)\n",
 | ||
|       "Requirement already satisfied: tbb4py in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (2022.2.0)\n",
 | ||
|       "Requirement already satisfied: mkl-service in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (2.4.1)\n",
 | ||
|       "Requirement already satisfied: more_itertools>=8.5.0 in /usr/local/lib/python3.11/dist-packages (from inflect>=0.3.1->g2p_en==2.1.0) (10.7.0)\n",
 | ||
|       "Requirement already satisfied: typeguard>=4.0.1 in /usr/local/lib/python3.11/dist-packages (from inflect>=0.3.1->g2p_en==2.1.0) (4.4.4)\n",
 | ||
|       "Requirement already satisfied: click in /usr/local/lib/python3.11/dist-packages (from nltk>=3.2.4->g2p_en==2.1.0) (8.2.1)\n",
 | ||
|       "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.11/dist-packages (from python-dateutil>=2.8.2->pandas==2.3.0) (1.17.0)\n",
 | ||
|       "Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (3.5)\n",
 | ||
|       "Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (3.1.6)\n",
 | ||
|       "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n",
 | ||
|       "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n",
 | ||
|       "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n",
 | ||
|       "Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (9.1.0.70)\n",
 | ||
|       "Requirement already satisfied: nvidia-cublas-cu12==12.4.5.8 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.5.8)\n",
 | ||
|       "Requirement already satisfied: nvidia-cufft-cu12==11.2.1.3 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (11.2.1.3)\n",
 | ||
|       "Requirement already satisfied: nvidia-curand-cu12==10.3.5.147 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (10.3.5.147)\n",
 | ||
|       "Requirement already satisfied: nvidia-cusolver-cu12==11.6.1.9 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (11.6.1.9)\n",
 | ||
|       "Requirement already satisfied: nvidia-cusparse-cu12==12.3.1.170 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.3.1.170)\n",
 | ||
|       "Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (0.6.2)\n",
 | ||
|       "Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (2.21.5)\n",
 | ||
|       "Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n",
 | ||
|       "Requirement already satisfied: nvidia-nvjitlink-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n",
 | ||
|       "Requirement already satisfied: triton==3.2.0 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (3.2.0)\n",
 | ||
|       "Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (1.13.1)\n",
 | ||
|       "Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch>=2.0.0->accelerate==1.8.1) (1.3.0)\n",
 | ||
|       "Requirement already satisfied: debugpy>=1.0 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (1.8.0)\n",
 | ||
|       "Requirement already satisfied: ipython>=7.23.1 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (7.34.0)\n",
 | ||
|       "Requirement already satisfied: jupyter-client>=6.1.12 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (8.6.3)\n",
 | ||
|       "Requirement already satisfied: matplotlib-inline>=0.1 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (0.1.7)\n",
 | ||
|       "Requirement already satisfied: nest-asyncio in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (1.6.0)\n",
 | ||
|       "Requirement already satisfied: pyzmq>=17 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (24.0.1)\n",
 | ||
|       "Requirement already satisfied: tornado>=6.1 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (6.5.1)\n",
 | ||
|       "Requirement already satisfied: traitlets>=5.1.0 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (5.7.1)\n",
 | ||
|       "Requirement already satisfied: comm>=0.1.3 in /usr/local/lib/python3.11/dist-packages (from ipywidgets->jupyter==1.1.1) (0.2.2)\n",
 | ||
|       "Requirement already satisfied: widgetsnbextension~=4.0.12 in /usr/local/lib/python3.11/dist-packages (from ipywidgets->jupyter==1.1.1) (4.0.14)\n",
 | ||
|       "Requirement already satisfied: jupyterlab-widgets~=3.0.12 in /usr/local/lib/python3.11/dist-packages (from ipywidgets->jupyter==1.1.1) (3.0.15)\n",
 | ||
|       "Requirement already satisfied: prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-console->jupyter==1.1.1) (3.0.51)\n",
 | ||
|       "Requirement already satisfied: pygments in /usr/local/lib/python3.11/dist-packages (from jupyter-console->jupyter==1.1.1) (2.19.2)\n",
 | ||
|       "Requirement already satisfied: jupyter-core in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (5.8.1)\n",
 | ||
|       "Requirement already satisfied: jupyterlab-server~=2.19 in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (2.27.3)\n",
 | ||
|       "Requirement already satisfied: jupyter-server<3,>=1.16.0 in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (2.12.5)\n",
 | ||
|       "Requirement already satisfied: jupyter-ydoc~=0.2.4 in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (0.2.5)\n",
 | ||
|       "Requirement already satisfied: jupyter-server-ydoc~=0.8.0 in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (0.8.0)\n",
 | ||
|       "Requirement already satisfied: nbclassic in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (1.3.1)\n",
 | ||
|       "Requirement already satisfied: argon2-cffi in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (25.1.0)\n",
 | ||
|       "Requirement already satisfied: ipython-genutils in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (0.2.0)\n",
 | ||
|       "Requirement already satisfied: nbformat in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (5.10.4)\n",
 | ||
|       "Requirement already satisfied: Send2Trash>=1.8.0 in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (1.8.3)\n",
 | ||
|       "Requirement already satisfied: terminado>=0.8.3 in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (0.18.1)\n",
 | ||
|       "Requirement already satisfied: prometheus-client in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (0.22.1)\n",
 | ||
|       "Requirement already satisfied: mistune<2,>=0.8.1 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.8.4)\n",
 | ||
|       "Requirement already satisfied: jupyterlab-pygments in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.3.0)\n",
 | ||
|       "Requirement already satisfied: entrypoints>=0.2.2 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.4)\n",
 | ||
|       "Requirement already satisfied: bleach in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (6.2.0)\n",
 | ||
|       "Requirement already satisfied: pandocfilters>=1.4.1 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (1.5.1)\n",
 | ||
|       "Requirement already satisfied: testpath in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.6.0)\n",
 | ||
|       "Requirement already satisfied: defusedxml in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.7.1)\n",
 | ||
|       "Requirement already satisfied: beautifulsoup4 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (4.13.4)\n",
 | ||
|       "Requirement already satisfied: nbclient<0.6.0,>=0.5.0 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.5.13)\n",
 | ||
|       "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (3.0.2)\n",
 | ||
|       "Requirement already satisfied: intel-openmp<2026,>=2024 in /usr/local/lib/python3.11/dist-packages (from mkl->numpy<2.1.0,>=1.26.0) (2024.2.0)\n",
 | ||
|       "Requirement already satisfied: tbb==2022.* in /usr/local/lib/python3.11/dist-packages (from mkl->numpy<2.1.0,>=1.26.0) (2022.2.0)\n",
 | ||
|       "Requirement already satisfied: tcmlib==1.* in /usr/local/lib/python3.11/dist-packages (from tbb==2022.*->mkl->numpy<2.1.0,>=1.26.0) (1.4.0)\n",
 | ||
|       "Requirement already satisfied: intel-cmplr-lib-rt in /usr/local/lib/python3.11/dist-packages (from mkl_umath->numpy<2.1.0,>=1.26.0) (2024.2.0)\n",
 | ||
|       "Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub==0.33.1) (3.4.2)\n",
 | ||
|       "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub==0.33.1) (3.10)\n",
 | ||
|       "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub==0.33.1) (2.5.0)\n",
 | ||
|       "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub==0.33.1) (2025.6.15)\n",
 | ||
|       "Requirement already satisfied: intel-cmplr-lib-ur==2024.2.0 in /usr/local/lib/python3.11/dist-packages (from intel-openmp<2026,>=2024->mkl->numpy<2.1.0,>=1.26.0) (2024.2.0)\n",
 | ||
|       "Requirement already satisfied: setuptools>=18.5 in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (75.2.0)\n",
 | ||
|       "Requirement already satisfied: jedi>=0.16 in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (0.19.2)\n",
 | ||
|       "Requirement already satisfied: decorator in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (4.4.2)\n",
 | ||
|       "Requirement already satisfied: pickleshare in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (0.7.5)\n",
 | ||
|       "Requirement already satisfied: backcall in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (0.2.0)\n",
 | ||
|       "Requirement already satisfied: pexpect>4.3 in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (4.9.0)\n",
 | ||
|       "Requirement already satisfied: platformdirs>=2.5 in /usr/local/lib/python3.11/dist-packages (from jupyter-core->jupyterlab->jupyter==1.1.1) (4.3.8)\n",
 | ||
|       "Requirement already satisfied: anyio>=3.1.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (4.9.0)\n",
 | ||
|       "Requirement already satisfied: jupyter-events>=0.9.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (0.12.0)\n",
 | ||
|       "Requirement already satisfied: jupyter-server-terminals in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (0.5.3)\n",
 | ||
|       "Requirement already satisfied: overrides in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (7.7.0)\n",
 | ||
|       "Requirement already satisfied: websocket-client in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.8.0)\n",
 | ||
|       "Requirement already satisfied: jupyter-server-fileid<1,>=0.6.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-server-ydoc~=0.8.0->jupyterlab->jupyter==1.1.1) (0.9.3)\n",
 | ||
|       "Requirement already satisfied: ypy-websocket<0.9.0,>=0.8.2 in /usr/local/lib/python3.11/dist-packages (from jupyter-server-ydoc~=0.8.0->jupyterlab->jupyter==1.1.1) (0.8.4)\n",
 | ||
|       "Requirement already satisfied: y-py<0.7.0,>=0.6.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-ydoc~=0.2.4->jupyterlab->jupyter==1.1.1) (0.6.2)\n",
 | ||
|       "Requirement already satisfied: babel>=2.10 in /usr/local/lib/python3.11/dist-packages (from jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (2.17.0)\n",
 | ||
|       "Requirement already satisfied: json5>=0.9.0 in /usr/local/lib/python3.11/dist-packages (from jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (0.12.0)\n",
 | ||
|       "Requirement already satisfied: jsonschema>=4.18.0 in /usr/local/lib/python3.11/dist-packages (from jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (4.24.0)\n",
 | ||
|       "Requirement already satisfied: notebook-shim>=0.2.3 in /usr/local/lib/python3.11/dist-packages (from nbclassic->jupyterlab->jupyter==1.1.1) (0.2.4)\n",
 | ||
|       "Requirement already satisfied: fastjsonschema>=2.15 in /usr/local/lib/python3.11/dist-packages (from nbformat->notebook->jupyter==1.1.1) (2.21.1)\n",
 | ||
|       "Requirement already satisfied: wcwidth in /usr/local/lib/python3.11/dist-packages (from prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0->jupyter-console->jupyter==1.1.1) (0.2.13)\n",
 | ||
|       "Requirement already satisfied: ptyprocess in /usr/local/lib/python3.11/dist-packages (from terminado>=0.8.3->notebook->jupyter==1.1.1) (0.7.0)\n",
 | ||
|       "Requirement already satisfied: argon2-cffi-bindings in /usr/local/lib/python3.11/dist-packages (from argon2-cffi->notebook->jupyter==1.1.1) (21.2.0)\n",
 | ||
|       "Requirement already satisfied: soupsieve>1.2 in /usr/local/lib/python3.11/dist-packages (from beautifulsoup4->nbconvert->jupyter==1.1.1) (2.7)\n",
 | ||
|       "Requirement already satisfied: webencodings in /usr/local/lib/python3.11/dist-packages (from bleach->nbconvert->jupyter==1.1.1) (0.5.1)\n",
 | ||
|       "Requirement already satisfied: sniffio>=1.1 in /usr/local/lib/python3.11/dist-packages (from anyio>=3.1.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.3.1)\n",
 | ||
|       "Requirement already satisfied: parso<0.9.0,>=0.8.4 in /usr/local/lib/python3.11/dist-packages (from jedi>=0.16->ipython>=7.23.1->ipykernel->jupyter==1.1.1) (0.8.4)\n",
 | ||
|       "Requirement already satisfied: attrs>=22.2.0 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=4.18.0->jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (25.3.0)\n",
 | ||
|       "Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=4.18.0->jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (2025.4.1)\n",
 | ||
|       "Requirement already satisfied: referencing>=0.28.4 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=4.18.0->jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (0.36.2)\n",
 | ||
|       "Requirement already satisfied: rpds-py>=0.7.1 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=4.18.0->jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (0.25.1)\n",
 | ||
|       "Requirement already satisfied: python-json-logger>=2.0.4 in /usr/local/lib/python3.11/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (3.3.0)\n",
 | ||
|       "Requirement already satisfied: rfc3339-validator in /usr/local/lib/python3.11/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (0.1.4)\n",
 | ||
|       "Requirement already satisfied: rfc3986-validator>=0.1.1 in /usr/local/lib/python3.11/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (0.1.1)\n",
 | ||
|       "Requirement already satisfied: aiofiles<23,>=22.1.0 in /usr/local/lib/python3.11/dist-packages (from ypy-websocket<0.9.0,>=0.8.2->jupyter-server-ydoc~=0.8.0->jupyterlab->jupyter==1.1.1) (22.1.0)\n",
 | ||
|       "Requirement already satisfied: aiosqlite<1,>=0.17.0 in /usr/local/lib/python3.11/dist-packages (from ypy-websocket<0.9.0,>=0.8.2->jupyter-server-ydoc~=0.8.0->jupyterlab->jupyter==1.1.1) (0.21.0)\n",
 | ||
|       "Requirement already satisfied: cffi>=1.0.1 in /usr/local/lib/python3.11/dist-packages (from argon2-cffi-bindings->argon2-cffi->notebook->jupyter==1.1.1) (1.17.1)\n",
 | ||
|       "Requirement already satisfied: pycparser in /usr/local/lib/python3.11/dist-packages (from cffi>=1.0.1->argon2-cffi-bindings->argon2-cffi->notebook->jupyter==1.1.1) (2.22)\n",
 | ||
|       "Requirement already satisfied: fqdn in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.5.1)\n",
 | ||
|       "Requirement already satisfied: isoduration in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (20.11.0)\n",
 | ||
|       "Requirement already satisfied: jsonpointer>1.13 in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (3.0.0)\n",
 | ||
|       "Requirement already satisfied: uri-template in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.3.0)\n",
 | ||
|       "Requirement already satisfied: webcolors>=24.6.0 in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (24.11.1)\n",
 | ||
|       "Requirement already satisfied: arrow>=0.15.0 in /usr/local/lib/python3.11/dist-packages (from isoduration->jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.3.0)\n",
 | ||
|       "Requirement already satisfied: types-python-dateutil>=2.8.10 in /usr/local/lib/python3.11/dist-packages (from arrow>=0.15.0->isoduration->jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (2.9.0.20250516)\n",
 | ||
|       "Obtaining file:///kaggle/working/nejm-brain-to-text\n",
 | ||
|       "  Preparing metadata (setup.py): started\n",
 | ||
|       "  Preparing metadata (setup.py): finished with status 'done'\n",
 | ||
|       "Installing collected packages: nejm_b2txt_utils\n",
 | ||
|       "  Running setup.py develop for nejm_b2txt_utils\n",
 | ||
|       "Successfully installed nejm_b2txt_utils-0.0.0\n"
 | ||
|      ]
 | ||
|     },
 | ||
|     {
 | ||
|      "name": "stderr",
 | ||
|      "output_type": "stream",
 | ||
|      "text": [
 | ||
|       "Cloning into 'nejm-brain-to-text'...\n"
 | ||
|      ]
 | ||
|     }
 | ||
|    ],
 | ||
|    "source": [
 | ||
|     "%%bash\n",
 | ||
|     "cd /kaggle/working/\n",
 | ||
|     "# rm -rf /kaggle/working/nejm-brain-to-text/\n",
 | ||
|     "# git clone https://github.com/ZH-CEN/nejm-brain-to-text.git\n",
 | ||
|     "# cd /kaggle/working/nejm-brain-to-text/\n",
 | ||
|     "# cp /kaggle/input/brain-to-text-baseline-model/t15_copyTask.pkl /kaggle/working/nejm-brain-to-text/data/t15_copyTask.pkl\n",
 | ||
|     "\n",
 | ||
|     "# Install the local package\n",
 | ||
|     "\n",
 | ||
|     "# ln -s /kaggle/input/brain-to-text-25/t15_pretrained_rnn_baseline/t15_pretrained_rnn_baseline /kaggle/working/nejm-brain-to-text/data\n",
 | ||
|     "# ln -s /kaggle/input/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final /kaggle/working/nejm-brain-to-text/data\n",
 | ||
|     "# ln -s /kaggle/input/rnn-pretagged-data /kaggle/working/nejm-brain-to-text/data\n",
 | ||
|     "\n",
 | ||
|     "# # Install PyTorch with CUDA 12.6\n",
 | ||
|     "pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126\n",
 | ||
|     "\n",
 | ||
|     "# Install additional packages with compatible versions\n",
 | ||
|     "# TODO: remove redis\n",
 | ||
|     "pip install \\\n",
 | ||
|     "    jupyter==1.1.1 \\\n",
 | ||
|     "    \"numpy>=1.26.0,<2.1.0\" \\\n",
 | ||
|     "    pandas==2.3.0 \\\n",
 | ||
|     "    matplotlib==3.10.1 \\\n",
 | ||
|     "    scipy==1.15.2 \\\n",
 | ||
|     "    scikit-learn==1.6.1 \\\n",
 | ||
|     "    tqdm==4.67.1 \\\n",
 | ||
|     "    g2p_en==2.1.0 \\\n",
 | ||
|     "    h5py==3.13.0 \\\n",
 | ||
|     "    omegaconf==2.3.0 \\\n",
 | ||
|     "    editdistance==0.8.1 \\\n",
 | ||
|     "    huggingface-hub==0.33.1 \\\n",
 | ||
|     "    transformers==4.53.0 \\\n",
 | ||
|     "    tokenizers==0.21.2 \\\n",
 | ||
|     "    accelerate==1.8.1 \\\n",
 | ||
|     "    bitsandbytes==0.46.0\n",
 | ||
|     "pip install -e ."
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": 2,
 | ||
|    "metadata": {},
 | ||
|    "outputs": [
 | ||
|     {
 | ||
|      "name": "stdout",
 | ||
|      "output_type": "stream",
 | ||
|      "text": [
 | ||
|       "/kaggle/working/nejm-brain-to-text\n"
 | ||
|      ]
 | ||
|     }
 | ||
|    ],
 | ||
|    "source": [
 | ||
|     "%cd /kaggle/working/nejm-brain-to-text\n",
 | ||
|     "import numpy as np\n",
 | ||
|     "import os\n",
 | ||
|     "import pickle\n",
 | ||
|     "import matplotlib.pyplot as plt\n",
 | ||
|     "import matplotlib\n",
 | ||
|     "from g2p_en import G2p\n",
 | ||
|     "import pandas as pd\n",
 | ||
|     "import numpy as np\n",
 | ||
|     "from nejm_b2txt_utils.general_utils import *\n",
 | ||
|     "\n",
 | ||
|     "matplotlib.rcParams['pdf.fonttype'] = 42\n",
 | ||
|     "matplotlib.rcParams['ps.fonttype'] = 42\n",
 | ||
|     "matplotlib.rcParams['font.family'] = 'sans-serif'\n"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": 4,
 | ||
|    "metadata": {},
 | ||
|    "outputs": [
 | ||
|     {
 | ||
|      "name": "stdout",
 | ||
|      "output_type": "stream",
 | ||
|      "text": [
 | ||
|       "/kaggle/working/nejm-brain-to-text/model_training\n"
 | ||
|      ]
 | ||
|     }
 | ||
|    ],
 | ||
|    "source": [
 | ||
|     "%cd model_training/\n",
 | ||
|     "from data_augmentations import gauss_smooth\n",
 | ||
|     "# single decoding step function that also returns smoothed input\n",
 | ||
|     "# smooths data and puts it through the model, returning both logits and smoothed input.\n",
 | ||
|     "def runSingleDecodingStepWithSmoothedInput(x, input_layer, model, model_args, device):\n",
 | ||
|     "\n",
 | ||
|     "    # Use autocast for efficiency\n",
 | ||
|     "    with torch.autocast(device_type = \"cuda\", enabled = model_args['use_amp'], dtype = torch.bfloat16):\n",
 | ||
|     "\n",
 | ||
|     "        smoothed_x = gauss_smooth(\n",
 | ||
|     "            inputs = x, \n",
 | ||
|     "            device = device,\n",
 | ||
|     "            smooth_kernel_std = model_args['dataset']['data_transforms']['smooth_kernel_std'],\n",
 | ||
|     "            smooth_kernel_size = model_args['dataset']['data_transforms']['smooth_kernel_size'],\n",
 | ||
|     "            padding = 'valid',\n",
 | ||
|     "        )\n",
 | ||
|     "\n",
 | ||
|     "        with torch.no_grad():\n",
 | ||
|     "            logits, _ = model(\n",
 | ||
|     "                x = smoothed_x,\n",
 | ||
|     "                day_idx = torch.tensor([input_layer], device=device),\n",
 | ||
|     "                states = None, # no initial states\n",
 | ||
|     "                return_state = True,\n",
 | ||
|     "            )\n",
 | ||
|     "\n",
 | ||
|     "    # convert both logits and smoothed input from bfloat16 to float32\n",
 | ||
|     "    logits = logits.float().cpu().numpy()\n",
 | ||
|     "    smoothed_input = smoothed_x.float().cpu().numpy()\n",
 | ||
|     "\n",
 | ||
|     "    # # original order is [BLANK, phonemes..., SIL]\n",
 | ||
|     "    # # rearrange so the order is [BLANK, SIL, phonemes...]\n",
 | ||
|     "    # logits = rearrange_speech_logits_pt(logits)\n",
 | ||
|     "\n",
 | ||
|     "    return logits, smoothed_input\n",
 | ||
|     "\n",
 | ||
|     "\n",
 | ||
|     "import h5py\n",
 | ||
|     "def load_h5py_file(file_path, b2txt_csv_df):\n",
 | ||
|     "    data = {\n",
 | ||
|     "        'neural_features': [],\n",
 | ||
|     "        'n_time_steps': [],\n",
 | ||
|     "        'seq_class_ids': [],\n",
 | ||
|     "        'seq_len': [],\n",
 | ||
|     "        'transcriptions': [],\n",
 | ||
|     "        'sentence_label': [],\n",
 | ||
|     "        'session': [],\n",
 | ||
|     "        'block_num': [],\n",
 | ||
|     "        'trial_num': [],\n",
 | ||
|     "        'corpus': [],\n",
 | ||
|     "    }\n",
 | ||
|     "    # Open the hdf5 file for that day\n",
 | ||
|     "    with h5py.File(file_path, 'r') as f:\n",
 | ||
|     "\n",
 | ||
|     "        keys = list(f.keys())\n",
 | ||
|     "\n",
 | ||
|     "        # For each trial in the selected trials in that day\n",
 | ||
|     "        for key in keys:\n",
 | ||
|     "            g = f[key]\n",
 | ||
|     "\n",
 | ||
|     "            neural_features = g['input_features'][:] # pyright: ignore[reportIndexIssue]\n",
 | ||
|     "            n_time_steps = g.attrs['n_time_steps']\n",
 | ||
|     "            seq_class_ids = g['seq_class_ids'][:] if 'seq_class_ids' in g else None # type: ignore\n",
 | ||
|     "            seq_len = g.attrs['seq_len'] if 'seq_len' in g.attrs else None\n",
 | ||
|     "            transcription = g['transcription'][:] if 'transcription' in g else None # type: ignore\n",
 | ||
|     "            sentence_label = g.attrs['sentence_label'][:] if 'sentence_label' in g.attrs else None # pyright: ignore[reportIndexIssue]\n",
 | ||
|     "            session = g.attrs['session']\n",
 | ||
|     "            block_num = g.attrs['block_num']\n",
 | ||
|     "            trial_num = g.attrs['trial_num']\n",
 | ||
|     "\n",
 | ||
|     "            # match this trial up with the csv to get the corpus name\n",
 | ||
|     "            year, month, day = session.split('.')[1:] # pyright: ignore[reportAttributeAccessIssue]\n",
 | ||
|     "            date = f'{year}-{month}-{day}'\n",
 | ||
|     "            row = b2txt_csv_df[(b2txt_csv_df['Date'] == date) & (b2txt_csv_df['Block number'] == block_num)]\n",
 | ||
|     "            corpus_name = row['Corpus'].values[0]\n",
 | ||
|     "\n",
 | ||
|     "            data['neural_features'].append(neural_features)\n",
 | ||
|     "            data['n_time_steps'].append(n_time_steps)\n",
 | ||
|     "            data['seq_class_ids'].append(seq_class_ids)\n",
 | ||
|     "            data['seq_len'].append(seq_len)\n",
 | ||
|     "            data['transcriptions'].append(transcription)\n",
 | ||
|     "            data['sentence_label'].append(sentence_label)\n",
 | ||
|     "            data['session'].append(session)\n",
 | ||
|     "            data['block_num'].append(block_num)\n",
 | ||
|     "            data['trial_num'].append(trial_num)\n",
 | ||
|     "            data['corpus'].append(corpus_name)\n",
 | ||
|     "    return data"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": 5,
 | ||
|    "metadata": {},
 | ||
|    "outputs": [],
 | ||
|    "source": [
 | ||
|     "LOGIT_TO_PHONEME = [\n",
 | ||
|     "    'BLANK',\n",
 | ||
|     "    'AA', 'AE', 'AH', 'AO', 'AW',\n",
 | ||
|     "    'AY', 'B',  'CH', 'D', 'DH',\n",
 | ||
|     "    'EH', 'ER', 'EY', 'F', 'G',\n",
 | ||
|     "    'HH', 'IH', 'IY', 'JH', 'K',\n",
 | ||
|     "    'L', 'M', 'N', 'NG', 'OW',\n",
 | ||
|     "    'OY', 'P', 'R', 'S', 'SH',\n",
 | ||
|     "    'T', 'TH', 'UH', 'UW', 'V',\n",
 | ||
|     "    'W', 'Y', 'Z', 'ZH',\n",
 | ||
|     "    ' | ',\n",
 | ||
|     "]"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "metadata": {},
 | ||
|    "source": [
 | ||
|     "# 数据分析与预处理"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "metadata": {},
 | ||
|    "source": [
 | ||
|     "## 数据准备"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": 6,
 | ||
|    "metadata": {},
 | ||
|    "outputs": [
 | ||
|     {
 | ||
|      "name": "stdout",
 | ||
|      "output_type": "stream",
 | ||
|      "text": [
 | ||
|       "/kaggle/working/nejm-brain-to-text\n"
 | ||
|      ]
 | ||
|     }
 | ||
|    ],
 | ||
|    "source": [
 | ||
|     "%cd /kaggle/working/nejm-brain-to-text/"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": 7,
 | ||
|    "metadata": {},
 | ||
|    "outputs": [],
 | ||
|    "source": [
 | ||
|     "import pandas as pd\n",
 | ||
|     "data = load_h5py_file(file_path='data/hdf5_data_final/t15.2023.08.11/data_train.hdf5',\n",
 | ||
|     "                        b2txt_csv_df=pd.read_csv('data/t15_copyTaskData_description.csv'))"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "metadata": {},
 | ||
|    "source": [
 | ||
|     "- **任务介绍** :机器学习解决高维信号的模式识别问题"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "metadata": {},
 | ||
|    "source": [
 | ||
|     "我们的数据集标签缺少时间戳,现在要进行的是半监督学习"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "metadata": {},
 | ||
|    "source": [
 | ||
|     "- 音素时间均等分割或者按照调研数据设定初始长度。然后筛掉异常值。提取出可用的训练集,再控制时间长短,查看样本类的长度"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": 8,
 | ||
|    "metadata": {},
 | ||
|    "outputs": [
 | ||
|     {
 | ||
|      "data": {
 | ||
|       "text/plain": [
 | ||
|        "{'neural_features': array([[ 2.3076649 , -0.78699756, -0.64687246, ...,  0.57367045,\n",
 | ||
|        "         -0.7091646 , -0.11018186],\n",
 | ||
|        "        [-0.5859305 , -0.78699756, -0.64687246, ...,  0.3122117 ,\n",
 | ||
|        "          1.7943763 , -0.76884896],\n",
 | ||
|        "        [-0.5859305 , -0.78699756, -0.64687246, ..., -0.21193463,\n",
 | ||
|        "         -0.8481289 , -0.7648201 ],\n",
 | ||
|        "        ...,\n",
 | ||
|        "        [-0.5859305 ,  0.22756557,  0.9262037 , ..., -0.34710956,\n",
 | ||
|        "          0.9710176 ,  2.5397465 ],\n",
 | ||
|        "        [-0.5859305 ,  0.22756557, -0.64687246, ..., -0.83613133,\n",
 | ||
|        "         -0.68723625,  0.10479005],\n",
 | ||
|        "        [ 0.8608672 , -0.78699756, -0.64687246, ..., -0.7171131 ,\n",
 | ||
|        "          0.7417906 , -0.7008622 ]], dtype=float32),\n",
 | ||
|        " 'n_time_steps': 321,\n",
 | ||
|        " 'seq_class_ids': array([ 7, 28, 17, 24, 40, 17, 31, 40, 20, 21, 25, 29, 12, 40,  0,  0,  0,\n",
 | ||
|        "         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n",
 | ||
|        "         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n",
 | ||
|        "         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n",
 | ||
|        "         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n",
 | ||
|        "         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n",
 | ||
|        "         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n",
 | ||
|        "         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n",
 | ||
|        "         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n",
 | ||
|        "         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n",
 | ||
|        "         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n",
 | ||
|        "         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n",
 | ||
|        "         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n",
 | ||
|        "         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n",
 | ||
|        "         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n",
 | ||
|        "         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n",
 | ||
|        "         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n",
 | ||
|        "         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n",
 | ||
|        "         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n",
 | ||
|        "         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n",
 | ||
|        "         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n",
 | ||
|        "         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n",
 | ||
|        "         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n",
 | ||
|        "         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n",
 | ||
|        "         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n",
 | ||
|        "         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n",
 | ||
|        "         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n",
 | ||
|        "         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n",
 | ||
|        "         0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,  0,\n",
 | ||
|        "         0,  0,  0,  0,  0,  0,  0], dtype=int32),\n",
 | ||
|        " 'seq_len': 14,\n",
 | ||
|        " 'transcriptions': array([ 66, 114, 105, 110, 103,  32, 105, 116,  32,  99, 108, 111, 115,\n",
 | ||
|        "        101, 114,  46,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
 | ||
|        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
 | ||
|        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
 | ||
|        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
 | ||
|        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
 | ||
|        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
 | ||
|        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
 | ||
|        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
 | ||
|        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
 | ||
|        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
 | ||
|        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
 | ||
|        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
 | ||
|        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
 | ||
|        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
 | ||
|        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
 | ||
|        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
 | ||
|        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
 | ||
|        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
 | ||
|        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
 | ||
|        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
 | ||
|        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
 | ||
|        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
 | ||
|        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
 | ||
|        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
 | ||
|        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
 | ||
|        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
 | ||
|        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
 | ||
|        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
 | ||
|        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
 | ||
|        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
 | ||
|        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
 | ||
|        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
 | ||
|        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
 | ||
|        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
 | ||
|        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
 | ||
|        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
 | ||
|        "          0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,   0,\n",
 | ||
|        "          0,   0,   0,   0,   0,   0], dtype=int32),\n",
 | ||
|        " 'sentence_label': 'Bring it closer.',\n",
 | ||
|        " 'session': 't15.2023.08.11',\n",
 | ||
|        " 'block_num': 2,\n",
 | ||
|        " 'trial_num': 0,\n",
 | ||
|        " 'corpus': '50-Word'}"
 | ||
|       ]
 | ||
|      },
 | ||
|      "execution_count": 8,
 | ||
|      "metadata": {},
 | ||
|      "output_type": "execute_result"
 | ||
|     }
 | ||
|    ],
 | ||
|    "source": [
 | ||
|     "def data_patch(data, index):\n",
 | ||
|     "    data_patch = {}\n",
 | ||
|     "    data_patch['neural_features'] = data['neural_features'][index]\n",
 | ||
|     "    data_patch['n_time_steps'] = data['n_time_steps'][index]\n",
 | ||
|     "    data_patch['seq_class_ids'] = data['seq_class_ids'][index]\n",
 | ||
|     "    data_patch['seq_len'] = data['seq_len'][index]\n",
 | ||
|     "    data_patch['transcriptions'] = data['transcriptions'][index]\n",
 | ||
|     "    data_patch['sentence_label'] = data['sentence_label'][index]\n",
 | ||
|     "    data_patch['session'] = data['session'][index]\n",
 | ||
|     "    data_patch['block_num'] = data['block_num'][index]\n",
 | ||
|     "    data_patch['trial_num'] = data['trial_num'][index]\n",
 | ||
|     "    data_patch['corpus'] = data['corpus'][index]\n",
 | ||
|     "    return data_patch\n",
 | ||
|     "\n",
 | ||
|     "data_patch(data, 0)"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": 9,
 | ||
|    "metadata": {},
 | ||
|    "outputs": [],
 | ||
|    "source": [
 | ||
|     "d1 = data_patch(data, 0)"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": 10,
 | ||
|    "metadata": {},
 | ||
|    "outputs": [
 | ||
|     {
 | ||
|      "name": "stdout",
 | ||
|      "output_type": "stream",
 | ||
|      "text": [
 | ||
|       "Transcriptions non-zero length: 16\n",
 | ||
|       "Seq class ids non-zero length: 14\n",
 | ||
|       "Seq len: 14\n"
 | ||
|      ]
 | ||
|     }
 | ||
|    ],
 | ||
|    "source": [
 | ||
|     "trans_len = len([x for x in d1['transcriptions'] if x != 0])\n",
 | ||
|     "seq_len_nonzero = len([x for x in d1['seq_class_ids'] if x != 0])\n",
 | ||
|     "seq_len = d1['seq_len']\n",
 | ||
|     "print(f\"Transcriptions non-zero length: {trans_len}\")\n",
 | ||
|     "print(f\"Seq class ids non-zero length: {seq_len_nonzero}\")\n",
 | ||
|     "print(f\"Seq len: {seq_len}\")"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": 11,
 | ||
|    "metadata": {},
 | ||
|    "outputs": [
 | ||
|     {
 | ||
|      "name": "stdout",
 | ||
|      "output_type": "stream",
 | ||
|      "text": [
 | ||
|       "Number of feature sequences: 14\n",
 | ||
|       "Shape of first sequence: (22, 512)\n"
 | ||
|      ]
 | ||
|     }
 | ||
|    ],
 | ||
|    "source": [
 | ||
|     "def create_time_windows(d1):\n",
 | ||
|     "    import numpy as np\n",
 | ||
|     "    n_time_steps = d1['n_time_steps']\n",
 | ||
|     "    seq_len = d1['seq_len']\n",
 | ||
|     "    # Create equal windows\n",
 | ||
|     "    edges = np.linspace(0, n_time_steps, seq_len + 1, dtype=int)\n",
 | ||
|     "    windows = [(edges[i], edges[i+1]) for i in range(seq_len)]\n",
 | ||
|     "    \n",
 | ||
|     "    # Extract feature sequences for each window\n",
 | ||
|     "    feature_sequences = []\n",
 | ||
|     "    for start, end in windows:\n",
 | ||
|     "        seq = d1['neural_features'][start:end, :]\n",
 | ||
|     "        feature_sequences.append(seq)\n",
 | ||
|     "    \n",
 | ||
|     "    return feature_sequences\n",
 | ||
|     "\n",
 | ||
|     "# Example usage\n",
 | ||
|     "feature_sequences = create_time_windows(d1)\n",
 | ||
|     "print(\"Number of feature sequences:\", len(feature_sequences))\n",
 | ||
|     "print(\"Shape of first sequence:\", feature_sequences[0].shape)\n"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": 12,
 | ||
|    "metadata": {},
 | ||
|    "outputs": [
 | ||
|     {
 | ||
|      "name": "stdout",
 | ||
|      "output_type": "stream",
 | ||
|      "text": [
 | ||
|       "Train: 45, Val: 41, Test: 41\n",
 | ||
|       "Train files (first 3): ['/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2025.03.14/data_train.hdf5', '/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2023.08.11/data_train.hdf5', '/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2023.11.19/data_train.hdf5']\n",
 | ||
|       "Val files (first 3): ['/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2025.03.14/data_val.hdf5', '/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2023.11.19/data_val.hdf5', '/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2024.03.08/data_val.hdf5']\n",
 | ||
|       "Test files (first 3): ['/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2025.03.14/data_test.hdf5', '/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2023.11.19/data_test.hdf5', '/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2024.03.08/data_test.hdf5']\n"
 | ||
|      ]
 | ||
|     }
 | ||
|    ],
 | ||
|    "source": [
 | ||
|     "import os\n",
 | ||
|     "\n",
 | ||
|     "def scan_hdf5_files(base_path):\n",
 | ||
|     "    train_files = []\n",
 | ||
|     "    val_files = []\n",
 | ||
|     "    test_files = []\n",
 | ||
|     "    for root, dirs, files in os.walk(base_path):\n",
 | ||
|     "        for file in files:\n",
 | ||
|     "            if file.endswith('.hdf5'):\n",
 | ||
|     "                abs_path = os.path.abspath(os.path.join(root, file))\n",
 | ||
|     "                if 'data_train.hdf5' in file:\n",
 | ||
|     "                    train_files.append(abs_path)\n",
 | ||
|     "                elif 'data_val.hdf5' in file:\n",
 | ||
|     "                    val_files.append(abs_path)\n",
 | ||
|     "                elif 'data_test.hdf5' in file:\n",
 | ||
|     "                    test_files.append(abs_path)\n",
 | ||
|     "    return train_files, val_files, test_files\n",
 | ||
|     "\n",
 | ||
|     "# Example usage\n",
 | ||
|     "FILE_PATH = 'data/hdf5_data_final'\n",
 | ||
|     "train_list, val_list, test_list = scan_hdf5_files(FILE_PATH)\n",
 | ||
|     "print(f\"Train: {len(train_list)}, Val: {len(val_list)}, Test: {len(test_list)}\")\n",
 | ||
|     "print(\"Train files (first 3):\", train_list[:3])\n",
 | ||
|     "print(\"Val files (first 3):\", val_list[:3])\n",
 | ||
|     "print(\"Test files (first 3):\", test_list[:3])"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "metadata": {},
 | ||
|    "source": [
 | ||
|     "## 标签处理"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": 13,
 | ||
|    "metadata": {},
 | ||
|    "outputs": [],
 | ||
|    "source": [
 | ||
|     "# def classify_windows_by_labels(d1):\n",
 | ||
|     "#     seq_class_ids = d1['seq_class_ids'][:d1['seq_len']]  # Take only the non-zero part\n",
 | ||
|     "#     windows = create_time_windows(d1)\n",
 | ||
|     "    \n",
 | ||
|     "#     classified_windows = {}\n",
 | ||
|     "#     for i, label in enumerate(seq_class_ids):\n",
 | ||
|     "#         char = LOGIT_TO_PHONEME[label]\n",
 | ||
|     "#         if char not in classified_windows:\n",
 | ||
|     "#             classified_windows[char] = []\n",
 | ||
|     "#         classified_windows[char].append(windows[i])\n",
 | ||
|     "    \n",
 | ||
|     "#     return classified_windows\n",
 | ||
|     "\n",
 | ||
|     "# # Example usage\n",
 | ||
|     "# classified = classify_windows_by_labels(d1)\n",
 | ||
|     "# print(\"Classified windows by label:\", classified)"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": 14,
 | ||
|    "metadata": {},
 | ||
|    "outputs": [],
 | ||
|    "source": [
 | ||
|     "# import pandas as pd\n",
 | ||
|     "# import tqdm\n",
 | ||
|     "\n",
 | ||
|     "# b2txt_csv_df = pd.read_csv('data/t15_copyTaskData_description.csv')\n",
 | ||
|     "\n",
 | ||
|     "# def workflow(max_files=None):\n",
 | ||
|     "#     group_by_labels = {}\n",
 | ||
|     "#     files_to_process = train_list[:max_files] if max_files is not None else train_list\n",
 | ||
|     "#     for file_path in tqdm.tqdm(files_to_process):\n",
 | ||
|     "#         data = load_h5py_file(file_path, b2txt_csv_df)\n",
 | ||
|     "#         for i in tqdm.tqdm(range(len(data['neural_features'])), leave=False):\n",
 | ||
|     "#             # Process only the first trial for simplicity\n",
 | ||
|     "#             d1 = data_patch(data, i)\n",
 | ||
|     "#             classified = classify_windows_by_labels(d1)\n",
 | ||
|     "#             for key, value in classified.items():\n",
 | ||
|     "#                 if key not in group_by_labels:\n",
 | ||
|     "#                     group_by_labels[key] = []\n",
 | ||
|     "#                 group_by_labels[key].extend(value)\n",
 | ||
|     "#     return group_by_labels\n",
 | ||
|     "\n",
 | ||
|     "# # Example usage\n",
 | ||
|     "# result = workflow()"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "metadata": {},
 | ||
|    "source": [
 | ||
|     "### 核函数扭曲时间\n",
 | ||
|     "控制音素时间长度相同"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": 15,
 | ||
|    "metadata": {},
 | ||
|    "outputs": [],
 | ||
|    "source": [
 | ||
|     "# total = 0\n",
 | ||
|     "# count = 0\n",
 | ||
|     "# time_distribution = []\n",
 | ||
|     "# for i in result.values():\n",
 | ||
|     "#     for j in i:\n",
 | ||
|     "#         total += j.shape[0]\n",
 | ||
|     "#         count += 1\n",
 | ||
|     "#         time_distribution.append(j.shape[0])\n",
 | ||
|     "# print(f\"Total time steps: {total}, Total windows: {count}, Average window length: {total/count:.2f}\")\n"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": 16,
 | ||
|    "metadata": {},
 | ||
|    "outputs": [],
 | ||
|    "source": [
 | ||
|     "# import numpy as np\n",
 | ||
|     "# plt.figure(figsize=(12, 6))\n",
 | ||
|     "# plt.hist(time_distribution, bins=50, alpha=0.7, edgecolor='black')\n",
 | ||
|     "# plt.xlabel('Window Length (time steps)')\n",
 | ||
|     "# plt.ylabel('Frequency')\n",
 | ||
|     "# plt.title('Distribution of Time Window Lengths')\n",
 | ||
|     "# plt.grid(True, alpha=0.3)\n",
 | ||
|     "# plt.axvline(np.mean(time_distribution), color='red', linestyle='--', label=f'Mean: {np.mean(time_distribution):.2f}')\n",
 | ||
|     "# plt.axvline(np.median(time_distribution), color='green', linestyle='--', label=f'Median: {np.median(time_distribution):.2f}')\n",
 | ||
|     "# plt.legend()\n",
 | ||
|     "# plt.show()\n",
 | ||
|     "\n",
 | ||
|     "# print(f\"Statistics:\")\n",
 | ||
|     "# print(f\"Mean: {np.mean(time_distribution):.2f}\")\n",
 | ||
|     "# print(f\"Median: {np.median(time_distribution):.2f}\")\n",
 | ||
|     "# print(f\"Std: {np.std(time_distribution):.2f}\")\n",
 | ||
|     "# print(f\"Min: {np.min(time_distribution)}\")\n",
 | ||
|     "# print(f\"Max: {np.max(time_distribution)}\")"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": 17,
 | ||
|    "metadata": {},
 | ||
|    "outputs": [],
 | ||
|    "source": [
 | ||
|     "# # 把时间序列对齐到这个长度上,然后做聚类去异常值\n",
 | ||
|     "# MEAN_WINDOWS_SIZE = 33\n",
 | ||
|     "\n",
 | ||
|     "# from scipy import interpolate\n",
 | ||
|     "# import numpy as np\n",
 | ||
|     "\n",
 | ||
|     "# def kernel_time_warp(sequence, target_length=33):\n",
 | ||
|     "#     \"\"\"\n",
 | ||
|     "#     使用核函数处理时间序列,将序列长度标准化到target_length\n",
 | ||
|     "#     - 对于长度 < target_length 的序列:使用插值扩展\n",
 | ||
|     "#     - 对于长度 > target_length 的序列:使用压缩采样\n",
 | ||
|     "#     - 对于长度 = target_length 的序列:直接返回\n",
 | ||
|     "    \n",
 | ||
|     "#     Args:\n",
 | ||
|     "#         sequence: 输入时间序列 (time_steps, features)\n",
 | ||
|     "#         target_length: 目标长度\n",
 | ||
|     "    \n",
 | ||
|     "#     Returns:\n",
 | ||
|     "#         warped_sequence: 处理后的时间序列 (target_length, features)\n",
 | ||
|     "#     \"\"\"\n",
 | ||
|     "#     original_length = sequence.shape[0]\n",
 | ||
|     "#     n_features = sequence.shape[1]\n",
 | ||
|     "    \n",
 | ||
|     "#     # 如果序列长度已经等于目标长度,直接返回\n",
 | ||
|     "#     if original_length == target_length:\n",
 | ||
|     "#         return sequence\n",
 | ||
|     "    \n",
 | ||
|     "#     # 处理边界情况\n",
 | ||
|     "#     if original_length == 0:\n",
 | ||
|     "#         return np.zeros((target_length, n_features))\n",
 | ||
|     "    \n",
 | ||
|     "#     if original_length == 1:\n",
 | ||
|     "#         # 如果只有一个时间步,复制到所有目标时间步\n",
 | ||
|     "#         return np.repeat(sequence, target_length, axis=0)\n",
 | ||
|     "    \n",
 | ||
|     "#     warped_sequence = np.zeros((target_length, n_features))\n",
 | ||
|     "    \n",
 | ||
|     "#     if original_length > target_length:\n",
 | ||
|     "#         # 压缩:长序列 -> 短序列\n",
 | ||
|     "#         # 使用均匀采样 + 局部平均的方式压缩\n",
 | ||
|     "#         compression_ratio = original_length / target_length\n",
 | ||
|     "        \n",
 | ||
|     "#         for i in range(target_length):\n",
 | ||
|     "#             # 计算当前目标位置对应的原始序列范围\n",
 | ||
|     "#             start_idx = int(i * compression_ratio)\n",
 | ||
|     "#             end_idx = int((i + 1) * compression_ratio)\n",
 | ||
|     "            \n",
 | ||
|     "#             # 确保不超出边界\n",
 | ||
|     "#             start_idx = max(0, start_idx)\n",
 | ||
|     "#             end_idx = min(original_length, end_idx)\n",
 | ||
|     "            \n",
 | ||
|     "#             if start_idx == end_idx:\n",
 | ||
|     "#                 # 避免空范围\n",
 | ||
|     "#                 end_idx = min(start_idx + 1, original_length)\n",
 | ||
|     "            \n",
 | ||
|     "#             # 对该范围内的数据取平均(压缩)\n",
 | ||
|     "#             warped_sequence[i] = np.mean(sequence[start_idx:end_idx], axis=0)\n",
 | ||
|     "    \n",
 | ||
|     "#     else:\n",
 | ||
|     "#         # 扩展:短序列 -> 长序列\n",
 | ||
|     "#         # 使用插值的方式扩展\n",
 | ||
|     "#         original_indices = np.linspace(0, 1, original_length)\n",
 | ||
|     "#         target_indices = np.linspace(0, 1, target_length)\n",
 | ||
|     "        \n",
 | ||
|     "#         for feature_idx in range(n_features):\n",
 | ||
|     "#             # 根据原始序列长度选择插值方法\n",
 | ||
|     "#             if original_length >= 3:\n",
 | ||
|     "#                 # 对于长度>=3的序列,使用三次样条插值\n",
 | ||
|     "#                 interpolator = interpolate.interp1d(\n",
 | ||
|     "#                     original_indices, \n",
 | ||
|     "#                     sequence[:, feature_idx], \n",
 | ||
|     "#                     kind='cubic', \n",
 | ||
|     "#                     bounds_error=False, \n",
 | ||
|     "#                     fill_value='extrapolate'\n",
 | ||
|     "#                 )\n",
 | ||
|     "#             else:\n",
 | ||
|     "#                 # 对于长度=2的序列,使用线性插值\n",
 | ||
|     "#                 interpolator = interpolate.interp1d(\n",
 | ||
|     "#                     original_indices, \n",
 | ||
|     "#                     sequence[:, feature_idx], \n",
 | ||
|     "#                     kind='linear', \n",
 | ||
|     "#                     bounds_error=False, \n",
 | ||
|     "#                     fill_value='extrapolate'\n",
 | ||
|     "#                 )\n",
 | ||
|     "            \n",
 | ||
|     "#             warped_sequence[:, feature_idx] = interpolator(target_indices)\n",
 | ||
|     "    \n",
 | ||
|     "#     return warped_sequence\n",
 | ||
|     "\n",
 | ||
|     "# def gaussian_kernel_weight(x, sigma=0.1):\n",
 | ||
|     "#     \"\"\"\n",
 | ||
|     "#     高斯核函数权重,用于平滑处理\n",
 | ||
|     "#     \"\"\"\n",
 | ||
|     "#     return np.exp(-0.5 * (x / sigma) ** 2)\n",
 | ||
|     "\n",
 | ||
|     "# def process_result_with_kernel(result_dict, target_length=33):\n",
 | ||
|     "#     \"\"\"\n",
 | ||
|     "#     使用核函数处理result字典中的所有时间序列\n",
 | ||
|     "    \n",
 | ||
|     "#     Args:\n",
 | ||
|     "#         result_dict: 包含时间序列的字典\n",
 | ||
|     "#         target_length: 目标长度\n",
 | ||
|     "    \n",
 | ||
|     "#     Returns:\n",
 | ||
|     "#         processed_result: 处理后的字典\n",
 | ||
|     "#     \"\"\"\n",
 | ||
|     "#     processed_result = {}\n",
 | ||
|     "    \n",
 | ||
|     "#     print(\"Processing time series with kernel warping...\")\n",
 | ||
|     "#     print(f\"Target length: {target_length}\")\n",
 | ||
|     "    \n",
 | ||
|     "#     # 统计不同长度的序列\n",
 | ||
|     "#     length_stats = {}\n",
 | ||
|     "#     total_sequences = 0\n",
 | ||
|     "    \n",
 | ||
|     "#     for label, sequences in tqdm.tqdm(result_dict.items()):\n",
 | ||
|     "#         processed_sequences = []\n",
 | ||
|     "        \n",
 | ||
|     "#         for seq in sequences:\n",
 | ||
|     "#             original_length = seq.shape[0]\n",
 | ||
|     "            \n",
 | ||
|     "#             # 统计长度分布\n",
 | ||
|     "#             if original_length not in length_stats:\n",
 | ||
|     "#                 length_stats[original_length] = 0\n",
 | ||
|     "#             length_stats[original_length] += 1\n",
 | ||
|     "#             total_sequences += 1\n",
 | ||
|     "            \n",
 | ||
|     "#             # 应用核函数时间扭曲(包括压缩和插值)\n",
 | ||
|     "#             warped_seq = kernel_time_warp(seq, target_length)\n",
 | ||
|     "#             processed_sequences.append(warped_seq)\n",
 | ||
|     "        \n",
 | ||
|     "#         processed_result[label] = processed_sequences\n",
 | ||
|     "#         print(f\"Label '{label}': {len(sequences)} sequences -> {len(processed_sequences)} sequences\")\n",
 | ||
|     "    \n",
 | ||
|     "#     # 打印长度统计信息\n",
 | ||
|     "#     print(f\"\\n原始序列长度分布:\")\n",
 | ||
|     "#     sorted_lengths = sorted(length_stats.items())\n",
 | ||
|     "#     short_count = sum(count for length, count in sorted_lengths if length < target_length)\n",
 | ||
|     "#     equal_count = sum(count for length, count in sorted_lengths if length == target_length)\n",
 | ||
|     "#     long_count = sum(count for length, count in sorted_lengths if length > target_length)\n",
 | ||
|     "    \n",
 | ||
|     "#     print(f\"短于目标长度({target_length}),需要插值扩展: {short_count} 个序列\")\n",
 | ||
|     "#     print(f\"等于目标长度({target_length}),无需处理: {equal_count} 个序列\")\n",
 | ||
|     "#     print(f\"长于目标长度({target_length}),需要压缩: {long_count} 个序列\")\n",
 | ||
|     "#     print(f\"总序列数: {total_sequences}\")\n",
 | ||
|     "    \n",
 | ||
|     "#     if len(sorted_lengths) <= 20:  # 如果长度种类不多,显示详细分布\n",
 | ||
|     "#         print(\"\\n详细长度分布:\")\n",
 | ||
|     "#         for length, count in sorted_lengths:\n",
 | ||
|     "#             percentage = (count / total_sequences) * 100\n",
 | ||
|     "#             operation = \"\"\n",
 | ||
|     "#             if length < target_length:\n",
 | ||
|     "#                 operation = \" (插值扩展)\"\n",
 | ||
|     "#             elif length > target_length:\n",
 | ||
|     "#                 operation = \" (压缩)\"\n",
 | ||
|     "#             else:\n",
 | ||
|     "#                 operation = \" (无需处理)\"\n",
 | ||
|     "#             print(f\"  长度 {length}: {count} 个 ({percentage:.1f}%){operation}\")\n",
 | ||
|     "    \n",
 | ||
|     "#     return processed_result\n",
 | ||
|     "\n",
 | ||
|     "# # 处理result字典\n",
 | ||
|     "# processed_result = process_result_with_kernel(result, MEAN_WINDOWS_SIZE)\n",
 | ||
|     "\n",
 | ||
|     "# # 验证处理结果\n",
 | ||
|     "# print(\"\\n处理后的统计信息:\")\n",
 | ||
|     "# total_sequences = 0\n",
 | ||
|     "# for label, sequences in processed_result.items():\n",
 | ||
|     "#     total_sequences += len(sequences)\n",
 | ||
|     "#     if sequences:  # 如果列表不为空\n",
 | ||
|     "#         print(f\"Label '{label}': {len(sequences)} sequences, shape: {sequences[0].shape}\")\n",
 | ||
|     "\n",
 | ||
|     "# print(f\"总共处理了 {total_sequences} 个时间序列,目标长度: {MEAN_WINDOWS_SIZE}\")\n",
 | ||
|     "\n",
 | ||
|     "# # 验证所有序列现在都是目标长度\n",
 | ||
|     "# all_correct_length = True\n",
 | ||
|     "# for label, sequences in processed_result.items():\n",
 | ||
|     "#     for seq in sequences:\n",
 | ||
|     "#         if seq.shape[0] != MEAN_WINDOWS_SIZE:\n",
 | ||
|     "#             print(f\"错误: 发现长度不正确的序列 - Label: {label}, Shape: {seq.shape}\")\n",
 | ||
|     "#             all_correct_length = False\n",
 | ||
|     "#             break\n",
 | ||
|     "#     if not all_correct_length:\n",
 | ||
|     "#         break\n",
 | ||
|     "\n",
 | ||
|     "# if all_correct_length:\n",
 | ||
|     "#     print(f\"✅ 验证通过: 所有序列长度都已正确调整为 {MEAN_WINDOWS_SIZE}\")\n",
 | ||
|     "#     print(\"  - 长序列已通过压缩(局部平均)缩短\")\n",
 | ||
|     "#     print(\"  - 短序列已通过插值扩展\")"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "metadata": {},
 | ||
|    "source": [
 | ||
|     "去掉异常的标签,去除的时候记得保存正常标签的原始索引,我们可能不用去除后的来训练模型。而是训练适应多个时间大小窗口的模型,通过单独扫描的WER大小来确定权重。再把两个一起扫描,按照权重赋值,同样用极大值抑制或者CTC来处理。\n",
 | ||
|     "建议CTC。毕竟我们多个长度窗口的模型已经和RNN差距不大了。"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": 18,
 | ||
|    "metadata": {},
 | ||
|    "outputs": [],
 | ||
|    "source": [
 | ||
|     "# import pickle\n",
 | ||
|     "# with open('processed_time_series.pkl', 'wb') as f:\n",
 | ||
|     "#     pickle.dump(processed_result, f)\n",
 | ||
|     "    \n",
 | ||
|     "# with open('time_series_format.pkl', 'wb') as f:\n",
 | ||
|     "#     pickle.dump(result, f)"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": 19,
 | ||
|    "metadata": {},
 | ||
|    "outputs": [],
 | ||
|    "source": [
 | ||
|     "# processed_result.keys()"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": 20,
 | ||
|    "metadata": {},
 | ||
|    "outputs": [],
 | ||
|    "source": [
 | ||
|     "# from sklearn.cluster import KMeans\n",
 | ||
|     "# from sklearn.metrics import silhouette_score\n",
 | ||
|     "# from sklearn.metrics import silhouette_samples\n",
 | ||
|     "\n",
 | ||
|     "# # for key, value in processed_result.items():\n",
 | ||
|     "# key = 'IH'\n",
 | ||
|     "# value = processed_result[key]\n",
 | ||
|     "# print(f\"Label: {key}, Number of sequences: {len(value)}, Shape of first sequence: {value[0].shape if value else 'N/A'}\")\n",
 | ||
|     "\n",
 | ||
|     "# # Apply KMeans clustering to the sequences for the selected label\n",
 | ||
|     "\n",
 | ||
|     "# # First, we need to reshape the sequences to 2D for clustering\n",
 | ||
|     "# # Since all sequences now have the same length (33), we can flatten them\n",
 | ||
|     "# sequences = value\n",
 | ||
|     "# flattened_sequences = []\n",
 | ||
|     "\n",
 | ||
|     "# for seq in sequences:\n",
 | ||
|     "#     # Flatten each sequence to 1D\n",
 | ||
|     "#     flattened_seq = seq.flatten()\n",
 | ||
|     "#     flattened_sequences.append(flattened_seq)\n",
 | ||
|     "\n",
 | ||
|     "# flattened_sequences = np.array(flattened_sequences)\n",
 | ||
|     "# print(f\"Flattened sequences shape: {flattened_sequences.shape}\")\n",
 | ||
|     "\n",
 | ||
|     "# # Perform KMeans clustering\n",
 | ||
|     "# n_clusters = 5  # You can adjust this number\n",
 | ||
|     "# kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)\n",
 | ||
|     "# cluster_labels = kmeans.fit_predict(flattened_sequences)\n",
 | ||
|     "\n",
 | ||
|     "# # Calculate silhouette score to evaluate clustering quality\n",
 | ||
|     "# silhouette_avg = silhouette_score(flattened_sequences, cluster_labels)\n",
 | ||
|     "\n",
 | ||
|     "# print(f\"Number of clusters: {n_clusters}\")\n",
 | ||
|     "# print(f\"Silhouette score: {silhouette_avg:.3f}\")\n",
 | ||
|     "# print(f\"Cluster distribution: {np.bincount(cluster_labels)}\")\n",
 | ||
|     "\n",
 | ||
|     "# # Visualize clustering results\n",
 | ||
|     "# plt.figure(figsize=(12, 8))\n",
 | ||
|     "\n",
 | ||
|     "# # Plot 1: Cluster distribution\n",
 | ||
|     "# plt.subplot(2, 2, 1)\n",
 | ||
|     "# unique_labels, counts = np.unique(cluster_labels, return_counts=True)\n",
 | ||
|     "# plt.bar(unique_labels, counts)\n",
 | ||
|     "# plt.xlabel('Cluster')\n",
 | ||
|     "# plt.ylabel('Number of Sequences')\n",
 | ||
|     "# plt.title(f'Cluster Distribution for Label \"{key}\"')\n",
 | ||
|     "\n",
 | ||
|     "# # Plot 2: First few dimensions of the data colored by cluster\n",
 | ||
|     "# plt.subplot(2, 2, 2)\n",
 | ||
|     "# for i in range(n_clusters):\n",
 | ||
|     "#     mask = cluster_labels == i\n",
 | ||
|     "#     plt.scatter(flattened_sequences[mask, 0], flattened_sequences[mask, 1], \n",
 | ||
|     "#                 label=f'Cluster {i}', alpha=0.6)\n",
 | ||
|     "# plt.xlabel('Feature 0')\n",
 | ||
|     "# plt.ylabel('Feature 1')\n",
 | ||
|     "# plt.title('Clusters in Feature Space (First 2 Dimensions)')\n",
 | ||
|     "# plt.legend()\n",
 | ||
|     "\n",
 | ||
|     "# # Plot 3: Silhouette analysis\n",
 | ||
|     "# plt.subplot(2, 2, 3)\n",
 | ||
|     "# silhouette_vals = silhouette_samples(flattened_sequences, cluster_labels)\n",
 | ||
|     "# y_lower = 10\n",
 | ||
|     "# for i in range(n_clusters):\n",
 | ||
|     "#     cluster_silhouette_vals = silhouette_vals[cluster_labels == i]\n",
 | ||
|     "#     cluster_silhouette_vals.sort()\n",
 | ||
|     "    \n",
 | ||
|     "#     size_cluster_i = cluster_silhouette_vals.shape[0]\n",
 | ||
|     "#     y_upper = y_lower + size_cluster_i\n",
 | ||
|     "    \n",
 | ||
|     "#     color = plt.cm.nipy_spectral(float(i) / n_clusters)\n",
 | ||
|     "#     plt.fill_betweenx(np.arange(y_lower, y_upper), 0, cluster_silhouette_vals,\n",
 | ||
|     "#                       facecolor=color, edgecolor=color, alpha=0.7)\n",
 | ||
|     "    \n",
 | ||
|     "#     plt.text(-0.05, y_lower + 0.5 * size_cluster_i, str(i))\n",
 | ||
|     "#     y_lower = y_upper + 10\n",
 | ||
|     "\n",
 | ||
|     "# plt.axvline(x=silhouette_avg, color=\"red\", linestyle=\"--\")\n",
 | ||
|     "# plt.xlabel('Silhouette coefficient values')\n",
 | ||
|     "# plt.ylabel('Cluster label')\n",
 | ||
|     "# plt.title('Silhouette Analysis')\n",
 | ||
|     "\n",
 | ||
|     "# # Plot 4: Try different numbers of clusters\n",
 | ||
|     "# plt.subplot(2, 2, 4)\n",
 | ||
|     "# cluster_range = range(2, min(11, len(sequences)//2))\n",
 | ||
|     "# silhouette_scores = []\n",
 | ||
|     "# inertias = []\n",
 | ||
|     "\n",
 | ||
|     "# for n_clust in cluster_range:\n",
 | ||
|     "#     kmeans_temp = KMeans(n_clusters=n_clust, random_state=42, n_init=10)\n",
 | ||
|     "#     cluster_labels_temp = kmeans_temp.fit_predict(flattened_sequences)\n",
 | ||
|     "#     silhouette_scores.append(silhouette_score(flattened_sequences, cluster_labels_temp))\n",
 | ||
|     "#     inertias.append(kmeans_temp.inertia_)\n",
 | ||
|     "\n",
 | ||
|     "# plt.plot(cluster_range, silhouette_scores, 'bo-')\n",
 | ||
|     "# plt.xlabel('Number of Clusters')\n",
 | ||
|     "# plt.ylabel('Average Silhouette Score')\n",
 | ||
|     "# plt.title('Silhouette Score vs Number of Clusters')\n",
 | ||
|     "# plt.grid(True)\n",
 | ||
|     "\n",
 | ||
|     "# plt.tight_layout()\n",
 | ||
|     "# plt.show()\n",
 | ||
|     "\n",
 | ||
|     "# # Print cluster centers information\n",
 | ||
|     "# print(f\"\\nCluster centers shape: {kmeans.cluster_centers_.shape}\")\n",
 | ||
|     "# print(f\"Each cluster center represents the average pattern for sequences in that cluster\")"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": 21,
 | ||
|    "metadata": {},
 | ||
|    "outputs": [],
 | ||
|    "source": [
 | ||
|     "# # 余弦距离分析和距离度量对比\n",
 | ||
|     "# print(\"\\n\" + \"=\"*70)\n",
 | ||
|     "# print(\"余弦距离分析\")\n",
 | ||
|     "# print(\"=\"*70)\n",
 | ||
|     "\n",
 | ||
|     "# # 5. 计算余弦距离\n",
 | ||
|     "# print(\"\\n计算余弦距离...\")\n",
 | ||
|     "# cosine_matrix, phonemes_cosine = calculate_inter_class_distances(centroids, 'cosine')\n",
 | ||
|     "\n",
 | ||
|     "# # 6. 可视化余弦距离\n",
 | ||
|     "# print(\"\\n可视化余弦距离矩阵...\")\n",
 | ||
|     "# df_cosine, most_sim_cos, most_diff_cos = visualize_distance_matrix(\n",
 | ||
|     "#     cosine_matrix, phonemes_cosine, 'cosine'\n",
 | ||
|     "# )\n",
 | ||
|     "\n",
 | ||
|     "# # 7. 分析相似音素群组(余弦距离)\n",
 | ||
|     "# similar_pairs_cos, threshold_cos = analyze_phoneme_groups(cosine_matrix, phonemes_cosine)\n",
 | ||
|     "\n",
 | ||
|     "# print(f\"\\n余弦距离分析完成!\")\n",
 | ||
|     "# print(f\"最相似音素对: {most_sim_cos}\")\n",
 | ||
|     "# print(f\"最不相似音素对: {most_diff_cos}\")\n",
 | ||
|     "\n",
 | ||
|     "# # 8. 比较两种距离度量\n",
 | ||
|     "# print(\"\\n\" + \"=\"*70)\n",
 | ||
|     "# print(\"距离度量对比分析\")\n",
 | ||
|     "# print(\"=\"*70)\n",
 | ||
|     "\n",
 | ||
|     "# def compare_distance_metrics(euclidean_matrix, cosine_matrix, phonemes):\n",
 | ||
|     "#     \"\"\"\n",
 | ||
|     "#     比较不同距离度量的结果\n",
 | ||
|     "#     \"\"\"\n",
 | ||
|     "#     # 提取上三角矩阵的距离值\n",
 | ||
|     "#     upper_triangle_indices = np.triu_indices_from(euclidean_matrix, k=1)\n",
 | ||
|     "#     euclidean_distances = euclidean_matrix[upper_triangle_indices]\n",
 | ||
|     "#     cosine_distances = cosine_matrix[upper_triangle_indices]\n",
 | ||
|     "    \n",
 | ||
|     "#     # 计算相关性\n",
 | ||
|     "#     correlation = np.corrcoef(euclidean_distances, cosine_distances)[0, 1]\n",
 | ||
|     "    \n",
 | ||
|     "#     # 创建比较图\n",
 | ||
|     "#     fig, axes = plt.subplots(2, 2, figsize=(16, 12))\n",
 | ||
|     "    \n",
 | ||
|     "#     # 图1: 距离分布对比\n",
 | ||
|     "#     ax1 = axes[0, 0]\n",
 | ||
|     "#     ax1.hist(euclidean_distances, bins=30, alpha=0.6, label='欧氏距离', color='blue')\n",
 | ||
|     "#     ax1.hist(cosine_distances, bins=30, alpha=0.6, label='余弦距离', color='red')\n",
 | ||
|     "#     ax1.set_xlabel('距离值')\n",
 | ||
|     "#     ax1.set_ylabel('频次')\n",
 | ||
|     "#     ax1.set_title('距离分布对比')\n",
 | ||
|     "#     ax1.legend()\n",
 | ||
|     "#     ax1.grid(True, alpha=0.3)\n",
 | ||
|     "    \n",
 | ||
|     "#     # 图2: 距离相关性散点图\n",
 | ||
|     "#     ax2 = axes[0, 1]\n",
 | ||
|     "#     ax2.scatter(euclidean_distances, cosine_distances, alpha=0.6, s=10)\n",
 | ||
|     "#     ax2.set_xlabel('欧氏距离')\n",
 | ||
|     "#     ax2.set_ylabel('余弦距离')\n",
 | ||
|     "#     ax2.set_title(f'距离度量相关性 (r={correlation:.3f})')\n",
 | ||
|     "#     ax2.grid(True, alpha=0.3)\n",
 | ||
|     "    \n",
 | ||
|     "#     # 添加拟合线\n",
 | ||
|     "#     z = np.polyfit(euclidean_distances, cosine_distances, 1)\n",
 | ||
|     "#     p = np.poly1d(z)\n",
 | ||
|     "#     ax2.plot(euclidean_distances, p(euclidean_distances), \"r--\", alpha=0.8)\n",
 | ||
|     "    \n",
 | ||
|     "#     # 图3: 最相似音素对比较\n",
 | ||
|     "#     ax3 = axes[1, 0]\n",
 | ||
|     "#     ax3.axis('off')\n",
 | ||
|     "    \n",
 | ||
|     "#     # 获取每种距离度量下的前10对最相似音素\n",
 | ||
|     "#     eucl_top10 = similar_pairs_eucl[:10]\n",
 | ||
|     "#     cos_top10 = similar_pairs_cos[:10]\n",
 | ||
|     "    \n",
 | ||
|     "#     comparison_text = \"最相似音素对比较 (前10对)\\n\\n\"\n",
 | ||
|     "#     comparison_text += f\"{'欧氏距离':<30} {'余弦距离':<30}\\n\"\n",
 | ||
|     "#     comparison_text += \"-\" * 60 + \"\\n\"\n",
 | ||
|     "    \n",
 | ||
|     "#     for i in range(min(10, len(eucl_top10), len(cos_top10))):\n",
 | ||
|     "#         eucl_pair = f\"{eucl_top10[i][0]}-{eucl_top10[i][1]} ({eucl_top10[i][2]:.3f})\"\n",
 | ||
|     "#         cos_pair = f\"{cos_top10[i][0]}-{cos_top10[i][1]} ({cos_top10[i][2]:.4f})\"\n",
 | ||
|     "#         comparison_text += f\"{eucl_pair:<30} {cos_pair:<30}\\n\"\n",
 | ||
|     "    \n",
 | ||
|     "#     ax3.text(0.05, 0.95, comparison_text, transform=ax3.transAxes, fontsize=9,\n",
 | ||
|     "#             verticalalignment='top', fontfamily='monospace',\n",
 | ||
|     "#             bbox=dict(boxstyle='round', facecolor='lightyellow', alpha=0.8))\n",
 | ||
|     "    \n",
 | ||
|     "#     # 图4: 统计对比\n",
 | ||
|     "#     ax4 = axes[1, 1]\n",
 | ||
|     "#     ax4.axis('off')\n",
 | ||
|     "    \n",
 | ||
|     "#     # 计算统计信息\n",
 | ||
|     "#     eucl_stats = {\n",
 | ||
|     "#         'mean': np.mean(euclidean_distances),\n",
 | ||
|     "#         'median': np.median(euclidean_distances),\n",
 | ||
|     "#         'std': np.std(euclidean_distances),\n",
 | ||
|     "#         'min': np.min(euclidean_distances),\n",
 | ||
|     "#         'max': np.max(euclidean_distances)\n",
 | ||
|     "#     }\n",
 | ||
|     "    \n",
 | ||
|     "#     cos_stats = {\n",
 | ||
|     "#         'mean': np.mean(cosine_distances),\n",
 | ||
|     "#         'median': np.median(cosine_distances),\n",
 | ||
|     "#         'std': np.std(cosine_distances),\n",
 | ||
|     "#         'min': np.min(cosine_distances),\n",
 | ||
|     "#         'max': np.max(cosine_distances)\n",
 | ||
|     "#     }\n",
 | ||
|     "    \n",
 | ||
|     "#     stats_text = f\"\"\"\n",
 | ||
|     "# 距离度量统计对比\n",
 | ||
|     "\n",
 | ||
|     "# 指标            欧氏距离        余弦距离\n",
 | ||
|     "# {'='*45}\n",
 | ||
|     "# 平均值          {eucl_stats['mean']:.4f}        {cos_stats['mean']:.4f}\n",
 | ||
|     "# 中位数          {eucl_stats['median']:.4f}        {cos_stats['median']:.4f}\n",
 | ||
|     "# 标准差          {eucl_stats['std']:.4f}        {cos_stats['std']:.4f}\n",
 | ||
|     "# 最小值          {eucl_stats['min']:.4f}        {cos_stats['min']:.4f}\n",
 | ||
|     "# 最大值          {eucl_stats['max']:.4f}        {cos_stats['max']:.4f}\n",
 | ||
|     "\n",
 | ||
|     "# 相关性系数: {correlation:.4f}\n",
 | ||
|     "\n",
 | ||
|     "# 解释:\n",
 | ||
|     "# - 欧氏距离: 测量特征空间中的直线距离\n",
 | ||
|     "# - 余弦距离: 测量向量间的角度差异\n",
 | ||
|     "# - 高相关性表明两种度量捕获相似的模式\n",
 | ||
|     "# \"\"\"\n",
 | ||
|     "    \n",
 | ||
|     "#     ax4.text(0.05, 0.95, stats_text, transform=ax4.transAxes, fontsize=9,\n",
 | ||
|     "#             verticalalignment='top', fontfamily='monospace',\n",
 | ||
|     "#             bbox=dict(boxstyle='round', facecolor='lightcyan', alpha=0.8))\n",
 | ||
|     "    \n",
 | ||
|     "#     plt.tight_layout()\n",
 | ||
|     "#     plt.show()\n",
 | ||
|     "    \n",
 | ||
|     "#     return correlation, eucl_stats, cos_stats\n",
 | ||
|     "\n",
 | ||
|     "# # 执行距离度量比较\n",
 | ||
|     "# correlation, eucl_stats, cos_stats = compare_distance_metrics(\n",
 | ||
|     "#     euclidean_matrix, cosine_matrix, phonemes\n",
 | ||
|     "# )\n",
 | ||
|     "\n",
 | ||
|     "# # 9. 音素聚类分析\n",
 | ||
|     "# print(\"\\n\" + \"=\"*70)\n",
 | ||
|     "# print(\"基于距离的音素聚类分析\")\n",
 | ||
|     "# print(\"=\"*70)\n",
 | ||
|     "\n",
 | ||
|     "# def analyze_phoneme_clusters(distance_matrix, phonemes, n_clusters_range=[3, 4, 5, 6]):\n",
 | ||
|     "#     \"\"\"\n",
 | ||
|     "#     使用不同数量的聚类分析音素群组\n",
 | ||
|     "#     \"\"\"\n",
 | ||
|     "#     from sklearn.cluster import AgglomerativeClustering\n",
 | ||
|     "    \n",
 | ||
|     "#     results = {}\n",
 | ||
|     "    \n",
 | ||
|     "#     fig, axes = plt.subplots(2, 2, figsize=(16, 12))\n",
 | ||
|     "#     axes = axes.flatten()\n",
 | ||
|     "    \n",
 | ||
|     "#     for i, n_clusters in enumerate(n_clusters_range):\n",
 | ||
|     "#         if i >= 4:  # 最多显示4个图\n",
 | ||
|     "#             break\n",
 | ||
|     "            \n",
 | ||
|     "#         # 执行层次聚类\n",
 | ||
|     "#         clustering = AgglomerativeClustering(\n",
 | ||
|     "#             n_clusters=n_clusters, \n",
 | ||
|     "#             metric='precomputed',\n",
 | ||
|     "#             linkage='average'\n",
 | ||
|     "#         )\n",
 | ||
|     "        \n",
 | ||
|     "#         cluster_labels = clustering.fit_predict(distance_matrix)\n",
 | ||
|     "        \n",
 | ||
|     "#         # 分析聚类结果\n",
 | ||
|     "#         clusters = {}\n",
 | ||
|     "#         for phoneme, label in zip(phonemes, cluster_labels):\n",
 | ||
|     "#             if label not in clusters:\n",
 | ||
|     "#                 clusters[label] = []\n",
 | ||
|     "#             clusters[label].append(phoneme)\n",
 | ||
|     "        \n",
 | ||
|     "#         results[n_clusters] = clusters\n",
 | ||
|     "        \n",
 | ||
|     "#         # 可视化聚类结果\n",
 | ||
|     "#         ax = axes[i]\n",
 | ||
|     "        \n",
 | ||
|     "#         # 创建颜色映射\n",
 | ||
|     "#         colors = plt.cm.Set3(np.linspace(0, 1, n_clusters))\n",
 | ||
|     "        \n",
 | ||
|     "#         # 为每个音素分配颜色\n",
 | ||
|     "#         phoneme_colors = [colors[cluster_labels[j]] for j in range(len(phonemes))]\n",
 | ||
|     "        \n",
 | ||
|     "#         # 使用PCA降维可视化(使用质心数据)\n",
 | ||
|     "#         from sklearn.decomposition import PCA\n",
 | ||
|     "        \n",
 | ||
|     "#         # 重新获取质心矩阵\n",
 | ||
|     "#         centroid_matrix = np.array([centroids[phoneme] for phoneme in phonemes])\n",
 | ||
|     "        \n",
 | ||
|     "#         if centroid_matrix.shape[1] > 2:\n",
 | ||
|     "#             pca = PCA(n_components=2)\n",
 | ||
|     "#             pca_result = pca.fit_transform(centroid_matrix)\n",
 | ||
|     "#         else:\n",
 | ||
|     "#             pca_result = centroid_matrix\n",
 | ||
|     "        \n",
 | ||
|     "#         # 绘制散点图\n",
 | ||
|     "#         for cluster_id in range(n_clusters):\n",
 | ||
|     "#             mask = cluster_labels == cluster_id\n",
 | ||
|     "#             ax.scatter(pca_result[mask, 0], pca_result[mask, 1], \n",
 | ||
|     "#                       c=[colors[cluster_id]], label=f'聚类 {cluster_id}', s=50, alpha=0.7)\n",
 | ||
|     "            \n",
 | ||
|     "#             # 添加音素标签\n",
 | ||
|     "#             for j, phoneme in enumerate(phonemes):\n",
 | ||
|     "#                 if cluster_labels[j] == cluster_id:\n",
 | ||
|     "#                     ax.annotate(phoneme, (pca_result[j, 0], pca_result[j, 1]), \n",
 | ||
|     "#                                xytext=(5, 5), textcoords='offset points', fontsize=8)\n",
 | ||
|     "        \n",
 | ||
|     "#         ax.set_title(f'{n_clusters} 个聚类')\n",
 | ||
|     "#         ax.set_xlabel('PC1')\n",
 | ||
|     "#         ax.set_ylabel('PC2')\n",
 | ||
|     "#         ax.legend()\n",
 | ||
|     "#         ax.grid(True, alpha=0.3)\n",
 | ||
|     "    \n",
 | ||
|     "#     plt.tight_layout()\n",
 | ||
|     "#     plt.show()\n",
 | ||
|     "    \n",
 | ||
|     "#     # 打印聚类结果\n",
 | ||
|     "#     for n_clusters, clusters in results.items():\n",
 | ||
|     "#         print(f\"\\n{n_clusters} 个聚类的结果:\")\n",
 | ||
|     "#         for cluster_id, phonemes_in_cluster in clusters.items():\n",
 | ||
|     "#             print(f\"  聚类 {cluster_id}: {', '.join(phonemes_in_cluster)}\")\n",
 | ||
|     "    \n",
 | ||
|     "#     return results\n",
 | ||
|     "\n",
 | ||
|     "# # 执行音素聚类分析\n",
 | ||
|     "# clustering_results = analyze_phoneme_clusters(euclidean_matrix, phonemes)\n",
 | ||
|     "\n",
 | ||
|     "# print(f\"\\n音素类间距离分析完成!\")\n",
 | ||
|     "# print(f\"发现了丰富的音素相似性模式和聚类结构。\")"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": 22,
 | ||
|    "metadata": {},
 | ||
|    "outputs": [],
 | ||
|    "source": [
 | ||
|     "# # 优化的DBSCAN异常值检测分析\n",
 | ||
|     "# import numpy as np\n",
 | ||
|     "# import matplotlib.pyplot as plt\n",
 | ||
|     "# from sklearn.cluster import DBSCAN\n",
 | ||
|     "# from sklearn.preprocessing import StandardScaler\n",
 | ||
|     "# from sklearn.decomposition import PCA\n",
 | ||
|     "# from sklearn.metrics import silhouette_score\n",
 | ||
|     "# import pandas as pd\n",
 | ||
|     "\n",
 | ||
|     "# print(\"=\"*70)\n",
 | ||
|     "# print(\"优化的DBSCAN异常值检测分析\")\n",
 | ||
|     "# print(\"=\"*70)\n",
 | ||
|     "\n",
 | ||
|     "# def determine_optimal_pca_components(data, min_variance_ratio=0.80, max_components=500):\n",
 | ||
|     "#     \"\"\"\n",
 | ||
|     "#     确定最优的PCA组件数,保证解释方差比达到要求\n",
 | ||
|     "#     \"\"\"\n",
 | ||
|     "#     print(f\"原始数据维度: {data.shape}\")\n",
 | ||
|     "    \n",
 | ||
|     "#     # 先用少量组件快速估计\n",
 | ||
|     "#     n_samples, n_features = data.shape\n",
 | ||
|     "    \n",
 | ||
|     "#     # 确保组件数不超过样本数和特征数的最小值\n",
 | ||
|     "#     max_possible_components = min(n_samples, n_features, max_components)\n",
 | ||
|     "    \n",
 | ||
|     "#     # 快速测试不同的组件数\n",
 | ||
|     "#     test_components = [50, 100, 200, 300, 400, 500]\n",
 | ||
|     "#     test_components = [c for c in test_components if c <= max_possible_components]\n",
 | ||
|     "    \n",
 | ||
|     "#     if not test_components:\n",
 | ||
|     "#         test_components = [min(50, max_possible_components)]\n",
 | ||
|     "    \n",
 | ||
|     "#     print(f\"测试的组件数: {test_components}\")\n",
 | ||
|     "    \n",
 | ||
|     "#     best_components = test_components[0]\n",
 | ||
|     "#     best_ratio = 0\n",
 | ||
|     "    \n",
 | ||
|     "#     for n_comp in test_components:\n",
 | ||
|     "#         pca_temp = PCA(n_components=n_comp, random_state=42)\n",
 | ||
|     "#         pca_temp.fit(data)\n",
 | ||
|     "#         variance_ratio = np.sum(pca_temp.explained_variance_ratio_)\n",
 | ||
|     "        \n",
 | ||
|     "#         print(f\"  {n_comp} 组件: 解释方差比 = {variance_ratio:.4f}\")\n",
 | ||
|     "        \n",
 | ||
|     "#         if variance_ratio >= min_variance_ratio:\n",
 | ||
|     "#             best_components = n_comp\n",
 | ||
|     "#             best_ratio = variance_ratio\n",
 | ||
|     "#             break\n",
 | ||
|     "#         elif variance_ratio > best_ratio:\n",
 | ||
|     "#             best_components = n_comp\n",
 | ||
|     "#             best_ratio = variance_ratio\n",
 | ||
|     "    \n",
 | ||
|     "#     print(f\"选择 {best_components} 个组件 (解释方差比: {best_ratio:.4f})\")\n",
 | ||
|     "#     return best_components\n",
 | ||
|     "\n",
 | ||
|     "# def smart_dbscan_outlier_detection(processed_result, target_phonemes=None, min_variance_ratio=0.75):\n",
 | ||
|     "#     \"\"\"\n",
 | ||
|     "#     智能DBSCAN异常值检测,使用合理的PCA降维\n",
 | ||
|     "#     \"\"\"\n",
 | ||
|     "#     outlier_results = {}\n",
 | ||
|     "    \n",
 | ||
|     "#     # 如果没有指定目标音素,选择样本数适中的音素\n",
 | ||
|     "#     if target_phonemes is None:\n",
 | ||
|     "#         phoneme_counts = [(p, len(seqs)) for p, seqs in processed_result.items()]\n",
 | ||
|     "#         # 选择样本数在50-1000之间的音素\n",
 | ||
|     "#         target_phonemes = [p for p, count in phoneme_counts if 50 <= count <= 1000]\n",
 | ||
|     "#         target_phonemes = target_phonemes[:5]  # 最多处理5个音素\n",
 | ||
|     "    \n",
 | ||
|     "#     print(f\"将处理的音素: {target_phonemes}\")\n",
 | ||
|     "    \n",
 | ||
|     "#     for phoneme in target_phonemes:\n",
 | ||
|     "#         if phoneme not in processed_result:\n",
 | ||
|     "#             continue\n",
 | ||
|     "            \n",
 | ||
|     "#         sequences = processed_result[phoneme]\n",
 | ||
|     "#         print(f\"\\n\" + \"=\"*50)\n",
 | ||
|     "#         print(f\"分析音素 '{phoneme}' ({len(sequences)} 个样本)\")\n",
 | ||
|     "#         print(\"=\"*50)\n",
 | ||
|     "        \n",
 | ||
|     "#         # 展平序列数据\n",
 | ||
|     "#         flattened_sequences = []\n",
 | ||
|     "#         for seq in sequences:\n",
 | ||
|     "#             flattened_sequences.append(seq.flatten())\n",
 | ||
|     "        \n",
 | ||
|     "#         flattened_sequences = np.array(flattened_sequences)\n",
 | ||
|     "#         print(f\"原始数据形状: {flattened_sequences.shape}\")\n",
 | ||
|     "        \n",
 | ||
|     "#         # 标准化数据\n",
 | ||
|     "#         scaler = StandardScaler()\n",
 | ||
|     "#         scaled_data = scaler.fit_transform(flattened_sequences)\n",
 | ||
|     "        \n",
 | ||
|     "#         # 智能确定PCA组件数\n",
 | ||
|     "#         optimal_components = determine_optimal_pca_components(\n",
 | ||
|     "#             scaled_data, min_variance_ratio=min_variance_ratio\n",
 | ||
|     "#         )\n",
 | ||
|     "        \n",
 | ||
|     "#         # 执行PCA降维\n",
 | ||
|     "#         pca = PCA(n_components=optimal_components, random_state=42)\n",
 | ||
|     "#         pca_data = pca.fit_transform(scaled_data)\n",
 | ||
|     "#         variance_explained = np.sum(pca.explained_variance_ratio_)\n",
 | ||
|     "        \n",
 | ||
|     "#         print(f\"PCA降维结果:\")\n",
 | ||
|     "#         print(f\"  原始维度: {scaled_data.shape[1]}\")\n",
 | ||
|     "#         print(f\"  降维后: {pca_data.shape[1]}\")\n",
 | ||
|     "#         print(f\"  解释方差比: {variance_explained:.4f}\")\n",
 | ||
|     "#         print(f\"  信息保留率: {variance_explained*100:.2f}%\")\n",
 | ||
|     "        \n",
 | ||
|     "#         # 使用简化的DBSCAN参数搜索\n",
 | ||
|     "#         print(f\"\\n开始DBSCAN参数搜索...\")\n",
 | ||
|     "        \n",
 | ||
|     "#         # 基于数据估计合理的eps范围\n",
 | ||
|     "#         from sklearn.neighbors import NearestNeighbors\n",
 | ||
|     "#         k = min(10, len(sequences)//10)\n",
 | ||
|     "#         nbrs = NearestNeighbors(n_neighbors=k)\n",
 | ||
|     "#         nbrs.fit(pca_data)\n",
 | ||
|     "#         distances, _ = nbrs.kneighbors(pca_data)\n",
 | ||
|     "#         k_distances = np.sort(distances[:, k-1])\n",
 | ||
|     "        \n",
 | ||
|     "#         # 选择eps候选值\n",
 | ||
|     "#         eps_candidates = [\n",
 | ||
|     "#             np.percentile(k_distances, 25),\n",
 | ||
|     "#             np.percentile(k_distances, 50),\n",
 | ||
|     "#             np.percentile(k_distances, 75),\n",
 | ||
|     "#             np.percentile(k_distances, 90)\n",
 | ||
|     "#         ]\n",
 | ||
|     "        \n",
 | ||
|     "#         min_samples_candidates = [5, 10, 15, 20]\n",
 | ||
|     "        \n",
 | ||
|     "#         print(f\"eps候选值: {[f'{e:.3f}' for e in eps_candidates]}\")\n",
 | ||
|     "#         print(f\"min_samples候选值: {min_samples_candidates}\")\n",
 | ||
|     "        \n",
 | ||
|     "#         best_score = -1\n",
 | ||
|     "#         best_result = None\n",
 | ||
|     "        \n",
 | ||
|     "#         for eps in eps_candidates:\n",
 | ||
|     "#             for min_samples in min_samples_candidates:\n",
 | ||
|     "#                 if min_samples >= len(sequences) // 5:  # min_samples不能太大\n",
 | ||
|     "#                     continue\n",
 | ||
|     "                \n",
 | ||
|     "#                 dbscan = DBSCAN(eps=eps, min_samples=min_samples)\n",
 | ||
|     "#                 labels = dbscan.fit_predict(pca_data)\n",
 | ||
|     "                \n",
 | ||
|     "#                 n_outliers = np.sum(labels == -1)\n",
 | ||
|     "#                 n_clusters = len(set(labels)) - (1 if -1 in labels else 0)\n",
 | ||
|     "#                 outlier_ratio = n_outliers / len(labels)\n",
 | ||
|     "                \n",
 | ||
|     "#                 # 计算评分\n",
 | ||
|     "#                 if n_clusters > 0 and 0.05 <= outlier_ratio <= 0.40:  # 异常值比例在5%-40%之间\n",
 | ||
|     "#                     try:\n",
 | ||
|     "#                         if len(set(labels[labels != -1])) > 1:\n",
 | ||
|     "#                             silhouette = silhouette_score(pca_data[labels != -1], labels[labels != -1])\n",
 | ||
|     "#                         else:\n",
 | ||
|     "#                             silhouette = 0.5  # 单聚类给中等分数\n",
 | ||
|     "#                     except:\n",
 | ||
|     "#                         silhouette = 0\n",
 | ||
|     "                    \n",
 | ||
|     "#                     # 综合评分:轮廓系数 + 合理的异常值比例奖励\n",
 | ||
|     "#                     if 0.1 <= outlier_ratio <= 0.25:  # 最理想的异常值比例\n",
 | ||
|     "#                         ratio_bonus = 0.2\n",
 | ||
|     "#                     else:\n",
 | ||
|     "#                         ratio_bonus = 0.1\n",
 | ||
|     "                    \n",
 | ||
|     "#                     score = silhouette + ratio_bonus\n",
 | ||
|     "                    \n",
 | ||
|     "#                     if score > best_score:\n",
 | ||
|     "#                         best_score = score\n",
 | ||
|     "#                         best_result = {\n",
 | ||
|     "#                             'eps': eps,\n",
 | ||
|     "#                             'min_samples': min_samples,\n",
 | ||
|     "#                             'labels': labels.copy(),\n",
 | ||
|     "#                             'outliers': np.where(labels == -1)[0],\n",
 | ||
|     "#                             'n_clusters': n_clusters,\n",
 | ||
|     "#                             'outlier_ratio': outlier_ratio,\n",
 | ||
|     "#                             'silhouette': silhouette\n",
 | ||
|     "#                         }\n",
 | ||
|     "                \n",
 | ||
|     "#                 print(f\"  eps={eps:.3f}, min_samples={min_samples}: \"\n",
 | ||
|     "#                       f\"{n_clusters}聚类, {n_outliers}异常值 ({outlier_ratio*100:.1f}%)\")\n",
 | ||
|     "        \n",
 | ||
|     "#         if best_result is not None:\n",
 | ||
|     "#             outlier_results[phoneme] = {\n",
 | ||
|     "#                 **best_result,\n",
 | ||
|     "#                 'pca_data': pca_data,\n",
 | ||
|     "#                 'pca_model': pca,\n",
 | ||
|     "#                 'variance_explained': variance_explained,\n",
 | ||
|     "#                 'original_data': flattened_sequences,\n",
 | ||
|     "#                 'scaled_data': scaled_data\n",
 | ||
|     "#             }\n",
 | ||
|     "            \n",
 | ||
|     "#             print(f\"\\n✅ 找到最佳参数:\")\n",
 | ||
|     "#             print(f\"   eps: {best_result['eps']:.3f}\")\n",
 | ||
|     "#             print(f\"   min_samples: {best_result['min_samples']}\")\n",
 | ||
|     "#             print(f\"   聚类数: {best_result['n_clusters']}\")\n",
 | ||
|     "#             print(f\"   异常值: {len(best_result['outliers'])} ({best_result['outlier_ratio']*100:.1f}%)\")\n",
 | ||
|     "#             print(f\"   轮廓系数: {best_result['silhouette']:.3f}\")\n",
 | ||
|     "#             print(f\"   综合评分: {best_score:.3f}\")\n",
 | ||
|     "#         else:\n",
 | ||
|     "#             print(f\"\\n❌ 未找到合适的DBSCAN参数\")\n",
 | ||
|     "    \n",
 | ||
|     "#     return outlier_results\n",
 | ||
|     "\n",
 | ||
|     "# def visualize_smart_dbscan_results(outlier_results):\n",
 | ||
|     "#     \"\"\"\n",
 | ||
|     "#     可视化智能DBSCAN结果\n",
 | ||
|     "#     \"\"\"\n",
 | ||
|     "#     if not outlier_results:\n",
 | ||
|     "#         print(\"没有结果可可视化\")\n",
 | ||
|     "#         return\n",
 | ||
|     "    \n",
 | ||
|     "#     n_phonemes = len(outlier_results)\n",
 | ||
|     "#     fig, axes = plt.subplots(2, n_phonemes, figsize=(6*n_phonemes, 10))\n",
 | ||
|     "    \n",
 | ||
|     "#     if n_phonemes == 1:\n",
 | ||
|     "#         axes = axes.reshape(2, 1)\n",
 | ||
|     "    \n",
 | ||
|     "#     for i, (phoneme, result) in enumerate(outlier_results.items()):\n",
 | ||
|     "#         # 上图:PCA散点图\n",
 | ||
|     "#         ax1 = axes[0, i]\n",
 | ||
|     "#         pca_data = result['pca_data']\n",
 | ||
|     "#         labels = result['labels']\n",
 | ||
|     "        \n",
 | ||
|     "#         # 绘制聚类\n",
 | ||
|     "#         unique_labels = set(labels)\n",
 | ||
|     "#         colors = plt.cm.Set3(np.linspace(0, 1, len(unique_labels)))\n",
 | ||
|     "        \n",
 | ||
|     "#         for j, label in enumerate(unique_labels):\n",
 | ||
|     "#             if label == -1:\n",
 | ||
|     "#                 mask = labels == label\n",
 | ||
|     "#                 ax1.scatter(pca_data[mask, 0], pca_data[mask, 1], \n",
 | ||
|     "#                            c='red', marker='x', s=60, label='异常值', alpha=0.8)\n",
 | ||
|     "#             else:\n",
 | ||
|     "#                 mask = labels == label\n",
 | ||
|     "#                 ax1.scatter(pca_data[mask, 0], pca_data[mask, 1], \n",
 | ||
|     "#                            c=[colors[j]], label=f'聚类 {label}', alpha=0.7, s=30)\n",
 | ||
|     "        \n",
 | ||
|     "#         ax1.set_title(f'音素 \"{phoneme}\" DBSCAN结果\\n'\n",
 | ||
|     "#                      f'{result[\"n_clusters\"]}聚类, {len(result[\"outliers\"])}异常值 '\n",
 | ||
|     "#                      f'({result[\"outlier_ratio\"]*100:.1f}%)')\n",
 | ||
|     "#         ax1.set_xlabel('PC1')\n",
 | ||
|     "#         ax1.set_ylabel('PC2')\n",
 | ||
|     "#         ax1.legend()\n",
 | ||
|     "#         ax1.grid(True, alpha=0.3)\n",
 | ||
|     "        \n",
 | ||
|     "#         # 下图:方差解释图\n",
 | ||
|     "#         ax2 = axes[1, i]\n",
 | ||
|     "#         pca_model = result['pca_model']\n",
 | ||
|     "#         n_components = len(pca_model.explained_variance_ratio_)\n",
 | ||
|     "        \n",
 | ||
|     "#         # 绘制累计方差解释比\n",
 | ||
|     "#         cumsum_var = np.cumsum(pca_model.explained_variance_ratio_)\n",
 | ||
|     "#         ax2.plot(range(1, n_components+1), cumsum_var, 'b-', marker='o')\n",
 | ||
|     "#         ax2.axhline(y=0.8, color='r', linestyle='--', label='80%阈值')\n",
 | ||
|     "#         ax2.axhline(y=0.9, color='g', linestyle='--', label='90%阈值')\n",
 | ||
|     "        \n",
 | ||
|     "#         ax2.set_xlabel('主成分数量')\n",
 | ||
|     "#         ax2.set_ylabel('累计解释方差比')\n",
 | ||
|     "#         ax2.set_title(f'PCA方差解释 (总计: {result[\"variance_explained\"]:.3f})')\n",
 | ||
|     "#         ax2.legend()\n",
 | ||
|     "#         ax2.grid(True, alpha=0.3)\n",
 | ||
|     "#         ax2.set_ylim(0, 1)\n",
 | ||
|     "    \n",
 | ||
|     "#     plt.tight_layout()\n",
 | ||
|     "#     plt.show()\n",
 | ||
|     "\n",
 | ||
|     "# def analyze_outliers_detailed(outlier_results):\n",
 | ||
|     "#     \"\"\"\n",
 | ||
|     "#     详细分析异常值\n",
 | ||
|     "#     \"\"\"\n",
 | ||
|     "#     print(\"\\n\" + \"=\"*70)\n",
 | ||
|     "#     print(\"详细异常值分析报告\")\n",
 | ||
|     "#     print(\"=\"*70)\n",
 | ||
|     "    \n",
 | ||
|     "#     for phoneme, result in outlier_results.items():\n",
 | ||
|     "#         print(f\"\\n音素 '{phoneme}' 异常值分析:\")\n",
 | ||
|     "#         print(\"-\" * 40)\n",
 | ||
|     "        \n",
 | ||
|     "#         outlier_indices = result['outliers']\n",
 | ||
|     "#         normal_indices = [i for i in range(len(result['labels'])) if i not in outlier_indices]\n",
 | ||
|     "        \n",
 | ||
|     "#         print(f\"总样本数: {len(result['labels'])}\")\n",
 | ||
|     "#         print(f\"正常样本: {len(normal_indices)} ({len(normal_indices)/len(result['labels'])*100:.1f}%)\")\n",
 | ||
|     "#         print(f\"异常样本: {len(outlier_indices)} ({len(outlier_indices)/len(result['labels'])*100:.1f}%)\")\n",
 | ||
|     "#         print(f\"聚类数量: {result['n_clusters']}\")\n",
 | ||
|     "#         print(f\"PCA维度: {result['pca_data'].shape[1]}\")\n",
 | ||
|     "#         print(f\"信息保留: {result['variance_explained']*100:.2f}%\")\n",
 | ||
|     "        \n",
 | ||
|     "#         # 分析异常值的特征\n",
 | ||
|     "#         if len(outlier_indices) > 0:\n",
 | ||
|     "#             outlier_data = result['pca_data'][outlier_indices]\n",
 | ||
|     "#             normal_data = result['pca_data'][normal_indices] if len(normal_indices) > 0 else None\n",
 | ||
|     "            \n",
 | ||
|     "#             print(f\"\\n异常值特征分析:\")\n",
 | ||
|     "#             print(f\"  PC1均值: {np.mean(outlier_data[:, 0]):.3f} ± {np.std(outlier_data[:, 0]):.3f}\")\n",
 | ||
|     "#             print(f\"  PC2均值: {np.mean(outlier_data[:, 1]):.3f} ± {np.std(outlier_data[:, 1]):.3f}\")\n",
 | ||
|     "            \n",
 | ||
|     "#             if normal_data is not None:\n",
 | ||
|     "#                 print(f\"正常值特征对比:\")\n",
 | ||
|     "#                 print(f\"  PC1均值: {np.mean(normal_data[:, 0]):.3f} ± {np.std(normal_data[:, 0]):.3f}\")\n",
 | ||
|     "#                 print(f\"  PC2均值: {np.mean(normal_data[:, 1]):.3f} ± {np.std(normal_data[:, 1]):.3f}\")\n",
 | ||
|     "\n",
 | ||
|     "# # 执行优化的DBSCAN异常值检测\n",
 | ||
|     "# print(\"开始智能DBSCAN异常值检测...\")\n",
 | ||
|     "\n",
 | ||
|     "# # 选择几个有代表性的音素进行分析\n",
 | ||
|     "# target_phonemes = ['IH', 'T', 'S', 'N', 'AH']  # 手动选择一些常见音素\n",
 | ||
|     "\n",
 | ||
|     "# outlier_results = smart_dbscan_outlier_detection(\n",
 | ||
|     "#     processed_result, \n",
 | ||
|     "#     target_phonemes=target_phonemes,\n",
 | ||
|     "#     min_variance_ratio=0.75  # 保留至少75%的方差\n",
 | ||
|     "# )\n",
 | ||
|     "\n",
 | ||
|     "# if outlier_results:\n",
 | ||
|     "#     print(f\"\\n✅ 成功检测到 {len(outlier_results)} 个音素的异常值!\")\n",
 | ||
|     "    \n",
 | ||
|     "#     # 可视化结果\n",
 | ||
|     "#     visualize_smart_dbscan_results(outlier_results)\n",
 | ||
|     "    \n",
 | ||
|     "#     # 详细分析\n",
 | ||
|     "#     analyze_outliers_detailed(outlier_results)\n",
 | ||
|     "    \n",
 | ||
|     "#     print(f\"\\n📊 异常值检测总结:\")\n",
 | ||
|     "#     for phoneme, result in outlier_results.items():\n",
 | ||
|     "#         print(f\"  {phoneme}: {len(result['outliers'])}/{len(result['labels'])} \"\n",
 | ||
|     "#               f\"({result['outlier_ratio']*100:.1f}%) 异常值\")\n",
 | ||
|     "    \n",
 | ||
|     "# else:\n",
 | ||
|     "#     print(\"❌ 未检测到任何有效的异常值结果\")\n",
 | ||
|     "\n",
 | ||
|     "# print(f\"\\n💡 关键改进:\")\n",
 | ||
|     "# print(f\"1. 使用自适应PCA降维,保留75%以上的方差信息\")\n",
 | ||
|     "# print(f\"2. 基于数据分布智能选择DBSCAN参数\")\n",
 | ||
|     "# print(f\"3. 合理的异常值比例范围(5%-40%)\")\n",
 | ||
|     "# print(f\"4. 综合评分机制平衡聚类质量和异常值检测\")"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": 23,
 | ||
|    "metadata": {},
 | ||
|    "outputs": [],
 | ||
|    "source": [
 | ||
|     "# # 类内相似度分析\n",
 | ||
|     "# import numpy as np\n",
 | ||
|     "# import matplotlib.pyplot as plt\n",
 | ||
|     "# from sklearn.metrics.pairwise import cosine_similarity, euclidean_distances\n",
 | ||
|     "# from sklearn.preprocessing import StandardScaler\n",
 | ||
|     "# import pandas as pd\n",
 | ||
|     "# import seaborn as sns\n",
 | ||
|     "\n",
 | ||
|     "# print(\"=\"*70)\n",
 | ||
|     "# print(\"类内相似度分析 - 音素分类有效性评估\")\n",
 | ||
|     "# print(\"=\"*70)\n",
 | ||
|     "\n",
 | ||
|     "# def calculate_intra_class_similarity(sequences, metric='cosine'):\n",
 | ||
|     "#     \"\"\"\n",
 | ||
|     "#     计算类内相似度\n",
 | ||
|     "#     \"\"\"\n",
 | ||
|     "#     if len(sequences) < 2:\n",
 | ||
|     "#         return np.nan, np.nan, np.nan\n",
 | ||
|     "    \n",
 | ||
|     "#     # 展平序列\n",
 | ||
|     "#     flattened_sequences = []\n",
 | ||
|     "#     for seq in sequences:\n",
 | ||
|     "#         flattened_sequences.append(seq.flatten())\n",
 | ||
|     "    \n",
 | ||
|     "#     flattened_sequences = np.array(flattened_sequences)\n",
 | ||
|     "    \n",
 | ||
|     "#     # 标准化\n",
 | ||
|     "#     scaler = StandardScaler()\n",
 | ||
|     "#     scaled_sequences = scaler.fit_transform(flattened_sequences)\n",
 | ||
|     "    \n",
 | ||
|     "#     if metric == 'cosine':\n",
 | ||
|     "#         # 计算余弦相似度\n",
 | ||
|     "#         similarity_matrix = cosine_similarity(scaled_sequences)\n",
 | ||
|     "#         # 提取上三角矩阵(排除对角线)\n",
 | ||
|     "#         upper_tri = np.triu(similarity_matrix, k=1)\n",
 | ||
|     "#         similarities = upper_tri[upper_tri > 0]\n",
 | ||
|     "#     elif metric == 'euclidean':\n",
 | ||
|     "#         # 计算欧氏距离,然后转换为相似度\n",
 | ||
|     "#         distance_matrix = euclidean_distances(scaled_sequences)\n",
 | ||
|     "#         # 转换为相似度(距离越小,相似度越高)\n",
 | ||
|     "#         max_dist = np.max(distance_matrix)\n",
 | ||
|     "#         similarity_matrix = 1 - (distance_matrix / max_dist)\n",
 | ||
|     "#         upper_tri = np.triu(similarity_matrix, k=1)\n",
 | ||
|     "#         similarities = upper_tri[upper_tri > 0]\n",
 | ||
|     "    \n",
 | ||
|     "#     mean_similarity = np.mean(similarities)\n",
 | ||
|     "#     std_similarity = np.std(similarities)\n",
 | ||
|     "#     median_similarity = np.median(similarities)\n",
 | ||
|     "    \n",
 | ||
|     "#     return mean_similarity, std_similarity, median_similarity\n",
 | ||
|     "\n",
 | ||
|     "# def analyze_phoneme_similarity(processed_result, metric='cosine', sample_limit=500):\n",
 | ||
|     "#     \"\"\"\n",
 | ||
|     "#     分析每个音素的类内相似度\n",
 | ||
|     "#     \"\"\"\n",
 | ||
|     "#     print(f\"使用 {metric} 相似度度量\")\n",
 | ||
|     "#     print(f\"每个音素最多分析 {sample_limit} 个样本\")\n",
 | ||
|     "#     print(\"-\" * 50)\n",
 | ||
|     "    \n",
 | ||
|     "#     phoneme_similarities = {}\n",
 | ||
|     "    \n",
 | ||
|     "#     for phoneme, sequences in processed_result.items():\n",
 | ||
|     "#         if len(sequences) < 5:  # 跳过样本数太少的音素\n",
 | ||
|     "#             print(f\"跳过音素 '{phoneme}' (样本数太少: {len(sequences)})\")\n",
 | ||
|     "#             continue\n",
 | ||
|     "        \n",
 | ||
|     "#         # 如果样本太多,随机采样\n",
 | ||
|     "#         if len(sequences) > sample_limit:\n",
 | ||
|     "#             indices = np.random.choice(len(sequences), sample_limit, replace=False)\n",
 | ||
|     "#             sampled_sequences = [sequences[i] for i in indices]\n",
 | ||
|     "#         else:\n",
 | ||
|     "#             sampled_sequences = sequences\n",
 | ||
|     "        \n",
 | ||
|     "#         mean_sim, std_sim, median_sim = calculate_intra_class_similarity(\n",
 | ||
|     "#             sampled_sequences, metric=metric\n",
 | ||
|     "#         )\n",
 | ||
|     "        \n",
 | ||
|     "#         phoneme_similarities[phoneme] = {\n",
 | ||
|     "#             'mean': mean_sim,\n",
 | ||
|     "#             'std': std_sim,\n",
 | ||
|     "#             'median': median_sim,\n",
 | ||
|     "#             'n_samples': len(sampled_sequences),\n",
 | ||
|     "#             'n_pairs': len(sampled_sequences) * (len(sampled_sequences) - 1) // 2\n",
 | ||
|     "#         }\n",
 | ||
|     "        \n",
 | ||
|     "#         print(f\"音素 '{phoneme}': 平均相似度={mean_sim:.4f} ± {std_sim:.4f}, \"\n",
 | ||
|     "#               f\"中位数={median_sim:.4f}, 样本数={len(sampled_sequences)}\")\n",
 | ||
|     "    \n",
 | ||
|     "#     return phoneme_similarities\n",
 | ||
|     "\n",
 | ||
|     "# def calculate_overall_similarity(processed_result, metric='cosine', sample_per_phoneme=50):\n",
 | ||
|     "#     \"\"\"\n",
 | ||
|     "#     计算全部音素作为一类的相似度\n",
 | ||
|     "#     \"\"\"\n",
 | ||
|     "#     print(f\"\\n计算全体音素相似度 (每个音素采样 {sample_per_phoneme} 个)\")\n",
 | ||
|     "#     print(\"-\" * 50)\n",
 | ||
|     "    \n",
 | ||
|     "#     all_sequences = []\n",
 | ||
|     "#     phoneme_labels = []\n",
 | ||
|     "    \n",
 | ||
|     "#     # 从每个音素中采样一定数量的序列\n",
 | ||
|     "#     for phoneme, sequences in processed_result.items():\n",
 | ||
|     "#         if len(sequences) < 5:\n",
 | ||
|     "#             continue\n",
 | ||
|     "        \n",
 | ||
|     "#         n_sample = min(sample_per_phoneme, len(sequences))\n",
 | ||
|     "#         indices = np.random.choice(len(sequences), n_sample, replace=False)\n",
 | ||
|     "        \n",
 | ||
|     "#         for i in indices:\n",
 | ||
|     "#             all_sequences.append(sequences[i])\n",
 | ||
|     "#             phoneme_labels.append(phoneme)\n",
 | ||
|     "    \n",
 | ||
|     "#     print(f\"总共收集了 {len(all_sequences)} 个序列,来自 {len(set(phoneme_labels))} 个音素\")\n",
 | ||
|     "    \n",
 | ||
|     "#     # 计算整体相似度\n",
 | ||
|     "#     mean_sim, std_sim, median_sim = calculate_intra_class_similarity(\n",
 | ||
|     "#         all_sequences, metric=metric\n",
 | ||
|     "#     )\n",
 | ||
|     "    \n",
 | ||
|     "#     overall_result = {\n",
 | ||
|     "#         'mean': mean_sim,\n",
 | ||
|     "#         'std': std_sim,\n",
 | ||
|     "#         'median': median_sim,\n",
 | ||
|     "#         'n_samples': len(all_sequences),\n",
 | ||
|     "#         'n_pairs': len(all_sequences) * (len(all_sequences) - 1) // 2,\n",
 | ||
|     "#         'n_phonemes': len(set(phoneme_labels))\n",
 | ||
|     "#     }\n",
 | ||
|     "    \n",
 | ||
|     "#     print(f\"全体音素: 平均相似度={mean_sim:.4f} ± {std_sim:.4f}, \"\n",
 | ||
|     "#           f\"中位数={median_sim:.4f}, 样本数={len(all_sequences)}\")\n",
 | ||
|     "    \n",
 | ||
|     "#     return overall_result, phoneme_labels\n",
 | ||
|     "\n",
 | ||
|     "# def visualize_similarity_comparison(phoneme_similarities, overall_result, metric='cosine'):\n",
 | ||
|     "#     \"\"\"\n",
 | ||
|     "#     可视化相似度比较\n",
 | ||
|     "#     \"\"\"\n",
 | ||
|     "#     fig, axes = plt.subplots(2, 2, figsize=(16, 12))\n",
 | ||
|     "    \n",
 | ||
|     "#     # 准备数据\n",
 | ||
|     "#     phonemes = list(phoneme_similarities.keys())\n",
 | ||
|     "#     mean_similarities = [phoneme_similarities[p]['mean'] for p in phonemes]\n",
 | ||
|     "#     std_similarities = [phoneme_similarities[p]['std'] for p in phonemes]\n",
 | ||
|     "#     sample_counts = [phoneme_similarities[p]['n_samples'] for p in phonemes]\n",
 | ||
|     "    \n",
 | ||
|     "#     overall_mean = overall_result['mean']\n",
 | ||
|     "#     overall_std = overall_result['std']\n",
 | ||
|     "    \n",
 | ||
|     "#     # 图1: 每个音素的平均相似度 vs 整体相似度\n",
 | ||
|     "#     ax1 = axes[0, 0]\n",
 | ||
|     "#     bars = ax1.bar(range(len(phonemes)), mean_similarities, \n",
 | ||
|     "#                    yerr=std_similarities, capsize=3, alpha=0.7, color='skyblue')\n",
 | ||
|     "#     ax1.axhline(y=overall_mean, color='red', linestyle='--', linewidth=2, \n",
 | ||
|     "#                 label=f'全体音素平均: {overall_mean:.4f}')\n",
 | ||
|     "#     ax1.fill_between(range(len(phonemes)), \n",
 | ||
|     "#                      overall_mean - overall_std, \n",
 | ||
|     "#                      overall_mean + overall_std, \n",
 | ||
|     "#                      alpha=0.2, color='red', label=f'全体音素范围: ±{overall_std:.4f}')\n",
 | ||
|     "    \n",
 | ||
|     "#     ax1.set_xlabel('音素')\n",
 | ||
|     "#     ax1.set_ylabel(f'{metric.title()} 相似度')\n",
 | ||
|     "#     ax1.set_title('各音素类内相似度 vs 全体音素相似度')\n",
 | ||
|     "#     ax1.set_xticks(range(len(phonemes)))\n",
 | ||
|     "#     ax1.set_xticklabels(phonemes, rotation=45)\n",
 | ||
|     "#     ax1.legend()\n",
 | ||
|     "#     ax1.grid(True, alpha=0.3)\n",
 | ||
|     "    \n",
 | ||
|     "#     # 在柱状图上显示数值\n",
 | ||
|     "#     for i, (bar, mean_val) in enumerate(zip(bars, mean_similarities)):\n",
 | ||
|     "#         if mean_val > overall_mean:\n",
 | ||
|     "#             ax1.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01, \n",
 | ||
|     "#                     f'{mean_val:.3f}', ha='center', va='bottom', fontsize=8, \n",
 | ||
|     "#                     color='green', weight='bold')\n",
 | ||
|     "#         else:\n",
 | ||
|     "#             ax1.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.01, \n",
 | ||
|     "#                     f'{mean_val:.3f}', ha='center', va='bottom', fontsize=8, \n",
 | ||
|     "#                     color='red', weight='bold')\n",
 | ||
|     "    \n",
 | ||
|     "#     # 图2: 相似度提升程度\n",
 | ||
|     "#     ax2 = axes[0, 1]\n",
 | ||
|     "#     improvements = [(sim - overall_mean) for sim in mean_similarities]\n",
 | ||
|     "#     colors = ['green' if imp > 0 else 'red' for imp in improvements]\n",
 | ||
|     "    \n",
 | ||
|     "#     bars2 = ax2.bar(range(len(phonemes)), improvements, color=colors, alpha=0.7)\n",
 | ||
|     "#     ax2.axhline(y=0, color='black', linestyle='-', linewidth=1)\n",
 | ||
|     "#     ax2.set_xlabel('音素')\n",
 | ||
|     "#     ax2.set_ylabel('相似度提升 (相对于全体)')\n",
 | ||
|     "#     ax2.set_title('音素分类的相似度提升效果')\n",
 | ||
|     "#     ax2.set_xticks(range(len(phonemes)))\n",
 | ||
|     "#     ax2.set_xticklabels(phonemes, rotation=45)\n",
 | ||
|     "#     ax2.grid(True, alpha=0.3)\n",
 | ||
|     "    \n",
 | ||
|     "#     # 显示数值\n",
 | ||
|     "#     for bar, imp in zip(bars2, improvements):\n",
 | ||
|     "#         ax2.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.001, \n",
 | ||
|     "#                 f'{imp:+.3f}', ha='center', va='bottom' if imp > 0 else 'top', \n",
 | ||
|     "#                 fontsize=8, weight='bold')\n",
 | ||
|     "    \n",
 | ||
|     "#     # 图3: 样本数量 vs 相似度\n",
 | ||
|     "#     ax3 = axes[1, 0]\n",
 | ||
|     "#     scatter = ax3.scatter(sample_counts, mean_similarities, \n",
 | ||
|     "#                          c=improvements, cmap='RdYlGn', s=60, alpha=0.7)\n",
 | ||
|     "#     ax3.axhline(y=overall_mean, color='red', linestyle='--', alpha=0.7)\n",
 | ||
|     "#     ax3.set_xlabel('样本数量')\n",
 | ||
|     "#     ax3.set_ylabel(f'{metric.title()} 相似度')\n",
 | ||
|     "#     ax3.set_title('样本数量 vs 相似度')\n",
 | ||
|     "#     ax3.grid(True, alpha=0.3)\n",
 | ||
|     "    \n",
 | ||
|     "#     # 添加颜色条\n",
 | ||
|     "#     cbar = plt.colorbar(scatter, ax=ax3)\n",
 | ||
|     "#     cbar.set_label('相似度提升')\n",
 | ||
|     "    \n",
 | ||
|     "#     # 图4: 相似度分布统计\n",
 | ||
|     "#     ax4 = axes[1, 1]\n",
 | ||
|     "#     ax4.axis('off')\n",
 | ||
|     "    \n",
 | ||
|     "#     # 计算统计信息\n",
 | ||
|     "#     positive_improvements = [imp for imp in improvements if imp > 0]\n",
 | ||
|     "#     negative_improvements = [imp for imp in improvements if imp <= 0]\n",
 | ||
|     "    \n",
 | ||
|     "#     avg_improvement = np.mean(improvements)\n",
 | ||
|     "#     max_improvement = np.max(improvements)\n",
 | ||
|     "#     min_improvement = np.min(improvements)\n",
 | ||
|     "    \n",
 | ||
|     "#     best_phoneme = phonemes[improvements.index(max_improvement)]\n",
 | ||
|     "#     worst_phoneme = phonemes[improvements.index(min_improvement)]\n",
 | ||
|     "    \n",
 | ||
|     "#     stats_text = f\"\"\"\n",
 | ||
|     "# 类内相似度分析统计报告\n",
 | ||
|     "\n",
 | ||
|     "# 度量方法: {metric.title()} 相似度\n",
 | ||
|     "\n",
 | ||
|     "# 全体音素基线:\n",
 | ||
|     "# - 平均相似度: {overall_mean:.4f} ± {overall_std:.4f}\n",
 | ||
|     "# - 样本总数: {overall_result['n_samples']}\n",
 | ||
|     "# - 音素数量: {overall_result['n_phonemes']}\n",
 | ||
|     "\n",
 | ||
|     "# 音素分类效果:\n",
 | ||
|     "# - 分析音素数: {len(phonemes)}\n",
 | ||
|     "# - 平均提升: {avg_improvement:+.4f}\n",
 | ||
|     "# - 最大提升: {max_improvement:+.4f} ({best_phoneme})\n",
 | ||
|     "# - 最小提升: {min_improvement:+.4f} ({worst_phoneme})\n",
 | ||
|     "\n",
 | ||
|     "# 提升统计:\n",
 | ||
|     "# - 相似度提升音素: {len(positive_improvements)}/{len(phonemes)} ({len(positive_improvements)/len(phonemes)*100:.1f}%)\n",
 | ||
|     "# - 相似度下降音素: {len(negative_improvements)}/{len(phonemes)} ({len(negative_improvements)/len(phonemes)*100:.1f}%)\n",
 | ||
|     "\n",
 | ||
|     "# 结论:\n",
 | ||
|     "# \"\"\"\n",
 | ||
|     "    \n",
 | ||
|     "#     if avg_improvement > 0:\n",
 | ||
|     "#         stats_text += f\"✅ 音素分类整体有效 (平均提升 {avg_improvement:.4f})\\n\"\n",
 | ||
|     "#         stats_text += f\"   按音素分类比混合所有音素更能保持相似性\"\n",
 | ||
|     "#     else:\n",
 | ||
|     "#         stats_text += f\"❌ 音素分类效果有限 (平均下降 {avg_improvement:.4f})\\n\"\n",
 | ||
|     "#         stats_text += f\"   可能需要重新考虑分类策略\"\n",
 | ||
|     "    \n",
 | ||
|     "#     ax4.text(0.05, 0.95, stats_text, transform=ax4.transAxes, fontsize=10,\n",
 | ||
|     "#             verticalalignment='top', fontfamily='monospace',\n",
 | ||
|     "#             bbox=dict(boxstyle='round', facecolor='lightblue', alpha=0.8))\n",
 | ||
|     "    \n",
 | ||
|     "#     plt.tight_layout()\n",
 | ||
|     "#     plt.show()\n",
 | ||
|     "    \n",
 | ||
|     "#     return improvements\n",
 | ||
|     "\n",
 | ||
|     "# def create_similarity_dataframe(phoneme_similarities, overall_result):\n",
 | ||
|     "#     \"\"\"\n",
 | ||
|     "#     创建相似度对比的DataFrame\n",
 | ||
|     "#     \"\"\"\n",
 | ||
|     "#     data = []\n",
 | ||
|     "#     overall_mean = overall_result['mean']\n",
 | ||
|     "    \n",
 | ||
|     "#     for phoneme, sim_data in phoneme_similarities.items():\n",
 | ||
|     "#         improvement = sim_data['mean'] - overall_mean\n",
 | ||
|     "#         relative_improvement = (improvement / overall_mean) * 100\n",
 | ||
|     "        \n",
 | ||
|     "#         data.append({\n",
 | ||
|     "#             'phoneme': phoneme,\n",
 | ||
|     "#             'intra_class_similarity': sim_data['mean'],\n",
 | ||
|     "#             'similarity_std': sim_data['std'],\n",
 | ||
|     "#             'n_samples': sim_data['n_samples'],\n",
 | ||
|     "#             'overall_baseline': overall_mean,\n",
 | ||
|     "#             'absolute_improvement': improvement,\n",
 | ||
|     "#             'relative_improvement_pct': relative_improvement,\n",
 | ||
|     "#             'is_better': improvement > 0\n",
 | ||
|     "#         })\n",
 | ||
|     "    \n",
 | ||
|     "#     df = pd.DataFrame(data)\n",
 | ||
|     "#     df = df.sort_values('absolute_improvement', ascending=False)\n",
 | ||
|     "    \n",
 | ||
|     "#     return df\n",
 | ||
|     "\n",
 | ||
|     "# # 执行类内相似度分析\n",
 | ||
|     "# print(\"开始类内相似度分析...\")\n",
 | ||
|     "\n",
 | ||
|     "# # 1. 分析每个音素的类内相似度\n",
 | ||
|     "# print(\"\\n1. 计算各音素类内相似度\")\n",
 | ||
|     "# phoneme_similarities_cosine = analyze_phoneme_similarity(\n",
 | ||
|     "#     processed_result, metric='cosine', sample_limit=300\n",
 | ||
|     "# )\n",
 | ||
|     "\n",
 | ||
|     "# # 2. 计算全体音素的相似度\n",
 | ||
|     "# print(\"\\n2. 计算全体音素相似度作为基线\")\n",
 | ||
|     "# overall_result_cosine, all_phoneme_labels = calculate_overall_similarity(\n",
 | ||
|     "#     processed_result, metric='cosine', sample_per_phoneme=30\n",
 | ||
|     "# )\n",
 | ||
|     "\n",
 | ||
|     "# # 3. 可视化比较\n",
 | ||
|     "# print(\"\\n3. 可视化相似度比较\")\n",
 | ||
|     "# improvements = visualize_similarity_comparison(\n",
 | ||
|     "#     phoneme_similarities_cosine, overall_result_cosine, metric='cosine'\n",
 | ||
|     "# )\n",
 | ||
|     "\n",
 | ||
|     "# # 4. 创建详细对比表\n",
 | ||
|     "# print(\"\\n4. 详细相似度对比表\")\n",
 | ||
|     "# df_similarity = create_similarity_dataframe(phoneme_similarities_cosine, overall_result_cosine)\n",
 | ||
|     "# print(df_similarity.to_string(index=False, float_format='%.4f'))\n",
 | ||
|     "\n",
 | ||
|     "# print(f\"\\n✅ 类内相似度分析完成!\")\n",
 | ||
|     "# print(f\"这个分析揭示了音素分类对神经信号相似性的影响。\")"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": null,
 | ||
|    "metadata": {},
 | ||
|    "outputs": [],
 | ||
|    "source": []
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "metadata": {},
 | ||
|    "source": [
 | ||
|     "# 🔗 数据集批量处理工作流"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": 24,
 | ||
|    "metadata": {},
 | ||
|    "outputs": [
 | ||
|     {
 | ||
|      "name": "stdout",
 | ||
|      "output_type": "stream",
 | ||
|      "text": [
 | ||
|       "/kaggle/working/nejm-brain-to-text/model_training\n",
 | ||
|       "======================================================================\n",
 | ||
|       "🚀 RNN数据批量处理工具 - 新版本\n",
 | ||
|       "======================================================================\n",
 | ||
|       "🔧 创建RNN数据处理器...\n",
 | ||
|       "🔧 初始化RNN数据处理器...\n",
 | ||
|       "   模型路径: ../data/t15_pretrained_rnn_baseline\n",
 | ||
|       "   数据目录: ../data/hdf5_data_final\n",
 | ||
|       "   计算设备: cuda:0\n",
 | ||
|       "📋 模型配置:\n",
 | ||
|       "   Sessions数量: 45\n",
 | ||
|       "   神经特征维度: 512\n",
 | ||
|       "   Patch size: 14\n",
 | ||
|       "   Patch stride: 4\n",
 | ||
|       "   输出类别数: 41\n",
 | ||
|       "✅ 模型加载成功\n",
 | ||
|       "📊 CSV数据加载完成: 265 条记录\n",
 | ||
|       "✅ 初始化完成!\n",
 | ||
|       "✅ RNN数据处理器创建成功!\n",
 | ||
|       "✅ 模型加载成功\n",
 | ||
|       "📊 CSV数据加载完成: 265 条记录\n",
 | ||
|       "✅ 初始化完成!\n",
 | ||
|       "✅ RNN数据处理器创建成功!\n"
 | ||
|      ]
 | ||
|     }
 | ||
|    ],
 | ||
|    "source": [
 | ||
|     "%cd model_training\n",
 | ||
|     "# 🚀 RNN数据批量处理工具 - 完整版\n",
 | ||
|     "import os\n",
 | ||
|     "import torch\n",
 | ||
|     "import numpy as np\n",
 | ||
|     "import pandas as pd\n",
 | ||
|     "from omegaconf import OmegaConf\n",
 | ||
|     "import time\n",
 | ||
|     "from tqdm import tqdm\n",
 | ||
|     "import h5py\n",
 | ||
|     "from pathlib import Path\n",
 | ||
|     "\n",
 | ||
|     "# 导入模型相关模块\n",
 | ||
|     "import sys\n",
 | ||
|     "sys.path.append('../model_training')\n",
 | ||
|     "from rnn_model import GRUDecoder\n",
 | ||
|     "from evaluate_model_helpers import *\n",
 | ||
|     "from data_augmentations import gauss_smooth\n",
 | ||
|     "\n",
 | ||
|     "print(\"=\"*70)\n",
 | ||
|     "print(\"🚀 RNN数据批量处理工具 - 新版本\")\n",
 | ||
|     "print(\"=\"*70)\n",
 | ||
|     "\n",
 | ||
|     "class RNNDataProcessor:\n",
 | ||
|     "    \"\"\"\n",
 | ||
|     "    RNN数据批量处理器 - 生成RNN输入输出拼接数据\n",
 | ||
|     "    \n",
 | ||
|     "    核心功能:\n",
 | ||
|     "    1. 加载预训练RNN模型\n",
 | ||
|     "    2. 处理原始神经数据(高斯平滑 + patch操作)\n",
 | ||
|     "    3. 获取RNN输出(40类置信度分数)\n",
 | ||
|     "    4. 拼接处理后的输入和输出\n",
 | ||
|     "    5. 批量保存所有session数据\n",
 | ||
|     "    \"\"\"\n",
 | ||
|     "    \n",
 | ||
|     "    def __init__(self, model_path, data_dir, csv_path, device='auto'):\n",
 | ||
|     "        \"\"\"\n",
 | ||
|     "        初始化处理器\n",
 | ||
|     "        \n",
 | ||
|     "        参数:\n",
 | ||
|     "            model_path: 预训练RNN模型路径\n",
 | ||
|     "            data_dir: 数据目录路径  \n",
 | ||
|     "            csv_path: 数据描述CSV文件路径\n",
 | ||
|     "            device: 计算设备 ('auto', 'cpu', 'cuda:0'等)\n",
 | ||
|     "        \"\"\"\n",
 | ||
|     "        self.model_path = model_path\n",
 | ||
|     "        self.data_dir = data_dir\n",
 | ||
|     "        self.csv_path = csv_path\n",
 | ||
|     "        \n",
 | ||
|     "        # 设备选择\n",
 | ||
|     "        if device == 'auto':\n",
 | ||
|     "            self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
 | ||
|     "        else:\n",
 | ||
|     "            self.device = torch.device(device)\n",
 | ||
|     "        \n",
 | ||
|     "        print(f\"🔧 初始化RNN数据处理器...\")\n",
 | ||
|     "        print(f\"   模型路径: {model_path}\")\n",
 | ||
|     "        print(f\"   数据目录: {data_dir}\")\n",
 | ||
|     "        print(f\"   计算设备: {self.device}\")\n",
 | ||
|     "        \n",
 | ||
|     "        # 加载配置和模型\n",
 | ||
|     "        self._load_config()\n",
 | ||
|     "        self._load_model()\n",
 | ||
|     "        self._load_csv()\n",
 | ||
|     "        \n",
 | ||
|     "        print(f\"✅ 初始化完成!\")\n",
 | ||
|     "    \n",
 | ||
|     "    def _load_config(self):\n",
 | ||
|     "        \"\"\"加载模型配置\"\"\"\n",
 | ||
|     "        config_path = os.path.join(self.model_path, 'checkpoint/args.yaml')\n",
 | ||
|     "        if not os.path.exists(config_path):\n",
 | ||
|     "            raise FileNotFoundError(f\"配置文件不存在: {config_path}\")\n",
 | ||
|     "        \n",
 | ||
|     "        self.model_args = OmegaConf.load(config_path)\n",
 | ||
|     "        \n",
 | ||
|     "        print(f\"📋 模型配置:\")\n",
 | ||
|     "        print(f\"   Sessions数量: {len(self.model_args['dataset']['sessions'])}\")\n",
 | ||
|     "        print(f\"   神经特征维度: {self.model_args['model']['n_input_features']}\")\n",
 | ||
|     "        print(f\"   Patch size: {self.model_args['model']['patch_size']}\")\n",
 | ||
|     "        print(f\"   Patch stride: {self.model_args['model']['patch_stride']}\")\n",
 | ||
|     "        print(f\"   输出类别数: {self.model_args['dataset']['n_classes']}\")\n",
 | ||
|     "    \n",
 | ||
|     "    def _load_model(self):\n",
 | ||
|     "        \"\"\"加载预训练RNN模型\"\"\"\n",
 | ||
|     "        try:\n",
 | ||
|     "            # 创建模型\n",
 | ||
|     "            self.model = GRUDecoder(\n",
 | ||
|     "                neural_dim=self.model_args['model']['n_input_features'],\n",
 | ||
|     "                n_units=self.model_args['model']['n_units'], \n",
 | ||
|     "                n_days=len(self.model_args['dataset']['sessions']),\n",
 | ||
|     "                n_classes=self.model_args['dataset']['n_classes'],\n",
 | ||
|     "                rnn_dropout=self.model_args['model']['rnn_dropout'],\n",
 | ||
|     "                input_dropout=self.model_args['model']['input_network']['input_layer_dropout'],\n",
 | ||
|     "                n_layers=self.model_args['model']['n_layers'],\n",
 | ||
|     "                patch_size=self.model_args['model']['patch_size'],\n",
 | ||
|     "                patch_stride=self.model_args['model']['patch_stride'],\n",
 | ||
|     "            )\n",
 | ||
|     "            \n",
 | ||
|     "            # 加载权重\n",
 | ||
|     "            checkpoint_path = os.path.join(self.model_path, 'checkpoint/best_checkpoint')\n",
 | ||
|     "            try:\n",
 | ||
|     "                checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)\n",
 | ||
|     "            except TypeError:\n",
 | ||
|     "                checkpoint = torch.load(checkpoint_path, map_location=self.device)\n",
 | ||
|     "            \n",
 | ||
|     "            # 清理键名\n",
 | ||
|     "            for key in list(checkpoint['model_state_dict'].keys()):\n",
 | ||
|     "                checkpoint['model_state_dict'][key.replace(\"module.\", \"\")] = checkpoint['model_state_dict'].pop(key)\n",
 | ||
|     "                checkpoint['model_state_dict'][key.replace(\"_orig_mod.\", \"\")] = checkpoint['model_state_dict'].pop(key)\n",
 | ||
|     "            \n",
 | ||
|     "            self.model.load_state_dict(checkpoint['model_state_dict'])\n",
 | ||
|     "            self.model.to(self.device)\n",
 | ||
|     "            self.model.eval()\n",
 | ||
|     "            \n",
 | ||
|     "            print(f\"✅ 模型加载成功\")\n",
 | ||
|     "            \n",
 | ||
|     "        except Exception as e:\n",
 | ||
|     "            print(f\"❌ 模型加载失败: {e}\")\n",
 | ||
|     "            raise\n",
 | ||
|     "    \n",
 | ||
|     "    def _load_csv(self):\n",
 | ||
|     "        \"\"\"加载数据描述文件\"\"\"\n",
 | ||
|     "        if not os.path.exists(self.csv_path):\n",
 | ||
|     "            raise FileNotFoundError(f\"CSV文件不存在: {self.csv_path}\")\n",
 | ||
|     "        \n",
 | ||
|     "        self.csv_df = pd.read_csv(self.csv_path)\n",
 | ||
|     "        print(f\"📊 CSV数据加载完成: {len(self.csv_df)} 条记录\")\n",
 | ||
|     "    \n",
 | ||
|     "    def _process_single_trial(self, neural_data, session_idx):\n",
 | ||
|     "        \"\"\"\n",
 | ||
|     "        处理单个试验数据\n",
 | ||
|     "        \n",
 | ||
|     "        参数:\n",
 | ||
|     "            neural_data: 原始神经数据 [time_steps, features]\n",
 | ||
|     "            session_idx: 会话索引\n",
 | ||
|     "        \n",
 | ||
|     "        返回:\n",
 | ||
|     "            dict: 包含拼接数据和统计信息\n",
 | ||
|     "        \"\"\"\n",
 | ||
|     "        # 添加batch维度\n",
 | ||
|     "        neural_input = np.expand_dims(neural_data, axis=0)\n",
 | ||
|     "        neural_tensor = torch.tensor(neural_input, device=self.device, dtype=torch.bfloat16)\n",
 | ||
|     "        \n",
 | ||
|     "        # 高斯平滑\n",
 | ||
|     "        with torch.autocast(device_type=\"cuda\" if self.device.type == \"cuda\" else \"cpu\", \n",
 | ||
|     "                           enabled=self.model_args.get('use_amp', False), dtype=torch.bfloat16):\n",
 | ||
|     "            \n",
 | ||
|     "            smoothed_data = gauss_smooth(\n",
 | ||
|     "                inputs=neural_tensor,\n",
 | ||
|     "                device=self.device,\n",
 | ||
|     "                smooth_kernel_std=self.model_args['dataset']['data_transforms']['smooth_kernel_std'],\n",
 | ||
|     "                smooth_kernel_size=self.model_args['dataset']['data_transforms']['smooth_kernel_size'],\n",
 | ||
|     "                padding='valid',\n",
 | ||
|     "            )\n",
 | ||
|     "            \n",
 | ||
|     "            # Patch操作(复制模型内部逻辑)\n",
 | ||
|     "            processed_data = smoothed_data\n",
 | ||
|     "            if self.model.patch_size > 0:\n",
 | ||
|     "                processed_data = processed_data.unsqueeze(1)  # [batch, 1, time, features]\n",
 | ||
|     "                processed_data = processed_data.permute(0, 3, 1, 2)  # [batch, features, 1, time]\n",
 | ||
|     "                \n",
 | ||
|     "                # 滑动窗口提取\n",
 | ||
|     "                patches = processed_data.unfold(3, self.model.patch_size, self.model.patch_stride)\n",
 | ||
|     "                patches = patches.squeeze(2)  # [batch, features, patches, patch_size]\n",
 | ||
|     "                patches = patches.permute(0, 2, 3, 1)  # [batch, patches, patch_size, features]\n",
 | ||
|     "                \n",
 | ||
|     "                # 展平最后两个维度\n",
 | ||
|     "                processed_data = patches.reshape(patches.size(0), patches.size(1), -1)\n",
 | ||
|     "            \n",
 | ||
|     "            # RNN推理\n",
 | ||
|     "            with torch.no_grad():\n",
 | ||
|     "                logits, _ = self.model(\n",
 | ||
|     "                    x=smoothed_data,\n",
 | ||
|     "                    day_idx=torch.tensor([session_idx], device=self.device),\n",
 | ||
|     "                    states=None,\n",
 | ||
|     "                    return_state=True,\n",
 | ||
|     "                )\n",
 | ||
|     "        \n",
 | ||
|     "        # 转换为numpy\n",
 | ||
|     "        processed_features = processed_data.float().cpu().numpy()[0]  # [time_steps, processed_features]\n",
 | ||
|     "        confidence_scores = logits.float().cpu().numpy()[0]  # [time_steps, 40]\n",
 | ||
|     "        \n",
 | ||
|     "        # 拼接数据\n",
 | ||
|     "        concatenated = np.concatenate([processed_features, confidence_scores], axis=1)\n",
 | ||
|     "        \n",
 | ||
|     "        return {\n",
 | ||
|     "            'concatenated_data': concatenated,\n",
 | ||
|     "            'processed_features': processed_features,\n",
 | ||
|     "            'confidence_scores': confidence_scores,\n",
 | ||
|     "            'original_time_steps': neural_data.shape[0],\n",
 | ||
|     "            'processed_time_steps': concatenated.shape[0],\n",
 | ||
|     "            'feature_reduction_ratio': concatenated.shape[0] / neural_data.shape[0]\n",
 | ||
|     "        }\n",
 | ||
|     "    \n",
 | ||
|     "    def process_session(self, session_name, data_types=['train', 'val', 'test']):\n",
 | ||
|     "        \"\"\"\n",
 | ||
|     "        处理单个session的数据\n",
 | ||
|     "        \n",
 | ||
|     "        参数:\n",
 | ||
|     "            session_name: 会话名称\n",
 | ||
|     "            data_types: 要处理的数据类型列表\n",
 | ||
|     "        \n",
 | ||
|     "        返回:\n",
 | ||
|     "            dict: 处理结果\n",
 | ||
|     "        \"\"\"\n",
 | ||
|     "        print(f\"\\n🔄 处理会话: {session_name}\")\n",
 | ||
|     "        \n",
 | ||
|     "        session_idx = self.model_args['dataset']['sessions'].index(session_name)\n",
 | ||
|     "        session_results = {}\n",
 | ||
|     "        \n",
 | ||
|     "        for data_type in data_types:\n",
 | ||
|     "            data_file = os.path.join(self.data_dir, session_name, f'data_{data_type}.hdf5')\n",
 | ||
|     "            \n",
 | ||
|     "            if not os.path.exists(data_file):\n",
 | ||
|     "                print(f\"  ⚠️  {data_type} 数据文件不存在,跳过\")\n",
 | ||
|     "                continue\n",
 | ||
|     "            \n",
 | ||
|     "            print(f\"  📁 处理 {data_type} 数据...\")\n",
 | ||
|     "            \n",
 | ||
|     "            try:\n",
 | ||
|     "                # 加载数据\n",
 | ||
|     "                data = load_h5py_file(data_file, self.csv_df)\n",
 | ||
|     "                num_trials = len(data['neural_features'])\n",
 | ||
|     "                \n",
 | ||
|     "                if num_trials == 0:\n",
 | ||
|     "                    print(f\"    ⚠️  {data_type} 数据为空\")\n",
 | ||
|     "                    continue\n",
 | ||
|     "                \n",
 | ||
|     "                # 处理所有试验\n",
 | ||
|     "                results = {\n",
 | ||
|     "                    'concatenated_data': [],\n",
 | ||
|     "                    'processed_features': [],\n",
 | ||
|     "                    'confidence_scores': [],\n",
 | ||
|     "                    'trial_metadata': [],\n",
 | ||
|     "                    'processing_stats': []\n",
 | ||
|     "                }\n",
 | ||
|     "                \n",
 | ||
|     "                for trial_idx in tqdm(range(num_trials), desc=f\"    {data_type}\", leave=False):\n",
 | ||
|     "                    neural_data = data['neural_features'][trial_idx]\n",
 | ||
|     "                    \n",
 | ||
|     "                    # 处理单个试验\n",
 | ||
|     "                    trial_result = self._process_single_trial(neural_data, session_idx)\n",
 | ||
|     "                    \n",
 | ||
|     "                    # 保存结果\n",
 | ||
|     "                    results['concatenated_data'].append(trial_result['concatenated_data'])\n",
 | ||
|     "                    results['processed_features'].append(trial_result['processed_features'])\n",
 | ||
|     "                    results['confidence_scores'].append(trial_result['confidence_scores'])\n",
 | ||
|     "                    \n",
 | ||
|     "                    # 保存元数据\n",
 | ||
|     "                    metadata = {\n",
 | ||
|     "                        'session': session_name,\n",
 | ||
|     "                        'data_type': data_type,\n",
 | ||
|     "                        'trial_idx': trial_idx,\n",
 | ||
|     "                        'block_num': data.get('block_num', [None])[trial_idx],\n",
 | ||
|     "                        'trial_num': data.get('trial_num', [None])[trial_idx],\n",
 | ||
|     "                        **trial_result\n",
 | ||
|     "                    }\n",
 | ||
|     "                    \n",
 | ||
|     "                    # 添加真实标签(如果可用)\n",
 | ||
|     "                    if data_type in ['train', 'val'] and 'sentence_label' in data:\n",
 | ||
|     "                        metadata.update({\n",
 | ||
|     "                            'sentence_label': data['sentence_label'][trial_idx],\n",
 | ||
|     "                            'seq_class_ids': data['seq_class_ids'][trial_idx],\n",
 | ||
|     "                            'seq_len': data['seq_len'][trial_idx]\n",
 | ||
|     "                        })\n",
 | ||
|     "                    \n",
 | ||
|     "                    results['trial_metadata'].append(metadata)\n",
 | ||
|     "                    results['processing_stats'].append(trial_result)\n",
 | ||
|     "                \n",
 | ||
|     "                # 统计信息\n",
 | ||
|     "                if results['concatenated_data']:\n",
 | ||
|     "                    time_steps = [data.shape[0] for data in results['concatenated_data']]\n",
 | ||
|     "                    feature_dims = [data.shape[1] for data in results['concatenated_data']]\n",
 | ||
|     "                    \n",
 | ||
|     "                    print(f\"    ✅ {data_type} 处理完成:\")\n",
 | ||
|     "                    print(f\"       试验数: {len(results['concatenated_data'])}\")\n",
 | ||
|     "                    print(f\"       时间步范围: {min(time_steps)}-{max(time_steps)}\")\n",
 | ||
|     "                    print(f\"       特征维度: {feature_dims[0]} (处理后特征: {feature_dims[0]-40}, 置信度: 40)\")\n",
 | ||
|     "                    \n",
 | ||
|     "                    avg_reduction = np.mean([stat['feature_reduction_ratio'] for stat in results['processing_stats']])\n",
 | ||
|     "                    print(f\"       平均时间压缩比: {avg_reduction:.3f}\")\n",
 | ||
|     "                    \n",
 | ||
|     "                    session_results[data_type] = results\n",
 | ||
|     "                \n",
 | ||
|     "            except Exception as e:\n",
 | ||
|     "                print(f\"    ❌ {data_type} 处理失败: {e}\")\n",
 | ||
|     "                continue\n",
 | ||
|     "        \n",
 | ||
|     "        return session_results\n",
 | ||
|     "    \n",
 | ||
|     "    def process_all_sessions(self, data_types=['train', 'val', 'test'], save_dir='./rnn_processed_data'):\n",
 | ||
|     "        \"\"\"\n",
 | ||
|     "        批量处理所有sessions\n",
 | ||
|     "        \n",
 | ||
|     "        参数:\n",
 | ||
|     "            data_types: 要处理的数据类型\n",
 | ||
|     "            save_dir: 保存目录\n",
 | ||
|     "        \n",
 | ||
|     "        返回:\n",
 | ||
|     "            dict: 所有处理结果\n",
 | ||
|     "        \"\"\"\n",
 | ||
|     "        print(f\"\\n🚀 开始批量处理所有会话...\")\n",
 | ||
|     "        print(f\"   目标数据类型: {data_types}\")\n",
 | ||
|     "        print(f\"   保存目录: {save_dir}\")\n",
 | ||
|     "        \n",
 | ||
|     "        save_path = Path(save_dir)\n",
 | ||
|     "        save_path.mkdir(parents=True, exist_ok=True)\n",
 | ||
|     "        \n",
 | ||
|     "        all_results = {}\n",
 | ||
|     "        sessions = self.model_args['dataset']['sessions']\n",
 | ||
|     "        \n",
 | ||
|     "        start_time = time.time()\n",
 | ||
|     "        \n",
 | ||
|     "        for i, session in enumerate(sessions):\n",
 | ||
|     "            print(f\"\\n📊 进度: {i+1}/{len(sessions)}\")\n",
 | ||
|     "            \n",
 | ||
|     "            try:\n",
 | ||
|     "                session_results = self.process_session(session, data_types)\n",
 | ||
|     "                \n",
 | ||
|     "                if session_results:\n",
 | ||
|     "                    all_results[session] = session_results\n",
 | ||
|     "                    \n",
 | ||
|     "                    # 保存单个session结果\n",
 | ||
|     "                    for data_type, data in session_results.items():\n",
 | ||
|     "                        filename = f\"{session}_{data_type}_rnn_processed.npz\"\n",
 | ||
|     "                        filepath = save_path / filename\n",
 | ||
|     "                        \n",
 | ||
|     "                        save_data = {\n",
 | ||
|     "                            'concatenated_data': np.array(data['concatenated_data'], dtype=object),\n",
 | ||
|     "                            'processed_features': np.array(data['processed_features'], dtype=object),\n",
 | ||
|     "                            'confidence_scores': np.array(data['confidence_scores'], dtype=object),\n",
 | ||
|     "                            'trial_metadata': np.array(data['trial_metadata'], dtype=object),\n",
 | ||
|     "                        }\n",
 | ||
|     "                        \n",
 | ||
|     "                        np.savez_compressed(str(filepath), **save_data)\n",
 | ||
|     "                        print(f\"      💾 保存: {filename}\")\n",
 | ||
|     "                \n",
 | ||
|     "            except Exception as e:\n",
 | ||
|     "                print(f\"❌ 会话 {session} 处理失败: {e}\")\n",
 | ||
|     "                continue\n",
 | ||
|     "        \n",
 | ||
|     "        # 生成总结\n",
 | ||
|     "        end_time = time.time()\n",
 | ||
|     "        processing_time = end_time - start_time\n",
 | ||
|     "        \n",
 | ||
|     "        total_trials = sum(\n",
 | ||
|     "            len(session_data[data_type]['concatenated_data'])\n",
 | ||
|     "            for session_data in all_results.values()\n",
 | ||
|     "            for data_type in session_data.keys()\n",
 | ||
|     "        )\n",
 | ||
|     "        \n",
 | ||
|     "        print(f\"\\n🎉 批量处理完成!\")\n",
 | ||
|     "        print(f\"⏱️  总耗时: {processing_time/60:.2f} 分钟\")\n",
 | ||
|     "        print(f\"📊 处理统计:\")\n",
 | ||
|     "        print(f\"   成功会话: {len(all_results)}/{len(sessions)}\")\n",
 | ||
|     "        print(f\"   总试验数: {total_trials}\")\n",
 | ||
|     "        print(f\"💾 数据保存在: {save_dir}\")\n",
 | ||
|     "        \n",
 | ||
|     "        # 保存总结信息\n",
 | ||
|     "        summary = {\n",
 | ||
|     "            'processing_time': processing_time,\n",
 | ||
|     "            'total_sessions': len(all_results),\n",
 | ||
|     "            'total_trials': total_trials,\n",
 | ||
|     "            'data_types': data_types,\n",
 | ||
|     "            'sessions': list(all_results.keys()),\n",
 | ||
|     "            'model_config': {\n",
 | ||
|     "                'patch_size': self.model_args['model']['patch_size'],\n",
 | ||
|     "                'patch_stride': self.model_args['model']['patch_stride'],\n",
 | ||
|     "                'smooth_kernel_size': self.model_args['dataset']['data_transforms']['smooth_kernel_size'],\n",
 | ||
|     "                'smooth_kernel_std': self.model_args['dataset']['data_transforms']['smooth_kernel_std'],\n",
 | ||
|     "            }\n",
 | ||
|     "        }\n",
 | ||
|     "        \n",
 | ||
|     "        import json\n",
 | ||
|     "        with open(save_path / 'processing_summary.json', 'w') as f:\n",
 | ||
|     "            json.dump(summary, f, indent=2)\n",
 | ||
|     "        \n",
 | ||
|     "        return all_results\n",
 | ||
|     "\n",
 | ||
|     "# 创建处理器实例\n",
 | ||
|     "print(\"🔧 创建RNN数据处理器...\")\n",
 | ||
|     "\n",
 | ||
|     "try:\n",
 | ||
|     "    processor = RNNDataProcessor(\n",
 | ||
|     "        model_path='../data/t15_pretrained_rnn_baseline',\n",
 | ||
|     "        data_dir='../data/hdf5_data_final',\n",
 | ||
|     "        csv_path='../data/t15_copyTaskData_description.csv',\n",
 | ||
|     "        device='auto'\n",
 | ||
|     "    )\n",
 | ||
|     "    \n",
 | ||
|     "    print(f\"✅ RNN数据处理器创建成功!\")\n",
 | ||
|     "    \n",
 | ||
|     "except Exception as e:\n",
 | ||
|     "    print(f\"❌ 处理器创建失败: {e}\")\n",
 | ||
|     "    processor = None"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": 25,
 | ||
|    "metadata": {},
 | ||
|    "outputs": [
 | ||
|     {
 | ||
|      "name": "stdout",
 | ||
|      "output_type": "stream",
 | ||
|      "text": [
 | ||
|       "======================================================================\n",
 | ||
|       "🎯 RNN数据批量处理 - 使用示例\n",
 | ||
|       "======================================================================\n",
 | ||
|       "\n",
 | ||
|       "📋 可用的处理方法:\n",
 | ||
|       "1️⃣  单session处理: processor.process_session('session_name')\n",
 | ||
|       "2️⃣  批量处理所有: processor.process_all_sessions()\n",
 | ||
|       "\n",
 | ||
|       "📊 可用会话数量: 45\n",
 | ||
|       "📝 前5个会话: ['t15.2023.08.11', 't15.2023.08.13', 't15.2023.08.18', 't15.2023.08.20', 't15.2023.08.25']\n",
 | ||
|       "\n",
 | ||
|       "🧪 快速测试: 处理会话 't15.2023.08.13' 的训练数据...\n",
 | ||
|       "\n",
 | ||
|       "🔄 处理会话: t15.2023.08.13\n",
 | ||
|       "  📁 处理 train 数据...\n"
 | ||
|      ]
 | ||
|     },
 | ||
|     {
 | ||
|      "name": "stderr",
 | ||
|      "output_type": "stream",
 | ||
|      "text": [
 | ||
|       "                                                            "
 | ||
|      ]
 | ||
|     },
 | ||
|     {
 | ||
|      "name": "stdout",
 | ||
|      "output_type": "stream",
 | ||
|      "text": [
 | ||
|       "    ✅ train 处理完成:\n",
 | ||
|       "       试验数: 348\n",
 | ||
|       "       时间步范围: 55-352\n",
 | ||
|       "       特征维度: 7209 (处理后特征: 7169, 置信度: 40)\n",
 | ||
|       "       平均时间压缩比: 0.243\n",
 | ||
|       "\n",
 | ||
|       "✅ 测试完成!结果概览:\n",
 | ||
|       "   处理的试验数: 348\n",
 | ||
|       "   第一个试验数据形状: (251, 7209)\n",
 | ||
|       "   特征维度详情:\n",
 | ||
|       "     - 处理后的神经特征: 7168 维\n",
 | ||
|       "     - RNN置信度分数: 41 维\n",
 | ||
|       "     - 总拼接特征: 7209 维\n",
 | ||
|       "     - 时间步数: 251\n",
 | ||
|       "   样本元数据:\n",
 | ||
|       "     - 原始时间步: 1023\n",
 | ||
|       "     - 处理后时间步: 251\n",
 | ||
|       "     - 时间压缩比: 0.245\n",
 | ||
|       "     - 句子标签: Which is most unfortunate because we all lose out.\n",
 | ||
|       "\n",
 | ||
|       "💡 要批量处理所有数据,运行:\n",
 | ||
|       "    results = processor.process_all_sessions()\n",
 | ||
|       "    # 这将处理所有45个sessions的train/val/test数据\n"
 | ||
|      ]
 | ||
|     },
 | ||
|     {
 | ||
|      "name": "stderr",
 | ||
|      "output_type": "stream",
 | ||
|      "text": [
 | ||
|       "\r"
 | ||
|      ]
 | ||
|     }
 | ||
|    ],
 | ||
|    "source": [
 | ||
|     "# 🎯 使用示例和批量处理\n",
 | ||
|     "\n",
 | ||
|     "print(\"=\"*70)\n",
 | ||
|     "print(\"🎯 RNN数据批量处理 - 使用示例\")\n",
 | ||
|     "print(\"=\"*70)\n",
 | ||
|     "\n",
 | ||
|     "if processor is not None:\n",
 | ||
|     "    \n",
 | ||
|     "    # 方法1: 处理单个session (推荐用于测试)\n",
 | ||
|     "    print(\"\\n📋 可用的处理方法:\")\n",
 | ||
|     "    print(\"1️⃣  单session处理: processor.process_session('session_name')\")\n",
 | ||
|     "    print(\"2️⃣  批量处理所有: processor.process_all_sessions()\")\n",
 | ||
|     "    \n",
 | ||
|     "    # 显示可用的sessions\n",
 | ||
|     "    sessions = processor.model_args['dataset']['sessions']\n",
 | ||
|     "    print(f\"\\n📊 可用会话数量: {len(sessions)}\")\n",
 | ||
|     "    print(f\"📝 前5个会话: {sessions[:5]}\")\n",
 | ||
|     "    \n",
 | ||
|     "    # 快速测试 - 处理第一个session的部分数据\n",
 | ||
|     "    test_session = sessions[1]  # 't15.2023.08.11'\n",
 | ||
|     "    \n",
 | ||
|     "    print(f\"\\n🧪 快速测试: 处理会话 '{test_session}' 的训练数据...\")\n",
 | ||
|     "    \n",
 | ||
|     "    # 处理单个session(仅train数据进行测试)\n",
 | ||
|     "    single_result = processor.process_session(test_session, ['train'])\n",
 | ||
|     "    \n",
 | ||
|     "    if single_result and 'train' in single_result:\n",
 | ||
|     "        train_data = single_result['train']\n",
 | ||
|     "        \n",
 | ||
|     "        print(f\"\\n✅ 测试完成!结果概览:\")\n",
 | ||
|     "        print(f\"   处理的试验数: {len(train_data['concatenated_data'])}\")\n",
 | ||
|     "        \n",
 | ||
|     "        if len(train_data['concatenated_data']) > 0:\n",
 | ||
|     "            sample_data = train_data['concatenated_data'][0]\n",
 | ||
|     "            print(f\"   第一个试验数据形状: {sample_data.shape}\")\n",
 | ||
|     "            print(f\"   特征维度详情:\")\n",
 | ||
|     "            print(f\"     - 处理后的神经特征: {sample_data.shape[1] - 41} 维\")\n",
 | ||
|     "            print(f\"     - RNN置信度分数: 41 维\")\n",
 | ||
|     "            print(f\"     - 总拼接特征: {sample_data.shape[1]} 维\")\n",
 | ||
|     "            print(f\"     - 时间步数: {sample_data.shape[0]}\")\n",
 | ||
|     "            \n",
 | ||
|     "            # 显示一些样本元数据\n",
 | ||
|     "            sample_metadata = train_data['trial_metadata'][0]\n",
 | ||
|     "            print(f\"   样本元数据:\")\n",
 | ||
|     "            print(f\"     - 原始时间步: {sample_metadata['original_time_steps']}\")\n",
 | ||
|     "            print(f\"     - 处理后时间步: {sample_metadata['processed_time_steps']}\")\n",
 | ||
|     "            print(f\"     - 时间压缩比: {sample_metadata['feature_reduction_ratio']:.3f}\")\n",
 | ||
|     "            \n",
 | ||
|     "            if 'sentence_label' in sample_metadata:\n",
 | ||
|     "                print(f\"     - 句子标签: {sample_metadata['sentence_label']}\")\n",
 | ||
|     "    \n",
 | ||
|     "    print(f\"\\n💡 要批量处理所有数据,运行:\")\n",
 | ||
|     "    print(f\"    results = processor.process_all_sessions()\")\n",
 | ||
|     "    print(f\"    # 这将处理所有45个sessions的train/val/test数据\")\n",
 | ||
|     "    \n",
 | ||
|     "else:\n",
 | ||
|     "    print(\"❌ 处理器未创建成功,请检查上面的错误信息\")"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": 26,
 | ||
|    "metadata": {},
 | ||
|    "outputs": [
 | ||
|     {
 | ||
|      "name": "stdout",
 | ||
|      "output_type": "stream",
 | ||
|      "text": [
 | ||
|       "======================================================================\n",
 | ||
|       "🚀 批量处理选项\n",
 | ||
|       "======================================================================\n",
 | ||
|       "📊 批量处理配置:\n",
 | ||
|       "   启用批量处理: False\n",
 | ||
|       "   保存目录: ./rnn_processed_data\n",
 | ||
|       "   数据类型: ['train', 'val', 'test']\n",
 | ||
|       "   总会话数: 45\n",
 | ||
|       "\n",
 | ||
|       "💡 要开始批量处理,请将 ENABLE_FULL_PROCESSING 设为 True\n",
 | ||
|       "   或者手动运行: processor.process_all_sessions()\n",
 | ||
|       "\n",
 | ||
|       "📋 数据使用说明:\n",
 | ||
|       "✅ 处理完成后,每个文件包含:\n",
 | ||
|       "   - concatenated_data: 拼接后的特征 [神经特征(7168) + 置信度(41)]\n",
 | ||
|       "   - processed_features: 仅处理后的神经特征\n",
 | ||
|       "   - confidence_scores: 仅RNN输出的41类置信度分数\n",
 | ||
|       "   - trial_metadata: 试验元数据(标签、时间步等)\n",
 | ||
|       "\n",
 | ||
|       "🔧 加载保存的数据:\n",
 | ||
|       "   data = np.load('session_name_train_rnn_processed.npz', allow_pickle=True)\n",
 | ||
|       "   features = data['concatenated_data']  # 用于训练分类器\n",
 | ||
|       "   metadata = data['trial_metadata']    # 获取标签和其他信息\n"
 | ||
|      ]
 | ||
|     }
 | ||
|    ],
 | ||
|    "source": [
 | ||
|     "# 🚀 批量处理所有数据 (可选择运行)\n",
 | ||
|     "\n",
 | ||
|     "print(\"=\"*70)\n",
 | ||
|     "print(\"🚀 批量处理选项\")\n",
 | ||
|     "print(\"=\"*70)\n",
 | ||
|     "\n",
 | ||
|     "# 设置参数\n",
 | ||
|     "ENABLE_FULL_PROCESSING = False  # 设为True开始批量处理\n",
 | ||
|     "SAVE_DIR = \"./rnn_processed_data\"  # 保存目录\n",
 | ||
|     "DATA_TYPES = ['train', 'val', 'test']  # 要处理的数据类型\n",
 | ||
|     "\n",
 | ||
|     "print(f\"📊 批量处理配置:\")\n",
 | ||
|     "print(f\"   启用批量处理: {ENABLE_FULL_PROCESSING}\")\n",
 | ||
|     "print(f\"   保存目录: {SAVE_DIR}\")\n",
 | ||
|     "print(f\"   数据类型: {DATA_TYPES}\")\n",
 | ||
|     "print(f\"   总会话数: {len(processor.model_args['dataset']['sessions'])}\")\n",
 | ||
|     "\n",
 | ||
|     "if ENABLE_FULL_PROCESSING and processor is not None:\n",
 | ||
|     "    print(f\"\\n🚀 开始批量处理所有数据...\")\n",
 | ||
|     "    print(f\"⚠️  这可能需要较长时间(预计30-60分钟)\")\n",
 | ||
|     "    \n",
 | ||
|     "    # 批量处理\n",
 | ||
|     "    all_results = processor.process_all_sessions(\n",
 | ||
|     "        data_types=DATA_TYPES,\n",
 | ||
|     "        save_dir=SAVE_DIR\n",
 | ||
|     "    )\n",
 | ||
|     "    \n",
 | ||
|     "    print(f\"🎉 批量处理完成!结果保存在: {SAVE_DIR}\")\n",
 | ||
|     "    \n",
 | ||
|     "else:\n",
 | ||
|     "    print(f\"\\n💡 要开始批量处理,请将 ENABLE_FULL_PROCESSING 设为 True\")\n",
 | ||
|     "    print(f\"   或者手动运行: processor.process_all_sessions()\")\n",
 | ||
|     "\n",
 | ||
|     "print(f\"\\n📋 数据使用说明:\")\n",
 | ||
|     "print(f\"✅ 处理完成后,每个文件包含:\")\n",
 | ||
|     "print(f\"   - concatenated_data: 拼接后的特征 [神经特征(7168) + 置信度(41)]\")\n",
 | ||
|     "print(f\"   - processed_features: 仅处理后的神经特征\")\n",
 | ||
|     "print(f\"   - confidence_scores: 仅RNN输出的41类置信度分数\")\n",
 | ||
|     "print(f\"   - trial_metadata: 试验元数据(标签、时间步等)\")\n",
 | ||
|     "print(f\"\")\n",
 | ||
|     "print(f\"🔧 加载保存的数据:\")\n",
 | ||
|     "print(f\"   data = np.load('session_name_train_rnn_processed.npz', allow_pickle=True)\")\n",
 | ||
|     "print(f\"   features = data['concatenated_data']  # 用于训练分类器\")\n",
 | ||
|     "print(f\"   metadata = data['trial_metadata']    # 获取标签和其他信息\")"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": 27,
 | ||
|    "metadata": {},
 | ||
|    "outputs": [
 | ||
|     {
 | ||
|      "data": {
 | ||
|       "text/plain": [
 | ||
|        "(212, 7209)"
 | ||
|       ]
 | ||
|      },
 | ||
|      "execution_count": 27,
 | ||
|      "metadata": {},
 | ||
|      "output_type": "execute_result"
 | ||
|     }
 | ||
|    ],
 | ||
|    "source": [
 | ||
|     "single_result['train']['concatenated_data'][2].shape"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "metadata": {},
 | ||
|    "source": [
 | ||
|     "# 模型建立"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": 28,
 | ||
|    "metadata": {},
 | ||
|    "outputs": [
 | ||
|     {
 | ||
|      "name": "stdout",
 | ||
|      "output_type": "stream",
 | ||
|      "text": [
 | ||
|       "🌲 初始化随机森林回归模型\n",
 | ||
|       "🌲 随机森林回归器初始化:\n",
 | ||
|       "   时间窗口大小: 30\n",
 | ||
|       "   树的数量: 100\n",
 | ||
|       "   最大深度: 10\n",
 | ||
|       "   并行任务: -1\n",
 | ||
|       "\n",
 | ||
|       "✅ 随机森林回归器准备完成!\n",
 | ||
|       "🔧 下一步: 准备训练数据和开始训练\n"
 | ||
|      ]
 | ||
|     }
 | ||
|    ],
 | ||
|    "source": [
 | ||
|     "# 🌲 随机森林多输出回归模型实现\n",
 | ||
|     "import numpy as np\n",
 | ||
|     "from sklearn.ensemble import RandomForestRegressor\n",
 | ||
|     "from sklearn.model_selection import train_test_split\n",
 | ||
|     "from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error\n",
 | ||
|     "from sklearn.preprocessing import StandardScaler\n",
 | ||
|     "import matplotlib.pyplot as plt\n",
 | ||
|     "import seaborn as sns\n",
 | ||
|     "from tqdm import tqdm\n",
 | ||
|     "import warnings\n",
 | ||
|     "warnings.filterwarnings('ignore')\n",
 | ||
|     "\n",
 | ||
|     "class TimeWindowRandomForestRegressor:\n",
 | ||
|     "    \"\"\"\n",
 | ||
|     "    基于时间窗口的随机森林多输出回归器\n",
 | ||
|     "    用于预测40个音素的概率分布\n",
 | ||
|     "    \"\"\"\n",
 | ||
|     "    \n",
 | ||
|     "    def __init__(self, window_size=30, n_estimators=100, max_depth=10, n_jobs=-1, random_state=42):\n",
 | ||
|     "        \"\"\"\n",
 | ||
|     "        初始化模型\n",
 | ||
|     "        \n",
 | ||
|     "        参数:\n",
 | ||
|     "            window_size: 时间窗口大小\n",
 | ||
|     "            n_estimators: 随机森林中树的数量\n",
 | ||
|     "            max_depth: 树的最大深度\n",
 | ||
|     "            n_jobs: 并行任务数\n",
 | ||
|     "            random_state: 随机种子\n",
 | ||
|     "        \"\"\"\n",
 | ||
|     "        self.window_size = window_size\n",
 | ||
|     "        self.n_estimators = n_estimators\n",
 | ||
|     "        self.max_depth = max_depth\n",
 | ||
|     "        self.n_jobs = n_jobs\n",
 | ||
|     "        self.random_state = random_state\n",
 | ||
|     "        \n",
 | ||
|     "        # 初始化模型和预处理器\n",
 | ||
|     "        self.regressor = RandomForestRegressor(\n",
 | ||
|     "            n_estimators=n_estimators,\n",
 | ||
|     "            max_depth=max_depth,\n",
 | ||
|     "            n_jobs=n_jobs,\n",
 | ||
|     "            random_state=random_state,\n",
 | ||
|     "            verbose=1\n",
 | ||
|     "        )\n",
 | ||
|     "        \n",
 | ||
|     "        self.scaler = StandardScaler()\n",
 | ||
|     "        self.is_fitted = False\n",
 | ||
|     "        \n",
 | ||
|     "        print(f\"🌲 随机森林回归器初始化:\")\n",
 | ||
|     "        print(f\"   时间窗口大小: {window_size}\")\n",
 | ||
|     "        print(f\"   树的数量: {n_estimators}\")\n",
 | ||
|     "        print(f\"   最大深度: {max_depth}\")\n",
 | ||
|     "        print(f\"   并行任务: {n_jobs}\")\n",
 | ||
|     "    \n",
 | ||
|     "    def create_time_windows(self, neural_features, phoneme_targets=None):\n",
 | ||
|     "        \"\"\"\n",
 | ||
|     "        创建时间窗口特征\n",
 | ||
|     "        \n",
 | ||
|     "        参数:\n",
 | ||
|     "            neural_features: 神经特征 [time_steps, 512]\n",
 | ||
|     "            phoneme_targets: 音素目标 [time_steps, 40] (可选)\n",
 | ||
|     "        \n",
 | ||
|     "        返回:\n",
 | ||
|     "            windowed_features: [samples, window_size * 512]\n",
 | ||
|     "            windowed_targets: [samples, 40] (如果提供targets)\n",
 | ||
|     "        \"\"\"\n",
 | ||
|     "        if len(neural_features) < self.window_size:\n",
 | ||
|     "            print(f\"⚠️ 数据长度 {len(neural_features)} 小于窗口大小 {self.window_size}\")\n",
 | ||
|     "            return None, None\n",
 | ||
|     "        \n",
 | ||
|     "        n_samples = len(neural_features) - self.window_size + 1\n",
 | ||
|     "        n_features = neural_features.shape[1]\n",
 | ||
|     "        \n",
 | ||
|     "        # 创建时间窗口特征\n",
 | ||
|     "        windowed_features = np.zeros((n_samples, self.window_size * n_features))\n",
 | ||
|     "        \n",
 | ||
|     "        for i in range(n_samples):\n",
 | ||
|     "            # 展平时间窗口内的所有特征\n",
 | ||
|     "            window_data = neural_features[i:i+self.window_size].flatten()\n",
 | ||
|     "            windowed_features[i] = window_data\n",
 | ||
|     "        \n",
 | ||
|     "        windowed_targets = None\n",
 | ||
|     "        if phoneme_targets is not None:\n",
 | ||
|     "            # 使用窗口中心点的音素概率作为目标\n",
 | ||
|     "            center_offset = self.window_size // 2\n",
 | ||
|     "            windowed_targets = phoneme_targets[center_offset:center_offset+n_samples]\n",
 | ||
|     "        \n",
 | ||
|     "        return windowed_features, windowed_targets\n",
 | ||
|     "    \n",
 | ||
|     "    def prepare_dataset_for_training(self, datasets_list, dataset_type=\"train\"):\n",
 | ||
|     "        \"\"\"\n",
 | ||
|     "        准备训练数据集\n",
 | ||
|     "        \n",
 | ||
|     "        参数:\n",
 | ||
|     "            datasets_list: DataFrame列表 (train_datasets, val_datasets, etc.)\n",
 | ||
|     "            dataset_type: 数据集类型名称\n",
 | ||
|     "        \n",
 | ||
|     "        返回:\n",
 | ||
|     "            X: 特征矩阵 [总样本数, window_size * 512]\n",
 | ||
|     "            y: 目标矩阵 [总样本数, 40]\n",
 | ||
|     "        \"\"\"\n",
 | ||
|     "        print(f\"\\n📊 准备{dataset_type}数据集:\")\n",
 | ||
|     "        print(f\"   输入数据集数量: {len(datasets_list)}\")\n",
 | ||
|     "        \n",
 | ||
|     "        all_X = []\n",
 | ||
|     "        all_y = []\n",
 | ||
|     "        \n",
 | ||
|     "        for i, df in enumerate(tqdm(datasets_list, desc=f\"处理{dataset_type}数据\")):\n",
 | ||
|     "            # 提取神经特征 (前512列)\n",
 | ||
|     "            neural_cols = [col for col in df.columns if col.startswith('neural_feat_')]\n",
 | ||
|     "            neural_features = df[neural_cols].values\n",
 | ||
|     "            \n",
 | ||
|     "            # 提取音素目标 (40列音素概率)\n",
 | ||
|     "            phoneme_cols = [col for col in df.columns if col.startswith('phoneme_')]\n",
 | ||
|     "            phoneme_targets = df[phoneme_cols].values\n",
 | ||
|     "            \n",
 | ||
|     "            # 按trial分组处理\n",
 | ||
|     "            trials = df['trial_idx'].unique()\n",
 | ||
|     "            \n",
 | ||
|     "            for trial_idx in trials:\n",
 | ||
|     "                trial_mask = df['trial_idx'] == trial_idx\n",
 | ||
|     "                trial_neural = neural_features[trial_mask]\n",
 | ||
|     "                trial_phonemes = phoneme_targets[trial_mask]\n",
 | ||
|     "                \n",
 | ||
|     "                # 创建时间窗口\n",
 | ||
|     "                windowed_X, windowed_y = self.create_time_windows(trial_neural, trial_phonemes)\n",
 | ||
|     "                \n",
 | ||
|     "                if windowed_X is not None and windowed_y is not None:\n",
 | ||
|     "                    all_X.append(windowed_X)\n",
 | ||
|     "                    all_y.append(windowed_y)\n",
 | ||
|     "        \n",
 | ||
|     "        if not all_X:\n",
 | ||
|     "            print(f\"❌ 没有有效的{dataset_type}数据\")\n",
 | ||
|     "            return None, None\n",
 | ||
|     "        \n",
 | ||
|     "        # 合并所有数据\n",
 | ||
|     "        X = np.vstack(all_X)\n",
 | ||
|     "        y = np.vstack(all_y)\n",
 | ||
|     "        \n",
 | ||
|     "        print(f\"   ✅ {dataset_type}数据准备完成:\")\n",
 | ||
|     "        print(f\"      特征矩阵形状: {X.shape}\")\n",
 | ||
|     "        print(f\"      目标矩阵形状: {y.shape}\")\n",
 | ||
|     "        print(f\"      内存使用: {X.nbytes / 1024**2:.1f} MB (X) + {y.nbytes / 1024**2:.1f} MB (y)\")\n",
 | ||
|     "        \n",
 | ||
|     "        return X, y\n",
 | ||
|     "    \n",
 | ||
|     "    def fit(self, X_train, y_train, X_val=None, y_val=None):\n",
 | ||
|     "        \"\"\"\n",
 | ||
|     "        训练模型\n",
 | ||
|     "        \n",
 | ||
|     "        参数:\n",
 | ||
|     "            X_train: 训练特征\n",
 | ||
|     "            y_train: 训练目标\n",
 | ||
|     "            X_val: 验证特征 (可选)\n",
 | ||
|     "            y_val: 验证目标 (可选)\n",
 | ||
|     "        \"\"\"\n",
 | ||
|     "        print(f\"\\n🚀 开始训练随机森林回归模型:\")\n",
 | ||
|     "        print(f\"   训练样本数: {X_train.shape[0]:,}\")\n",
 | ||
|     "        print(f\"   特征维度: {X_train.shape[1]:,}\")\n",
 | ||
|     "        print(f\"   目标维度: {y_train.shape[1]}\")\n",
 | ||
|     "        \n",
 | ||
|     "        # 标准化特征\n",
 | ||
|     "        print(\"   🔄 标准化特征...\")\n",
 | ||
|     "        X_train_scaled = self.scaler.fit_transform(X_train)\n",
 | ||
|     "        \n",
 | ||
|     "        # 训练模型\n",
 | ||
|     "        print(\"   🌲 训练随机森林...\")\n",
 | ||
|     "        self.regressor.fit(X_train_scaled, y_train)\n",
 | ||
|     "        \n",
 | ||
|     "        self.is_fitted = True\n",
 | ||
|     "        \n",
 | ||
|     "        # 计算训练集性能\n",
 | ||
|     "        print(\"   📊 评估训练集性能...\")\n",
 | ||
|     "        train_predictions = self.regressor.predict(X_train_scaled)\n",
 | ||
|     "        train_mse = mean_squared_error(y_train, train_predictions)\n",
 | ||
|     "        train_r2 = r2_score(y_train, train_predictions)\n",
 | ||
|     "        train_mae = mean_absolute_error(y_train, train_predictions)\n",
 | ||
|     "        \n",
 | ||
|     "        print(f\"   ✅ 训练完成!\")\n",
 | ||
|     "        print(f\"      训练集 MSE: {train_mse:.6f}\")\n",
 | ||
|     "        print(f\"      训练集 R²: {train_r2:.4f}\")\n",
 | ||
|     "        print(f\"      训练集 MAE: {train_mae:.6f}\")\n",
 | ||
|     "        \n",
 | ||
|     "        # 如果有验证集,计算验证集性能\n",
 | ||
|     "        if X_val is not None and y_val is not None:\n",
 | ||
|     "            print(\"   📊 评估验证集性能...\")\n",
 | ||
|     "            X_val_scaled = self.scaler.transform(X_val)\n",
 | ||
|     "            val_predictions = self.regressor.predict(X_val_scaled)\n",
 | ||
|     "            val_mse = mean_squared_error(y_val, val_predictions)\n",
 | ||
|     "            val_r2 = r2_score(y_val, val_predictions)\n",
 | ||
|     "            val_mae = mean_absolute_error(y_val, val_predictions)\n",
 | ||
|     "            \n",
 | ||
|     "            print(f\"      验证集 MSE: {val_mse:.6f}\")\n",
 | ||
|     "            print(f\"      验证集 R²: {val_r2:.4f}\")\n",
 | ||
|     "            print(f\"      验证集 MAE: {val_mae:.6f}\")\n",
 | ||
|     "            \n",
 | ||
|     "            return {\n",
 | ||
|     "                'train_mse': train_mse, 'train_r2': train_r2, 'train_mae': train_mae,\n",
 | ||
|     "                'val_mse': val_mse, 'val_r2': val_r2, 'val_mae': val_mae\n",
 | ||
|     "            }\n",
 | ||
|     "        \n",
 | ||
|     "        return {\n",
 | ||
|     "            'train_mse': train_mse, 'train_r2': train_r2, 'train_mae': train_mae\n",
 | ||
|     "        }\n",
 | ||
|     "    \n",
 | ||
|     "    def predict(self, X):\n",
 | ||
|     "        \"\"\"预测\"\"\"\n",
 | ||
|     "        if not self.is_fitted:\n",
 | ||
|     "            raise ValueError(\"模型尚未训练,请先调用fit()方法\")\n",
 | ||
|     "        \n",
 | ||
|     "        X_scaled = self.scaler.transform(X)\n",
 | ||
|     "        return self.regressor.predict(X_scaled)\n",
 | ||
|     "    \n",
 | ||
|     "    def get_feature_importance(self, top_k=20):\n",
 | ||
|     "        \"\"\"获取特征重要性\"\"\"\n",
 | ||
|     "        if not self.is_fitted:\n",
 | ||
|     "            raise ValueError(\"模型尚未训练,请先调用fit()方法\")\n",
 | ||
|     "        \n",
 | ||
|     "        importances = self.regressor.feature_importances_\n",
 | ||
|     "        \n",
 | ||
|     "        # 创建特征名称 (window_timestep_feature)\n",
 | ||
|     "        feature_names = []\n",
 | ||
|     "        for t in range(self.window_size):\n",
 | ||
|     "            for f in range(512):\n",
 | ||
|     "                feature_names.append(f\"t{t}_feat{f}\")\n",
 | ||
|     "        \n",
 | ||
|     "        # 获取top-k重要特征\n",
 | ||
|     "        top_indices = np.argsort(importances)[::-1][:top_k]\n",
 | ||
|     "        top_features = [(feature_names[i], importances[i]) for i in top_indices]\n",
 | ||
|     "        \n",
 | ||
|     "        return top_features, importances\n",
 | ||
|     "\n",
 | ||
|     "# 初始化模型\n",
 | ||
|     "print(\"🌲 初始化随机森林回归模型\")\n",
 | ||
|     "rf_regressor = TimeWindowRandomForestRegressor(\n",
 | ||
|     "    window_size=30,      # 时间窗口大小\n",
 | ||
|     "    n_estimators=100,    # 树的数量\n",
 | ||
|     "    max_depth=10,        # 最大深度 (防止过拟合)\n",
 | ||
|     "    n_jobs=-1,           # 使用所有CPU核心\n",
 | ||
|     "    random_state=42\n",
 | ||
|     ")\n",
 | ||
|     "\n",
 | ||
|     "print(\"\\n✅ 随机森林回归器准备完成!\")\n",
 | ||
|     "print(\"🔧 下一步: 准备训练数据和开始训练\")"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": null,
 | ||
|    "metadata": {},
 | ||
|    "outputs": [
 | ||
|     {
 | ||
|      "name": "stdout",
 | ||
|      "output_type": "stream",
 | ||
|      "text": [
 | ||
|       "🚀 开始数据准备和模型训练流程\n",
 | ||
|       "============================================================\n"
 | ||
|      ]
 | ||
|     },
 | ||
|     {
 | ||
|      "ename": "NameError",
 | ||
|      "evalue": "name 'train_datasets' is not defined",
 | ||
|      "output_type": "error",
 | ||
|      "traceback": [
 | ||
|       "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
 | ||
|       "\u001b[0;31mNameError\u001b[0m                                 Traceback (most recent call last)",
 | ||
|       "\u001b[0;32m/tmp/ipykernel_37/3627267466.py\u001b[0m in \u001b[0;36m<cell line: 0>\u001b[0;34m()\u001b[0m\n\u001b[1;32m      4\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      5\u001b[0m \u001b[0;31m# 检查数据可用性\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 6\u001b[0;31m \u001b[0;32mif\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0mtrain_datasets\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m      7\u001b[0m     \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"❌ 没有可用的训练数据,请先运行数据处理工作流\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m      8\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
 | ||
|       "\u001b[0;31mNameError\u001b[0m: name 'train_datasets' is not defined"
 | ||
|      ]
 | ||
|     }
 | ||
|    ],
 | ||
|    "source": [
 | ||
|     "# 🚀 准备数据并训练随机森林回归模型\n",
 | ||
|     "print(\"🚀 开始数据准备和模型训练流程\")\n",
 | ||
|     "print(\"=\"*60)\n",
 | ||
|     "\n",
 | ||
|     "# 检查数据可用性\n",
 | ||
|     "if not train_datasets:\n",
 | ||
|     "    print(\"❌ 没有可用的训练数据,请先运行数据处理工作流\")\n",
 | ||
|     "else:\n",
 | ||
|     "    print(f\"✅ 检测到数据:\")\n",
 | ||
|     "    print(f\"   训练数据集: {len(train_datasets)} 个sessions\")\n",
 | ||
|     "    print(f\"   验证数据集: {len(val_datasets)} 个sessions\")\n",
 | ||
|     "    print(f\"   测试数据集: {len(test_datasets)} 个sessions\")\n",
 | ||
|     "\n",
 | ||
|     "    # 1. 准备训练数据\n",
 | ||
|     "    print(f\"\\n📊 第1步: 准备训练数据\")\n",
 | ||
|     "    X_train, y_train = rf_regressor.prepare_dataset_for_training(train_datasets, \"训练集\")\n",
 | ||
|     "    \n",
 | ||
|     "    # 2. 准备验证数据\n",
 | ||
|     "    print(f\"\\n📊 第2步: 准备验证数据\")\n",
 | ||
|     "    X_val, y_val = rf_regressor.prepare_dataset_for_training(val_datasets, \"验证集\")\n",
 | ||
|     "    \n",
 | ||
|     "    if X_train is not None and y_train is not None:\n",
 | ||
|     "        print(f\"\\n📈 数据准备完成统计:\")\n",
 | ||
|     "        print(f\"   训练集: {X_train.shape[0]:,} 样本\")\n",
 | ||
|     "        print(f\"   验证集: {X_val.shape[0]:,} 样本\" if X_val is not None else \"   验证集: 无\")\n",
 | ||
|     "        print(f\"   特征维度: {X_train.shape[1]:,} (时间窗口30 × 512特征)\")\n",
 | ||
|     "        print(f\"   目标维度: {y_train.shape[1]} (40个音素概率)\")\n",
 | ||
|     "        \n",
 | ||
|     "        # 检查数据质量\n",
 | ||
|     "        print(f\"\\n🔍 数据质量检查:\")\n",
 | ||
|     "        print(f\"   训练特征范围: [{X_train.min():.4f}, {X_train.max():.4f}]\")\n",
 | ||
|     "        print(f\"   训练目标范围: [{y_train.min():.4f}, {y_train.max():.4f}]\")\n",
 | ||
|     "        print(f\"   训练特征均值: {X_train.mean():.4f}\")\n",
 | ||
|     "        print(f\"   训练目标均值: {y_train.mean():.4f}\")\n",
 | ||
|     "        \n",
 | ||
|     "        # 检查是否有NaN或无穷值\n",
 | ||
|     "        nan_count_X = np.isnan(X_train).sum()\n",
 | ||
|     "        nan_count_y = np.isnan(y_train).sum()\n",
 | ||
|     "        inf_count_X = np.isinf(X_train).sum()\n",
 | ||
|     "        inf_count_y = np.isinf(y_train).sum()\n",
 | ||
|     "        \n",
 | ||
|     "        print(f\"   NaN检查: X有{nan_count_X}个, y有{nan_count_y}个\")\n",
 | ||
|     "        print(f\"   Inf检查: X有{inf_count_X}个, y有{inf_count_y}个\")\n",
 | ||
|     "        \n",
 | ||
|     "        if nan_count_X > 0 or nan_count_y > 0 or inf_count_X > 0 or inf_count_y > 0:\n",
 | ||
|     "            print(\"⚠️ 检测到异常值,将进行清理...\")\n",
 | ||
|     "            # 清理异常值\n",
 | ||
|     "            valid_mask = ~(np.isnan(X_train).any(axis=1) | np.isnan(y_train).any(axis=1) | \n",
 | ||
|     "                          np.isinf(X_train).any(axis=1) | np.isinf(y_train).any(axis=1))\n",
 | ||
|     "            X_train = X_train[valid_mask]\n",
 | ||
|     "            y_train = y_train[valid_mask]\n",
 | ||
|     "            \n",
 | ||
|     "            if X_val is not None and y_val is not None:\n",
 | ||
|     "                valid_mask_val = ~(np.isnan(X_val).any(axis=1) | np.isnan(y_val).any(axis=1) | \n",
 | ||
|     "                                  np.isinf(X_val).any(axis=1) | np.isinf(y_val).any(axis=1))\n",
 | ||
|     "                X_val = X_val[valid_mask_val]\n",
 | ||
|     "                y_val = y_val[valid_mask_val]\n",
 | ||
|     "            \n",
 | ||
|     "            print(f\"✅ 数据清理完成,剩余训练样本: {X_train.shape[0]:,}\")\n",
 | ||
|     "        \n",
 | ||
|     "        # 3. 训练模型\n",
 | ||
|     "        print(f\"\\n🌲 第3步: 训练随机森林回归模型\")\n",
 | ||
|     "        training_results = rf_regressor.fit(X_train, y_train, X_val, y_val)\n",
 | ||
|     "        \n",
 | ||
|     "        # 4. 分析训练结果\n",
 | ||
|     "        print(f\"\\n📊 第4步: 训练结果分析\")\n",
 | ||
|     "        print(\"=\"*50)\n",
 | ||
|     "        \n",
 | ||
|     "        for metric, value in training_results.items():\n",
 | ||
|     "            metric_name = metric.replace('_', ' ').title()\n",
 | ||
|     "            print(f\"   {metric_name}: {value:.6f}\")\n",
 | ||
|     "        \n",
 | ||
|     "        # 5. 特征重要性分析\n",
 | ||
|     "        print(f\"\\n🔍 第5步: 特征重要性分析\")\n",
 | ||
|     "        top_features, all_importances = rf_regressor.get_feature_importance(top_k=20)\n",
 | ||
|     "        \n",
 | ||
|     "        print(f\"\\n🏆 Top 20 重要特征:\")\n",
 | ||
|     "        print(f\"{'排名':>4} {'特征名称':>15} {'重要性':>10}\")\n",
 | ||
|     "        print(\"-\" * 35)\n",
 | ||
|     "        for i, (feature_name, importance) in enumerate(top_features):\n",
 | ||
|     "            print(f\"{i+1:>4} {feature_name:>15} {importance:>10.6f}\")\n",
 | ||
|     "        \n",
 | ||
|     "        # 分析时间窗口内的重要性分布\n",
 | ||
|     "        print(f\"\\n📈 时间窗口重要性分布:\")\n",
 | ||
|     "        window_importances = np.zeros(rf_regressor.window_size)\n",
 | ||
|     "        for i in range(rf_regressor.window_size):\n",
 | ||
|     "            start_idx = i * 512\n",
 | ||
|     "            end_idx = (i + 1) * 512\n",
 | ||
|     "            window_importances[i] = all_importances[start_idx:end_idx].sum()\n",
 | ||
|     "        \n",
 | ||
|     "        max_time_step = np.argmax(window_importances)\n",
 | ||
|     "        print(f\"   最重要的时间步: t{max_time_step} (重要性: {window_importances[max_time_step]:.6f})\")\n",
 | ||
|     "        print(f\"   窗口中心位置: t{rf_regressor.window_size//2}\")\n",
 | ||
|     "        print(f\"   重要性分布: 前5个时间步的重要性\")\n",
 | ||
|     "        for i in range(min(5, len(window_importances))):\n",
 | ||
|     "            print(f\"      t{i}: {window_importances[i]:.6f}\")\n",
 | ||
|     "        \n",
 | ||
|     "        print(f\"\\n✅ 随机森林回归模型训练完成!\")\n",
 | ||
|     "        print(f\"🎯 模型可以预测40个音素的概率分布\")\n",
 | ||
|     "        print(f\"📊 基于30时间步的神经特征窗口\")\n",
 | ||
|     "        print(f\"🌲 使用{rf_regressor.n_estimators}棵决策树\")\n",
 | ||
|     "        \n",
 | ||
|     "    else:\n",
 | ||
|     "        print(\"❌ 数据准备失败,无法训练模型\")"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": null,
 | ||
|    "metadata": {},
 | ||
|    "outputs": [],
 | ||
|    "source": [
 | ||
|     "# 📊 模型评估和可视化分析\n",
 | ||
|     "def evaluate_phoneme_predictions(rf_model, X_test, y_test, dataset_name=\"测试集\"):\n",
 | ||
|     "    \"\"\"\n",
 | ||
|     "    评估每个音素的预测性能\n",
 | ||
|     "    \"\"\"\n",
 | ||
|     "    print(f\"\\n📊 {dataset_name}详细评估\")\n",
 | ||
|     "    print(\"=\"*50)\n",
 | ||
|     "    \n",
 | ||
|     "    # 获取预测结果\n",
 | ||
|     "    y_pred = rf_model.predict(X_test)\n",
 | ||
|     "    \n",
 | ||
|     "    # 计算每个音素的性能指标\n",
 | ||
|     "    phoneme_metrics = []\n",
 | ||
|     "    \n",
 | ||
|     "    for i in range(40):  # 40个音素\n",
 | ||
|     "        phoneme_name = LOGIT_TO_PHONEME[i]\n",
 | ||
|     "        \n",
 | ||
|     "        # 计算单个音素的指标\n",
 | ||
|     "        mse = mean_squared_error(y_test[:, i], y_pred[:, i])\n",
 | ||
|     "        r2 = r2_score(y_test[:, i], y_pred[:, i])\n",
 | ||
|     "        mae = mean_absolute_error(y_test[:, i], y_pred[:, i])\n",
 | ||
|     "        \n",
 | ||
|     "        # 计算相关系数\n",
 | ||
|     "        correlation = np.corrcoef(y_test[:, i], y_pred[:, i])[0, 1]\n",
 | ||
|     "        \n",
 | ||
|     "        phoneme_metrics.append({\n",
 | ||
|     "            'phoneme_id': i,\n",
 | ||
|     "            'phoneme_name': phoneme_name,\n",
 | ||
|     "            'mse': mse, \n",
 | ||
|     "            'r2': r2,\n",
 | ||
|     "            'mae': mae,\n",
 | ||
|     "            'correlation': correlation if not np.isnan(correlation) else 0.0\n",
 | ||
|     "        })\n",
 | ||
|     "    \n",
 | ||
|     "    # 转换为DataFrame便于分析\n",
 | ||
|     "    metrics_df = pd.DataFrame(phoneme_metrics)\n",
 | ||
|     "    \n",
 | ||
|     "    # 打印总体统计\n",
 | ||
|     "    print(f\"📈 总体性能指标:\")\n",
 | ||
|     "    print(f\"   平均 MSE: {metrics_df['mse'].mean():.6f}\")\n",
 | ||
|     "    print(f\"   平均 R²: {metrics_df['r2'].mean():.4f}\")\n",
 | ||
|     "    print(f\"   平均 MAE: {metrics_df['mae'].mean():.6f}\")\n",
 | ||
|     "    print(f\"   平均相关系数: {metrics_df['correlation'].mean():.4f}\")\n",
 | ||
|     "    \n",
 | ||
|     "    # 找出最佳和最差预测的音素\n",
 | ||
|     "    best_r2_idx = metrics_df['r2'].idxmax()\n",
 | ||
|     "    worst_r2_idx = metrics_df['r2'].idxmin()\n",
 | ||
|     "    \n",
 | ||
|     "    print(f\"\\n🏆 最佳预测音素:\")\n",
 | ||
|     "    best_phoneme = metrics_df.loc[best_r2_idx]\n",
 | ||
|     "    print(f\"   {best_phoneme['phoneme_name']} (ID: {best_phoneme['phoneme_id']})\")\n",
 | ||
|     "    print(f\"   R²: {best_phoneme['r2']:.4f}, MSE: {best_phoneme['mse']:.6f}\")\n",
 | ||
|     "    \n",
 | ||
|     "    print(f\"\\n📉 最差预测音素:\")\n",
 | ||
|     "    worst_phoneme = metrics_df.loc[worst_r2_idx]\n",
 | ||
|     "    print(f\"   {worst_phoneme['phoneme_name']} (ID: {worst_phoneme['phoneme_id']})\")\n",
 | ||
|     "    print(f\"   R²: {worst_phoneme['r2']:.4f}, MSE: {worst_phoneme['mse']:.6f}\")\n",
 | ||
|     "    \n",
 | ||
|     "    return metrics_df, y_pred\n",
 | ||
|     "\n",
 | ||
|     "def visualize_prediction_results(metrics_df, y_true, y_pred, save_plots=False):\n",
 | ||
|     "    \"\"\"\n",
 | ||
|     "    可视化预测结果\n",
 | ||
|     "    \"\"\"\n",
 | ||
|     "    print(f\"\\n📊 创建可视化图表...\")\n",
 | ||
|     "    \n",
 | ||
|     "    # 设置图表样式\n",
 | ||
|     "    plt.style.use('default')\n",
 | ||
|     "    fig = plt.figure(figsize=(20, 12))\n",
 | ||
|     "    \n",
 | ||
|     "    # 1. R²分数分布\n",
 | ||
|     "    plt.subplot(2, 3, 1)\n",
 | ||
|     "    plt.hist(metrics_df['r2'], bins=20, alpha=0.7, color='skyblue', edgecolor='black')\n",
 | ||
|     "    plt.axvline(metrics_df['r2'].mean(), color='red', linestyle='--', \n",
 | ||
|     "                label=f'平均值: {metrics_df[\"r2\"].mean():.4f}')\n",
 | ||
|     "    plt.xlabel('R² Score')\n",
 | ||
|     "    plt.ylabel('音素数量')\n",
 | ||
|     "    plt.title('R² Score 分布')\n",
 | ||
|     "    plt.legend()\n",
 | ||
|     "    plt.grid(True, alpha=0.3)\n",
 | ||
|     "    \n",
 | ||
|     "    # 2. MSE分布\n",
 | ||
|     "    plt.subplot(2, 3, 2)\n",
 | ||
|     "    plt.hist(metrics_df['mse'], bins=20, alpha=0.7, color='lightcoral', edgecolor='black')\n",
 | ||
|     "    plt.axvline(metrics_df['mse'].mean(), color='red', linestyle='--',\n",
 | ||
|     "                label=f'平均值: {metrics_df[\"mse\"].mean():.6f}')\n",
 | ||
|     "    plt.xlabel('Mean Squared Error')\n",
 | ||
|     "    plt.ylabel('音素数量')\n",
 | ||
|     "    plt.title('MSE 分布')\n",
 | ||
|     "    plt.legend()\n",
 | ||
|     "    plt.grid(True, alpha=0.3)\n",
 | ||
|     "    \n",
 | ||
|     "    # 3. 前10个音素的性能对比\n",
 | ||
|     "    plt.subplot(2, 3, 3)\n",
 | ||
|     "    top_10 = metrics_df.nlargest(10, 'r2')\n",
 | ||
|     "    bars = plt.bar(range(10), top_10['r2'], color='lightgreen', alpha=0.7)\n",
 | ||
|     "    plt.xlabel('音素排名')\n",
 | ||
|     "    plt.ylabel('R² Score')\n",
 | ||
|     "    plt.title('Top 10 音素预测性能')\n",
 | ||
|     "    plt.xticks(range(10), top_10['phoneme_name'], rotation=45)\n",
 | ||
|     "    plt.grid(True, alpha=0.3)\n",
 | ||
|     "    \n",
 | ||
|     "    # 添加数值标签\n",
 | ||
|     "    for i, bar in enumerate(bars):\n",
 | ||
|     "        height = bar.get_height()\n",
 | ||
|     "        plt.text(bar.get_x() + bar.get_width()/2., height + 0.001,\n",
 | ||
|     "                f'{height:.3f}', ha='center', va='bottom', fontsize=8)\n",
 | ||
|     "    \n",
 | ||
|     "    # 4. 真实值 vs 预测值散点图 (选择最佳音素)\n",
 | ||
|     "    plt.subplot(2, 3, 4)\n",
 | ||
|     "    best_phoneme_idx = metrics_df['r2'].idxmax()\n",
 | ||
|     "    phoneme_id = metrics_df.loc[best_phoneme_idx, 'phoneme_id']\n",
 | ||
|     "    phoneme_name = metrics_df.loc[best_phoneme_idx, 'phoneme_name']\n",
 | ||
|     "    \n",
 | ||
|     "    # 随机采样1000个点以避免图表过于密集\n",
 | ||
|     "    sample_size = min(1000, len(y_true))\n",
 | ||
|     "    sample_indices = np.random.choice(len(y_true), sample_size, replace=False)\n",
 | ||
|     "    \n",
 | ||
|     "    plt.scatter(y_true[sample_indices, phoneme_id], y_pred[sample_indices, phoneme_id], \n",
 | ||
|     "               alpha=0.6, s=20, color='blue')\n",
 | ||
|     "    \n",
 | ||
|     "    # 添加对角线 (完美预测线)\n",
 | ||
|     "    min_val = min(y_true[:, phoneme_id].min(), y_pred[:, phoneme_id].min())\n",
 | ||
|     "    max_val = max(y_true[:, phoneme_id].max(), y_pred[:, phoneme_id].max())\n",
 | ||
|     "    plt.plot([min_val, max_val], [min_val, max_val], 'r--', alpha=0.8, label='完美预测')\n",
 | ||
|     "    \n",
 | ||
|     "    plt.xlabel('真实值')\n",
 | ||
|     "    plt.ylabel('预测值')\n",
 | ||
|     "    plt.title(f'最佳音素 {phoneme_name} 的预测结果')\n",
 | ||
|     "    plt.legend()\n",
 | ||
|     "    plt.grid(True, alpha=0.3)\n",
 | ||
|     "    \n",
 | ||
|     "    # 5. 相关系数热力图 (前20个音素)\n",
 | ||
|     "    plt.subplot(2, 3, 5)\n",
 | ||
|     "    top_20_correlations = metrics_df.nlargest(20, 'correlation')\n",
 | ||
|     "    corr_data = top_20_correlations[['phoneme_name', 'correlation']].set_index('phoneme_name')\n",
 | ||
|     "    \n",
 | ||
|     "    # 创建热力图数据\n",
 | ||
|     "    heatmap_data = corr_data.values.reshape(-1, 1)\n",
 | ||
|     "    im = plt.imshow(heatmap_data.T, cmap='RdYlBu_r', aspect='auto', vmin=0, vmax=1)\n",
 | ||
|     "    \n",
 | ||
|     "    plt.colorbar(im, shrink=0.8)\n",
 | ||
|     "    plt.yticks([0], ['相关系数'])\n",
 | ||
|     "    plt.xticks(range(len(top_20_correlations)), top_20_correlations['phoneme_name'], \n",
 | ||
|     "               rotation=45, ha='right')\n",
 | ||
|     "    plt.title('Top 20 音素相关系数')\n",
 | ||
|     "    \n",
 | ||
|     "    # 6. 各音素预测误差箱线图 (前10个音素)\n",
 | ||
|     "    plt.subplot(2, 3, 6)\n",
 | ||
|     "    top_10_ids = metrics_df.nlargest(10, 'r2')['phoneme_id'].values\n",
 | ||
|     "    errors_data = []\n",
 | ||
|     "    labels = []\n",
 | ||
|     "    \n",
 | ||
|     "    for phoneme_id in top_10_ids:\n",
 | ||
|     "        errors = np.abs(y_true[:, phoneme_id] - y_pred[:, phoneme_id])\n",
 | ||
|     "        errors_data.append(errors)\n",
 | ||
|     "        labels.append(LOGIT_TO_PHONEME[phoneme_id])\n",
 | ||
|     "    \n",
 | ||
|     "    plt.boxplot(errors_data, labels=labels)\n",
 | ||
|     "    plt.xlabel('音素')\n",
 | ||
|     "    plt.ylabel('绝对误差')\n",
 | ||
|     "    plt.title('Top 10 音素预测误差分布')\n",
 | ||
|     "    plt.xticks(rotation=45)\n",
 | ||
|     "    plt.grid(True, alpha=0.3)\n",
 | ||
|     "    \n",
 | ||
|     "    plt.tight_layout()\n",
 | ||
|     "    \n",
 | ||
|     "    if save_plots:\n",
 | ||
|     "        plt.savefig('./processed_datasets/rf_regression_results.png', dpi=300, bbox_inches='tight')\n",
 | ||
|     "        print(\"📁 图表已保存至: ./processed_datasets/rf_regression_results.png\")\n",
 | ||
|     "    \n",
 | ||
|     "    plt.show()\n",
 | ||
|     "\n",
 | ||
|     "# 如果模型训练成功,进行评估\n",
 | ||
|     "if 'rf_regressor' in locals() and rf_regressor.is_fitted:\n",
 | ||
|     "    print(f\"\\n🎯 开始模型评估和可视化\")\n",
 | ||
|     "    \n",
 | ||
|     "    # 评估验证集\n",
 | ||
|     "    if X_val is not None and y_val is not None:\n",
 | ||
|     "        val_metrics, val_predictions = evaluate_phoneme_predictions(\n",
 | ||
|     "            rf_regressor, X_val, y_val, \"验证集\"\n",
 | ||
|     "        )\n",
 | ||
|     "        \n",
 | ||
|     "        # 可视化结果\n",
 | ||
|     "        visualize_prediction_results(val_metrics, y_val, val_predictions, save_plots=True)\n",
 | ||
|     "        \n",
 | ||
|     "        # 保存详细结果\n",
 | ||
|     "        val_metrics.to_csv('./processed_datasets/phoneme_prediction_metrics.csv', index=False)\n",
 | ||
|     "        print(f\"\\n📁 详细评估结果已保存至: ./processed_datasets/phoneme_prediction_metrics.csv\")\n",
 | ||
|     "        \n",
 | ||
|     "        # 准备测试集数据 (如果有)\n",
 | ||
|     "        if test_datasets:\n",
 | ||
|     "            print(f\"\\n🔮 准备测试集预测...\")\n",
 | ||
|     "            X_test, y_test = rf_regressor.prepare_dataset_for_training(test_datasets, \"测试集\")\n",
 | ||
|     "            \n",
 | ||
|     "            if X_test is not None:\n",
 | ||
|     "                test_metrics, test_predictions = evaluate_phoneme_predictions(\n",
 | ||
|     "                    rf_regressor, X_test, y_test, \"测试集\"\n",
 | ||
|     "                )\n",
 | ||
|     "                print(f\"\\n✅ 测试集评估完成\")\n",
 | ||
|     "            else:\n",
 | ||
|     "                print(f\"⚠️ 测试集数据准备失败\")\n",
 | ||
|     "    \n",
 | ||
|     "    print(f\"\\n🎉 随机森林回归模型完整评估完成!\")\n",
 | ||
|     "    print(f\"📊 生成了详细的性能分析和可视化图表\")\n",
 | ||
|     "    print(f\"🔧 模型已准备好用于实际预测任务\")\n",
 | ||
|     "    \n",
 | ||
|     "else:\n",
 | ||
|     "    print(\"⚠️ 模型尚未训练完成,请先运行训练代码\")"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": null,
 | ||
|    "metadata": {},
 | ||
|    "outputs": [],
 | ||
|    "source": [
 | ||
|     "# 🎯 回归结果转分类结果分析\n",
 | ||
|     "def regression_to_classification_analysis(y_true_probs, y_pred_probs, show_detailed_metrics=True):\n",
 | ||
|     "    \"\"\"\n",
 | ||
|     "    将回归预测的40个音素概率转换为分类结果并进行分析\n",
 | ||
|     "    \n",
 | ||
|     "    参数:\n",
 | ||
|     "        y_true_probs: 真实的40个音素概率 [n_samples, 40]\n",
 | ||
|     "        y_pred_probs: 预测的40个音素概率 [n_samples, 40]\n",
 | ||
|     "        show_detailed_metrics: 是否显示详细的分类指标\n",
 | ||
|     "     \n",
 | ||
|     "    返回:\n",
 | ||
|     "        classification_results: 包含分类结果的字典\n",
 | ||
|     "    \"\"\"\n",
 | ||
|     "    print(\"🎯 回归结果转分类结果分析\")\n",
 | ||
|     "    print(\"=\"*60)\n",
 | ||
|     "    \n",
 | ||
|     "    # 1. 将概率转换为分类标签\n",
 | ||
|     "    y_true_classes = np.argmax(y_true_probs, axis=1)  # 真实类别\n",
 | ||
|     "    y_pred_classes = np.argmax(y_pred_probs, axis=1)  # 预测类别\n",
 | ||
|     "    \n",
 | ||
|     "    # 2. 计算分类准确率\n",
 | ||
|     "    accuracy = (y_true_classes == y_pred_classes).mean()\n",
 | ||
|     "    \n",
 | ||
|     "    print(f\"📊 分类结果概览:\")\n",
 | ||
|     "    print(f\"   总样本数: {len(y_true_classes):,}\")\n",
 | ||
|     "    print(f\"   分类准确率: {accuracy:.4f} ({accuracy*100:.2f}%)\")\n",
 | ||
|     "    print(f\"   正确预测: {(y_true_classes == y_pred_classes).sum():,}\")\n",
 | ||
|     "    print(f\"   错误预测: {(y_true_classes != y_pred_classes).sum():,}\")\n",
 | ||
|     "    \n",
 | ||
|     "    # 3. 分析预测置信度\n",
 | ||
|     "    pred_confidences = np.max(y_pred_probs, axis=1)  # 预测的最大概率\n",
 | ||
|     "    true_confidences = np.max(y_true_probs, axis=1)  # 真实的最大概率\n",
 | ||
|     "    \n",
 | ||
|     "    print(f\"\\n🔍 预测置信度分析:\")\n",
 | ||
|     "    print(f\"   预测置信度均值: {pred_confidences.mean():.4f}\")\n",
 | ||
|     "    print(f\"   预测置信度标准差: {pred_confidences.std():.4f}\")\n",
 | ||
|     "    print(f\"   预测置信度范围: [{pred_confidences.min():.4f}, {pred_confidences.max():.4f}]\")\n",
 | ||
|     "    print(f\"   真实置信度均值: {true_confidences.mean():.4f}\")\n",
 | ||
|     "    \n",
 | ||
|     "    # 4. 按置信度分组的准确率分析\n",
 | ||
|     "    confidence_bins = [0.0, 0.3, 0.5, 0.7, 0.9, 1.0]\n",
 | ||
|     "    print(f\"\\n📈 按预测置信度分组的准确率:\")\n",
 | ||
|     "    print(f\"{'置信度区间':>12} {'样本数':>8} {'准确率':>8} {'百分比':>8}\")\n",
 | ||
|     "    print(\"-\" * 40)\n",
 | ||
|     "    \n",
 | ||
|     "    for i in range(len(confidence_bins)-1):\n",
 | ||
|     "        low, high = confidence_bins[i], confidence_bins[i+1]\n",
 | ||
|     "        mask = (pred_confidences >= low) & (pred_confidences < high)\n",
 | ||
|     "        if i == len(confidence_bins)-2:  # 最后一个区间包含等号\n",
 | ||
|     "            mask = (pred_confidences >= low) & (pred_confidences <= high)\n",
 | ||
|     "        \n",
 | ||
|     "        if mask.sum() > 0:\n",
 | ||
|     "            bin_accuracy = (y_true_classes[mask] == y_pred_classes[mask]).mean()\n",
 | ||
|     "            count = mask.sum()\n",
 | ||
|     "            percentage = count / len(y_true_classes) * 100\n",
 | ||
|     "            print(f\"[{low:.1f}, {high:.1f}{')'if i<len(confidence_bins)-2 else ']':>1} {count:>8} {bin_accuracy:>8.4f} {percentage:>7.1f}%\")\n",
 | ||
|     "    \n",
 | ||
|     "    # 5. 混淆矩阵分析(Top-K音素)\n",
 | ||
|     "    from collections import Counter\n",
 | ||
|     "    \n",
 | ||
|     "    # 找出最常见的音素\n",
 | ||
|     "    true_counter = Counter(y_true_classes)\n",
 | ||
|     "    pred_counter = Counter(y_pred_classes)\n",
 | ||
|     "    \n",
 | ||
|     "    most_common_true = true_counter.most_common(10)\n",
 | ||
|     "    most_common_pred = pred_counter.most_common(10)\n",
 | ||
|     "    \n",
 | ||
|     "    print(f\"\\n🏆 最常见的音素 (真实 vs 预测):\")\n",
 | ||
|     "    print(f\"{'真实音素':>12} {'次数':>6} {'预测音素':>12} {'次数':>6}\")\n",
 | ||
|     "    print(\"-\" * 42)\n",
 | ||
|     "    \n",
 | ||
|     "    for i in range(min(len(most_common_true), len(most_common_pred))):\n",
 | ||
|     "        true_id, true_count = most_common_true[i]\n",
 | ||
|     "        pred_id, pred_count = most_common_pred[i]\n",
 | ||
|     "        true_name = LOGIT_TO_PHONEME[true_id]\n",
 | ||
|     "        pred_name = LOGIT_TO_PHONEME[pred_id]\n",
 | ||
|     "        print(f\"{true_name:>12} {true_count:>6} {pred_name:>12} {pred_count:>6}\")\n",
 | ||
|     "    \n",
 | ||
|     "    # 6. 每个音素的分类性能\n",
 | ||
|     "    if show_detailed_metrics:\n",
 | ||
|     "        from sklearn.metrics import classification_report, confusion_matrix\n",
 | ||
|     "        \n",
 | ||
|     "        print(f\"\\n📋 详细分类报告 (前20个最常见音素):\")\n",
 | ||
|     "        \n",
 | ||
|     "        # 获取前20个最常见的音素\n",
 | ||
|     "        top_20_phonemes = [phoneme_id for phoneme_id, _ in most_common_true[:20]]\n",
 | ||
|     "        \n",
 | ||
|     "        # 创建掩码,只包含这些音素\n",
 | ||
|     "        mask_top20 = np.isin(y_true_classes, top_20_phonemes)\n",
 | ||
|     "        y_true_top20 = y_true_classes[mask_top20]\n",
 | ||
|     "        y_pred_top20 = y_pred_classes[mask_top20]\n",
 | ||
|     "        \n",
 | ||
|     "        # 生成分类报告\n",
 | ||
|     "        target_names = [LOGIT_TO_PHONEME[i] for i in top_20_phonemes]\n",
 | ||
|     "        \n",
 | ||
|     "        try:\n",
 | ||
|     "            report = classification_report(\n",
 | ||
|     "                y_true_top20, y_pred_top20, \n",
 | ||
|     "                labels=top_20_phonemes,\n",
 | ||
|     "                target_names=target_names,\n",
 | ||
|     "                output_dict=True,\n",
 | ||
|     "                zero_division=0\n",
 | ||
|     "            )\n",
 | ||
|     "            \n",
 | ||
|     "            # 打印格式化的报告\n",
 | ||
|     "            print(f\"{'音素':>8} {'精确率':>8} {'召回率':>8} {'F1分数':>8} {'支持数':>8}\")\n",
 | ||
|     "            print(\"-\" * 48)\n",
 | ||
|     "            \n",
 | ||
|     "            for phoneme_id in top_20_phonemes:\n",
 | ||
|     "                phoneme_name = LOGIT_TO_PHONEME[phoneme_id]\n",
 | ||
|     "                if phoneme_name in report:\n",
 | ||
|     "                    metrics = report[phoneme_name]\n",
 | ||
|     "                    print(f\"{phoneme_name:>8} {metrics['precision']:>8.4f} {metrics['recall']:>8.4f} \"\n",
 | ||
|     "                          f\"{metrics['f1-score']:>8.4f} {int(metrics['support']):>8}\")\n",
 | ||
|     "            \n",
 | ||
|     "            # 总体指标\n",
 | ||
|     "            macro_avg = report['macro avg']\n",
 | ||
|     "            weighted_avg = report['weighted avg']\n",
 | ||
|     "            print(\"-\" * 48)\n",
 | ||
|     "            print(f\"{'宏平均':>8} {macro_avg['precision']:>8.4f} {macro_avg['recall']:>8.4f} \"\n",
 | ||
|     "                  f\"{macro_avg['f1-score']:>8.4f}\")\n",
 | ||
|     "            print(f\"{'加权平均':>8} {weighted_avg['precision']:>8.4f} {weighted_avg['recall']:>8.4f} \"\n",
 | ||
|     "                  f\"{weighted_avg['f1-score']:>8.4f}\")\n",
 | ||
|     "            \n",
 | ||
|     "        except Exception as e:\n",
 | ||
|     "            print(f\"分类报告生成失败: {e}\")\n",
 | ||
|     "    \n",
 | ||
|     "    # 7. Top-K准确率分析\n",
 | ||
|     "    print(f\"\\n🎯 Top-K 准确率分析:\")\n",
 | ||
|     "    for k in [1, 3, 5, 10]:\n",
 | ||
|     "        # 计算Top-K准确率\n",
 | ||
|     "        top_k_pred = np.argsort(y_pred_probs, axis=1)[:, -k:]  # 取概率最高的K个\n",
 | ||
|     "        top_k_accuracy = np.mean([y_true_classes[i] in top_k_pred[i] for i in range(len(y_true_classes))])\n",
 | ||
|     "        print(f\"   Top-{k} 准确率: {top_k_accuracy:.4f} ({top_k_accuracy*100:.2f}%)\")\n",
 | ||
|     "    \n",
 | ||
|     "    # 8. 错误分析 - 最常见的预测错误\n",
 | ||
|     "    print(f\"\\n❌ 最常见的预测错误:\")\n",
 | ||
|     "    error_mask = y_true_classes != y_pred_classes\n",
 | ||
|     "    error_pairs = list(zip(y_true_classes[error_mask], y_pred_classes[error_mask]))\n",
 | ||
|     "    error_counter = Counter(error_pairs)\n",
 | ||
|     "    \n",
 | ||
|     "    print(f\"{'真实音素':>12} {'预测音素':>12} {'错误次数':>8}\")\n",
 | ||
|     "    print(\"-\" * 36)\n",
 | ||
|     "    for (true_id, pred_id), count in error_counter.most_common(10):\n",
 | ||
|     "        true_name = LOGIT_TO_PHONEME[true_id]\n",
 | ||
|     "        pred_name = LOGIT_TO_PHONEME[pred_id]\n",
 | ||
|     "        print(f\"{true_name:>12} {pred_name:>12} {count:>8}\")\n",
 | ||
|     "    \n",
 | ||
|     "    # 返回结果字典\n",
 | ||
|     "    classification_results = {\n",
 | ||
|     "        'accuracy': accuracy,\n",
 | ||
|     "        'y_true_classes': y_true_classes,\n",
 | ||
|     "        'y_pred_classes': y_pred_classes,\n",
 | ||
|     "        'pred_confidences': pred_confidences,\n",
 | ||
|     "        'true_confidences': true_confidences,\n",
 | ||
|     "        'most_common_errors': error_counter.most_common(10)\n",
 | ||
|     "    }\n",
 | ||
|     "    \n",
 | ||
|     "    return classification_results\n",
 | ||
|     "\n",
 | ||
|     "def create_classification_visualizations(y_true_probs, y_pred_probs, classification_results):\n",
 | ||
|     "    \"\"\"\n",
 | ||
|     "    为分类结果创建可视化图表\n",
 | ||
|     "    \"\"\"\n",
 | ||
|     "    print(f\"\\n📊 创建分类结果可视化...\")\n",
 | ||
|     "    \n",
 | ||
|     "    fig, axes = plt.subplots(2, 3, figsize=(18, 12))\n",
 | ||
|     "    fig.suptitle('随机森林回归转分类结果分析', fontsize=16, fontweight='bold')\n",
 | ||
|     "    \n",
 | ||
|     "    y_true_classes = classification_results['y_true_classes']\n",
 | ||
|     "    y_pred_classes = classification_results['y_pred_classes']\n",
 | ||
|     "    pred_confidences = classification_results['pred_confidences']\n",
 | ||
|     "    \n",
 | ||
|     "    # 1. 预测置信度分布\n",
 | ||
|     "    axes[0, 0].hist(pred_confidences, bins=50, alpha=0.7, color='skyblue', edgecolor='black')\n",
 | ||
|     "    axes[0, 0].axvline(pred_confidences.mean(), color='red', linestyle='--', \n",
 | ||
|     "                      label=f'均值: {pred_confidences.mean():.3f}')\n",
 | ||
|     "    axes[0, 0].set_xlabel('预测置信度')\n",
 | ||
|     "    axes[0, 0].set_ylabel('样本数量')\n",
 | ||
|     "    axes[0, 0].set_title('预测置信度分布')\n",
 | ||
|     "    axes[0, 0].legend()\n",
 | ||
|     "    axes[0, 0].grid(True, alpha=0.3)\n",
 | ||
|     "    \n",
 | ||
|     "    # 2. 准确率 vs 置信度\n",
 | ||
|     "    confidence_bins = np.linspace(0, 1, 21)\n",
 | ||
|     "    bin_centers = (confidence_bins[:-1] + confidence_bins[1:]) / 2\n",
 | ||
|     "    bin_accuracies = []\n",
 | ||
|     "    bin_counts = []\n",
 | ||
|     "    \n",
 | ||
|     "    for i in range(len(confidence_bins)-1):\n",
 | ||
|     "        mask = (pred_confidences >= confidence_bins[i]) & (pred_confidences < confidence_bins[i+1])\n",
 | ||
|     "        if mask.sum() > 0:\n",
 | ||
|     "            accuracy = (y_true_classes[mask] == y_pred_classes[mask]).mean()\n",
 | ||
|     "            bin_accuracies.append(accuracy)\n",
 | ||
|     "            bin_counts.append(mask.sum())\n",
 | ||
|     "        else:\n",
 | ||
|     "            bin_accuracies.append(0)\n",
 | ||
|     "            bin_counts.append(0)\n",
 | ||
|     "    \n",
 | ||
|     "    # 只显示有数据的bins\n",
 | ||
|     "    valid_bins = np.array(bin_counts) > 0\n",
 | ||
|     "    axes[0, 1].plot(bin_centers[valid_bins], np.array(bin_accuracies)[valid_bins], \n",
 | ||
|     "                   'bo-', linewidth=2, markersize=6)\n",
 | ||
|     "    axes[0, 1].set_xlabel('预测置信度')\n",
 | ||
|     "    axes[0, 1].set_ylabel('准确率')\n",
 | ||
|     "    axes[0, 1].set_title('准确率 vs 预测置信度')\n",
 | ||
|     "    axes[0, 1].grid(True, alpha=0.3)\n",
 | ||
|     "    axes[0, 1].set_ylim(0, 1)\n",
 | ||
|     "    \n",
 | ||
|     "    # 3. 最常见音素的预测准确率\n",
 | ||
|     "    from collections import Counter\n",
 | ||
|     "    true_counter = Counter(y_true_classes)\n",
 | ||
|     "    most_common_phonemes = [phoneme_id for phoneme_id, _ in true_counter.most_common(15)]\n",
 | ||
|     "    \n",
 | ||
|     "    phoneme_accuracies = []\n",
 | ||
|     "    phoneme_names = []\n",
 | ||
|     "    for phoneme_id in most_common_phonemes:\n",
 | ||
|     "        mask = y_true_classes == phoneme_id\n",
 | ||
|     "        if mask.sum() > 0:\n",
 | ||
|     "            accuracy = (y_pred_classes[mask] == phoneme_id).mean()\n",
 | ||
|     "            phoneme_accuracies.append(accuracy)\n",
 | ||
|     "            phoneme_names.append(LOGIT_TO_PHONEME[phoneme_id])\n",
 | ||
|     "    \n",
 | ||
|     "    bars = axes[0, 2].bar(range(len(phoneme_names)), phoneme_accuracies, \n",
 | ||
|     "                         color='lightgreen', alpha=0.7)\n",
 | ||
|     "    axes[0, 2].set_xlabel('音素')\n",
 | ||
|     "    axes[0, 2].set_ylabel('准确率')\n",
 | ||
|     "    axes[0, 2].set_title('Top 15 音素的分类准确率')\n",
 | ||
|     "    axes[0, 2].set_xticks(range(len(phoneme_names)))\n",
 | ||
|     "    axes[0, 2].set_xticklabels(phoneme_names, rotation=45, ha='right')\n",
 | ||
|     "    axes[0, 2].grid(True, alpha=0.3)\n",
 | ||
|     "    \n",
 | ||
|     "    # 添加数值标签\n",
 | ||
|     "    for bar, acc in zip(bars, phoneme_accuracies):\n",
 | ||
|     "        height = bar.get_height()\n",
 | ||
|     "        axes[0, 2].text(bar.get_x() + bar.get_width()/2., height + 0.01,\n",
 | ||
|     "                       f'{acc:.3f}', ha='center', va='bottom', fontsize=8)\n",
 | ||
|     "    \n",
 | ||
|     "    # 4. 混淆矩阵(前10个最常见音素)\n",
 | ||
|     "    from sklearn.metrics import confusion_matrix\n",
 | ||
|     "    top_10_phonemes = most_common_phonemes[:10]\n",
 | ||
|     "    mask_top10 = np.isin(y_true_classes, top_10_phonemes) & np.isin(y_pred_classes, top_10_phonemes)\n",
 | ||
|     "    \n",
 | ||
|     "    if mask_top10.sum() > 0:\n",
 | ||
|     "        cm = confusion_matrix(y_true_classes[mask_top10], y_pred_classes[mask_top10], \n",
 | ||
|     "                            labels=top_10_phonemes)\n",
 | ||
|     "        \n",
 | ||
|     "        im = axes[1, 0].imshow(cm, interpolation='nearest', cmap='Blues')\n",
 | ||
|     "        axes[1, 0].set_title('混淆矩阵 (Top 10 音素)')\n",
 | ||
|     "        \n",
 | ||
|     "        # 添加颜色条\n",
 | ||
|     "        cbar = plt.colorbar(im, ax=axes[1, 0], shrink=0.8)\n",
 | ||
|     "        cbar.set_label('预测次数')\n",
 | ||
|     "        \n",
 | ||
|     "        # 设置标签\n",
 | ||
|     "        tick_marks = np.arange(len(top_10_phonemes))\n",
 | ||
|     "        top_10_names = [LOGIT_TO_PHONEME[i] for i in top_10_phonemes]\n",
 | ||
|     "        axes[1, 0].set_xticks(tick_marks)\n",
 | ||
|     "        axes[1, 0].set_yticks(tick_marks)\n",
 | ||
|     "        axes[1, 0].set_xticklabels(top_10_names, rotation=45, ha='right')\n",
 | ||
|     "        axes[1, 0].set_yticklabels(top_10_names)\n",
 | ||
|     "        axes[1, 0].set_xlabel('预测音素')\n",
 | ||
|     "        axes[1, 0].set_ylabel('真实音素')\n",
 | ||
|     "    \n",
 | ||
|     "    # 5. Top-K准确率\n",
 | ||
|     "    k_values = [1, 2, 3, 4, 5, 10, 15, 20]\n",
 | ||
|     "    top_k_accuracies = []\n",
 | ||
|     "    \n",
 | ||
|     "    for k in k_values:\n",
 | ||
|     "        top_k_pred = np.argsort(y_pred_probs, axis=1)[:, -k:]\n",
 | ||
|     "        top_k_accuracy = np.mean([y_true_classes[i] in top_k_pred[i] for i in range(len(y_true_classes))])\n",
 | ||
|     "        top_k_accuracies.append(top_k_accuracy)\n",
 | ||
|     "    \n",
 | ||
|     "    axes[1, 1].plot(k_values, top_k_accuracies, 'ro-', linewidth=2, markersize=8)\n",
 | ||
|     "    axes[1, 1].set_xlabel('K 值')\n",
 | ||
|     "    axes[1, 1].set_ylabel('Top-K 准确率')\n",
 | ||
|     "    axes[1, 1].set_title('Top-K 准确率曲线')\n",
 | ||
|     "    axes[1, 1].grid(True, alpha=0.3)\n",
 | ||
|     "    axes[1, 1].set_ylim(0, 1)\n",
 | ||
|     "    \n",
 | ||
|     "    # 添加数值标签\n",
 | ||
|     "    for k, acc in zip(k_values, top_k_accuracies):\n",
 | ||
|     "        axes[1, 1].annotate(f'{acc:.3f}', (k, acc), textcoords=\"offset points\", \n",
 | ||
|     "                           xytext=(0,10), ha='center')\n",
 | ||
|     "    \n",
 | ||
|     "    # 6. 错误分析 - 最常见错误的热力图\n",
 | ||
|     "    error_pairs = classification_results['most_common_errors'][:25]  # 前25个最常见错误\n",
 | ||
|     "    if error_pairs:\n",
 | ||
|     "        # 创建错误矩阵\n",
 | ||
|     "        unique_phonemes = list(set([pair[0][0] for pair in error_pairs] + [pair[0][1] for pair in error_pairs]))\n",
 | ||
|     "        error_matrix = np.zeros((len(unique_phonemes), len(unique_phonemes)))\n",
 | ||
|     "        \n",
 | ||
|     "        phoneme_to_idx = {phoneme: i for i, phoneme in enumerate(unique_phonemes)}\n",
 | ||
|     "        \n",
 | ||
|     "        for (true_id, pred_id), count in error_pairs:\n",
 | ||
|     "            if true_id in phoneme_to_idx and pred_id in phoneme_to_idx:\n",
 | ||
|     "                true_idx = phoneme_to_idx[true_id]\n",
 | ||
|     "                pred_idx = phoneme_to_idx[pred_id]\n",
 | ||
|     "                error_matrix[true_idx, pred_idx] = count\n",
 | ||
|     "        \n",
 | ||
|     "        im = axes[1, 2].imshow(error_matrix, cmap='Reds', interpolation='nearest')\n",
 | ||
|     "        axes[1, 2].set_title('最常见错误分布')\n",
 | ||
|     "        \n",
 | ||
|     "        # 设置标签\n",
 | ||
|     "        phoneme_names = [LOGIT_TO_PHONEME[p] for p in unique_phonemes]\n",
 | ||
|     "        axes[1, 2].set_xticks(range(len(phoneme_names)))\n",
 | ||
|     "        axes[1, 2].set_yticks(range(len(phoneme_names)))\n",
 | ||
|     "        axes[1, 2].set_xticklabels(phoneme_names, rotation=45, ha='right')\n",
 | ||
|     "        axes[1, 2].set_yticklabels(phoneme_names)\n",
 | ||
|     "        axes[1, 2].set_xlabel('预测音素')\n",
 | ||
|     "        axes[1, 2].set_ylabel('真实音素')\n",
 | ||
|     "        \n",
 | ||
|     "        # 添加颜色条\n",
 | ||
|     "        cbar = plt.colorbar(im, ax=axes[1, 2], shrink=0.8)\n",
 | ||
|     "        cbar.set_label('错误次数')\n",
 | ||
|     "    \n",
 | ||
|     "    plt.tight_layout()\n",
 | ||
|     "    plt.savefig('./processed_datasets/classification_analysis.png', dpi=300, bbox_inches='tight')\n",
 | ||
|     "    print(\"📁 分类分析图表已保存至: ./processed_datasets/classification_analysis.png\")\n",
 | ||
|     "    plt.show()\n",
 | ||
|     "\n",
 | ||
|     "print(\"✅ 回归转分类分析功能已创建!\")\n",
 | ||
|     "print(\"🎯 主要功能:\")\n",
 | ||
|     "print(\"• 将40维概率回归结果转换为分类预测\")\n",
 | ||
|     "print(\"• 计算分类准确率和置信度分析\")\n",
 | ||
|     "print(\"• 提供Top-K准确率评估\")\n",
 | ||
|     "print(\"• 生成详细的混淆矩阵和错误分析\")\n",
 | ||
|     "print(\"• 创建全面的可视化图表\")"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": null,
 | ||
|    "metadata": {},
 | ||
|    "outputs": [],
 | ||
|    "source": [
 | ||
|     "# 🎯 完整的回归转分类评估流程\n",
 | ||
|     "def complete_regression_classification_evaluation(rf_model, X_test, y_test, dataset_name=\"测试集\"):\n",
 | ||
|     "    \"\"\"\n",
 | ||
|     "    完整的回归模型转分类结果评估流程\n",
 | ||
|     "    \"\"\"\n",
 | ||
|     "    print(f\"\\n🎯 {dataset_name}完整评估: 回归 → 分类\")\n",
 | ||
|     "    print(\"=\"*70)\n",
 | ||
|     "    \n",
 | ||
|     "    # 1. 获取回归预测结果\n",
 | ||
|     "    print(\"📊 第1步: 获取回归预测...\")\n",
 | ||
|     "    y_pred_probs = rf_model.predict(X_test)\n",
 | ||
|     "    \n",
 | ||
|     "    # 确保概率值在合理范围内\n",
 | ||
|     "    y_pred_probs = np.clip(y_pred_probs, 0, 1)\n",
 | ||
|     "    \n",
 | ||
|     "    # 2. 回归性能评估\n",
 | ||
|     "    print(\"\\n📈 第2步: 回归性能评估...\")\n",
 | ||
|     "    mse = mean_squared_error(y_test, y_pred_probs) \n",
 | ||
|     "    mae = mean_absolute_error(y_test, y_pred_probs)\n",
 | ||
|     "    r2 = r2_score(y_test, y_pred_probs)\n",
 | ||
|     "    \n",
 | ||
|     "    print(f\"   回归 MSE: {mse:.6f}\")\n",
 | ||
|     "    print(f\"   回归 MAE: {mae:.6f}\")\n",
 | ||
|     "    print(f\"   回归 R²: {r2:.4f}\")\n",
 | ||
|     "    \n",
 | ||
|     "    # 3. 概率归一化(softmax)\n",
 | ||
|     "    print(\"\\n🔄 第3步: 概率归一化...\")\n",
 | ||
|     "    def softmax(x, axis=-1):\n",
 | ||
|     "        exp_x = np.exp(x - np.max(x, axis=axis, keepdims=True))\n",
 | ||
|     "        return exp_x / np.sum(exp_x, axis=axis, keepdims=True)\n",
 | ||
|     "    \n",
 | ||
|     "    # 对预测结果应用softmax,使其成为真正的概率分布\n",
 | ||
|     "    y_pred_probs_normalized = softmax(y_pred_probs)\n",
 | ||
|     "    y_test_normalized = softmax(y_test)  # 也对真实标签归一化\n",
 | ||
|     "    \n",
 | ||
|     "    print(f\"   预测概率归一化前: 每行和均值 = {np.mean(np.sum(y_pred_probs, axis=1)):.4f}\")\n",
 | ||
|     "    print(f\"   预测概率归一化后: 每行和均值 = {np.mean(np.sum(y_pred_probs_normalized, axis=1)):.4f}\")\n",
 | ||
|     "    \n",
 | ||
|     "    # 4. 分类结果分析\n",
 | ||
|     "    print(\"\\n🎯 第4步: 分类结果分析...\")\n",
 | ||
|     "    classification_results = regression_to_classification_analysis(\n",
 | ||
|     "        y_test_normalized, y_pred_probs_normalized, show_detailed_metrics=True\n",
 | ||
|     "    )\n",
 | ||
|     "    \n",
 | ||
|     "    # 5. 创建可视化\n",
 | ||
|     "    print(\"\\n📊 第5步: 创建可视化图表...\")\n",
 | ||
|     "    create_classification_visualizations(y_test_normalized, y_pred_probs_normalized, classification_results)\n",
 | ||
|     "    \n",
 | ||
|     "    # 6. 保存结果\n",
 | ||
|     "    print(\"\\n💾 第6步: 保存分析结果...\")\n",
 | ||
|     "    \n",
 | ||
|     "    # 保存分类结果\n",
 | ||
|     "    results_df = pd.DataFrame({\n",
 | ||
|     "        'true_class': classification_results['y_true_classes'],\n",
 | ||
|     "        'pred_class': classification_results['y_pred_classes'],\n",
 | ||
|     "        'true_phoneme': [LOGIT_TO_PHONEME[i] for i in classification_results['y_true_classes']],\n",
 | ||
|     "        'pred_phoneme': [LOGIT_TO_PHONEME[i] for i in classification_results['y_pred_classes']],\n",
 | ||
|     "        'pred_confidence': classification_results['pred_confidences'],\n",
 | ||
|     "        'is_correct': classification_results['y_true_classes'] == classification_results['y_pred_classes']\n",
 | ||
|     "    })\n",
 | ||
|     "    \n",
 | ||
|     "    results_df.to_csv('./processed_datasets/classification_results.csv', index=False)\n",
 | ||
|     "    \n",
 | ||
|     "    # 保存详细的概率预测\n",
 | ||
|     "    prob_results_df = pd.DataFrame(y_pred_probs_normalized, \n",
 | ||
|     "                                  columns=[f'prob_{LOGIT_TO_PHONEME[i]}' for i in range(40)])\n",
 | ||
|     "    prob_results_df['true_class'] = classification_results['y_true_classes']\n",
 | ||
|     "    prob_results_df['pred_class'] = classification_results['y_pred_classes']\n",
 | ||
|     "    \n",
 | ||
|     "    prob_results_df.to_csv('./processed_datasets/probability_predictions.csv', index=False)\n",
 | ||
|     "    \n",
 | ||
|     "    print(\"📁 结果已保存:\")\n",
 | ||
|     "    print(\"   • ./processed_datasets/classification_results.csv (分类结果)\")\n",
 | ||
|     "    print(\"   • ./processed_datasets/probability_predictions.csv (概率预测)\")\n",
 | ||
|     "    print(\"   • ./processed_datasets/classification_analysis.png (可视化图表)\")\n",
 | ||
|     "    \n",
 | ||
|     "    # 7. 总结报告\n",
 | ||
|     "    print(f\"\\n📋 {dataset_name}评估总结:\")\n",
 | ||
|     "    print(\"=\"*50)\n",
 | ||
|     "    print(f\"🔸 回归性能:\")\n",
 | ||
|     "    print(f\"   MSE: {mse:.6f}\")\n",
 | ||
|     "    print(f\"   R²: {r2:.4f}\")\n",
 | ||
|     "    print(f\"🔸 分类性能:\")\n",
 | ||
|     "    print(f\"   准确率: {classification_results['accuracy']:.4f} ({classification_results['accuracy']*100:.2f}%)\")\n",
 | ||
|     "    print(f\"   平均置信度: {classification_results['pred_confidences'].mean():.4f}\")\n",
 | ||
|     "    \n",
 | ||
|     "    # 计算Top-K准确率\n",
 | ||
|     "    for k in [1, 3, 5]:\n",
 | ||
|     "        top_k_pred = np.argsort(y_pred_probs_normalized, axis=1)[:, -k:]\n",
 | ||
|     "        top_k_accuracy = np.mean([classification_results['y_true_classes'][i] in top_k_pred[i] \n",
 | ||
|     "                                for i in range(len(classification_results['y_true_classes']))])\n",
 | ||
|     "        print(f\"   Top-{k} 准确率: {top_k_accuracy:.4f} ({top_k_accuracy*100:.2f}%)\")\n",
 | ||
|     "    \n",
 | ||
|     "    return {\n",
 | ||
|     "        'regression_metrics': {'mse': mse, 'mae': mae, 'r2': r2},\n",
 | ||
|     "        'classification_results': classification_results,\n",
 | ||
|     "        'normalized_predictions': y_pred_probs_normalized,\n",
 | ||
|     "        'normalized_true': y_test_normalized\n",
 | ||
|     "    }\n",
 | ||
|     "\n",
 | ||
|     "# 如果模型已训练且有验证数据,执行完整评估\n",
 | ||
|     "if 'rf_regressor' in locals() and hasattr(rf_regressor, 'is_fitted') and rf_regressor.is_fitted:\n",
 | ||
|     "    if 'X_val' in locals() and X_val is not None and 'y_val' in locals() and y_val is not None:\n",
 | ||
|     "        print(\"🚀 开始完整的回归转分类评估...\")\n",
 | ||
|     "        \n",
 | ||
|     "        # 执行完整评估\n",
 | ||
|     "        evaluation_results = complete_regression_classification_evaluation(\n",
 | ||
|     "            rf_regressor, X_val, y_val, \"验证集\"\n",
 | ||
|     "        )\n",
 | ||
|     "        \n",
 | ||
|     "        print(f\"\\n🎉 评估完成!\")\n",
 | ||
|     "        print(f\"✅ 随机森林回归模型成功转换为分类结果\")\n",
 | ||
|     "        print(f\"📊 生成了详细的性能分析和可视化\")\n",
 | ||
|     "        print(f\"💾 所有结果已保存到文件\")\n",
 | ||
|     "        \n",
 | ||
|     "        # 如果有测试数据,也进行评估\n",
 | ||
|     "        if 'X_test' in locals() and X_test is not None and 'y_test' in locals() and y_test is not None:\n",
 | ||
|     "            print(f\"\\n🔮 开始测试集评估...\")\n",
 | ||
|     "            test_evaluation_results = complete_regression_classification_evaluation(\n",
 | ||
|     "                rf_regressor, X_test, y_test, \"测试集\"\n",
 | ||
|     "            )\n",
 | ||
|     "    else:\n",
 | ||
|     "        print(\"⚠️ 没有可用的验证数据进行评估\")\n",
 | ||
|     "else:\n",
 | ||
|     "    print(\"⚠️ 随机森林模型尚未训练完成\")\n",
 | ||
|     "    print(\"💡 请先运行前面的训练代码\")"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": null,
 | ||
|    "metadata": {},
 | ||
|    "outputs": [
 | ||
|     {
 | ||
|      "data": {
 | ||
|       "text/plain": [
 | ||
|        "array([17.75      , -0.80859375, -2.03125   , -1.8046875 , -0.85546875,\n",
 | ||
|        "       -1.421875  ,  1.765625  , -2.703125  , -1.984375  ,  4.0625    ,\n",
 | ||
|        "        2.        , -3.15625   ,  0.72265625, -0.8671875 , -1.90625   ,\n",
 | ||
|        "       -2.0625    , -1.28125   , -1.03125   ,  0.21289062, -1.890625  ,\n",
 | ||
|        "       -0.4453125 , -0.5546875 ,  0.5625    , -0.421875  , -0.22460938,\n",
 | ||
|        "        0.3515625 , -2.375     , -1.8984375 ,  2.796875  ,  0.3515625 ,\n",
 | ||
|        "       -2.484375  ,  1.453125  ,  0.30078125, -2.390625  ,  0.19335938,\n",
 | ||
|        "        0.35742188, -1.484375  , -2.8125    , -0.84375   , -3.0625    ,\n",
 | ||
|        "        4.96875   ], dtype=float32)"
 | ||
|       ]
 | ||
|      },
 | ||
|      "execution_count": 84,
 | ||
|      "metadata": {},
 | ||
|      "output_type": "execute_result"
 | ||
|     }
 | ||
|    ],
 | ||
|    "source": [
 | ||
|     "single_result['train']['confidence_scores'][0][0]"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": null,
 | ||
|    "metadata": {},
 | ||
|    "outputs": [
 | ||
|     {
 | ||
|      "data": {
 | ||
|       "text/plain": [
 | ||
|        "dict_keys(['concatenated_data', 'processed_features', 'confidence_scores', 'trial_metadata', 'processing_stats'])"
 | ||
|       ]
 | ||
|      },
 | ||
|      "execution_count": 77,
 | ||
|      "metadata": {},
 | ||
|      "output_type": "execute_result"
 | ||
|     }
 | ||
|    ],
 | ||
|    "source": [
 | ||
|     "single_result['train'].keys()"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": null,
 | ||
|    "metadata": {},
 | ||
|    "outputs": [],
 | ||
|    "source": []
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "markdown",
 | ||
|    "metadata": {},
 | ||
|    "source": [
 | ||
|     "## 🌲 随机森林"
 | ||
|    ]
 | ||
|   },
 | ||
|   {
 | ||
|    "cell_type": "code",
 | ||
|    "execution_count": null,
 | ||
|    "metadata": {},
 | ||
|    "outputs": [
 | ||
|     {
 | ||
|      "name": "stdout",
 | ||
|      "output_type": "stream",
 | ||
|      "text": [
 | ||
|       "\n",
 | ||
|       "======================================================================\n",
 | ||
|       "🎯 重新定义任务:音素分类任务\n",
 | ||
|       "======================================================================\n",
 | ||
|       "📋 任务重新定义:\n",
 | ||
|       "   输入: 神经特征 (前7168维)\n",
 | ||
|       "   输出: 音素置信度 (后41维RNN输出)\n",
 | ||
|       "   目标: 预测每个时间步的音素概率分布\n",
 | ||
|       "   分类: 选择最大置信度对应的音素\n",
 | ||
|       "\n",
 | ||
|       "🔧 数据重新处理:\n",
 | ||
|       "   神经特征矩阵: (348, 7168)\n",
 | ||
|       "   音素logits矩阵: (348, 41)\n",
 | ||
|       "   音素类别标签: (348,) (值范围: 0-0)\n",
 | ||
|       "   音素分布: 1 个不同音素\n",
 | ||
|       "   样本最多的前5个音素:\n",
 | ||
|       "     音素 0: 348 次\n",
 | ||
|       "\n",
 | ||
|       "🔄 训练测试集切分:\n",
 | ||
|       "   训练集: 278 样本\n",
 | ||
|       "   测试集: 70 样本\n",
 | ||
|       "   音素类别数: 1\n",
 | ||
|       "   训练集音素分布: 1 个不同音素\n",
 | ||
|       "   测试集音素分布: 1 个不同音素\n",
 | ||
|       "\n",
 | ||
|       "======================================================================\n",
 | ||
|       "🌲 随机森林回归 + 分类\n",
 | ||
|       "======================================================================\n",
 | ||
|       "📊 方案1: 多输出回归 (神经特征 → 音素logits)\n",
 | ||
|       "🚀 训练回归模型...\n"
 | ||
|      ]
 | ||
|     },
 | ||
|     {
 | ||
|      "ename": "KeyboardInterrupt",
 | ||
|      "evalue": "",
 | ||
|      "output_type": "error",
 | ||
|      "traceback": [
 | ||
|       "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
 | ||
|       "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
 | ||
|       "\u001b[0;32m/tmp/ipykernel_186/3764425816.py\u001b[0m in \u001b[0;36m<cell line: 0>\u001b[0;34m()\u001b[0m\n\u001b[1;32m     85\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"🚀 训练回归模型...\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     86\u001b[0m \u001b[0mstart_time\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 87\u001b[0;31m \u001b[0mrf_regressor\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX_train_neu\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my_logits_train\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     88\u001b[0m \u001b[0mregression_time\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtime\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m-\u001b[0m \u001b[0mstart_time\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     89\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34mf\"✅ 回归训练完成!耗时: {regression_time:.2f} 秒\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
 | ||
|       "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/sklearn/base.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(estimator, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1387\u001b[0m                 )\n\u001b[1;32m   1388\u001b[0m             ):\n\u001b[0;32m-> 1389\u001b[0;31m                 \u001b[0;32mreturn\u001b[0m \u001b[0mfit_method\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mestimator\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1390\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1391\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mwrapper\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
 | ||
|       "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/sklearn/multioutput.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, X, y, sample_weight, **fit_params)\u001b[0m\n\u001b[1;32m    272\u001b[0m                 \u001b[0mrouted_params\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mestimator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"sample_weight\"\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msample_weight\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    273\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 274\u001b[0;31m         self.estimators_ = Parallel(n_jobs=self.n_jobs)(\n\u001b[0m\u001b[1;32m    275\u001b[0m             delayed(_fit_estimator)(\n\u001b[1;32m    276\u001b[0m                 \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mestimator\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mrouted_params\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mestimator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
 | ||
|       "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/sklearn/utils/parallel.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, iterable)\u001b[0m\n\u001b[1;32m     75\u001b[0m             \u001b[0;32mfor\u001b[0m \u001b[0mdelayed_func\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m \u001b[0;32min\u001b[0m \u001b[0miterable\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     76\u001b[0m         )\n\u001b[0;32m---> 77\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__call__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0miterable_with_config\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     78\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     79\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
 | ||
|       "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/joblib/parallel.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, iterable)\u001b[0m\n\u001b[1;32m   1984\u001b[0m             \u001b[0moutput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_get_sequential_output\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0miterable\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1985\u001b[0m             \u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1986\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0moutput\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreturn_generator\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1987\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1988\u001b[0m         \u001b[0;31m# Let's create an ID that uniquely identifies the current call. If the\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
 | ||
|       "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/joblib/parallel.py\u001b[0m in \u001b[0;36m_get_sequential_output\u001b[0;34m(self, iterable)\u001b[0m\n\u001b[1;32m   1912\u001b[0m                 \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mn_dispatched_batches\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1913\u001b[0m                 \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mn_dispatched_tasks\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1914\u001b[0;31m                 \u001b[0mres\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfunc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1915\u001b[0m                 \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mn_completed_tasks\u001b[0m \u001b[0;34m+=\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1916\u001b[0m                 \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mprint_progress\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
 | ||
|       "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/sklearn/utils/parallel.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, *args, **kwargs)\u001b[0m\n\u001b[1;32m    137\u001b[0m             \u001b[0mconfig\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    138\u001b[0m         \u001b[0;32mwith\u001b[0m \u001b[0mconfig_context\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m**\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 139\u001b[0;31m             \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfunction\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    140\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    141\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
 | ||
|       "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/sklearn/multioutput.py\u001b[0m in \u001b[0;36m_fit_estimator\u001b[0;34m(estimator, X, y, sample_weight, **fit_params)\u001b[0m\n\u001b[1;32m     61\u001b[0m         \u001b[0mestimator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msample_weight\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0msample_weight\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mfit_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     62\u001b[0m     \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 63\u001b[0;31m         \u001b[0mestimator\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mfit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mX\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mfit_params\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     64\u001b[0m     \u001b[0;32mreturn\u001b[0m \u001b[0mestimator\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     65\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
 | ||
|       "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/sklearn/base.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(estimator, *args, **kwargs)\u001b[0m\n\u001b[1;32m   1387\u001b[0m                 )\n\u001b[1;32m   1388\u001b[0m             ):\n\u001b[0;32m-> 1389\u001b[0;31m                 \u001b[0;32mreturn\u001b[0m \u001b[0mfit_method\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mestimator\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1390\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1391\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mwrapper\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
 | ||
|       "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/sklearn/ensemble/_forest.py\u001b[0m in \u001b[0;36mfit\u001b[0;34m(self, X, y, sample_weight)\u001b[0m\n\u001b[1;32m    485\u001b[0m             \u001b[0;31m# parallel_backend contexts set at a higher level,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    486\u001b[0m             \u001b[0;31m# since correctness does not rely on using threads.\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 487\u001b[0;31m             trees = Parallel(\n\u001b[0m\u001b[1;32m    488\u001b[0m                 \u001b[0mn_jobs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mn_jobs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    489\u001b[0m                 \u001b[0mverbose\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mverbose\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
 | ||
|       "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/sklearn/utils/parallel.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, iterable)\u001b[0m\n\u001b[1;32m     75\u001b[0m             \u001b[0;32mfor\u001b[0m \u001b[0mdelayed_func\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkwargs\u001b[0m \u001b[0;32min\u001b[0m \u001b[0miterable\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     76\u001b[0m         )\n\u001b[0;32m---> 77\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0msuper\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__call__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0miterable_with_config\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     78\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     79\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
 | ||
|       "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/joblib/parallel.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, iterable)\u001b[0m\n\u001b[1;32m   2070\u001b[0m         \u001b[0mnext\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2071\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 2072\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0moutput\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreturn_generator\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mlist\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   2073\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   2074\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0m__repr__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
 | ||
|       "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/joblib/parallel.py\u001b[0m in \u001b[0;36m_get_outputs\u001b[0;34m(self, iterator, pre_dispatch)\u001b[0m\n\u001b[1;32m   1680\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1681\u001b[0m             \u001b[0;32mwith\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_backend\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mretrieval_context\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1682\u001b[0;31m                 \u001b[0;32myield\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_retrieve\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1683\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1684\u001b[0m         \u001b[0;32mexcept\u001b[0m \u001b[0mGeneratorExit\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
 | ||
|       "\u001b[0;32m/usr/local/lib/python3.11/dist-packages/joblib/parallel.py\u001b[0m in \u001b[0;36m_retrieve\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m   1798\u001b[0m                     \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_jobs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mget_status\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtimeout\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtimeout\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m==\u001b[0m \u001b[0mTASK_PENDING\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1799\u001b[0m                 ):\n\u001b[0;32m-> 1800\u001b[0;31m                     \u001b[0mtime\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msleep\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;36m0.01\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m   1801\u001b[0m                     \u001b[0;32mcontinue\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m   1802\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
 | ||
|       "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
 | ||
|      ]
 | ||
|     }
 | ||
|    ],
 | ||
|    "source": [
 | ||
|     "# 🎯 正确的任务:音素分类 (神经特征 → 音素置信度回归)\n",
 | ||
|     "\n",
 | ||
|     "print(\"\\n\" + \"=\"*70)\n",
 | ||
|     "print(\"🎯 重新定义任务:音素分类任务\")\n",
 | ||
|     "print(\"=\"*70)\n",
 | ||
|     "\n",
 | ||
|     "print(\"📋 任务重新定义:\")\n",
 | ||
|     "print(\"   输入: 神经特征 (前7168维)\")\n",
 | ||
|     "print(\"   输出: 音素置信度 (后41维RNN输出)\")\n",
 | ||
|     "print(\"   目标: 预测每个时间步的音素概率分布\")\n",
 | ||
|     "print(\"   分类: 选择最大置信度对应的音素\") \n",
 | ||
|     "\n",
 | ||
|     "# 重新准备数据\n",
 | ||
|     "print(f\"\\n🔧 数据重新处理:\")\n",
 | ||
|     "\n",
 | ||
|     "# 分离输入特征和目标\n",
 | ||
|     "X_neural = []  # 神经特征 (7168维)\n",
 | ||
|     "y_phoneme_logits = []  # 音素置信度 (41维)\n",
 | ||
|     "\n",
 | ||
|     "for features in valid_features:\n",
 | ||
|     "    # features shape: [time_steps, 7209]\n",
 | ||
|     "    neural_part = features[:, :7168]  # 前7168维是神经特征\n",
 | ||
|     "    rnn_part = features[:, 7168:]     # 后41维是RNN输出(音素logits)\n",
 | ||
|     "    \n",
 | ||
|     "    # 对时间维度做平均\n",
 | ||
|     "    neural_pooled = np.mean(neural_part, axis=0)  # [7168]\n",
 | ||
|     "    rnn_pooled = np.mean(rnn_part, axis=0)        # [41]\n",
 | ||
|     "    \n",
 | ||
|     "    X_neural.append(neural_pooled)\n",
 | ||
|     "    y_phoneme_logits.append(rnn_pooled)\n",
 | ||
|     "\n",
 | ||
|     "X_neural = np.array(X_neural)  # [348, 7168]\n",
 | ||
|     "y_phoneme_logits = np.array(y_phoneme_logits)  # [348, 41]\n",
 | ||
|     "\n",
 | ||
|     "print(f\"   神经特征矩阵: {X_neural.shape}\")\n",
 | ||
|     "print(f\"   音素logits矩阵: {y_phoneme_logits.shape}\")\n",
 | ||
|     "\n",
 | ||
|     "# 从音素logits得到分类标签\n",
 | ||
|     "y_phoneme_class = np.argmax(y_phoneme_logits, axis=1)  # 选择最大值的索引\n",
 | ||
|     "print(f\"   音素类别标签: {y_phoneme_class.shape} (值范围: {y_phoneme_class.min()}-{y_phoneme_class.max()})\")\n",
 | ||
|     "\n",
 | ||
|     "# 显示音素分布\n",
 | ||
|     "from collections import Counter\n",
 | ||
|     "phoneme_dist = Counter(y_phoneme_class)\n",
 | ||
|     "print(f\"   音素分布: {len(phoneme_dist)} 个不同音素\")\n",
 | ||
|     "print(f\"   样本最多的前5个音素:\")\n",
 | ||
|     "for phoneme_id, count in phoneme_dist.most_common(5):\n",
 | ||
|     "    print(f\"     音素 {phoneme_id}: {count} 次\")\n",
 | ||
|     "\n",
 | ||
|     "# 训练测试集切分\n",
 | ||
|     "print(f\"\\n🔄 训练测试集切分:\")\n",
 | ||
|     "X_train_neu, X_test_neu, y_logits_train, y_logits_test, y_class_train, y_class_test = train_test_split(\n",
 | ||
|     "    X_neural, y_phoneme_logits, y_phoneme_class,\n",
 | ||
|     "    test_size=0.2, random_state=42, stratify=y_phoneme_class\n",
 | ||
|     ")\n",
 | ||
|     "\n",
 | ||
|     "print(f\"   训练集: {X_train_neu.shape[0]} 样本\")\n",
 | ||
|     "print(f\"   测试集: {X_test_neu.shape[0]} 样本\")\n",
 | ||
|     "print(f\"   音素类别数: {len(np.unique(y_phoneme_class))}\")\n",
 | ||
|     "\n",
 | ||
|     "# 检查类别分布\n",
 | ||
|     "train_dist = Counter(y_class_train)\n",
 | ||
|     "test_dist = Counter(y_class_test)\n",
 | ||
|     "print(f\"   训练集音素分布: {len(train_dist)} 个不同音素\")\n",
 | ||
|     "print(f\"   测试集音素分布: {len(test_dist)} 个不同音素\")\n",
 | ||
|     "\n",
 | ||
|     "print(\"\\n\" + \"=\"*70)\n",
 | ||
|     "print(\"🌲 随机森林回归 + 分类\")\n",
 | ||
|     "print(\"=\"*70)\n",
 | ||
|     "\n",
 | ||
|     "# 方案1: 多输出回归 (预测41维音素logits)\n",
 | ||
|     "from sklearn.ensemble import RandomForestRegressor\n",
 | ||
|     "from sklearn.multioutput import MultiOutputRegressor\n",
 | ||
|     "\n",
 | ||
|     "print(\"📊 方案1: 多输出回归 (神经特征 → 音素logits)\")\n",
 | ||
|     "rf_regressor = MultiOutputRegressor(\n",
 | ||
|     "    RandomForestRegressor(\n",
 | ||
|     "        n_estimators=100,\n",
 | ||
|     "        max_depth=10,\n",
 | ||
|     "        random_state=42,\n",
 | ||
|     "        n_jobs=-1\n",
 | ||
|     "    )\n",
 | ||
|     ")\n",
 | ||
|     "\n",
 | ||
|     "print(\"🚀 训练回归模型...\")\n",
 | ||
|     "start_time = time.time()\n",
 | ||
|     "rf_regressor.fit(X_train_neu, y_logits_train)\n",
 | ||
|     "regression_time = time.time() - start_time\n",
 | ||
|     "print(f\"✅ 回归训练完成!耗时: {regression_time:.2f} 秒\")\n",
 | ||
|     "\n",
 | ||
|     "# 预测音素logits\n",
 | ||
|     "print(\"🔮 预测音素置信度...\")\n",
 | ||
|     "y_logits_pred_train = rf_regressor.predict(X_train_neu)\n",
 | ||
|     "y_logits_pred_test = rf_regressor.predict(X_test_neu)\n",
 | ||
|     "\n",
 | ||
|     "# 从预测的logits得到分类结果\n",
 | ||
|     "y_class_pred_train = np.argmax(y_logits_pred_train, axis=1)\n",
 | ||
|     "y_class_pred_test = np.argmax(y_logits_pred_test, axis=1)\n",
 | ||
|     "\n",
 | ||
|     "# 评估分类性能\n",
 | ||
|     "from sklearn.metrics import accuracy_score, classification_report\n",
 | ||
|     "\n",
 | ||
|     "train_acc = accuracy_score(y_class_train, y_class_pred_train)\n",
 | ||
|     "test_acc = accuracy_score(y_class_test, y_class_pred_test)\n",
 | ||
|     "\n",
 | ||
|     "print(f\"\\n📊 音素分类性能评估:\")\n",
 | ||
|     "print(f\"   训练集准确率: {train_acc:.4f} ({train_acc*100:.2f}%)\")\n",
 | ||
|     "print(f\"   测试集准确率: {test_acc:.4f} ({test_acc*100:.2f}%)\")\n",
 | ||
|     "print(f\"   过拟合程度: {(train_acc - test_acc)*100:.2f}%\")\n",
 | ||
|     "\n",
 | ||
|     "# 评估回归性能\n",
 | ||
|     "from sklearn.metrics import mean_squared_error, r2_score\n",
 | ||
|     "\n",
 | ||
|     "mse_train = mean_squared_error(y_logits_train, y_logits_pred_train)\n",
 | ||
|     "mse_test = mean_squared_error(y_logits_test, y_logits_pred_test)\n",
 | ||
|     "r2_train = r2_score(y_logits_train, y_logits_pred_train)\n",
 | ||
|     "r2_test = r2_score(y_logits_test, y_logits_pred_test)\n",
 | ||
|     "\n",
 | ||
|     "print(f\"\\n📈 音素logits回归性能:\")\n",
 | ||
|     "print(f\"   训练集 MSE: {mse_train:.6f}\")\n",
 | ||
|     "print(f\"   测试集 MSE: {mse_test:.6f}\")\n",
 | ||
|     "print(f\"   训练集 R²: {r2_train:.4f}\")\n",
 | ||
|     "print(f\"   测试集 R²: {r2_test:.4f}\")\n",
 | ||
|     "\n",
 | ||
|     "print(f\"\\n✨ 任务修正完成!现在是正确的音素分类任务\")"
 | ||
|    ]
 | ||
|   }
 | ||
|  ],
 | ||
|  "metadata": {
 | ||
|   "kaggle": {
 | ||
|    "accelerator": "tpu1vmV38",
 | ||
|    "dataSources": [
 | ||
|     {
 | ||
|      "databundleVersionId": 13056355,
 | ||
|      "sourceId": 106809,
 | ||
|      "sourceType": "competition"
 | ||
|     }
 | ||
|    ],
 | ||
|    "dockerImageVersionId": 31091,
 | ||
|    "isGpuEnabled": false,
 | ||
|    "isInternetEnabled": true,
 | ||
|    "language": "python",
 | ||
|    "sourceType": "notebook"
 | ||
|   },
 | ||
|   "kernelspec": {
 | ||
|    "display_name": "Python 3 (ipykernel)",
 | ||
|    "language": "python",
 | ||
|    "name": "python3"
 | ||
|   },
 | ||
|   "language_info": {
 | ||
|    "codemirror_mode": {
 | ||
|     "name": "ipython",
 | ||
|     "version": 3
 | ||
|    },
 | ||
|    "file_extension": ".py",
 | ||
|    "mimetype": "text/x-python",
 | ||
|    "name": "python",
 | ||
|    "nbconvert_exporter": "python",
 | ||
|    "pygments_lexer": "ipython3",
 | ||
|    "version": "3.11.13"
 | ||
|   }
 | ||
|  },
 | ||
|  "nbformat": 4,
 | ||
|  "nbformat_minor": 4
 | ||
| }
 | 
