b2txt25/brain-to-text-25/brain-to-text-25-datapre.ipynb
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 环境配置 与 utils"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Looking in indexes: https://download.pytorch.org/whl/cu126\n",
"Requirement already satisfied: torch in /usr/local/lib/python3.11/dist-packages (2.6.0+cu124)\n",
"Requirement already satisfied: torchvision in /usr/local/lib/python3.11/dist-packages (0.21.0+cu124)\n",
"Requirement already satisfied: torchaudio in /usr/local/lib/python3.11/dist-packages (2.6.0+cu124)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from torch) (3.18.0)\n",
"Requirement already satisfied: typing-extensions>=4.10.0 in /usr/local/lib/python3.11/dist-packages (from torch) (4.14.0)\n",
"Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch) (3.5)\n",
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch) (3.1.6)\n",
"Requirement already satisfied: fsspec in /usr/local/lib/python3.11/dist-packages (from torch) (2025.5.1)\n",
"Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)\n",
" Downloading https://download.pytorch.org/whl/cu126/nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n",
"Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)\n",
" Downloading https://download.pytorch.org/whl/cu126/nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n",
"Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)\n",
" Downloading https://download.pytorch.org/whl/cu126/nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n",
"Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)\n",
" Downloading https://download.pytorch.org/whl/cu126/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n",
"Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)\n",
" Downloading https://download.pytorch.org/whl/cu126/nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n",
"Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)\n",
" Downloading https://download.pytorch.org/whl/cu126/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n",
"Collecting nvidia-curand-cu12==10.3.5.147 (from torch)\n",
" Downloading https://download.pytorch.org/whl/cu126/nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n",
"Collecting nvidia-cusolver-cu12==11.6.1.9 (from torch)\n",
" Downloading https://download.pytorch.org/whl/cu126/nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n",
"Collecting nvidia-cusparse-cu12==12.3.1.170 (from torch)\n",
" Downloading https://download.pytorch.org/whl/cu126/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)\n",
"Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /usr/local/lib/python3.11/dist-packages (from torch) (0.6.2)\n",
"Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from torch) (2.21.5)\n",
"Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch) (12.4.127)\n",
"Collecting nvidia-nvjitlink-cu12==12.4.127 (from torch)\n",
" Downloading https://download.pytorch.org/whl/cu126/nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)\n",
"Requirement already satisfied: triton==3.2.0 in /usr/local/lib/python3.11/dist-packages (from torch) (3.2.0)\n",
"Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch) (1.13.1)\n",
"Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch) (1.3.0)\n",
"Requirement already satisfied: numpy in /usr/local/lib/python3.11/dist-packages (from torchvision) (1.26.4)\n",
"Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /usr/local/lib/python3.11/dist-packages (from torchvision) (11.2.1)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from jinja2->torch) (3.0.2)\n",
"Requirement already satisfied: mkl_fft in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (1.3.8)\n",
"Requirement already satisfied: mkl_random in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (1.2.4)\n",
"Requirement already satisfied: mkl_umath in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (0.1.1)\n",
"Requirement already satisfied: mkl in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (2025.2.0)\n",
"Requirement already satisfied: tbb4py in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (2022.2.0)\n",
"Requirement already satisfied: mkl-service in /usr/local/lib/python3.11/dist-packages (from numpy->torchvision) (2.4.1)\n",
"Requirement already satisfied: intel-openmp<2026,>=2024 in /usr/local/lib/python3.11/dist-packages (from mkl->numpy->torchvision) (2024.2.0)\n",
"Requirement already satisfied: tbb==2022.* in /usr/local/lib/python3.11/dist-packages (from mkl->numpy->torchvision) (2022.2.0)\n",
"Requirement already satisfied: tcmlib==1.* in /usr/local/lib/python3.11/dist-packages (from tbb==2022.*->mkl->numpy->torchvision) (1.4.0)\n",
"Requirement already satisfied: intel-cmplr-lib-rt in /usr/local/lib/python3.11/dist-packages (from mkl_umath->numpy->torchvision) (2024.2.0)\n",
"Requirement already satisfied: intel-cmplr-lib-ur==2024.2.0 in /usr/local/lib/python3.11/dist-packages (from intel-openmp<2026,>=2024->mkl->numpy->torchvision) (2024.2.0)\n",
"Downloading https://download.pytorch.org/whl/cu126/nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl (363.4 MB)\n",
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 363.4/363.4 MB 4.5 MB/s eta 0:00:00\n",
"Downloading https://download.pytorch.org/whl/cu126/nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (13.8 MB)\n",
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 13.8/13.8 MB 84.2 MB/s eta 0:00:00\n",
"Downloading https://download.pytorch.org/whl/cu126/nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (24.6 MB)\n",
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 24.6/24.6 MB 78.5 MB/s eta 0:00:00\n",
"Downloading https://download.pytorch.org/whl/cu126/nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (883 kB)\n",
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 883.7/883.7 kB 41.1 MB/s eta 0:00:00\n",
"Downloading https://download.pytorch.org/whl/cu126/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl (664.8 MB)\n",
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 664.8/664.8 MB 2.1 MB/s eta 0:00:00\n",
"Downloading https://download.pytorch.org/whl/cu126/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl (211.5 MB)\n",
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 211.5/211.5 MB 2.1 MB/s eta 0:00:00\n",
"Downloading https://download.pytorch.org/whl/cu126/nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl (56.3 MB)\n",
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 56.3/56.3 MB 30.7 MB/s eta 0:00:00\n",
"Downloading https://download.pytorch.org/whl/cu126/nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl (127.9 MB)\n",
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 127.9/127.9 MB 12.8 MB/s eta 0:00:00\n",
"Downloading https://download.pytorch.org/whl/cu126/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl (207.5 MB)\n",
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 207.5/207.5 MB 7.8 MB/s eta 0:00:00\n",
"Downloading https://download.pytorch.org/whl/cu126/nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl (21.1 MB)\n",
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 21.1/21.1 MB 79.7 MB/s eta 0:00:00\n",
"Installing collected packages: nvidia-nvjitlink-cu12, nvidia-curand-cu12, nvidia-cufft-cu12, nvidia-cuda-runtime-cu12, nvidia-cuda-nvrtc-cu12, nvidia-cuda-cupti-cu12, nvidia-cublas-cu12, nvidia-cusparse-cu12, nvidia-cudnn-cu12, nvidia-cusolver-cu12\n",
" Attempting uninstall: nvidia-nvjitlink-cu12\n",
" Found existing installation: nvidia-nvjitlink-cu12 12.5.82\n",
" Uninstalling nvidia-nvjitlink-cu12-12.5.82:\n",
" Successfully uninstalled nvidia-nvjitlink-cu12-12.5.82\n",
" Attempting uninstall: nvidia-curand-cu12\n",
" Found existing installation: nvidia-curand-cu12 10.3.6.82\n",
" Uninstalling nvidia-curand-cu12-10.3.6.82:\n",
" Successfully uninstalled nvidia-curand-cu12-10.3.6.82\n",
" Attempting uninstall: nvidia-cufft-cu12\n",
" Found existing installation: nvidia-cufft-cu12 11.2.3.61\n",
" Uninstalling nvidia-cufft-cu12-11.2.3.61:\n",
" Successfully uninstalled nvidia-cufft-cu12-11.2.3.61\n",
" Attempting uninstall: nvidia-cuda-runtime-cu12\n",
" Found existing installation: nvidia-cuda-runtime-cu12 12.5.82\n",
" Uninstalling nvidia-cuda-runtime-cu12-12.5.82:\n",
" Successfully uninstalled nvidia-cuda-runtime-cu12-12.5.82\n",
" Attempting uninstall: nvidia-cuda-nvrtc-cu12\n",
" Found existing installation: nvidia-cuda-nvrtc-cu12 12.5.82\n",
" Uninstalling nvidia-cuda-nvrtc-cu12-12.5.82:\n",
" Successfully uninstalled nvidia-cuda-nvrtc-cu12-12.5.82\n",
" Attempting uninstall: nvidia-cuda-cupti-cu12\n",
" Found existing installation: nvidia-cuda-cupti-cu12 12.5.82\n",
" Uninstalling nvidia-cuda-cupti-cu12-12.5.82:\n",
" Successfully uninstalled nvidia-cuda-cupti-cu12-12.5.82\n",
" Attempting uninstall: nvidia-cublas-cu12\n",
" Found existing installation: nvidia-cublas-cu12 12.5.3.2\n",
" Uninstalling nvidia-cublas-cu12-12.5.3.2:\n",
" Successfully uninstalled nvidia-cublas-cu12-12.5.3.2\n",
" Attempting uninstall: nvidia-cusparse-cu12\n",
" Found existing installation: nvidia-cusparse-cu12 12.5.1.3\n",
" Uninstalling nvidia-cusparse-cu12-12.5.1.3:\n",
" Successfully uninstalled nvidia-cusparse-cu12-12.5.1.3\n",
" Attempting uninstall: nvidia-cudnn-cu12\n",
" Found existing installation: nvidia-cudnn-cu12 9.3.0.75\n",
" Uninstalling nvidia-cudnn-cu12-9.3.0.75:\n",
" Successfully uninstalled nvidia-cudnn-cu12-9.3.0.75\n",
" Attempting uninstall: nvidia-cusolver-cu12\n",
" Found existing installation: nvidia-cusolver-cu12 11.6.3.83\n",
" Uninstalling nvidia-cusolver-cu12-11.6.3.83:\n",
" Successfully uninstalled nvidia-cusolver-cu12-11.6.3.83\n",
"Successfully installed nvidia-cublas-cu12-12.4.5.8 nvidia-cuda-cupti-cu12-12.4.127 nvidia-cuda-nvrtc-cu12-12.4.127 nvidia-cuda-runtime-cu12-12.4.127 nvidia-cudnn-cu12-9.1.0.70 nvidia-cufft-cu12-11.2.1.3 nvidia-curand-cu12-10.3.5.147 nvidia-cusolver-cu12-11.6.1.9 nvidia-cusparse-cu12-12.3.1.170 nvidia-nvjitlink-cu12-12.4.127\n",
"Collecting jupyter==1.1.1\n",
" Downloading jupyter-1.1.1-py2.py3-none-any.whl.metadata (2.0 kB)\n",
"Requirement already satisfied: numpy<2.1.0,>=1.26.0 in /usr/local/lib/python3.11/dist-packages (1.26.4)\n",
"Collecting pandas==2.3.0\n",
" Downloading pandas-2.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (91 kB)\n",
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 91.2/91.2 kB 4.6 MB/s eta 0:00:00\n",
"Collecting matplotlib==3.10.1\n",
" Downloading matplotlib-3.10.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (11 kB)\n",
"Collecting scipy==1.15.2\n",
" Downloading scipy-1.15.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (61 kB)\n",
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 62.0/62.0 kB 3.0 MB/s eta 0:00:00\n",
"Collecting scikit-learn==1.6.1\n",
" Downloading scikit_learn-1.6.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (18 kB)\n",
"Requirement already satisfied: tqdm==4.67.1 in /usr/local/lib/python3.11/dist-packages (4.67.1)\n",
"Collecting g2p_en==2.1.0\n",
" Downloading g2p_en-2.1.0-py3-none-any.whl.metadata (4.5 kB)\n",
"Collecting h5py==3.13.0\n",
" Downloading h5py-3.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (2.5 kB)\n",
"Requirement already satisfied: omegaconf==2.3.0 in /usr/local/lib/python3.11/dist-packages (2.3.0)\n",
"Requirement already satisfied: editdistance==0.8.1 in /usr/local/lib/python3.11/dist-packages (0.8.1)\n",
"Requirement already satisfied: huggingface-hub==0.33.1 in /usr/local/lib/python3.11/dist-packages (0.33.1)\n",
"Collecting transformers==4.53.0\n",
" Downloading transformers-4.53.0-py3-none-any.whl.metadata (39 kB)\n",
"Requirement already satisfied: tokenizers==0.21.2 in /usr/local/lib/python3.11/dist-packages (0.21.2)\n",
"Requirement already satisfied: accelerate==1.8.1 in /usr/local/lib/python3.11/dist-packages (1.8.1)\n",
"Collecting bitsandbytes==0.46.0\n",
" Downloading bitsandbytes-0.46.0-py3-none-manylinux_2_24_x86_64.whl.metadata (10 kB)\n",
"Requirement already satisfied: notebook in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (6.5.4)\n",
"Requirement already satisfied: jupyter-console in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (6.1.0)\n",
"Requirement already satisfied: nbconvert in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (6.4.5)\n",
"Requirement already satisfied: ipykernel in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (6.17.1)\n",
"Requirement already satisfied: ipywidgets in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (8.1.5)\n",
"Requirement already satisfied: jupyterlab in /usr/local/lib/python3.11/dist-packages (from jupyter==1.1.1) (3.6.8)\n",
"Requirement already satisfied: python-dateutil>=2.8.2 in /usr/local/lib/python3.11/dist-packages (from pandas==2.3.0) (2.9.0.post0)\n",
"Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.11/dist-packages (from pandas==2.3.0) (2025.2)\n",
"Requirement already satisfied: tzdata>=2022.7 in /usr/local/lib/python3.11/dist-packages (from pandas==2.3.0) (2025.2)\n",
"Requirement already satisfied: contourpy>=1.0.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (1.3.2)\n",
"Requirement already satisfied: cycler>=0.10 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (0.12.1)\n",
"Requirement already satisfied: fonttools>=4.22.0 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (4.58.4)\n",
"Requirement already satisfied: kiwisolver>=1.3.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (1.4.8)\n",
"Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (25.0)\n",
"Requirement already satisfied: pillow>=8 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (11.2.1)\n",
"Requirement already satisfied: pyparsing>=2.3.1 in /usr/local/lib/python3.11/dist-packages (from matplotlib==3.10.1) (3.0.9)\n",
"Requirement already satisfied: joblib>=1.2.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn==1.6.1) (1.5.1)\n",
"Requirement already satisfied: threadpoolctl>=3.1.0 in /usr/local/lib/python3.11/dist-packages (from scikit-learn==1.6.1) (3.6.0)\n",
"Requirement already satisfied: nltk>=3.2.4 in /usr/local/lib/python3.11/dist-packages (from g2p_en==2.1.0) (3.9.1)\n",
"Requirement already satisfied: inflect>=0.3.1 in /usr/local/lib/python3.11/dist-packages (from g2p_en==2.1.0) (7.5.0)\n",
"Collecting distance>=0.1.3 (from g2p_en==2.1.0)\n",
" Downloading Distance-0.1.3.tar.gz (180 kB)\n",
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 180.3/180.3 kB 9.7 MB/s eta 0:00:00\n",
" Preparing metadata (setup.py): started\n",
" Preparing metadata (setup.py): finished with status 'done'\n",
"Requirement already satisfied: antlr4-python3-runtime==4.9.* in /usr/local/lib/python3.11/dist-packages (from omegaconf==2.3.0) (4.9.3)\n",
"Requirement already satisfied: PyYAML>=5.1.0 in /usr/local/lib/python3.11/dist-packages (from omegaconf==2.3.0) (6.0.2)\n",
"Requirement already satisfied: filelock in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (3.18.0)\n",
"Requirement already satisfied: fsspec>=2023.5.0 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (2025.5.1)\n",
"Requirement already satisfied: requests in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (2.32.4)\n",
"Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (4.14.0)\n",
"Requirement already satisfied: hf-xet<2.0.0,>=1.1.2 in /usr/local/lib/python3.11/dist-packages (from huggingface-hub==0.33.1) (1.1.5)\n",
"Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.11/dist-packages (from transformers==4.53.0) (2024.11.6)\n",
"Requirement already satisfied: safetensors>=0.4.3 in /usr/local/lib/python3.11/dist-packages (from transformers==4.53.0) (0.5.3)\n",
"Requirement already satisfied: psutil in /usr/local/lib/python3.11/dist-packages (from accelerate==1.8.1) (7.0.0)\n",
"Requirement already satisfied: torch>=2.0.0 in /usr/local/lib/python3.11/dist-packages (from accelerate==1.8.1) (2.6.0+cu124)\n",
"Requirement already satisfied: mkl_fft in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (1.3.8)\n",
"Requirement already satisfied: mkl_random in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (1.2.4)\n",
"Requirement already satisfied: mkl_umath in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (0.1.1)\n",
"Requirement already satisfied: mkl in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (2025.2.0)\n",
"Requirement already satisfied: tbb4py in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (2022.2.0)\n",
"Requirement already satisfied: mkl-service in /usr/local/lib/python3.11/dist-packages (from numpy<2.1.0,>=1.26.0) (2.4.1)\n",
"Requirement already satisfied: more_itertools>=8.5.0 in /usr/local/lib/python3.11/dist-packages (from inflect>=0.3.1->g2p_en==2.1.0) (10.7.0)\n",
"Requirement already satisfied: typeguard>=4.0.1 in /usr/local/lib/python3.11/dist-packages (from inflect>=0.3.1->g2p_en==2.1.0) (4.4.4)\n",
"Requirement already satisfied: click in /usr/local/lib/python3.11/dist-packages (from nltk>=3.2.4->g2p_en==2.1.0) (8.2.1)\n",
"Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.11/dist-packages (from python-dateutil>=2.8.2->pandas==2.3.0) (1.17.0)\n",
"Requirement already satisfied: networkx in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (3.5)\n",
"Requirement already satisfied: jinja2 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (3.1.6)\n",
"Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n",
"Requirement already satisfied: nvidia-cuda-runtime-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n",
"Requirement already satisfied: nvidia-cuda-cupti-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n",
"Requirement already satisfied: nvidia-cudnn-cu12==9.1.0.70 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (9.1.0.70)\n",
"Requirement already satisfied: nvidia-cublas-cu12==12.4.5.8 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.5.8)\n",
"Requirement already satisfied: nvidia-cufft-cu12==11.2.1.3 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (11.2.1.3)\n",
"Requirement already satisfied: nvidia-curand-cu12==10.3.5.147 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (10.3.5.147)\n",
"Requirement already satisfied: nvidia-cusolver-cu12==11.6.1.9 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (11.6.1.9)\n",
"Requirement already satisfied: nvidia-cusparse-cu12==12.3.1.170 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.3.1.170)\n",
"Requirement already satisfied: nvidia-cusparselt-cu12==0.6.2 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (0.6.2)\n",
"Requirement already satisfied: nvidia-nccl-cu12==2.21.5 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (2.21.5)\n",
"Requirement already satisfied: nvidia-nvtx-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n",
"Requirement already satisfied: nvidia-nvjitlink-cu12==12.4.127 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (12.4.127)\n",
"Requirement already satisfied: triton==3.2.0 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (3.2.0)\n",
"Requirement already satisfied: sympy==1.13.1 in /usr/local/lib/python3.11/dist-packages (from torch>=2.0.0->accelerate==1.8.1) (1.13.1)\n",
"Requirement already satisfied: mpmath<1.4,>=1.1.0 in /usr/local/lib/python3.11/dist-packages (from sympy==1.13.1->torch>=2.0.0->accelerate==1.8.1) (1.3.0)\n",
"Requirement already satisfied: debugpy>=1.0 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (1.8.0)\n",
"Requirement already satisfied: ipython>=7.23.1 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (7.34.0)\n",
"Requirement already satisfied: jupyter-client>=6.1.12 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (8.6.3)\n",
"Requirement already satisfied: matplotlib-inline>=0.1 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (0.1.7)\n",
"Requirement already satisfied: nest-asyncio in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (1.6.0)\n",
"Requirement already satisfied: pyzmq>=17 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (24.0.1)\n",
"Requirement already satisfied: tornado>=6.1 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (6.5.1)\n",
"Requirement already satisfied: traitlets>=5.1.0 in /usr/local/lib/python3.11/dist-packages (from ipykernel->jupyter==1.1.1) (5.7.1)\n",
"Requirement already satisfied: comm>=0.1.3 in /usr/local/lib/python3.11/dist-packages (from ipywidgets->jupyter==1.1.1) (0.2.2)\n",
"Requirement already satisfied: widgetsnbextension~=4.0.12 in /usr/local/lib/python3.11/dist-packages (from ipywidgets->jupyter==1.1.1) (4.0.14)\n",
"Requirement already satisfied: jupyterlab-widgets~=3.0.12 in /usr/local/lib/python3.11/dist-packages (from ipywidgets->jupyter==1.1.1) (3.0.15)\n",
"Requirement already satisfied: prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-console->jupyter==1.1.1) (3.0.51)\n",
"Requirement already satisfied: pygments in /usr/local/lib/python3.11/dist-packages (from jupyter-console->jupyter==1.1.1) (2.19.2)\n",
"Requirement already satisfied: jupyter-core in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (5.8.1)\n",
"Requirement already satisfied: jupyterlab-server~=2.19 in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (2.27.3)\n",
"Requirement already satisfied: jupyter-server<3,>=1.16.0 in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (2.12.5)\n",
"Requirement already satisfied: jupyter-ydoc~=0.2.4 in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (0.2.5)\n",
"Requirement already satisfied: jupyter-server-ydoc~=0.8.0 in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (0.8.0)\n",
"Requirement already satisfied: nbclassic in /usr/local/lib/python3.11/dist-packages (from jupyterlab->jupyter==1.1.1) (1.3.1)\n",
"Requirement already satisfied: argon2-cffi in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (25.1.0)\n",
"Requirement already satisfied: ipython-genutils in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (0.2.0)\n",
"Requirement already satisfied: nbformat in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (5.10.4)\n",
"Requirement already satisfied: Send2Trash>=1.8.0 in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (1.8.3)\n",
"Requirement already satisfied: terminado>=0.8.3 in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (0.18.1)\n",
"Requirement already satisfied: prometheus-client in /usr/local/lib/python3.11/dist-packages (from notebook->jupyter==1.1.1) (0.22.1)\n",
"Requirement already satisfied: mistune<2,>=0.8.1 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.8.4)\n",
"Requirement already satisfied: jupyterlab-pygments in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.3.0)\n",
"Requirement already satisfied: entrypoints>=0.2.2 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.4)\n",
"Requirement already satisfied: bleach in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (6.2.0)\n",
"Requirement already satisfied: pandocfilters>=1.4.1 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (1.5.1)\n",
"Requirement already satisfied: testpath in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.6.0)\n",
"Requirement already satisfied: defusedxml in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.7.1)\n",
"Requirement already satisfied: beautifulsoup4 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (4.13.4)\n",
"Requirement already satisfied: nbclient<0.6.0,>=0.5.0 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (0.5.13)\n",
"Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.11/dist-packages (from nbconvert->jupyter==1.1.1) (3.0.2)\n",
"Requirement already satisfied: intel-openmp<2026,>=2024 in /usr/local/lib/python3.11/dist-packages (from mkl->numpy<2.1.0,>=1.26.0) (2024.2.0)\n",
"Requirement already satisfied: tbb==2022.* in /usr/local/lib/python3.11/dist-packages (from mkl->numpy<2.1.0,>=1.26.0) (2022.2.0)\n",
"Requirement already satisfied: tcmlib==1.* in /usr/local/lib/python3.11/dist-packages (from tbb==2022.*->mkl->numpy<2.1.0,>=1.26.0) (1.4.0)\n",
"Requirement already satisfied: intel-cmplr-lib-rt in /usr/local/lib/python3.11/dist-packages (from mkl_umath->numpy<2.1.0,>=1.26.0) (2024.2.0)\n",
"Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub==0.33.1) (3.4.2)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub==0.33.1) (3.10)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub==0.33.1) (2.5.0)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests->huggingface-hub==0.33.1) (2025.6.15)\n",
"Requirement already satisfied: intel-cmplr-lib-ur==2024.2.0 in /usr/local/lib/python3.11/dist-packages (from intel-openmp<2026,>=2024->mkl->numpy<2.1.0,>=1.26.0) (2024.2.0)\n",
"Requirement already satisfied: setuptools>=18.5 in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (75.2.0)\n",
"Requirement already satisfied: jedi>=0.16 in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (0.19.2)\n",
"Requirement already satisfied: decorator in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (4.4.2)\n",
"Requirement already satisfied: pickleshare in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (0.7.5)\n",
"Requirement already satisfied: backcall in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (0.2.0)\n",
"Requirement already satisfied: pexpect>4.3 in /usr/local/lib/python3.11/dist-packages (from ipython>=7.23.1->ipykernel->jupyter==1.1.1) (4.9.0)\n",
"Requirement already satisfied: platformdirs>=2.5 in /usr/local/lib/python3.11/dist-packages (from jupyter-core->jupyterlab->jupyter==1.1.1) (4.3.8)\n",
"Requirement already satisfied: anyio>=3.1.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (4.9.0)\n",
"Requirement already satisfied: jupyter-events>=0.9.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (0.12.0)\n",
"Requirement already satisfied: jupyter-server-terminals in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (0.5.3)\n",
"Requirement already satisfied: overrides in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (7.7.0)\n",
"Requirement already satisfied: websocket-client in /usr/local/lib/python3.11/dist-packages (from jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.8.0)\n",
"Requirement already satisfied: jupyter-server-fileid<1,>=0.6.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-server-ydoc~=0.8.0->jupyterlab->jupyter==1.1.1) (0.9.3)\n",
"Requirement already satisfied: ypy-websocket<0.9.0,>=0.8.2 in /usr/local/lib/python3.11/dist-packages (from jupyter-server-ydoc~=0.8.0->jupyterlab->jupyter==1.1.1) (0.8.4)\n",
"Requirement already satisfied: y-py<0.7.0,>=0.6.0 in /usr/local/lib/python3.11/dist-packages (from jupyter-ydoc~=0.2.4->jupyterlab->jupyter==1.1.1) (0.6.2)\n",
"Requirement already satisfied: babel>=2.10 in /usr/local/lib/python3.11/dist-packages (from jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (2.17.0)\n",
"Requirement already satisfied: json5>=0.9.0 in /usr/local/lib/python3.11/dist-packages (from jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (0.12.0)\n",
"Requirement already satisfied: jsonschema>=4.18.0 in /usr/local/lib/python3.11/dist-packages (from jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (4.24.0)\n",
"Requirement already satisfied: notebook-shim>=0.2.3 in /usr/local/lib/python3.11/dist-packages (from nbclassic->jupyterlab->jupyter==1.1.1) (0.2.4)\n",
"Requirement already satisfied: fastjsonschema>=2.15 in /usr/local/lib/python3.11/dist-packages (from nbformat->notebook->jupyter==1.1.1) (2.21.1)\n",
"Requirement already satisfied: wcwidth in /usr/local/lib/python3.11/dist-packages (from prompt-toolkit!=3.0.0,!=3.0.1,<3.1.0,>=2.0.0->jupyter-console->jupyter==1.1.1) (0.2.13)\n",
"Requirement already satisfied: ptyprocess in /usr/local/lib/python3.11/dist-packages (from terminado>=0.8.3->notebook->jupyter==1.1.1) (0.7.0)\n",
"Requirement already satisfied: argon2-cffi-bindings in /usr/local/lib/python3.11/dist-packages (from argon2-cffi->notebook->jupyter==1.1.1) (21.2.0)\n",
"Requirement already satisfied: soupsieve>1.2 in /usr/local/lib/python3.11/dist-packages (from beautifulsoup4->nbconvert->jupyter==1.1.1) (2.7)\n",
"Requirement already satisfied: webencodings in /usr/local/lib/python3.11/dist-packages (from bleach->nbconvert->jupyter==1.1.1) (0.5.1)\n",
"Requirement already satisfied: sniffio>=1.1 in /usr/local/lib/python3.11/dist-packages (from anyio>=3.1.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.3.1)\n",
"Requirement already satisfied: parso<0.9.0,>=0.8.4 in /usr/local/lib/python3.11/dist-packages (from jedi>=0.16->ipython>=7.23.1->ipykernel->jupyter==1.1.1) (0.8.4)\n",
"Requirement already satisfied: attrs>=22.2.0 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=4.18.0->jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (25.3.0)\n",
"Requirement already satisfied: jsonschema-specifications>=2023.03.6 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=4.18.0->jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (2025.4.1)\n",
"Requirement already satisfied: referencing>=0.28.4 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=4.18.0->jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (0.36.2)\n",
"Requirement already satisfied: rpds-py>=0.7.1 in /usr/local/lib/python3.11/dist-packages (from jsonschema>=4.18.0->jupyterlab-server~=2.19->jupyterlab->jupyter==1.1.1) (0.25.1)\n",
"Requirement already satisfied: python-json-logger>=2.0.4 in /usr/local/lib/python3.11/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (3.3.0)\n",
"Requirement already satisfied: rfc3339-validator in /usr/local/lib/python3.11/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (0.1.4)\n",
"Requirement already satisfied: rfc3986-validator>=0.1.1 in /usr/local/lib/python3.11/dist-packages (from jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (0.1.1)\n",
"Requirement already satisfied: aiofiles<23,>=22.1.0 in /usr/local/lib/python3.11/dist-packages (from ypy-websocket<0.9.0,>=0.8.2->jupyter-server-ydoc~=0.8.0->jupyterlab->jupyter==1.1.1) (22.1.0)\n",
"Requirement already satisfied: aiosqlite<1,>=0.17.0 in /usr/local/lib/python3.11/dist-packages (from ypy-websocket<0.9.0,>=0.8.2->jupyter-server-ydoc~=0.8.0->jupyterlab->jupyter==1.1.1) (0.21.0)\n",
"Requirement already satisfied: cffi>=1.0.1 in /usr/local/lib/python3.11/dist-packages (from argon2-cffi-bindings->argon2-cffi->notebook->jupyter==1.1.1) (1.17.1)\n",
"Requirement already satisfied: pycparser in /usr/local/lib/python3.11/dist-packages (from cffi>=1.0.1->argon2-cffi-bindings->argon2-cffi->notebook->jupyter==1.1.1) (2.22)\n",
"Requirement already satisfied: fqdn in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.5.1)\n",
"Requirement already satisfied: isoduration in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (20.11.0)\n",
"Requirement already satisfied: jsonpointer>1.13 in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (3.0.0)\n",
"Requirement already satisfied: uri-template in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.3.0)\n",
"Requirement already satisfied: webcolors>=24.6.0 in /usr/local/lib/python3.11/dist-packages (from jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (24.11.1)\n",
"Requirement already satisfied: arrow>=0.15.0 in /usr/local/lib/python3.11/dist-packages (from isoduration->jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (1.3.0)\n",
"Requirement already satisfied: types-python-dateutil>=2.8.10 in /usr/local/lib/python3.11/dist-packages (from arrow>=0.15.0->isoduration->jsonschema[format-nongpl]>=4.18.0->jupyter-events>=0.9.0->jupyter-server<3,>=1.16.0->jupyterlab->jupyter==1.1.1) (2.9.0.20250516)\n",
"Downloading jupyter-1.1.1-py2.py3-none-any.whl (2.7 kB)\n",
"Downloading pandas-2.3.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (12.4 MB)\n",
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 12.4/12.4 MB 84.0 MB/s eta 0:00:00\n",
"Downloading matplotlib-3.10.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (8.6 MB)\n",
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 8.6/8.6 MB 103.4 MB/s eta 0:00:00\n",
"Downloading scipy-1.15.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (37.6 MB)\n",
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 37.6/37.6 MB 44.8 MB/s eta 0:00:00\n",
"Downloading scikit_learn-1.6.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (13.5 MB)\n",
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 13.5/13.5 MB 75.2 MB/s eta 0:00:00\n",
"Downloading g2p_en-2.1.0-py3-none-any.whl (3.1 MB)\n",
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 3.1/3.1 MB 80.6 MB/s eta 0:00:00\n",
"Downloading h5py-3.13.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (4.5 MB)\n",
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4.5/4.5 MB 89.1 MB/s eta 0:00:00\n",
"Downloading transformers-4.53.0-py3-none-any.whl (10.8 MB)\n",
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 10.8/10.8 MB 105.1 MB/s eta 0:00:00\n",
"Downloading bitsandbytes-0.46.0-py3-none-manylinux_2_24_x86_64.whl (67.0 MB)\n",
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 67.0/67.0 MB 17.9 MB/s eta 0:00:00\n",
"Building wheels for collected packages: distance\n",
" Building wheel for distance (setup.py): started\n",
" Building wheel for distance (setup.py): finished with status 'done'\n",
" Created wheel for distance: filename=Distance-0.1.3-py3-none-any.whl size=16256 sha256=3bf3b151b4d15c4f5e442085e48cff050eec4c6d192a307e2b09a77daa84a5dc\n",
" Stored in directory: /root/.cache/pip/wheels/fb/cd/9c/3ab5d666e3bcacc58900b10959edd3816cc9557c7337986322\n",
"Successfully built distance\n",
"Installing collected packages: distance, jupyter, scipy, transformers, scikit-learn, pandas, matplotlib, h5py, g2p_en, bitsandbytes\n",
" Attempting uninstall: scipy\n",
" Found existing installation: scipy 1.15.3\n",
" Uninstalling scipy-1.15.3:\n",
" Successfully uninstalled scipy-1.15.3\n",
" Attempting uninstall: transformers\n",
" Found existing installation: transformers 4.52.4\n",
" Uninstalling transformers-4.52.4:\n",
" Successfully uninstalled transformers-4.52.4\n",
" Attempting uninstall: scikit-learn\n",
" Found existing installation: scikit-learn 1.2.2\n",
" Uninstalling scikit-learn-1.2.2:\n",
" Successfully uninstalled scikit-learn-1.2.2\n",
" Attempting uninstall: pandas\n",
" Found existing installation: pandas 2.2.3\n",
" Uninstalling pandas-2.2.3:\n",
" Successfully uninstalled pandas-2.2.3\n",
" Attempting uninstall: matplotlib\n",
" Found existing installation: matplotlib 3.7.2\n",
" Uninstalling matplotlib-3.7.2:\n",
" Successfully uninstalled matplotlib-3.7.2\n",
" Attempting uninstall: h5py\n",
" Found existing installation: h5py 3.14.0\n",
" Uninstalling h5py-3.14.0:\n",
" Successfully uninstalled h5py-3.14.0\n",
"Successfully installed bitsandbytes-0.46.0 distance-0.1.3 g2p_en-2.1.0 h5py-3.13.0 jupyter-1.1.1 matplotlib-3.10.1 pandas-2.3.0 scikit-learn-1.6.1 scipy-1.15.2 transformers-4.53.0\n",
"Requirement already satisfied: PyDrive2 in /usr/local/lib/python3.11/dist-packages (1.21.3)\n",
"Requirement already satisfied: google-api-python-client>=1.12.5 in /usr/local/lib/python3.11/dist-packages (from PyDrive2) (2.173.0)\n",
"Requirement already satisfied: oauth2client>=4.0.0 in /usr/local/lib/python3.11/dist-packages (from PyDrive2) (4.1.3)\n",
"Requirement already satisfied: PyYAML>=3.0 in /usr/local/lib/python3.11/dist-packages (from PyDrive2) (6.0.2)\n",
"Collecting cryptography<44 (from PyDrive2)\n",
" Downloading cryptography-43.0.3-cp39-abi3-manylinux_2_28_x86_64.whl.metadata (5.4 kB)\n",
"Collecting pyOpenSSL<=24.2.1,>=19.1.0 (from PyDrive2)\n",
" Downloading pyOpenSSL-24.2.1-py3-none-any.whl.metadata (13 kB)\n",
"Requirement already satisfied: cffi>=1.12 in /usr/local/lib/python3.11/dist-packages (from cryptography<44->PyDrive2) (1.17.1)\n",
"Requirement already satisfied: httplib2<1.0.0,>=0.19.0 in /usr/local/lib/python3.11/dist-packages (from google-api-python-client>=1.12.5->PyDrive2) (0.22.0)\n",
"Requirement already satisfied: google-auth!=2.24.0,!=2.25.0,<3.0.0,>=1.32.0 in /usr/local/lib/python3.11/dist-packages (from google-api-python-client>=1.12.5->PyDrive2) (2.40.3)\n",
"Requirement already satisfied: google-auth-httplib2<1.0.0,>=0.2.0 in /usr/local/lib/python3.11/dist-packages (from google-api-python-client>=1.12.5->PyDrive2) (0.2.0)\n",
"Requirement already satisfied: google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0,>=1.31.5 in /usr/local/lib/python3.11/dist-packages (from google-api-python-client>=1.12.5->PyDrive2) (1.34.1)\n",
"Requirement already satisfied: uritemplate<5,>=3.0.1 in /usr/local/lib/python3.11/dist-packages (from google-api-python-client>=1.12.5->PyDrive2) (4.2.0)\n",
"Requirement already satisfied: pyasn1>=0.1.7 in /usr/local/lib/python3.11/dist-packages (from oauth2client>=4.0.0->PyDrive2) (0.6.1)\n",
"Requirement already satisfied: pyasn1-modules>=0.0.5 in /usr/local/lib/python3.11/dist-packages (from oauth2client>=4.0.0->PyDrive2) (0.4.2)\n",
"Requirement already satisfied: rsa>=3.1.4 in /usr/local/lib/python3.11/dist-packages (from oauth2client>=4.0.0->PyDrive2) (4.9.1)\n",
"Requirement already satisfied: six>=1.6.1 in /usr/local/lib/python3.11/dist-packages (from oauth2client>=4.0.0->PyDrive2) (1.17.0)\n",
"Requirement already satisfied: pycparser in /usr/local/lib/python3.11/dist-packages (from cffi>=1.12->cryptography<44->PyDrive2) (2.22)\n",
"Requirement already satisfied: googleapis-common-protos<2.0dev,>=1.56.2 in /usr/local/lib/python3.11/dist-packages (from google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0,>=1.31.5->google-api-python-client>=1.12.5->PyDrive2) (1.70.0)\n",
"Requirement already satisfied: protobuf!=3.20.0,!=3.20.1,!=4.21.0,!=4.21.1,!=4.21.2,!=4.21.3,!=4.21.4,!=4.21.5,<4.0.0dev,>=3.19.5 in /usr/local/lib/python3.11/dist-packages (from google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0,>=1.31.5->google-api-python-client>=1.12.5->PyDrive2) (3.20.3)\n",
"Requirement already satisfied: requests<3.0.0dev,>=2.18.0 in /usr/local/lib/python3.11/dist-packages (from google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0,>=1.31.5->google-api-python-client>=1.12.5->PyDrive2) (2.32.4)\n",
"Requirement already satisfied: cachetools<6.0,>=2.0.0 in /usr/local/lib/python3.11/dist-packages (from google-auth!=2.24.0,!=2.25.0,<3.0.0,>=1.32.0->google-api-python-client>=1.12.5->PyDrive2) (5.5.2)\n",
"Requirement already satisfied: pyparsing!=3.0.0,!=3.0.1,!=3.0.2,!=3.0.3,<4,>=2.4.2 in /usr/local/lib/python3.11/dist-packages (from httplib2<1.0.0,>=0.19.0->google-api-python-client>=1.12.5->PyDrive2) (3.0.9)\n",
"Requirement already satisfied: charset_normalizer<4,>=2 in /usr/local/lib/python3.11/dist-packages (from requests<3.0.0dev,>=2.18.0->google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0,>=1.31.5->google-api-python-client>=1.12.5->PyDrive2) (3.4.2)\n",
"Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.11/dist-packages (from requests<3.0.0dev,>=2.18.0->google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0,>=1.31.5->google-api-python-client>=1.12.5->PyDrive2) (3.10)\n",
"Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.11/dist-packages (from requests<3.0.0dev,>=2.18.0->google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0,>=1.31.5->google-api-python-client>=1.12.5->PyDrive2) (2.5.0)\n",
"Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.11/dist-packages (from requests<3.0.0dev,>=2.18.0->google-api-core!=2.0.*,!=2.1.*,!=2.2.*,!=2.3.0,<3.0.0,>=1.31.5->google-api-python-client>=1.12.5->PyDrive2) (2025.6.15)\n",
"Downloading cryptography-43.0.3-cp39-abi3-manylinux_2_28_x86_64.whl (4.0 MB)\n",
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4.0/4.0 MB 49.5 MB/s eta 0:00:00\n",
"Downloading pyOpenSSL-24.2.1-py3-none-any.whl (58 kB)\n",
" ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 58.4/58.4 kB 3.4 MB/s eta 0:00:00\n",
"Installing collected packages: cryptography, pyOpenSSL\n",
" Attempting uninstall: cryptography\n",
" Found existing installation: cryptography 44.0.3\n",
" Uninstalling cryptography-44.0.3:\n",
" Successfully uninstalled cryptography-44.0.3\n",
" Attempting uninstall: pyOpenSSL\n",
" Found existing installation: pyOpenSSL 25.1.0\n",
" Uninstalling pyOpenSSL-25.1.0:\n",
" Successfully uninstalled pyOpenSSL-25.1.0\n",
"Successfully installed cryptography-43.0.3 pyOpenSSL-24.2.1\n",
"Obtaining file:///kaggle/working/nejm-brain-to-text\n",
" Preparing metadata (setup.py): started\n",
" Preparing metadata (setup.py): finished with status 'done'\n",
"Installing collected packages: nejm_b2txt_utils\n",
" Running setup.py develop for nejm_b2txt_utils\n",
"Successfully installed nejm_b2txt_utils-0.0.0\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"Cloning into 'nejm-brain-to-text'...\n",
"ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.\n",
"bigframes 2.8.0 requires google-cloud-bigquery-storage<3.0.0,>=2.30.0, which is not installed.\n",
"gensim 4.3.3 requires scipy<1.14.0,>=1.7.0, but you have scipy 1.15.2 which is incompatible.\n",
"dask-cudf-cu12 25.2.2 requires pandas<2.2.4dev0,>=2.0, but you have pandas 2.3.0 which is incompatible.\n",
"cudf-cu12 25.2.2 requires pandas<2.2.4dev0,>=2.0, but you have pandas 2.3.0 which is incompatible.\n",
"datasets 3.6.0 requires fsspec[http]<=2025.3.0,>=2023.1.0, but you have fsspec 2025.5.1 which is incompatible.\n",
"ydata-profiling 4.16.1 requires matplotlib<=3.10,>=3.5, but you have matplotlib 3.10.1 which is incompatible.\n",
"category-encoders 2.7.0 requires scikit-learn<1.6.0,>=1.0.0, but you have scikit-learn 1.6.1 which is incompatible.\n",
"cesium 0.12.4 requires numpy<3.0,>=2.0, but you have numpy 1.26.4 which is incompatible.\n",
"google-colab 1.0.0 requires google-auth==2.38.0, but you have google-auth 2.40.3 which is incompatible.\n",
"google-colab 1.0.0 requires notebook==6.5.7, but you have notebook 6.5.4 which is incompatible.\n",
"google-colab 1.0.0 requires pandas==2.2.2, but you have pandas 2.3.0 which is incompatible.\n",
"google-colab 1.0.0 requires requests==2.32.3, but you have requests 2.32.4 which is incompatible.\n",
"google-colab 1.0.0 requires tornado==6.4.2, but you have tornado 6.5.1 which is incompatible.\n",
"dopamine-rl 4.1.2 requires gymnasium>=1.0.0, but you have gymnasium 0.29.0 which is incompatible.\n",
"pandas-gbq 0.29.1 requires google-api-core<3.0.0,>=2.10.2, but you have google-api-core 1.34.1 which is incompatible.\n",
"bigframes 2.8.0 requires google-cloud-bigquery[bqstorage,pandas]>=3.31.0, but you have google-cloud-bigquery 3.25.0 which is incompatible.\n",
"bigframes 2.8.0 requires rich<14,>=12.4.4, but you have rich 14.0.0 which is incompatible.\n"
]
}
],
"source": [
"%%bash\n",
"rm -rf /kaggle/working/nejm-brain-to-text/\n",
"git clone https://github.com/ZH-CEN/nejm-brain-to-text.git\n",
"cp /kaggle/input/brain-to-text-baseline-model/t15_copyTask.pkl /kaggle/working/nejm-brain-to-text/data/t15_copyTask.pkl\n",
"\n",
"ln -s /kaggle/input/brain-to-text-25/t15_pretrained_rnn_baseline/t15_pretrained_rnn_baseline /kaggle/working/nejm-brain-to-text/data\n",
"ln -s /kaggle/input/brain-to-text-25/t15_copyTask_neuralData/hdf5_data_final /kaggle/working/nejm-brain-to-text/data\n",
"ln -s /kaggle/input/rnn-pretagged-data /kaggle/working/nejm-brain-to-text/data\n",
"\n",
"pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu126\n",
"\n",
"pip install \\\n",
" jupyter==1.1.1 \\\n",
" \"numpy>=1.26.0,<2.1.0\" \\\n",
" pandas==2.3.0 \\\n",
" matplotlib==3.10.1 \\\n",
" scipy==1.15.2 \\\n",
" scikit-learn==1.6.1 \\\n",
" tqdm==4.67.1 \\\n",
" g2p_en==2.1.0 \\\n",
" h5py==3.13.0 \\\n",
" omegaconf==2.3.0 \\\n",
" editdistance==0.8.1 \\\n",
" huggingface-hub==0.33.1 \\\n",
" transformers==4.53.0 \\\n",
" tokenizers==0.21.2 \\\n",
" accelerate==1.8.1 \\\n",
" bitsandbytes==0.46.0\n",
" \n",
"pip install PyDrive2\n",
"\n",
"cd /kaggle/working/nejm-brain-to-text/\n",
"pip install -e .\n"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/kaggle/working/nejm-brain-to-text\n"
]
}
],
"source": [
"%cd /kaggle/working/nejm-brain-to-text\n",
"import numpy as np\n",
"import os\n",
"import pickle\n",
"import matplotlib.pyplot as plt\n",
"import matplotlib\n",
"from g2p_en import G2p\n",
"import pandas as pd\n",
"import numpy as np\n",
"from nejm_b2txt_utils.general_utils import *\n",
"\n",
"matplotlib.rcParams['pdf.fonttype'] = 42\n",
"matplotlib.rcParams['ps.fonttype'] = 42\n",
"matplotlib.rcParams['font.family'] = 'sans-serif'\n"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/kaggle/working/nejm-brain-to-text/model_training\n"
]
}
],
"source": [
"%cd model_training/\n",
"from data_augmentations import gauss_smooth\n",
"# single decoding step function that also returns smoothed input\n",
"# smooths data and puts it through the model, returning both logits and smoothed input.\n",
"def runSingleDecodingStepWithSmoothedInput(x, input_layer, model, model_args, device):\n",
"\n",
" # Use autocast for efficiency\n",
" with torch.autocast(device_type = \"cuda\", enabled = model_args['use_amp'], dtype = torch.bfloat16):\n",
"\n",
" smoothed_x = gauss_smooth(\n",
" inputs = x, \n",
" device = device,\n",
" smooth_kernel_std = model_args['dataset']['data_transforms']['smooth_kernel_std'],\n",
" smooth_kernel_size = model_args['dataset']['data_transforms']['smooth_kernel_size'],\n",
" padding = 'valid',\n",
" )\n",
"\n",
" with torch.no_grad():\n",
" logits, _ = model(\n",
" x = smoothed_x,\n",
" day_idx = torch.tensor([input_layer], device=device),\n",
" states = None, # no initial states\n",
" return_state = True,\n",
" )\n",
"\n",
" # convert both logits and smoothed input from bfloat16 to float32\n",
" logits = logits.float().cpu().numpy()\n",
" smoothed_input = smoothed_x.float().cpu().numpy()\n",
"\n",
" # # original order is [BLANK, phonemes..., SIL]\n",
" # # rearrange so the order is [BLANK, SIL, phonemes...]\n",
" # logits = rearrange_speech_logits_pt(logits)\n",
"\n",
" return logits, smoothed_input\n",
"\n",
"\n",
"import h5py\n",
"def load_h5py_file(file_path, b2txt_csv_df):\n",
" data = {\n",
" 'neural_features': [],\n",
" 'n_time_steps': [],\n",
" 'seq_class_ids': [],\n",
" 'seq_len': [],\n",
" 'transcriptions': [],\n",
" 'sentence_label': [],\n",
" 'session': [],\n",
" 'block_num': [],\n",
" 'trial_num': [],\n",
" 'corpus': [],\n",
" }\n",
" # Open the hdf5 file for that day\n",
" with h5py.File(file_path, 'r') as f:\n",
"\n",
" keys = list(f.keys())\n",
"\n",
" # For each trial in the selected trials in that day\n",
" for key in keys:\n",
" g = f[key]\n",
"\n",
" neural_features = g['input_features'][:] # pyright: ignore[reportIndexIssue]\n",
" n_time_steps = g.attrs['n_time_steps']\n",
" seq_class_ids = g['seq_class_ids'][:] if 'seq_class_ids' in g else None # type: ignore\n",
" seq_len = g.attrs['seq_len'] if 'seq_len' in g.attrs else None\n",
" transcription = g['transcription'][:] if 'transcription' in g else None # type: ignore\n",
" sentence_label = g.attrs['sentence_label'][:] if 'sentence_label' in g.attrs else None # pyright: ignore[reportIndexIssue]\n",
" session = g.attrs['session']\n",
" block_num = g.attrs['block_num']\n",
" trial_num = g.attrs['trial_num']\n",
"\n",
" # match this trial up with the csv to get the corpus name\n",
" year, month, day = session.split('.')[1:] # pyright: ignore[reportAttributeAccessIssue]\n",
" date = f'{year}-{month}-{day}'\n",
" row = b2txt_csv_df[(b2txt_csv_df['Date'] == date) & (b2txt_csv_df['Block number'] == block_num)]\n",
" corpus_name = row['Corpus'].values[0]\n",
"\n",
" data['neural_features'].append(neural_features)\n",
" data['n_time_steps'].append(n_time_steps)\n",
" data['seq_class_ids'].append(seq_class_ids)\n",
" data['seq_len'].append(seq_len)\n",
" data['transcriptions'].append(transcription)\n",
" data['sentence_label'].append(sentence_label)\n",
" data['session'].append(session)\n",
" data['block_num'].append(block_num)\n",
" data['trial_num'].append(trial_num)\n",
" data['corpus'].append(corpus_name)\n",
" return data"
]
},
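{
"cell_type": "markdown",
"metadata": {},
"source": [
"The decoding step above smooths the neural features with `gauss_smooth` before the RNN. Below is a self-contained sketch of that kind of Gaussian temporal smoothing (an assumed illustration, not the repo's implementation; the kernel values are placeholders, the real ones come from `model_args['dataset']['data_transforms']`):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"def gauss_smooth_sketch(x, std=2.0, size=21):\n",
"    # x: [batch, time, features]; depthwise-convolve each feature channel\n",
"    # with a normalized Gaussian along time ('valid' padding trims the edges).\n",
"    t = torch.arange(size, dtype=torch.float32) - (size - 1) / 2\n",
"    k = torch.exp(-0.5 * (t / std) ** 2)\n",
"    k = k / k.sum()\n",
"    xc = x.transpose(1, 2)                    # [batch, features, time]\n",
"    w = k.view(1, 1, -1).repeat(xc.shape[1], 1, 1)\n",
"    y = F.conv1d(xc, w, groups=xc.shape[1])   # valid padding by default\n",
"    return y.transpose(1, 2)\n",
"\n",
"x = torch.randn(1, 321, 512)                  # one trial: 321 steps x 512 features\n",
"print(gauss_smooth_sketch(x).shape)           # time axis shrinks by size - 1"
]
},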
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"LOGIT_TO_PHONEME = [\n",
" 'BLANK',\n",
" 'AA', 'AE', 'AH', 'AO', 'AW',\n",
" 'AY', 'B', 'CH', 'D', 'DH',\n",
" 'EH', 'ER', 'EY', 'F', 'G',\n",
" 'HH', 'IH', 'IY', 'JH', 'K',\n",
" 'L', 'M', 'N', 'NG', 'OW',\n",
" 'OY', 'P', 'R', 'S', 'SH',\n",
" 'T', 'TH', 'UH', 'UW', 'V',\n",
" 'W', 'Y', 'Z', 'ZH',\n",
" ' | ',\n",
"]"
]
},
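{
"cell_type": "markdown",
"metadata": {},
"source": [
"`LOGIT_TO_PHONEME` maps the model's 41 CTC output classes to ARPABET phonemes, with `BLANK` at index 0 and the word separator `' | '` at index 40. Below is a minimal greedy CTC decoding sketch over a `[T, 41]` logit array (an assumed illustration, not the repo's decoder): take the best class per frame, collapse repeated classes, and drop blanks."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def greedy_ctc_decode(logits):\n",
"    # logits: [T, 41]; argmax per frame, collapse runs, drop BLANK (index 0)\n",
"    best = np.argmax(logits, axis=-1)\n",
"    keep = [int(b) for i, b in enumerate(best)\n",
"            if b != 0 and (i == 0 or b != best[i - 1])]\n",
"    return [LOGIT_TO_PHONEME[i] for i in keep]\n",
"\n",
"# toy frames voting B, B, BLANK, R, R  ->  ['B', 'R']\n",
"toy = np.zeros((5, 41))\n",
"toy[0, 7] = toy[1, 7] = toy[3, 28] = toy[4, 28] = 1.0\n",
"print(greedy_ctc_decode(toy))"
]
},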
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 数据分析与预处理"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 数据准备"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/kaggle/working/nejm-brain-to-text\n"
]
}
],
"source": [
"%cd /kaggle/working/nejm-brain-to-text/"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"data = load_h5py_file(file_path='/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2023.08.11/data_train.hdf5',\n",
" b2txt_csv_df=pd.read_csv('/kaggle/working/nejm-brain-to-text/data/t15_copyTaskData_description.csv'))"
]
},
{
"cell_type": "code",
"execution_count": 57,
"metadata": {},
"outputs": [],
"source": [
"data2 = load_h5py_file(file_path='/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2023.08.13/data_train.hdf5',\n",
" b2txt_csv_df=pd.read_csv('/kaggle/working/nejm-brain-to-text/data/t15_copyTaskData_description.csv'))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- **任务介绍** :机器学习解决高维信号的模式识别问题"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"我们的数据集标签缺少时间戳,现在要进行的是半监督学习"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"- 音素时间均等分割或者按照调研数据设定初始长度。然后筛掉异常值。提取出可用的训练集,再控制时间长短,查看样本类的长度"
]
},
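{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal sketch of the uniform initialization described above: with no timestamps, give each of the `seq_len` phonemes an equal share of the trial's `n_time_steps` frames as a starting alignment (an assumption to be refined and outlier-filtered later, not ground truth):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"def uniform_phoneme_segments(n_time_steps, seq_len):\n",
"    # Split [0, n_time_steps) into seq_len near-equal windows, one per phoneme.\n",
"    bounds = np.linspace(0, n_time_steps, seq_len + 1).astype(int)\n",
"    return list(zip(bounds[:-1], bounds[1:]))\n",
"\n",
"# e.g. a 321-frame trial with 14 phoneme labels -> 14 windows of ~23 frames\n",
"print(uniform_phoneme_segments(321, 14)[:3])"
]
},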
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'neural_features': array([[ 2.3076649 , -0.78699756, -0.64687246, ..., 0.57367045,\n",
" -0.7091646 , -0.11018186],\n",
" [-0.5859305 , -0.78699756, -0.64687246, ..., 0.3122117 ,\n",
" 1.7943763 , -0.76884896],\n",
" [-0.5859305 , -0.78699756, -0.64687246, ..., -0.21193463,\n",
" -0.8481289 , -0.7648201 ],\n",
" ...,\n",
" [-0.5859305 , 0.22756557, 0.9262037 , ..., -0.34710956,\n",
" 0.9710176 , 2.5397465 ],\n",
" [-0.5859305 , 0.22756557, -0.64687246, ..., -0.83613133,\n",
" -0.68723625, 0.10479005],\n",
" [ 0.8608672 , -0.78699756, -0.64687246, ..., -0.7171131 ,\n",
" 0.7417906 , -0.7008622 ]], dtype=float32),\n",
" 'n_time_steps': 321,\n",
" 'seq_class_ids': array([ 7, 28, 17, 24, 40, 17, 31, 40, 20, 21, 25, 29, 12, 40, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0], dtype=int32),\n",
" 'seq_len': 14,\n",
" 'transcriptions': array([ 66, 114, 105, 110, 103, 32, 105, 116, 32, 99, 108, 111, 115,\n",
" 101, 114, 46, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,\n",
" 0, 0, 0, 0, 0, 0], dtype=int32),\n",
" 'sentence_label': 'Bring it closer.',\n",
" 'session': 't15.2023.08.11',\n",
" 'block_num': 2,\n",
" 'trial_num': 0,\n",
" 'corpus': '50-Word'}"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"def data_patch(data, index):\n",
" data_patch = {}\n",
" data_patch['neural_features'] = data['neural_features'][index]\n",
" data_patch['n_time_steps'] = data['n_time_steps'][index]\n",
" data_patch['seq_class_ids'] = data['seq_class_ids'][index]\n",
" data_patch['seq_len'] = data['seq_len'][index]\n",
" data_patch['transcriptions'] = data['transcriptions'][index]\n",
" data_patch['sentence_label'] = data['sentence_label'][index]\n",
" data_patch['session'] = data['session'][index]\n",
" data_patch['block_num'] = data['block_num'][index]\n",
" data_patch['trial_num'] = data['trial_num'][index]\n",
" data_patch['corpus'] = data['corpus'][index]\n",
" return data_patch"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"d1 = data_patch(data, 0)"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Transcriptions non-zero length: 16\n",
"Seq class ids non-zero length: 14\n",
"Seq len: 14\n"
]
}
],
"source": [
"trans_len = len([x for x in d1['transcriptions'] if x != 0])\n",
"seq_len_nonzero = len([x for x in d1['seq_class_ids'] if x != 0])\n",
"seq_len = d1['seq_len']\n",
"print(f\"Transcriptions non-zero length: {trans_len}\")\n",
"print(f\"Seq class ids non-zero length: {seq_len_nonzero}\")\n",
"print(f\"Seq len: {seq_len}\")"
]
},
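{
"cell_type": "markdown",
"metadata": {},
"source": [
"The `transcriptions` array stores the sentence as zero-padded ASCII byte values, which is why its non-zero length (16) equals the character count of 'Bring it closer.'. A quick sketch to decode it back into text:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Decode the zero-padded ASCII codes in 'transcriptions' back into a string.\n",
"text = bytes(int(c) for c in d1['transcriptions'] if c != 0).decode('ascii')\n",
"print(text)  # -> Bring it closer."
]
},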
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Number of feature sequences: 14\n",
"Shape of first sequence: (22, 512)\n"
]
}
],
"source": [
"def create_time_windows(d1):\n",
" import numpy as np\n",
" n_time_steps = d1['n_time_steps']\n",
" seq_len = d1['seq_len']\n",
" # Create equal windows\n",
" edges = np.linspace(0, n_time_steps, seq_len + 1, dtype=int)\n",
" windows = [(edges[i], edges[i+1]) for i in range(seq_len)]\n",
" \n",
" # Extract feature sequences for each window\n",
" feature_sequences = []\n",
" for start, end in windows:\n",
" seq = d1['neural_features'][start:end, :]\n",
" feature_sequences.append(seq)\n",
" \n",
" return feature_sequences\n",
"\n",
"# Example usage\n",
"feature_sequences = create_time_windows(d1)\n",
"print(\"Number of feature sequences:\", len(feature_sequences))\n",
"print(\"Shape of first sequence:\", feature_sequences[0].shape)\n"
]
},
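{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because the labels carry no timestamps, a rough first pass is to assume the i-th equal-width window belongs to the i-th phoneme. The sketch below pairs each window with its class id from `seq_class_ids`; this equal-split alignment is an assumption, not a ground-truth alignment."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Pair each equal-width window with its phoneme class id (assumed 1:1 order).\n",
"labels = d1['seq_class_ids'][:d1['seq_len']]\n",
"pairs = list(zip(feature_sequences, labels))\n",
"\n",
"for seq, label in pairs[:3]:\n",
"    print(f\"phoneme class {label}: window shape {seq.shape}\")"
]
},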
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Train: 45, Val: 41, Test: 41\n",
"Train files (first 3): ['/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2025.03.14/data_train.hdf5', '/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2023.08.11/data_train.hdf5', '/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2023.11.19/data_train.hdf5']\n",
"Val files (first 3): ['/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2025.03.14/data_val.hdf5', '/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2023.11.19/data_val.hdf5', '/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2024.03.08/data_val.hdf5']\n",
"Test files (first 3): ['/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2025.03.14/data_test.hdf5', '/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2023.11.19/data_test.hdf5', '/kaggle/working/nejm-brain-to-text/data/hdf5_data_final/t15.2024.03.08/data_test.hdf5']\n"
]
}
],
"source": [
"import os\n",
"\n",
"def scan_hdf5_files(base_path):\n",
" train_files = []\n",
" val_files = []\n",
" test_files = []\n",
" for root, dirs, files in os.walk(base_path):\n",
" for file in files:\n",
" if file.endswith('.hdf5'):\n",
" abs_path = os.path.abspath(os.path.join(root, file))\n",
" if 'data_train.hdf5' in file:\n",
" train_files.append(abs_path)\n",
" elif 'data_val.hdf5' in file:\n",
" val_files.append(abs_path)\n",
" elif 'data_test.hdf5' in file:\n",
" test_files.append(abs_path)\n",
" return train_files, val_files, test_files\n",
"\n",
"# Example usage\n",
"FILE_PATH = 'data/hdf5_data_final'\n",
"train_list, val_list, test_list = scan_hdf5_files(FILE_PATH)\n",
"print(f\"Train: {len(train_list)}, Val: {len(val_list)}, Test: {len(test_list)}\")\n",
"print(\"Train files (first 3):\", train_list[:3])\n",
"print(\"Val files (first 3):\", val_list[:3])\n",
"print(\"Test files (first 3):\", test_list[:3])"
]
},
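{
"cell_type": "markdown",
"metadata": {},
"source": [
"To pair the three splits back up by recording day, the session name can be recovered from each file's parent directory. A small sketch, assuming the layout `<session>/data_<split>.hdf5` shown in the listing above:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from collections import defaultdict\n",
"\n",
"# Group train/val/test files by their session directory name.\n",
"files_by_session = defaultdict(dict)\n",
"for split, paths in [('train', train_list), ('val', val_list), ('test', test_list)]:\n",
"    for path in paths:\n",
"        session = os.path.basename(os.path.dirname(path))  # e.g. 't15.2023.08.11'\n",
"        files_by_session[session][split] = path\n",
"\n",
"print(f\"{len(files_by_session)} sessions indexed\")"
]
},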
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 🔗 数据集批量处理工作流"
]
},
{
"cell_type": "code",
"execution_count": 25,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"/kaggle/working/nejm-brain-to-text/model_training\n"
]
}
],
"source": [
"%cd model_training/"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"======================================================================\n",
"🚀 RNN数据批量处理工具 - 增强内存管理版本\n",
"======================================================================\n",
"🔧 创建RNN数据处理器增强内存管理版...\n",
"🔧 初始化RNN数据处理器增强内存管理版...\n",
" 模型路径: ../data/t15_pretrained_rnn_baseline\n",
" 数据目录: ../data/hdf5_data_final\n",
" 计算设备: cpu\n",
" 批量保存间隔: 25 个试验\n",
" 内存警告阈值: 80%\n",
" 初始内存状态: RAM: 0.52GB (3.8%)\n",
"📋 模型配置:\n",
" Sessions数量: 45\n",
" 神经特征维度: 512\n",
" Patch size: 14\n",
" Patch stride: 4\n",
" 输出类别数: 41\n",
"🔄 加载模型... 当前内存: RAM: 0.52GB (3.8%)\n",
"✅ 模型加载成功,内存清理后: RAM: 1.18GB (8.6%)\n",
"📊 CSV数据加载完成: 265 条记录\n",
"✅ 初始化完成!当前内存: RAM: 1.18GB (8.6%)\n",
"✅ RNN数据处理器创建成功\n",
"🧠 内存管理功能已启用:\n",
" - 批量保存间隔: 25个试验\n",
" - 自动内存监控和清理\n",
" - GPU内存即时释放\n",
" - 垃圾回收优化\n",
"✅ 模型加载成功,内存清理后: RAM: 1.18GB (8.6%)\n",
"📊 CSV数据加载完成: 265 条记录\n",
"✅ 初始化完成!当前内存: RAM: 1.18GB (8.6%)\n",
"✅ RNN数据处理器创建成功\n",
"🧠 内存管理功能已启用:\n",
" - 批量保存间隔: 25个试验\n",
" - 自动内存监控和清理\n",
" - GPU内存即时释放\n",
" - 垃圾回收优化\n"
]
}
],
"source": [
"# 🚀 RNN数据批量处理工具 - 完整版(增强内存管理 + 自动上传)\n",
"import os\n",
"import torch\n",
"import numpy as np\n",
"import pandas as pd\n",
"from omegaconf import OmegaConf\n",
"import time\n",
"from tqdm import tqdm\n",
"import h5py\n",
"from pathlib import Path\n",
"import gc # 垃圾回收\n",
"import psutil # 内存监控\n",
"\n",
"# 导入模型相关模块\n",
"import sys\n",
"sys.path.append('../model_training')\n",
"from rnn_model import GRUDecoder\n",
"from evaluate_model_helpers import *\n",
"from data_augmentations import gauss_smooth\n",
"\n",
"print(\"=\"*70)\n",
"print(\"🚀 RNN数据批量处理工具 - 增强内存管理 + 自动上传版本\")\n",
"print(\"=\"*70)\n",
"\n",
"class MemoryManager:\n",
" \"\"\"内存管理器 - 监控和清理内存\"\"\"\n",
" \n",
" @staticmethod\n",
" def get_memory_info():\n",
" \"\"\"获取内存使用情况\"\"\"\n",
" process = psutil.Process()\n",
" memory_info = process.memory_info()\n",
" memory_percent = process.memory_percent()\n",
" \n",
" # GPU内存如果可用\n",
" gpu_memory = \"\"\n",
" if torch.cuda.is_available():\n",
" gpu_allocated = torch.cuda.memory_allocated() / 1024**3\n",
" gpu_reserved = torch.cuda.memory_reserved() / 1024**3\n",
" gpu_memory = f\" | GPU: {gpu_allocated:.2f}GB allocated, {gpu_reserved:.2f}GB reserved\"\n",
" \n",
" return f\"RAM: {memory_info.rss / 1024**3:.2f}GB ({memory_percent:.1f}%){gpu_memory}\"\n",
" \n",
" @staticmethod\n",
" def clear_memory():\n",
" \"\"\"清理内存\"\"\"\n",
" # 清理Python垃圾回收\n",
" collected = gc.collect()\n",
" \n",
" # 清理GPU内存\n",
" if torch.cuda.is_available():\n",
" torch.cuda.empty_cache()\n",
" torch.cuda.synchronize()\n",
" \n",
" return collected\n",
" \n",
" @staticmethod\n",
" def memory_warning_check():\n",
" \"\"\"检查内存使用情况并发出警告\"\"\"\n",
" memory_percent = psutil.virtual_memory().percent\n",
" if memory_percent > 85:\n",
" print(f\" 内存使用率过高: {memory_percent:.1f}%\")\n",
" return True\n",
" return False\n",
"\n",
"class RNNDataProcessor:\n",
" \"\"\"\n",
" RNN数据批量处理器 - 生成RNN输入输出拼接数据\n",
" 增强版本优化内存管理支持大数据集处理自动上传到WebDAV\n",
" \n",
" 核心功能:\n",
" 1. 加载预训练RNN模型\n",
" 2. 处理原始神经数据(高斯平滑 + patch操作\n",
" 3. 获取RNN输出40类置信度分数\n",
" 4. 拼接处理后的输入和输出\n",
" 5. 批量保存所有session数据\n",
" 6. 自动上传到WebDAV并删除本地文件\n",
" 7. 内存管理和监控\n",
" \"\"\"\n",
" \n",
" def __init__(self, model_path, data_dir, csv_path, device='auto', \n",
" batch_save_interval=50, memory_threshold=80, \n",
" enable_auto_upload=True, webdav_uploader=None):\n",
" \"\"\"\n",
" 初始化处理器\n",
" \n",
" 参数:\n",
" model_path: 预训练RNN模型路径\n",
" data_dir: 数据目录路径 \n",
" csv_path: 数据描述CSV文件路径\n",
" device: 计算设备 ('auto', 'cpu', 'cuda:0'等)\n",
" batch_save_interval: 批量保存间隔每N个试验保存一次\n",
" memory_threshold: 内存警告阈值(百分比)\n",
" enable_auto_upload: 是否启用自动上传\n",
" webdav_uploader: WebDAV上传器实例\n",
" \"\"\"\n",
" self.model_path = model_path\n",
" self.data_dir = data_dir\n",
" self.csv_path = csv_path\n",
" self.batch_save_interval = batch_save_interval\n",
" self.memory_threshold = memory_threshold\n",
" self.enable_auto_upload = enable_auto_upload\n",
" self.webdav_uploader = webdav_uploader\n",
" \n",
" # 设备选择\n",
" if device == 'auto':\n",
" self.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')\n",
" else:\n",
" self.device = torch.device(device)\n",
" \n",
" print(f\"🔧 初始化RNN数据处理器增强内存管理 + 自动上传版)...\")\n",
" print(f\" 模型路径: {model_path}\")\n",
" print(f\" 数据目录: {data_dir}\")\n",
" print(f\" 计算设备: {self.device}\")\n",
" print(f\" 批量保存间隔: {batch_save_interval} 个试验\")\n",
" print(f\" 内存警告阈值: {memory_threshold}%\")\n",
" print(f\" 自动上传: {'✅ 启用' if enable_auto_upload else '❌ 禁用'}\")\n",
" \n",
" # 初始内存状态\n",
" print(f\" 初始内存状态: {MemoryManager.get_memory_info()}\")\n",
" \n",
" # 加载配置和模型\n",
" self._load_config()\n",
" self._load_model()\n",
" self._load_csv()\n",
" \n",
" print(f\" 初始化完成!当前内存: {MemoryManager.get_memory_info()}\")\n",
" \n",
" def _load_config(self):\n",
" \"\"\"加载模型配置\"\"\"\n",
" config_path = os.path.join(self.model_path, 'checkpoint/args.yaml')\n",
" if not os.path.exists(config_path):\n",
" raise FileNotFoundError(f\"配置文件不存在: {config_path}\")\n",
" \n",
" self.model_args = OmegaConf.load(config_path)\n",
" \n",
" print(f\" 模型配置:\")\n",
" print(f\" Sessions数量: {len(self.model_args['dataset']['sessions'])}\")\n",
" print(f\" 神经特征维度: {self.model_args['model']['n_input_features']}\")\n",
" print(f\" Patch size: {self.model_args['model']['patch_size']}\")\n",
" print(f\" Patch stride: {self.model_args['model']['patch_stride']}\")\n",
" print(f\" 输出类别数: {self.model_args['dataset']['n_classes']}\")\n",
" \n",
" def _load_model(self):\n",
" \"\"\"加载预训练RNN模型\"\"\"\n",
" try:\n",
" print(f\" 加载模型... 当前内存: {MemoryManager.get_memory_info()}\")\n",
" \n",
" # 创建模型\n",
" self.model = GRUDecoder(\n",
" neural_dim=self.model_args['model']['n_input_features'],\n",
" n_units=self.model_args['model']['n_units'], \n",
" n_days=len(self.model_args['dataset']['sessions']),\n",
" n_classes=self.model_args['dataset']['n_classes'],\n",
" rnn_dropout=self.model_args['model']['rnn_dropout'],\n",
" input_dropout=self.model_args['model']['input_network']['input_layer_dropout'],\n",
" n_layers=self.model_args['model']['n_layers'],\n",
" patch_size=self.model_args['model']['patch_size'],\n",
" patch_stride=self.model_args['model']['patch_stride'],\n",
" )\n",
" \n",
" # 加载权重\n",
" checkpoint_path = os.path.join(self.model_path, 'checkpoint/best_checkpoint')\n",
" try:\n",
" checkpoint = torch.load(checkpoint_path, map_location=self.device, weights_only=False)\n",
" except TypeError:\n",
" checkpoint = torch.load(checkpoint_path, map_location=self.device)\n",
" \n",
" # 清理键名\n",
" for key in list(checkpoint['model_state_dict'].keys()):\n",
" checkpoint['model_state_dict'][key.replace(\"module.\", \"\")] = checkpoint['model_state_dict'].pop(key)\n",
" checkpoint['model_state_dict'][key.replace(\"_orig_mod.\", \"\")] = checkpoint['model_state_dict'].pop(key)\n",
" \n",
" self.model.load_state_dict(checkpoint['model_state_dict'])\n",
" self.model.to(self.device)\n",
" self.model.eval()\n",
" \n",
" # 立即清理checkpoint内存\n",
" del checkpoint\n",
" MemoryManager.clear_memory()\n",
" \n",
" print(f\" 模型加载成功,内存清理后: {MemoryManager.get_memory_info()}\")\n",
" \n",
" except Exception as e:\n",
" print(f\" 模型加载失败: {e}\")\n",
" raise\n",
" \n",
" def _load_csv(self):\n",
" \"\"\"加载数据描述文件\"\"\"\n",
" if not os.path.exists(self.csv_path):\n",
" raise FileNotFoundError(f\"CSV文件不存在: {self.csv_path}\")\n",
" \n",
" self.csv_df = pd.read_csv(self.csv_path)\n",
" print(f\"📊 CSV数据加载完成: {len(self.csv_df)} 条记录\")\n",
" \n",
" def _process_single_trial(self, neural_data, session_idx):\n",
" \"\"\"\n",
" 处理单个试验数据(优化内存使用)\n",
" \n",
" 参数:\n",
" neural_data: 原始神经数据 [time_steps, features]\n",
" session_idx: 会话索引\n",
" \n",
" 返回:\n",
" dict: 包含拼接数据和统计信息\n",
" \"\"\"\n",
" try:\n",
" # 添加batch维度\n",
" neural_input = np.expand_dims(neural_data, axis=0)\n",
" neural_tensor = torch.tensor(neural_input, device=self.device, dtype=torch.bfloat16)\n",
" \n",
" # 高斯平滑\n",
" with torch.autocast(device_type=\"cuda\" if self.device.type == \"cuda\" else \"cpu\", \n",
" enabled=self.model_args.get('use_amp', False), dtype=torch.bfloat16):\n",
" \n",
" smoothed_data = gauss_smooth(\n",
" inputs=neural_tensor,\n",
" device=self.device,\n",
" smooth_kernel_std=self.model_args['dataset']['data_transforms']['smooth_kernel_std'],\n",
" smooth_kernel_size=self.model_args['dataset']['data_transforms']['smooth_kernel_size'],\n",
" padding='valid',\n",
" )\n",
" \n",
" # Patch操作复制模型内部逻辑\n",
" processed_data = smoothed_data\n",
" if self.model.patch_size > 0:\n",
" processed_data = processed_data.unsqueeze(1) # [batch, 1, time, features]\n",
" processed_data = processed_data.permute(0, 3, 1, 2) # [batch, features, 1, time]\n",
" \n",
" # 滑动窗口提取\n",
" patches = processed_data.unfold(3, self.model.patch_size, self.model.patch_stride)\n",
" patches = patches.squeeze(2) # [batch, features, patches, patch_size]\n",
" patches = patches.permute(0, 2, 3, 1) # [batch, patches, patch_size, features]\n",
" \n",
" # 展平最后两个维度\n",
" processed_data = patches.reshape(patches.size(0), patches.size(1), -1)\n",
" \n",
" # RNN推理\n",
" with torch.no_grad():\n",
" logits, _ = self.model(\n",
" x=smoothed_data,\n",
" day_idx=torch.tensor([session_idx], device=self.device),\n",
" states=None,\n",
" return_state=True,\n",
" )\n",
" \n",
" # 转换为numpy并立即释放GPU内存\n",
" processed_features = processed_data.float().cpu().numpy()[0] # [time_steps, processed_features]\n",
" confidence_scores = logits.float().cpu().numpy()[0] # [time_steps, 40]\n",
" \n",
" # 立即清理GPU张量\n",
" del neural_tensor, smoothed_data, processed_data, logits\n",
" if torch.cuda.is_available():\n",
" torch.cuda.empty_cache()\n",
" \n",
" # 拼接数据\n",
" concatenated = np.concatenate([processed_features, confidence_scores], axis=1)\n",
" \n",
" return {\n",
" 'concatenated_data': concatenated,\n",
" 'processed_features': processed_features,\n",
" 'confidence_scores': confidence_scores,\n",
" 'original_time_steps': neural_data.shape[0],\n",
" 'processed_time_steps': concatenated.shape[0],\n",
" 'feature_reduction_ratio': concatenated.shape[0] / neural_data.shape[0]\n",
" }\n",
" \n",
" except Exception as e:\n",
" # 确保GPU内存清理\n",
" if torch.cuda.is_available():\n",
" torch.cuda.empty_cache()\n",
" raise e\n",
" \n",
" def _upload_and_cleanup_file(self, filepath):\n",
" \"\"\"上传文件到WebDAV并删除本地文件\"\"\"\n",
" if not self.enable_auto_upload or not self.webdav_uploader:\n",
" return False\n",
" \n",
" try:\n",
" # 上传文件\n",
" filename = os.path.basename(filepath)\n",
" success = self.webdav_uploader.upload_file(\n",
" str(filepath), \n",
" '/移动云盘/DATA/rnn_processed_data/'\n",
" )\n",
" \n",
" if success:\n",
" # 删除本地文件\n",
" os.remove(filepath)\n",
" print(f\" 📤 上传并删除: {filename}\")\n",
" return True\n",
" else:\n",
" print(f\" ❌ 上传失败,保留本地文件: {filename}\")\n",
" return False\n",
" \n",
" except Exception as e:\n",
" print(f\" ⚠️ 上传过程出错: {e}\")\n",
" return False\n",
" \n",
" def _save_batch_data(self, results, session_name, data_type, save_path, batch_idx=None):\n",
" \"\"\"\n",
" 保存批次数据(减少内存占用 + 自动上传)\n",
" \n",
" 参数:\n",
" results: 结果数据\n",
" session_name: 会话名称\n",
" data_type: 数据类型\n",
" save_path: 保存路径\n",
" batch_idx: 批次索引(可选)\n",
" \"\"\"\n",
" if not results['concatenated_data']:\n",
" return\n",
" \n",
" # 生成文件名\n",
" if batch_idx is not None:\n",
" filename = f\"{session_name}_{data_type}_rnn_processed_batch{batch_idx}.npz\"\n",
" else:\n",
" filename = f\"{session_name}_{data_type}_rnn_processed.npz\"\n",
" \n",
" filepath = save_path / filename\n",
" \n",
" save_data = {\n",
" 'concatenated_data': np.array(results['concatenated_data'], dtype=object),\n",
" 'processed_features': np.array(results['processed_features'], dtype=object),\n",
" 'confidence_scores': np.array(results['confidence_scores'], dtype=object),\n",
" 'trial_metadata': np.array(results['trial_metadata'], dtype=object),\n",
" }\n",
" \n",
" # 保存文件\n",
" np.savez_compressed(str(filepath), **save_data)\n",
" print(f\" 💾 保存批次: {filename} ({len(results['concatenated_data'])} 个试验)\")\n",
" \n",
" # 自动上传并删除本地文件\n",
" if self.enable_auto_upload:\n",
" self._upload_and_cleanup_file(filepath)\n",
" \n",
" # 清理结果数据释放内存\n",
" for key in results:\n",
" if isinstance(results[key], list):\n",
" results[key].clear()\n",
" \n",
" # 强制垃圾回收\n",
" MemoryManager.clear_memory()\n",
" \n",
" def process_session(self, session_name, data_types=['train', 'val', 'test'], save_dir='./rnn_processed_data'):\n",
" \"\"\"\n",
" 处理单个session的数据优化内存管理 + 自动上传)\n",
" \n",
" 参数:\n",
" session_name: 会话名称\n",
" data_types: 要处理的数据类型列表\n",
" save_dir: 保存目录\n",
" \n",
" 返回:\n",
" dict: 处理结果摘要\n",
" \"\"\"\n",
" print(f\"\\n 处理会话: {session_name}\")\n",
" print(f\" 开始时内存: {MemoryManager.get_memory_info()}\")\n",
" \n",
" session_idx = self.model_args['dataset']['sessions'].index(session_name)\n",
" session_results = {}\n",
" \n",
" # 确保保存目录存在\n",
" save_path = Path(save_dir)\n",
" save_path.mkdir(parents=True, exist_ok=True)\n",
" \n",
" for data_type in data_types:\n",
" data_file = os.path.join(self.data_dir, session_name, f'data_{data_type}.hdf5')\n",
" \n",
" if not os.path.exists(data_file):\n",
" print(f\" {data_type} 数据文件不存在,跳过\")\n",
" continue\n",
" \n",
" print(f\" 处理 {data_type} 数据...\")\n",
" \n",
" try:\n",
" # 加载数据\n",
" data = load_h5py_file(data_file, self.csv_df)\n",
" num_trials = len(data['neural_features'])\n",
" \n",
" if num_trials == 0:\n",
" print(f\" {data_type} 数据为空\")\n",
" continue\n",
" \n",
" # 处理所有试验(批量保存策略)\n",
" results = {\n",
" 'concatenated_data': [],\n",
" 'processed_features': [],\n",
" 'confidence_scores': [],\n",
" 'trial_metadata': [],\n",
" 'processing_stats': []\n",
" }\n",
" \n",
" batch_count = 0\n",
" total_processed = 0\n",
" uploaded_files = 0\n",
" \n",
" for trial_idx in tqdm(range(num_trials), desc=f\" {data_type}\", leave=False):\n",
" # 检查内存使用情况\n",
" if trial_idx % 10 == 0: # 每10个trial检查一次\n",
" if MemoryManager.memory_warning_check():\n",
" MemoryManager.clear_memory()\n",
" print(f\" 🧹 内存清理: {MemoryManager.get_memory_info()}\")\n",
" \n",
" neural_data = data['neural_features'][trial_idx]\n",
" \n",
" # 处理单个试验\n",
" trial_result = self._process_single_trial(neural_data, session_idx)\n",
" \n",
" # 保存结果\n",
" results['concatenated_data'].append(trial_result['concatenated_data'])\n",
" results['processed_features'].append(trial_result['processed_features'])\n",
" results['confidence_scores'].append(trial_result['confidence_scores'])\n",
" \n",
" # 保存元数据\n",
" metadata = {\n",
" 'session': session_name,\n",
" 'data_type': data_type,\n",
" 'trial_idx': trial_idx,\n",
" 'block_num': data.get('block_num', [None])[trial_idx],\n",
" 'trial_num': data.get('trial_num', [None])[trial_idx],\n",
" **{k: v for k, v in trial_result.items() if k != 'concatenated_data'}\n",
" }\n",
" \n",
" # 添加真实标签(如果可用)\n",
" if data_type in ['train', 'val'] and 'sentence_label' in data:\n",
" metadata.update({\n",
" 'sentence_label': data['sentence_label'][trial_idx],\n",
" 'seq_class_ids': data['seq_class_ids'][trial_idx],\n",
" 'seq_len': data['seq_len'][trial_idx]\n",
" })\n",
" \n",
" results['trial_metadata'].append(metadata)\n",
" results['processing_stats'].append(trial_result)\n",
" total_processed += 1\n",
" \n",
" # 批量保存策略\n",
" if (trial_idx + 1) % self.batch_save_interval == 0 or trial_idx == num_trials - 1:\n",
" self._save_batch_data(results, session_name, data_type, save_path, batch_count)\n",
" if self.enable_auto_upload:\n",
" uploaded_files += 1\n",
" batch_count += 1\n",
" \n",
" # 强制内存清理\n",
" MemoryManager.clear_memory()\n",
" \n",
" # 统计信息\n",
" if total_processed > 0:\n",
" print(f\" {data_type} 处理完成:\")\n",
" print(f\" 试验数: {total_processed}\")\n",
" print(f\" 保存批次数: {batch_count}\")\n",
" if self.enable_auto_upload:\n",
" print(f\" 上传文件数: {uploaded_files}\")\n",
" print(f\" 最终内存: {MemoryManager.get_memory_info()}\")\n",
" \n",
" session_results[data_type] = {\n",
" 'total_trials': total_processed,\n",
" 'batch_count': batch_count,\n",
" 'uploaded_files': uploaded_files if self.enable_auto_upload else 0,\n",
" 'files': [f\"{session_name}_{data_type}_rnn_processed_batch{i}.npz\" for i in range(batch_count)]\n",
" }\n",
" \n",
" # 清理大型数据对象\n",
" del data\n",
" MemoryManager.clear_memory()\n",
" \n",
" except Exception as e:\n",
" print(f\" {data_type} 处理失败: {e}\")\n",
" # 确保内存清理\n",
" MemoryManager.clear_memory()\n",
" continue\n",
" \n",
" print(f\" 会话完成时内存: {MemoryManager.get_memory_info()}\")\n",
" return session_results\n",
" \n",
" def process_all_sessions(self, data_types=['train', 'val', 'test'], save_dir='./rnn_processed_data'):\n",
" \"\"\"\n",
" 批量处理所有sessions优化内存管理 + 自动上传)\n",
" \n",
" 参数:\n",
" data_types: 要处理的数据类型\n",
" save_dir: 保存目录\n",
" \n",
" 返回:\n",
" dict: 所有处理结果摘要\n",
" \"\"\"\n",
" print(f\"\\n 开始批量处理所有会话(增强内存管理 + 自动上传)...\")\n",
" print(f\" 目标数据类型: {data_types}\")\n",
" print(f\" 保存目录: {save_dir}\")\n",
" print(f\" 批量保存间隔: {self.batch_save_interval}\")\n",
" print(f\" 自动上传: {'✅ 启用' if self.enable_auto_upload else '❌ 禁用'}\")\n",
" print(f\" 初始内存状态: {MemoryManager.get_memory_info()}\")\n",
" \n",
" save_path = Path(save_dir)\n",
" save_path.mkdir(parents=True, exist_ok=True)\n",
" \n",
" all_results = {}\n",
" sessions = self.model_args['dataset']['sessions']\n",
" \n",
" start_time = time.time()\n",
" total_uploaded_files = 0\n",
" \n",
" for i, session in enumerate(sessions):\n",
" print(f\"\\n 进度: {i+1}/{len(sessions)} - {session}\")\n",
" \n",
" try:\n",
" session_results = self.process_session(session, data_types, save_dir)\n",
" \n",
" if session_results:\n",
" all_results[session] = session_results\n",
" \n",
" # 统计上传文件数\n",
" session_uploaded = sum(\n",
" type_data.get('uploaded_files', 0) \n",
" for type_data in session_results.values()\n",
" )\n",
" total_uploaded_files += session_uploaded\n",
" \n",
" print(f\" 会话 {session} 完成\")\n",
" if self.enable_auto_upload:\n",
" print(f\" 本会话上传文件: {session_uploaded}\")\n",
" else:\n",
" print(f\" 会话 {session} 无有效数据\")\n",
" \n",
" # 每处理几个session进行一次深度内存清理\n",
" if (i + 1) % 5 == 0:\n",
" print(f\" 深度内存清理...\")\n",
" collected = MemoryManager.clear_memory()\n",
" print(f\" 回收对象数: {collected}, 当前内存: {MemoryManager.get_memory_info()}\")\n",
" \n",
" except Exception as e:\n",
" print(f\" 会话 {session} 处理失败: {e}\")\n",
" # 确保内存清理\n",
" MemoryManager.clear_memory()\n",
" continue\n",
" \n",
" # 生成总结\n",
" end_time = time.time()\n",
" processing_time = end_time - start_time\n",
" \n",
" total_trials = sum(\n",
" session_data[data_type]['total_trials']\n",
" for session_data in all_results.values()\n",
" for data_type in session_data.keys()\n",
" )\n",
" \n",
" total_files = sum(\n",
" session_data[data_type]['batch_count']\n",
" for session_data in all_results.values()\n",
" for data_type in session_data.keys()\n",
" )\n",
" \n",
" print(f\"\\n 批量处理完成!\")\n",
" print(f\"⏱ 总耗时: {processing_time/60:.2f} 分钟\")\n",
" print(f\" 处理统计:\")\n",
" print(f\" 成功会话: {len(all_results)}/{len(sessions)}\")\n",
" print(f\" 总试验数: {total_trials}\")\n",
" print(f\" 生成文件总数: {total_files}\")\n",
" if self.enable_auto_upload:\n",
" print(f\" 📤 上传文件总数: {total_uploaded_files}\")\n",
" print(f\" 💾 本地保留文件: {total_files - total_uploaded_files}\")\n",
" print(f\" 最终内存状态: {MemoryManager.get_memory_info()}\")\n",
" print(f\" 数据保存在: {save_dir}\")\n",
" \n",
" # 保存总结信息\n",
" summary = {\n",
" 'processing_time': processing_time,\n",
" 'total_sessions': len(all_results),\n",
" 'total_trials': total_trials,\n",
" 'total_files': total_files,\n",
" 'uploaded_files': total_uploaded_files if self.enable_auto_upload else 0,\n",
" 'auto_upload_enabled': self.enable_auto_upload,\n",
" 'data_types': data_types,\n",
" 'sessions': list(all_results.keys()),\n",
" 'batch_save_interval': self.batch_save_interval,\n",
" 'memory_management': True,\n",
" 'model_config': {\n",
" 'patch_size': self.model_args['model']['patch_size'],\n",
" 'patch_stride': self.model_args['model']['patch_stride'],\n",
" 'smooth_kernel_size': self.model_args['dataset']['data_transforms']['smooth_kernel_size'],\n",
" 'smooth_kernel_std': self.model_args['dataset']['data_transforms']['smooth_kernel_std'],\n",
" }\n",
" }\n",
" \n",
" import json\n",
" with open(save_path / 'processing_summary.json', 'w') as f:\n",
" json.dump(summary, f, indent=2)\n",
" \n",
" return all_results\n",
"\n",
"# 检查WebDAV上传器是否可用\n",
"try:\n",
" # 如果之前的上传器可用,直接使用\n",
" if 'uploader' in globals():\n",
" webdav_uploader_instance = uploader\n",
" print(\"🔗 使用现有的WebDAV上传器\")\n",
" else:\n",
" # 创建简单的WebDAV上传器\n",
" from webdav3.client import Client\n",
" \n",
" class SimpleWebDAVUploader:\n",
" def __init__(self):\n",
" self.client = Client({\n",
" 'webdav_hostname': 'http://zchens.cn:5244/dav/',\n",
" 'webdav_login': 'admin',\n",
" 'webdav_password': 'Zccns20050420',\n",
" 'webdav_timeout': 30\n",
" })\n",
" \n",
" def upload_file(self, local_file, remote_dir):\n",
" try:\n",
" filename = os.path.basename(local_file)\n",
" remote_path = remote_dir.rstrip('/') + '/' + filename\n",
" \n",
" if not self.client.check(remote_dir):\n",
" self.client.mkdir(remote_dir)\n",
" \n",
" self.client.upload_sync(remote_path=remote_path, local_path=local_file)\n",
" return True\n",
" except Exception as e:\n",
" print(f\"上传失败: {e}\")\n",
" return False\n",
" \n",
" webdav_uploader_instance = SimpleWebDAVUploader()\n",
" print(\"🔗 创建新的WebDAV上传器\")\n",
" \n",
"except Exception as e:\n",
" webdav_uploader_instance = None\n",
" print(f\"⚠️ WebDAV上传器不可用: {e}\")\n",
"\n",
"# 创建处理器实例(增强内存管理 + 自动上传)\n",
"print(\"🔧 创建RNN数据处理器增强内存管理 + 自动上传版)...\")\n",
"\n",
"try:\n",
" processor = RNNDataProcessor(\n",
" model_path='../data/t15_pretrained_rnn_baseline',\n",
" data_dir='../data/hdf5_data_final',\n",
" csv_path='../data/t15_copyTaskData_description.csv',\n",
" device='auto',\n",
" batch_save_interval=25, # 每25个试验保存一次减少内存积累\n",
" memory_threshold=80, # 80%内存使用率时警告\n",
" enable_auto_upload=True, # 启用自动上传\n",
" webdav_uploader=webdav_uploader_instance # 传入WebDAV上传器\n",
" )\n",
" \n",
" print(f\" RNN数据处理器创建成功\")\n",
" print(f\" 功能特性:\")\n",
" print(f\" - 批量保存间隔: 25个试验\")\n",
" print(f\" - 自动内存监控和清理\")\n",
" print(f\" - GPU内存即时释放\")\n",
" print(f\" - 垃圾回收优化\")\n",
" print(f\" - 📤 自动上传到WebDAV\")\n",
" print(f\" - 🗑️ 自动删除本地文件\")\n",
" \n",
"except Exception as e:\n",
" print(f\" 处理器创建失败: {e}\")\n",
" processor = None"
]
},
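{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the sliding-window patch step inside `_process_single_trial` easy to verify in isolation, here is the same unfold logic on a dummy tensor, using the printed model config (patch_size=14, patch_stride=4, 512 input features). The 100-step input length is arbitrary."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"x = torch.randn(1, 100, 512)              # [batch, time, features]\n",
"p = x.unsqueeze(1).permute(0, 3, 1, 2)    # [batch, features, 1, time]\n",
"patches = p.unfold(3, 14, 4).squeeze(2)   # [batch, features, n_patches, patch_size]\n",
"patches = patches.permute(0, 2, 3, 1)     # [batch, n_patches, patch_size, features]\n",
"out = patches.reshape(1, patches.size(1), -1)\n",
"\n",
"# floor((100 - 14) / 4) + 1 = 22 patches, each flattened to 14 * 512 = 7168 features\n",
"print(out.shape)  # torch.Size([1, 22, 7168])"
]
},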
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"======================================================================\n",
"🎯 RNN数据批量处理 - 使用示例(内存优化版)\n",
"======================================================================\n",
"\n",
"📋 可用的处理方法:\n",
"1⃣ 单session处理: processor.process_session('session_name')\n",
"2⃣ 批量处理所有: processor.process_all_sessions()\n",
"\n",
"🧠 内存管理特性:\n",
" ✅ 自动批量保存 (每25个试验)\n",
" ✅ 实时内存监控和清理\n",
" ✅ GPU内存即时释放\n",
" ✅ 垃圾回收优化\n",
"\n",
"📊 可用会话数量: 45\n",
"📝 前5个会话: ['t15.2023.08.11', 't15.2023.08.13', 't15.2023.08.18', 't15.2023.08.20', 't15.2023.08.25']\n",
"💡 问题会话 t15.2023.09.01 在位置: 6\n",
"\n",
"🔍 当前系统状态:\n",
" RAM: 1.18GB (8.6%)\n",
"\n",
"🧪 快速测试: 处理会话 't15.2023.08.13' 的训练数据...\n",
" 这将测试内存管理功能...\n",
"\n",
"🔄 处理会话: t15.2023.08.13\n",
" 开始时内存: RAM: 1.18GB (8.6%)\n",
" 📁 处理 train 数据...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" train: 0%| | 0/348 [00:00<?, ?it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"⚠️ 内存使用率过高: 96.0%\n",
" 🧹 内存清理: RAM: 1.39GB (10.2%)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" train: 3%|▎ | 10/348 [00:06<03:27, 1.63it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"⚠️ 内存使用率过高: 93.7%\n",
" 🧹 内存清理: RAM: 1.56GB (11.4%)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" train: 6%|▌ | 20/348 [00:13<02:59, 1.82it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"⚠️ 内存使用率过高: 93.8%\n",
" 🧹 内存清理: RAM: 1.65GB (12.0%)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" train: 7%|▋ | 24/348 [00:15<03:22, 1.60it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" 💾 保存批次: t15.2023.08.13_train_rnn_processed_batch0.npz (25 个试验)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" train: 9%|▊ | 30/348 [00:41<09:06, 1.72s/it]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"⚠️ 内存使用率过高: 93.3%\n",
" 🧹 内存清理: RAM: 1.55GB (11.3%)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" train: 11%|█▏ | 40/348 [00:47<03:49, 1.34it/s]"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"⚠️ 内存使用率过高: 93.7%\n",
" 🧹 内存清理: RAM: 1.63GB (11.9%)\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" "
]
},
{
"ename": "KeyboardInterrupt",
"evalue": "",
"output_type": "error",
"traceback": [
"\u001b[1;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[1;31mKeyboardInterrupt\u001b[0m Traceback (most recent call last)",
"Cell \u001b[1;32mIn[2], line 36\u001b[0m\n\u001b[0;32m 33\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m 这将测试内存管理功能...\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m 35\u001b[0m \u001b[38;5;66;03m# 处理单个session仅train数据进行测试\u001b[39;00m\n\u001b[1;32m---> 36\u001b[0m single_result \u001b[38;5;241m=\u001b[39m \u001b[43mprocessor\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mprocess_session\u001b[49m\u001b[43m(\u001b[49m\u001b[43mtest_session\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[38;5;124;43mtrain\u001b[39;49m\u001b[38;5;124;43m'\u001b[39;49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 38\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m single_result \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtrain\u001b[39m\u001b[38;5;124m'\u001b[39m \u001b[38;5;129;01min\u001b[39;00m single_result:\n\u001b[0;32m 39\u001b[0m train_info \u001b[38;5;241m=\u001b[39m single_result[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtrain\u001b[39m\u001b[38;5;124m'\u001b[39m]\n",
"Cell \u001b[1;32mIn[1], line 400\u001b[0m, in \u001b[0;36mRNNDataProcessor.process_session\u001b[1;34m(self, session_name, data_types, save_dir)\u001b[0m\n\u001b[0;32m 398\u001b[0m \u001b[38;5;66;03m# 批量保存策略\u001b[39;00m\n\u001b[0;32m 399\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (trial_idx \u001b[38;5;241m+\u001b[39m \u001b[38;5;241m1\u001b[39m) \u001b[38;5;241m%\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mbatch_save_interval \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m \u001b[38;5;129;01mor\u001b[39;00m trial_idx \u001b[38;5;241m==\u001b[39m num_trials \u001b[38;5;241m-\u001b[39m \u001b[38;5;241m1\u001b[39m:\n\u001b[1;32m--> 400\u001b[0m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_save_batch_data\u001b[49m\u001b[43m(\u001b[49m\u001b[43mresults\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msession_name\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mdata_type\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43msave_path\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mbatch_count\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 401\u001b[0m batch_count \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;241m1\u001b[39m\n\u001b[0;32m 403\u001b[0m \u001b[38;5;66;03m# 强制内存清理\u001b[39;00m\n",
"Cell \u001b[1;32mIn[1], line 296\u001b[0m, in \u001b[0;36mRNNDataProcessor._save_batch_data\u001b[1;34m(self, results, session_name, data_type, save_path, batch_idx)\u001b[0m\n\u001b[0;32m 287\u001b[0m filepath \u001b[38;5;241m=\u001b[39m save_path \u001b[38;5;241m/\u001b[39m filename\n\u001b[0;32m 289\u001b[0m save_data \u001b[38;5;241m=\u001b[39m {\n\u001b[0;32m 290\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mconcatenated_data\u001b[39m\u001b[38;5;124m'\u001b[39m: np\u001b[38;5;241m.\u001b[39marray(results[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mconcatenated_data\u001b[39m\u001b[38;5;124m'\u001b[39m], dtype\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mobject\u001b[39m),\n\u001b[0;32m 291\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mprocessed_features\u001b[39m\u001b[38;5;124m'\u001b[39m: np\u001b[38;5;241m.\u001b[39marray(results[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mprocessed_features\u001b[39m\u001b[38;5;124m'\u001b[39m], dtype\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mobject\u001b[39m),\n\u001b[0;32m 292\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mconfidence_scores\u001b[39m\u001b[38;5;124m'\u001b[39m: np\u001b[38;5;241m.\u001b[39marray(results[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mconfidence_scores\u001b[39m\u001b[38;5;124m'\u001b[39m], dtype\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mobject\u001b[39m),\n\u001b[0;32m 293\u001b[0m \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtrial_metadata\u001b[39m\u001b[38;5;124m'\u001b[39m: np\u001b[38;5;241m.\u001b[39marray(results[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mtrial_metadata\u001b[39m\u001b[38;5;124m'\u001b[39m], dtype\u001b[38;5;241m=\u001b[39m\u001b[38;5;28mobject\u001b[39m),\n\u001b[0;32m 294\u001b[0m }\n\u001b[1;32m--> 296\u001b[0m np\u001b[38;5;241m.\u001b[39msavez_compressed(\u001b[38;5;28mstr\u001b[39m(filepath), \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39msave_data)\n\u001b[0;32m 297\u001b[0m \u001b[38;5;28mprint\u001b[39m(\u001b[38;5;124mf\u001b[39m\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124m 💾 保存批次: \u001b[39m\u001b[38;5;132;01m{\u001b[39;00mfilename\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m (\u001b[39m\u001b[38;5;132;01m{\u001b[39;00m\u001b[38;5;28mlen\u001b[39m(results[\u001b[38;5;124m'\u001b[39m\u001b[38;5;124mconcatenated_data\u001b[39m\u001b[38;5;124m'\u001b[39m])\u001b[38;5;132;01m}\u001b[39;00m\u001b[38;5;124m 个试验)\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m 299\u001b[0m \u001b[38;5;66;03m# 清理结果数据释放内存\u001b[39;00m\n",
"File \u001b[1;32md:\\SoftWare\\Anaconda3\\envs\\b2txt25\\lib\\site-packages\\numpy\\lib\\_npyio_impl.py:753\u001b[0m, in \u001b[0;36msavez_compressed\u001b[1;34m(file, *args, **kwds)\u001b[0m\n\u001b[0;32m 689\u001b[0m \u001b[38;5;129m@array_function_dispatch\u001b[39m(_savez_compressed_dispatcher)\n\u001b[0;32m 690\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m\u001b[38;5;250m \u001b[39m\u001b[38;5;21msavez_compressed\u001b[39m(file, \u001b[38;5;241m*\u001b[39margs, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwds):\n\u001b[0;32m 691\u001b[0m \u001b[38;5;250m \u001b[39m\u001b[38;5;124;03m\"\"\"\u001b[39;00m\n\u001b[0;32m 692\u001b[0m \u001b[38;5;124;03m Save several arrays into a single file in compressed ``.npz`` format.\u001b[39;00m\n\u001b[0;32m 693\u001b[0m \n\u001b[1;32m (...)\u001b[0m\n\u001b[0;32m 751\u001b[0m \n\u001b[0;32m 752\u001b[0m \u001b[38;5;124;03m \"\"\"\u001b[39;00m\n\u001b[1;32m--> 753\u001b[0m \u001b[43m_savez\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfile\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwds\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m)\u001b[49m\n",
"File \u001b[1;32md:\\SoftWare\\Anaconda3\\envs\\b2txt25\\lib\\site-packages\\numpy\\lib\\_npyio_impl.py:786\u001b[0m, in \u001b[0;36m_savez\u001b[1;34m(file, args, kwds, compress, allow_pickle, pickle_kwargs)\u001b[0m\n\u001b[0;32m 784\u001b[0m \u001b[38;5;66;03m# always force zip64, gh-10776\u001b[39;00m\n\u001b[0;32m 785\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m zipf\u001b[38;5;241m.\u001b[39mopen(fname, \u001b[38;5;124m'\u001b[39m\u001b[38;5;124mw\u001b[39m\u001b[38;5;124m'\u001b[39m, force_zip64\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m) \u001b[38;5;28;01mas\u001b[39;00m fid:\n\u001b[1;32m--> 786\u001b[0m \u001b[38;5;28;43mformat\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mwrite_array\u001b[49m\u001b[43m(\u001b[49m\u001b[43mfid\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mval\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 787\u001b[0m \u001b[43m \u001b[49m\u001b[43mallow_pickle\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mallow_pickle\u001b[49m\u001b[43m,\u001b[49m\n\u001b[0;32m 788\u001b[0m \u001b[43m \u001b[49m\u001b[43mpickle_kwargs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mpickle_kwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 790\u001b[0m zipf\u001b[38;5;241m.\u001b[39mclose()\n",
"File \u001b[1;32md:\\SoftWare\\Anaconda3\\envs\\b2txt25\\lib\\site-packages\\numpy\\lib\\format.py:746\u001b[0m, in \u001b[0;36mwrite_array\u001b[1;34m(fp, array, version, allow_pickle, pickle_kwargs)\u001b[0m\n\u001b[0;32m 744\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m pickle_kwargs \u001b[38;5;129;01mis\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m:\n\u001b[0;32m 745\u001b[0m pickle_kwargs \u001b[38;5;241m=\u001b[39m {}\n\u001b[1;32m--> 746\u001b[0m pickle\u001b[38;5;241m.\u001b[39mdump(array, fp, protocol\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m4\u001b[39m, \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mpickle_kwargs)\n\u001b[0;32m 747\u001b[0m \u001b[38;5;28;01melif\u001b[39;00m array\u001b[38;5;241m.\u001b[39mflags\u001b[38;5;241m.\u001b[39mf_contiguous \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m array\u001b[38;5;241m.\u001b[39mflags\u001b[38;5;241m.\u001b[39mc_contiguous:\n\u001b[0;32m 748\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m isfileobj(fp):\n",
"File \u001b[1;32md:\\SoftWare\\Anaconda3\\envs\\b2txt25\\lib\\zipfile.py:1142\u001b[0m, in \u001b[0;36m_ZipWriteFile.write\u001b[1;34m(self, data)\u001b[0m\n\u001b[0;32m 1140\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_crc \u001b[38;5;241m=\u001b[39m crc32(data, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_crc)\n\u001b[0;32m 1141\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compressor:\n\u001b[1;32m-> 1142\u001b[0m data \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43m_compressor\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompress\u001b[49m\u001b[43m(\u001b[49m\u001b[43mdata\u001b[49m\u001b[43m)\u001b[49m\n\u001b[0;32m 1143\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_compress_size \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m \u001b[38;5;28mlen\u001b[39m(data)\n\u001b[0;32m 1144\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_fileobj\u001b[38;5;241m.\u001b[39mwrite(data)\n",
"\u001b[1;31mKeyboardInterrupt\u001b[0m: "
]
}
],
"source": [
"# 使用示例和批量处理(增强内存管理版)\n",
"\n",
"print(\"=\"*70)\n",
"print(\"RNN数据批量处理 - 使用示例(内存优化版)\")\n",
"print(\"=\"*70)\n",
"\n",
"if processor is not None:\n",
" \n",
" # 方法1: 处理单个session (推荐用于测试)\n",
" print(\"\\n可用的处理方法:\")\n",
" print(\"1. 单session处理: processor.process_session('session_name')\")\n",
" print(\"2. 批量处理所有: processor.process_all_sessions()\")\n",
" print(\"\\n内存管理特性:\")\n",
" print(\" 自动批量保存 (每25个试验)\")\n",
" print(\" 实时内存监控和清理\")\n",
" print(\" GPU内存即时释放\")\n",
" print(\" 垃圾回收优化\")\n",
" \n",
" # 显示可用的sessions\n",
" sessions = processor.model_args['dataset']['sessions']\n",
" print(f\"\\n可用会话数量: {len(sessions)}\")\n",
" print(f\"前5个会话: {sessions[:5]}\")\n",
" print(f\"问题会话 t15.2023.09.01 在位置: {sessions.index('t15.2023.09.01') if 't15.2023.09.01' in sessions else '未找到'}\")\n",
" \n",
" # 内存状态检查\n",
" print(f\"\\n当前系统状态:\")\n",
" print(f\" {MemoryManager.get_memory_info()}\")\n",
" \n",
" # 快速测试 - 处理第一个session的部分数据\n",
" test_session = sessions[1] # 't15.2023.08.11'\n",
" \n",
" print(f\"\\n快速测试: 处理会话 '{test_session}' 的训练数据...\")\n",
" print(f\" 这将测试内存管理功能...\")\n",
" \n",
" # 处理单个session仅train数据进行测试\n",
" single_result = processor.process_session(test_session, ['train'])\n",
" \n",
" if single_result and 'train' in single_result:\n",
" train_info = single_result['train']\n",
" \n",
" print(f\"\\n内存管理测试完成结果概览:\")\n",
" print(f\" 处理的试验数: {train_info['total_trials']}\")\n",
" print(f\" 保存的批次数: {train_info['batch_count']}\")\n",
" print(f\" 生成的文件: {len(train_info['files'])}\")\n",
" print(f\" 内存管理状态: {MemoryManager.get_memory_info()}\")\n",
" \n",
" # 加载一个批次文件来验证\n",
" if train_info['files']:\n",
" first_file = Path('./rnn_processed_data') / train_info['files'][0]\n",
" if first_file.exists():\n",
" test_data = np.load(str(first_file), allow_pickle=True)\n",
" sample_data = test_data['concatenated_data'][0]\n",
" \n",
" print(f\"\\n数据验证 (第一批次):\")\n",
" print(f\" 批次文件: {train_info['files'][0]}\")\n",
" print(f\" 样本数据形状: {sample_data.shape}\")\n",
" print(f\" 特征维度详情:\")\n",
" print(f\" - 处理后的神经特征: {sample_data.shape[1] - 41} 维\")\n",
" print(f\" - RNN置信度分数: 41 维\")\n",
" print(f\" - 总拼接特征: {sample_data.shape[1]} 维\")\n",
" print(f\" - 时间步数: {sample_data.shape[0]}\")\n",
" \n",
" # 显示一些样本元数据\n",
" sample_metadata = test_data['trial_metadata'][0]\n",
" print(f\" 样本元数据:\")\n",
" print(f\" - 原始时间步: {sample_metadata['original_time_steps']}\")\n",
" print(f\" - 处理后时间步: {sample_metadata['processed_time_steps']}\")\n",
" print(f\" - 时间压缩比: {sample_metadata['feature_reduction_ratio']:.3f}\")\n",
" \n",
" if 'sentence_label' in sample_metadata:\n",
" print(f\" - 句子标签: {sample_metadata['sentence_label']}\")\n",
" \n",
" # 清理测试数据\n",
" del test_data, sample_data, sample_metadata\n",
" MemoryManager.clear_memory()\n",
" \n",
" print(f\"\\n处理大数据集建议:\")\n",
" print(f\" 使用增强版处理器,自动内存管理\")\n",
" print(f\" 数据自动分批保存,避免内存溢出\") \n",
" print(f\" 可安全处理 t15.2023.09.01 等大批次\")\n",
" print(f\" 要批量处理所有数据,运行:\")\n",
" print(f\" results = processor.process_all_sessions()\")\n",
" \n",
"else:\n",
" print(\"处理器未创建成功,请检查上面的错误信息\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"======================================================================\n",
"🚀 批量处理选项\n",
"======================================================================\n",
"📊 批量处理配置:\n",
" 启用批量处理: True\n",
" 保存目录: ./rnn_processed_data\n",
" 数据类型: ['train', 'val', 'test']\n",
" 总会话数: 45\n",
"\n",
"🚀 开始批量处理所有数据...\n",
"⚠️ 这可能需要较长时间预计30-60分钟\n",
"\n",
"🚀 开始批量处理所有会话...\n",
" 目标数据类型: ['train', 'val', 'test']\n",
" 保存目录: ./rnn_processed_data\n",
"\n",
"📊 进度: 1/45\n",
"\n",
"🔄 处理会话: t15.2023.08.11\n",
" 📁 处理 train 数据...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" ✅ train 处理完成:\n",
" 试验数: 288\n",
" 时间步范围: 30-251\n",
" 特征维度: 7209 (处理后特征: 7169, 置信度: 40)\n",
" 平均时间压缩比: 0.240\n",
" ⚠️ val 数据文件不存在,跳过\n",
" ⚠️ test 数据文件不存在,跳过\n",
" 💾 保存: t15.2023.08.11_train_rnn_processed.npz\n",
"\n",
"📊 进度: 2/45\n",
"\n",
"🔄 处理会话: t15.2023.08.13\n",
" 📁 处理 train 数据...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" ✅ train 处理完成:\n",
" 试验数: 348\n",
" 时间步范围: 55-352\n",
" 特征维度: 7209 (处理后特征: 7169, 置信度: 40)\n",
" 平均时间压缩比: 0.243\n",
" 📁 处理 val 数据...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" ✅ val 处理完成:\n",
" 试验数: 35\n",
" 时间步范围: 90-296\n",
" 特征维度: 7209 (处理后特征: 7169, 置信度: 40)\n",
" 平均时间压缩比: 0.243\n",
" 📁 处理 test 数据...\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
" \r"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
" ✅ test 处理完成:\n",
" 试验数: 35\n",
" 时间步范围: 80-238\n",
" 特征维度: 7209 (处理后特征: 7169, 置信度: 40)\n",
" 平均时间压缩比: 0.242\n",
" 💾 保存: t15.2023.08.13_train_rnn_processed.npz\n",
" 💾 保存: t15.2023.08.13_val_rnn_processed.npz\n"
]
}
],
"source": [
"# 🚀 批量处理所有数据 (增强内存管理 + 自动上传版本)\n",
"\n",
"print(\"=\"*70)\n",
"print(\"批量处理选项 - 内存优化 + 自动上传版\")\n",
"print(\"=\"*70)\n",
"\n",
"# 设置参数\n",
"ENABLE_FULL_PROCESSING = True # 设为True开始批量处理 \n",
"SAVE_DIR = \"./rnn_processed_data\" # 保存目录\n",
"DATA_TYPES = ['train', 'val', 'test'] # 要处理的数据类型\n",
"\n",
"print(f\" 批量处理配置 (内存优化 + 自动上传版):\")\n",
"print(f\" 启用批量处理: {ENABLE_FULL_PROCESSING}\")\n",
"print(f\" 保存目录: {SAVE_DIR}\")\n",
"print(f\" 数据类型: {DATA_TYPES}\")\n",
"print(f\" 总会话数: {len(processor.model_args['dataset']['sessions'])}\")\n",
"print(f\" 批量保存策略: 每{processor.batch_save_interval}个试验保存一次\")\n",
"print(f\" 内存监控阈值: {processor.memory_threshold}%\")\n",
"print(f\" 📤 自动上传: {'✅ 启用' if processor.enable_auto_upload else '❌ 禁用'}\")\n",
"print(f\" 🗑️ 自动清理: {'✅ 启用' if processor.enable_auto_upload else '❌ 禁用'}\")\n",
"\n",
"# 显示新功能优势\n",
"print(f\"\\n 🆕 新增功能优势:\")\n",
"print(f\" 📤 处理完成后自动上传到WebDAV\")\n",
"print(f\" 🗑️ 上传成功后自动删除本地文件\")\n",
"print(f\" 💾 节省本地存储空间\")\n",
"print(f\" ☁️ 数据安全备份到云端\")\n",
"print(f\" 🔄 无需手动管理文件传输\")\n",
"\n",
"print(f\"\\n 🔧 技术特性:\")\n",
"print(f\" 自动分批保存,避免内存积累\")\n",
"print(f\" 实时GPU内存清理\")\n",
"print(f\" 垃圾回收优化\")\n",
"print(f\" 内存使用监控和警告\")\n",
"print(f\" 可处理 t15.2023.09.01 等大数据集\")\n",
"\n",
"if ENABLE_FULL_PROCESSING and processor is not None:\n",
" print(f\"\\n 🚀 开始批量处理所有数据(内存优化 + 自动上传版)...\")\n",
" print(f\" 这可能需要较长时间预计30-60分钟\")\n",
" print(f\" 内存不足问题已解决,可安全处理大数据集\")\n",
" print(f\" 📤 文件将自动上传到WebDAV并删除本地副本\")\n",
" print(f\" 开始时内存状态: {MemoryManager.get_memory_info()}\")\n",
" \n",
" # 记录处理开始时间\n",
" import time\n",
" start_processing_time = time.time()\n",
" \n",
" # 批量处理\n",
" all_results = processor.process_all_sessions(\n",
" data_types=DATA_TYPES,\n",
" save_dir=SAVE_DIR\n",
" )\n",
" \n",
" # 计算处理时间\n",
" end_processing_time = time.time()\n",
" total_processing_time = end_processing_time - start_processing_time\n",
" \n",
" print(f\"\\n 🎉 批量处理完成!\")\n",
" print(f\" ⏱ 总处理时间: {total_processing_time/60:.2f} 分钟\")\n",
" print(f\" 最终内存状态: {MemoryManager.get_memory_info()}\")\n",
" \n",
" # 详细统计\n",
" total_files = 0\n",
" total_uploaded = 0\n",
" for session_name, session_data in all_results.items():\n",
" for data_type, type_data in session_data.items():\n",
" total_files += type_data['batch_count']\n",
" total_uploaded += type_data.get('uploaded_files', 0)\n",
" \n",
" print(f\"\\n 📊 处理统计详情:\")\n",
" print(f\" 成功处理的会话: {len(all_results)}\")\n",
" print(f\" 生成文件总数: {total_files}\")\n",
" print(f\" 📤 成功上传文件: {total_uploaded}\")\n",
" print(f\" 💾 本地保留文件: {total_files - total_uploaded}\")\n",
" print(f\" 🔄 上传成功率: {(total_uploaded/total_files*100) if total_files > 0 else 0:.1f}%\")\n",
" \n",
" # 存储空间统计\n",
" try:\n",
" import os\n",
" local_size = 0\n",
" if os.path.exists(SAVE_DIR):\n",
" for root, dirs, files in os.walk(SAVE_DIR):\n",
" for file in files:\n",
" local_size += os.path.getsize(os.path.join(root, file))\n",
" \n",
" print(f\" 💾 剩余本地文件大小: {local_size / 1024**3:.2f} GB\")\n",
" except:\n",
" pass\n",
" \n",
" print(f\"\\n ☁️ WebDAV云端存储:\")\n",
" print(f\" 远程路径: /移动云盘/DATA/rnn_processed_data/\")\n",
" print(f\" 文件格式: session_name_datatype_rnn_processed_batchN.npz\")\n",
" print(f\" 例如: t15.2023.08.13_train_rnn_processed_batch0.npz\")\n",
" \n",
" if total_uploaded > 0:\n",
" print(f\"\\n ✅ 自动上传工作流成功!\")\n",
" print(f\" 所有处理完的文件已自动上传到云端\")\n",
" print(f\" 本地存储空间得到有效管理\")\n",
" print(f\" 数据安全性和可访问性得到保障\")\n",
" \n",
"else:\n",
" print(f\"\\n 要开始批量处理,请将 ENABLE_FULL_PROCESSING 设为 True\")\n",
" print(f\" 或者手动运行: processor.process_all_sessions()\")\n",
"\n",
"print(f\"\\n 📋 数据使用说明(自动上传版本):\")\n",
"print(f\" 🔄 处理流程:\")\n",
"print(f\" 1. 处理原始神经数据 → RNN输出\")\n",
"print(f\" 2. 保存到本地 (.npz 文件)\")\n",
"print(f\" 3. 自动上传到WebDAV云端\")\n",
"print(f\" 4. 删除本地文件,释放存储空间\")\n",
"print(f\"\")\n",
"print(f\" 📤 云端文件结构:\")\n",
"print(f\" /移动云盘/DATA/rnn_processed_data/\")\n",
"print(f\" ├── t15.2023.08.11_train_rnn_processed_batch0.npz\")\n",
"print(f\" ├── t15.2023.08.11_train_rnn_processed_batch1.npz\")\n",
"print(f\" ├── t15.2023.08.11_val_rnn_processed_batch0.npz\")\n",
"print(f\" └── ...\")\n",
"print(f\"\")\n",
"print(f\" 📥 下载和使用:\")\n",
"print(f\" # 从WebDAV下载文件\")\n",
"print(f\" uploader.client.download_sync(\")\n",
"print(f\" remote_path='/移动云盘/DATA/rnn_processed_data/filename.npz',\")\n",
"print(f\" local_path='./filename.npz'\")\n",
"print(f\" )\")\n",
"print(f\" \")\n",
"print(f\" # 加载数据\")\n",
"print(f\" data = np.load('filename.npz', allow_pickle=True)\")\n",
"print(f\" features = data['concatenated_data']\")\n",
"print(f\" metadata = data['trial_metadata']\")\n",
"print(f\"\")\n",
"print(f\" 🎯 优势总结:\")\n",
"print(f\" ✅ 解决了 t15.2023.09.01 内存不足问题\")\n",
"print(f\" ✅ 数据自动分批,便于后续加载\")\n",
"print(f\" ✅ 处理速度优化,内存使用稳定\")\n",
"print(f\" ✅ 错误恢复能力强,单个批次失败不影响整体\")\n",
"print(f\" 🆕 自动上传,无需手动管理文件\")\n",
"print(f\" 🆕 本地存储空间自动释放\")\n",
"print(f\" 🆕 云端数据安全备份\")"
]
},
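{
"cell_type": "markdown",
"metadata": {},
"source": [
"A runnable version of the loading pattern printed above. The batch filename here is hypothetical; substitute any `.npz` file that was actually produced locally or downloaded from WebDAV."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import numpy as np\n",
"\n",
"# Hypothetical batch file; replace with an existing .npz path.\n",
"batch_file = './rnn_processed_data/t15.2023.08.13_train_rnn_processed_batch0.npz'\n",
"\n",
"npz = np.load(batch_file, allow_pickle=True)\n",
"features = npz['concatenated_data']   # object array: one [T, D] matrix per trial\n",
"metadata = npz['trial_metadata']      # object array: one metadata dict per trial\n",
"\n",
"print(f\"{len(features)} trials; first trial shape: {features[0].shape}\")\n",
"print(f\"first sentence: {metadata[0].get('sentence_label', 'n/a')}\")"
]
},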
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 📁 云端文件管理工具\n",
"\n",
"def list_cloud_files(remote_dir='/移动云盘/DATA/rnn_processed_data/'):\n",
" \"\"\"列出云端的所有处理文件\"\"\"\n",
" if 'uploader' not in globals():\n",
" print(\"❌ WebDAV上传器未初始化\")\n",
" return []\n",
" \n",
" try:\n",
" files = uploader.client.list(remote_dir)\n",
" rnn_files = [f for f in files if f.endswith('.npz')]\n",
" \n",
" print(f\"☁️ 云端文件列表 ({len(rnn_files)} 个文件):\")\n",
" print(f\"📍 路径: {remote_dir}\")\n",
" print(\"-\" * 60)\n",
" \n",
" for i, file in enumerate(rnn_files, 1):\n",
" print(f\"{i:3d}. {file}\")\n",
" \n",
" return rnn_files\n",
" \n",
" except Exception as e:\n",
" print(f\"❌ 获取文件列表失败: {e}\")\n",
" return []\n",
"\n",
"def download_cloud_file(filename, local_dir='./downloaded_data/'):\n",
" \"\"\"从云端下载单个文件\"\"\"\n",
" if 'uploader' not in globals():\n",
" print(\"❌ WebDAV上传器未初始化\")\n",
" return False\n",
" \n",
" try:\n",
" # 确保本地目录存在\n",
" os.makedirs(local_dir, exist_ok=True)\n",
" \n",
" remote_path = f'/移动云盘/DATA/rnn_processed_data/{filename}'\n",
" local_path = os.path.join(local_dir, filename)\n",
" \n",
" uploader.client.download_sync(remote_path=remote_path, local_path=local_path)\n",
" \n",
" print(f\"✅ 下载成功: {filename}\")\n",
" print(f\"📁 保存到: {local_path}\")\n",
" \n",
" # 显示文件信息\n",
" if os.path.exists(local_path):\n",
" file_size = os.path.getsize(local_path) / 1024**2 # MB\n",
" print(f\"📊 文件大小: {file_size:.2f} MB\")\n",
" \n",
" return True\n",
" \n",
" except Exception as e:\n",
" print(f\"❌ 下载失败: {e}\")\n",
" return False\n",
"\n",
"def download_session_files(session_name, data_types=['train', 'val', 'test'], local_dir='./downloaded_data/'):\n",
" \"\"\"下载指定会话的所有文件\"\"\"\n",
" files = list_cloud_files()\n",
" \n",
" session_files = []\n",
" for file in files:\n",
" if file.startswith(session_name):\n",
" for data_type in data_types:\n",
" if f'_{data_type}_' in file:\n",
" session_files.append(file)\n",
" break\n",
" \n",
" if not session_files:\n",
" print(f\"❌ 未找到会话 {session_name} 的文件\")\n",
" return False\n",
" \n",
" print(f\"\\n📥 下载会话 {session_name} 的文件...\")\n",
" success_count = 0\n",
" \n",
" for file in session_files:\n",
" if download_cloud_file(file, local_dir):\n",
" success_count += 1\n",
" \n",
" print(f\"\\n✅ 下载完成: {success_count}/{len(session_files)} 个文件\")\n",
" return success_count == len(session_files)\n",
"\n",
"def check_local_storage():\n",
" \"\"\"检查本地存储使用情况\"\"\"\n",
" print(\"💾 本地存储检查:\")\n",
" \n",
" # 检查处理数据目录\n",
" if os.path.exists('./rnn_processed_data'):\n",
" total_size = 0\n",
" file_count = 0\n",
" \n",
" for root, dirs, files in os.walk('./rnn_processed_data'):\n",
" for file in files:\n",
" if file.endswith('.npz'):\n",
" filepath = os.path.join(root, file)\n",
" total_size += os.path.getsize(filepath)\n",
" file_count += 1\n",
" \n",
" print(f\" 📁 ./rnn_processed_data/\")\n",
" print(f\" 📊 文件数量: {file_count}\")\n",
" print(f\" 📊 总大小: {total_size / 1024**3:.2f} GB\")\n",
" \n",
" if file_count > 0:\n",
" print(f\" 💡 建议: 这些文件已处理完成,可以删除以释放空间\")\n",
" print(f\" 使用: rm -rf ./rnn_processed_data/\")\n",
" else:\n",
" print(f\" ✅ ./rnn_processed_data/ 目录不存在或为空\")\n",
" \n",
" # 检查下载目录\n",
" if os.path.exists('./downloaded_data'):\n",
" download_size = 0\n",
" download_count = 0\n",
" \n",
" for root, dirs, files in os.walk('./downloaded_data'):\n",
" for file in files:\n",
" if file.endswith('.npz'):\n",
" filepath = os.path.join(root, file)\n",
" download_size += os.path.getsize(filepath)\n",
" download_count += 1\n",
" \n",
" print(f\" 📁 ./downloaded_data/\")\n",
" print(f\" 📊 下载文件数: {download_count}\")\n",
" print(f\" 📊 下载大小: {download_size / 1024**3:.2f} GB\")\n",
"\n",
"# 使用示例\n",
"print(\"📁 云端文件管理工具已加载!\")\n",
"print(\"\\n🛠 可用函数:\")\n",
"print(\"• list_cloud_files() # 列出所有云端文件\")\n",
"print(\"• download_cloud_file('filename.npz') # 下载单个文件\")\n",
"print(\"• download_session_files('t15.2023.08.13') # 下载指定会话的所有文件\")\n",
"print(\"• check_local_storage() # 检查本地存储使用情况\")\n",
"\n",
"print(\"\\n💡 使用示例:\")\n",
"print(\"# 查看云端有哪些文件\")\n",
"print(\"files = list_cloud_files()\")\n",
"print(\"\")\n",
"print(\"# 下载特定文件\")\n",
"print(\"download_cloud_file('t15.2023.08.13_train_rnn_processed_batch0.npz')\")\n",
"print(\"\")\n",
"print(\"# 下载整个会话的数据\")\n",
"print(\"download_session_files('t15.2023.08.13')\")\n",
"print(\"\")\n",
"print(\"# 检查本地存储\")\n",
"print(\"check_local_storage()\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 📤 WebDAV文件上传工具"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 📚 WebDAV库选择 - 现成的解决方案"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"✅ webdavclient3 可用\n",
"修复后的WebDAV上传函数:\n",
"1. 单个文件上传会自动添加文件名到远程路径\n",
"2. 目录上传会过滤掉 .git 等不需要的文件\n",
"3. 显示详细的上传和跳过信息\n",
"\n",
"测试单个文件上传...\n",
"单个文件上传结果: {'success': True, 'local_file': 'F:\\\\BRAIN-TO-TEXT\\\\nejm-brain-to-text\\\\brain-to-text-25\\\\client_secrets.json', 'remote_path': '/移动云盘/DATA/client_secrets.json', 'library': 'webdavclient3'}\n",
"\n",
"测试目录上传(带过滤)...\n",
"单个文件上传结果: {'success': True, 'local_file': 'F:\\\\BRAIN-TO-TEXT\\\\nejm-brain-to-text\\\\brain-to-text-25\\\\client_secrets.json', 'remote_path': '/移动云盘/DATA/client_secrets.json', 'library': 'webdavclient3'}\n",
"\n",
"测试目录上传(带过滤)...\n",
"目录上传结果: {'success': True, 'local_dir': 'F:\\\\BRAIN-TO-TEXT\\\\nejm-brain-to-text\\\\data-kaggle', 'remote_dir': '/移动云盘/DATA/data-kaggle', 'uploaded_files': [], 'skipped_files': [], 'library': 'webdavclient3'}\n",
"上传了 0 个文件\n",
"跳过了 0 个文件\n",
"目录上传结果: {'success': True, 'local_dir': 'F:\\\\BRAIN-TO-TEXT\\\\nejm-brain-to-text\\\\data-kaggle', 'remote_dir': '/移动云盘/DATA/data-kaggle', 'uploaded_files': [], 'skipped_files': [], 'library': 'webdavclient3'}\n",
"上传了 0 个文件\n",
"跳过了 0 个文件\n"
]
}
],
"source": [
"# 📋 WebDAV库安装指南\n",
"\n",
"print(\"📦 安装WebDAV客户端库:\")\n",
"print(\"pip install webdavclient3\")\n",
"print(\"\")\n",
"print(\"💡 如果已安装,下面的简化版上传工具就可以直接使用了!\")\n",
"print(\" 所有复杂的代码都已简化为易用的类和函数。\")\n",
"\n",
"# 检查安装状态\n",
"try:\n",
" from webdav3.client import Client\n",
" print(\"✅ webdavclient3 已安装并可用\")\n",
"except ImportError:\n",
" print(\"❌ 需要安装: pip install webdavclient3\")"
]
},
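{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before wrapping the library in a class, here is a minimal sketch of the `webdavclient3` API itself, assuming the same server URL and credentials used later in this notebook: `check()` doubles as a cheap connectivity test and `list()` browses a remote directory. The remote path `/移动云盘/DATA/` is assumed to already exist on the server."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 🔍 Minimal webdavclient3 sketch: verify the connection and browse the server\n",
"from webdav3.client import Client\n",
"\n",
"client = Client({\n",
"    'webdav_hostname': 'http://zchens.cn:5244/dav/',\n",
"    'webdav_login': 'admin',\n",
"    'webdav_password': 'Zccns20050420',\n",
"    'webdav_timeout': 30\n",
"})\n",
"\n",
"# check() returns True if the remote path exists\n",
"if client.check('/移动云盘/DATA/'):\n",
"    # list() returns the entry names directly under the remote directory\n",
"    for name in client.list('/移动云盘/DATA/'):\n",
"        print(name)\n",
"else:\n",
"    print('❌ Remote path not reachable - check URL/credentials')"
]
},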
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"✅ WebDAV连接已建立: http://zchens.cn:5244/dav/\n",
"\n",
"📖 使用方法:\n",
"# 上传单个文件\n",
"uploader.upload_file(r'F:\\path\\to\\file.txt')\n",
"\n",
"# 上传目录(自动过滤.git等文件\n",
"uploader.upload_dir(r'F:\\path\\to\\dir', '/移动云盘/DATA/目标目录')\n",
"\n",
"# 上传目录(不过滤任何文件)\n",
"uploader.upload_dir(r'F:\\path\\to\\dir', '/移动云盘/DATA/目标目录', exclude_git=False)\n"
]
}
],
"source": [
"# 📤 简化版WebDAV上传工具\n",
"\n",
"from webdav3.client import Client\n",
"import os\n",
"import fnmatch\n",
"\n",
"class WebDAVUploader:\n",
" \"\"\"简化的WebDAV上传器\"\"\"\n",
" \n",
" def __init__(self, url='http://zchens.cn:5244/dav/', username='admin', password='Zccns20050420'):\n",
" \"\"\"初始化WebDAV连接\"\"\"\n",
" self.client = Client({\n",
" 'webdav_hostname': url,\n",
" 'webdav_login': username,\n",
" 'webdav_password': password,\n",
" 'webdav_timeout': 30\n",
" })\n",
" print(f\"✅ WebDAV连接已建立: {url}\")\n",
" \n",
" def upload_file(self, local_file, remote_dir='/移动云盘/DATA/'):\n",
" \"\"\"上传单个文件\"\"\"\n",
" try:\n",
" filename = os.path.basename(local_file)\n",
" remote_path = remote_dir.rstrip('/') + '/' + filename\n",
" \n",
" # 确保远程目录存在\n",
" if not self.client.check(remote_dir):\n",
" self.client.mkdir(remote_dir)\n",
" \n",
" self.client.upload_sync(remote_path=remote_path, local_path=local_file)\n",
" print(f\"✅ 文件上传成功: {filename} -> {remote_path}\")\n",
" return True\n",
" \n",
" except Exception as e:\n",
" print(f\"❌ 文件上传失败: {e}\")\n",
" return False\n",
" \n",
" def upload_dir(self, local_dir, remote_dir, exclude_git=True):\n",
" \"\"\"上传目录(自动过滤不需要的文件)\"\"\"\n",
" exclude_patterns = ['.git*', '__pycache__*', '*.pyc', '.vscode*'] if exclude_git else []\n",
" \n",
" try:\n",
" uploaded = 0\n",
" skipped = 0\n",
" \n",
" for root, dirs, files in os.walk(local_dir):\n",
" # 过滤目录\n",
" if exclude_git:\n",
" dirs[:] = [d for d in dirs if not any(fnmatch.fnmatch(d, p) for p in exclude_patterns)]\n",
" \n",
" for file in files:\n",
" # 检查是否跳过文件\n",
" if exclude_git and any(fnmatch.fnmatch(file, p) for p in exclude_patterns):\n",
" skipped += 1\n",
" continue\n",
" \n",
" local_file = os.path.join(root, file)\n",
" rel_path = os.path.relpath(local_file, local_dir)\n",
" remote_file = remote_dir.rstrip('/') + '/' + rel_path.replace('\\\\', '/')\n",
" \n",
" # 确保远程目录存在\n",
" remote_file_dir = '/'.join(remote_file.split('/')[:-1])\n",
" if not self.client.check(remote_file_dir):\n",
" self.client.mkdir(remote_file_dir)\n",
" \n",
" self.client.upload_sync(remote_path=remote_file, local_path=local_file)\n",
" uploaded += 1\n",
" \n",
" print(f\"✅ 目录上传完成: 上传 {uploaded} 个文件,跳过 {skipped} 个文件\")\n",
" return True\n",
" \n",
" except Exception as e:\n",
" print(f\"❌ 目录上传失败: {e}\")\n",
" return False\n",
"\n",
"# 创建上传器实例\n",
"uploader = WebDAVUploader()\n",
"\n",
"print(\"\\n📖 使用方法:\")\n",
"print(\"# 上传单个文件\")\n",
"print(\"uploader.upload_file(r'F:\\\\path\\\\to\\\\file.txt')\")\n",
"print(\"\")\n",
"print(\"# 上传目录(自动过滤.git等文件\") \n",
"print(\"uploader.upload_dir(r'F:\\\\path\\\\to\\\\dir', '/移动云盘/DATA/目标目录')\")\n",
"print(\"\")\n",
"print(\"# 上传目录(不过滤任何文件)\")\n",
"print(\"uploader.upload_dir(r'F:\\\\path\\\\to\\\\dir', '/移动云盘/DATA/目标目录', exclude_git=False)\")"
]
},
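{
"cell_type": "markdown",
"metadata": {},
"source": [
"A hedged sketch for verifying an upload: `client.info()` returns the remote file's WebDAV properties (its size is reported as a string), so comparing it against `os.path.getsize()` catches truncated transfers. `verify_upload` is a hypothetical helper, not part of the uploader class; it reuses the `uploader.client` connection created above."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# ✅ Hypothetical helper: verify that a file arrived intact on the WebDAV server\n",
"import os\n",
"\n",
"def verify_upload(local_file, remote_dir='/移动云盘/DATA/'):\n",
"    \"\"\"Compare local and remote file sizes after an upload\"\"\"\n",
"    remote_path = remote_dir.rstrip('/') + '/' + os.path.basename(local_file)\n",
"    if not uploader.client.check(remote_path):\n",
"        print(f\"❌ Missing on server: {remote_path}\")\n",
"        return False\n",
"    # info() returns a dict of WebDAV properties; 'size' comes back as a string\n",
"    remote_size = int(uploader.client.info(remote_path)['size'])\n",
"    local_size = os.path.getsize(local_file)\n",
"    ok = remote_size == local_size\n",
"    print(f\"{'✅' if ok else '❌'} {remote_path}: remote {remote_size} B vs local {local_size} B\")\n",
"    return ok\n",
"\n",
"# Example, assuming the file was uploaded with uploader.upload_file first:\n",
"# verify_upload(r'F:\\BRAIN-TO-TEXT\\nejm-brain-to-text\\brain-to-text-25\\client_secrets.json')"
]
},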
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"==================================================\n",
"WebDAV上传示例\n",
"==================================================\n",
"\n",
"1⃣ 上传单个文件:\n",
"✅ 文件上传成功: client_secrets.json -> /移动云盘/DATA/client_secrets.json\n",
"\n",
"2⃣ 上传目录(过滤.git等文件:\n",
"✅ 目录上传完成: 上传 0 个文件,跳过 0 个文件\n",
"\n",
"3⃣ 上传RNN处理结果:\n",
"✅ 文件上传成功: client_secrets.json -> /移动云盘/DATA/client_secrets.json\n",
"\n",
"2⃣ 上传目录(过滤.git等文件:\n",
"✅ 目录上传完成: 上传 0 个文件,跳过 0 个文件\n",
"\n",
"3⃣ 上传RNN处理结果:\n",
"❌ 目录上传失败: HTTPConnectionPool(host='127.0.0.1', port=7897): Read timed out. (read timeout=30)\n",
"RNN处理结果上传完成\n",
"\n",
"✨ 上传任务完成!现在你的文件应该在云盘中了。\n"
]
}
],
"source": [
"# 🚀 快速使用示例\n",
"\n",
"print(\"=\"*50)\n",
"print(\"WebDAV上传示例\")\n",
"print(\"=\"*50)\n",
"\n",
"# 示例1: 上传单个文件\n",
"print(\"\\n1⃣ 上传单个文件:\")\n",
"uploader.upload_file(r'F:\\BRAIN-TO-TEXT\\nejm-brain-to-text\\brain-to-text-25\\client_secrets.json')\n",
"\n",
"# 示例2: 上传目录(自动过滤.git\n",
"print(\"\\n2⃣ 上传目录(过滤.git等文件:\")\n",
"uploader.upload_dir(\n",
" r'F:\\BRAIN-TO-TEXT\\nejm-brain-to-text\\data-kaggle', \n",
" '/移动云盘/DATA/data-kaggle-clean'\n",
")\n",
"\n",
"# 示例3: 上传处理后的数据目录\n",
"print(\"\\n3⃣ 上传RNN处理结果:\")\n",
"if os.path.exists('./rnn_processed_data'):\n",
" uploader.upload_dir(\n",
" './rnn_processed_data',\n",
" '/移动云盘/DATA/rnn_processed_data'\n",
" )\n",
" print(\"RNN处理结果上传完成\")\n",
"else:\n",
" print(\"⚠️ 未找到RNN处理数据目录\")\n",
"\n",
"print(\"\\n✨ 上传任务完成!现在你的文件应该在云盘中了。\")"
]
},
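{
"cell_type": "markdown",
"metadata": {},
"source": [
"The third step above timed out through a local proxy (`127.0.0.1:7897`) at the default 30 s `webdav_timeout`. A hedged workaround sketch, assuming the `uploader` instance and the `timeout` parameter of `WebDAVUploader.__init__`: retry each file a few times, and recreate the uploader with a longer timeout for large `.npz` batches. `upload_with_retry` is a hypothetical helper, not part of the original toolset."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 🔁 Hypothetical retry wrapper for flaky or slow connections\n",
"import time\n",
"\n",
"def upload_with_retry(local_file, remote_dir='/移动云盘/DATA/', retries=3, wait_s=5):\n",
"    \"\"\"Retry uploader.upload_file a few times before giving up\"\"\"\n",
"    for attempt in range(1, retries + 1):\n",
"        if uploader.upload_file(local_file, remote_dir):\n",
"            return True\n",
"        print(f\"⏳ Attempt {attempt}/{retries} failed, retrying in {wait_s}s...\")\n",
"        time.sleep(wait_s)\n",
"    return False\n",
"\n",
"# For large files a longer timeout is usually the real fix:\n",
"# slow_uploader = WebDAVUploader(timeout=300)"
]
},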
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"🎯 便捷函数已定义!\n",
"\n",
"可用的快捷函数:\n",
"• quick_upload_file('文件路径') # 上传单个文件\n",
"• quick_upload_project('项目目录') # 上传整个项目\n",
"• quick_upload_results() # 上传所有结果文件\n",
"\n",
"例如:\n",
"quick_upload_file('data.csv')\n",
"quick_upload_project(r'F:\\BRAIN-TO-TEXT\\nejm-brain-to-text')\n",
"quick_upload_results()\n"
]
}
],
"source": [
"# 🎯 便捷函数 - 一键上传常用文件\n",
"\n",
"def quick_upload_file(file_path, remote_dir='/移动云盘/DATA/'):\n",
" \"\"\"快速上传单个文件\"\"\"\n",
" return uploader.upload_file(file_path, remote_dir)\n",
"\n",
"def quick_upload_project(project_dir, remote_name=None):\n",
" \"\"\"快速上传整个项目目录(自动过滤.git等\"\"\"\n",
" if remote_name is None:\n",
" remote_name = os.path.basename(project_dir.rstrip('/\\\\'))\n",
" \n",
" remote_dir = f'/移动云盘/DATA/{remote_name}'\n",
" return uploader.upload_dir(project_dir, remote_dir, exclude_git=True)\n",
"\n",
"def quick_upload_results():\n",
" \"\"\"快速上传所有结果文件\"\"\"\n",
" results = []\n",
" \n",
" # 上传RNN处理结果\n",
" if os.path.exists('./rnn_processed_data'):\n",
" print(\"📊 上传RNN处理结果...\")\n",
" results.append(uploader.upload_dir('./rnn_processed_data', '/移动云盘/DATA/rnn_processed_data'))\n",
" \n",
" # 上传notebook文件\n",
" notebook_files = [f for f in os.listdir('.') if f.endswith('.ipynb')]\n",
" for nb in notebook_files:\n",
" print(f\"📓 上传notebook: {nb}\")\n",
" results.append(uploader.upload_file(nb, '/移动云盘/DATA/notebooks/'))\n",
" \n",
" success_count = sum(results)\n",
" print(f\"\\n✅ 完成!成功上传 {success_count}/{len(results)} 个项目\")\n",
" return all(results)\n",
"\n",
"# 使用示例\n",
"print(\"🎯 便捷函数已定义!\")\n",
"print(\"\\n可用的快捷函数:\")\n",
"print(\"• quick_upload_file('文件路径') # 上传单个文件\")\n",
"print(\"• quick_upload_project('项目目录') # 上传整个项目\")\n",
"print(\"• quick_upload_results() # 上传所有结果文件\")\n",
"print(\"\\n例如:\")\n",
"print(\"quick_upload_file('data.csv')\")\n",
"print(\"quick_upload_project(r'F:\\\\BRAIN-TO-TEXT\\\\nejm-brain-to-text')\")\n",
"print(\"quick_upload_results()\")"
]
},
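{
"cell_type": "markdown",
"metadata": {},
"source": [
"For completeness, the download direction uses the symmetric `download_sync()` call from `webdavclient3`. A hedged sketch, assuming the `uploader` instance from above; `quick_download_file` is a hypothetical helper mirroring `quick_upload_file`."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# 📥 Hypothetical download counterpart to quick_upload_file\n",
"import os\n",
"\n",
"def quick_download_file(filename, remote_dir='/移动云盘/DATA/', local_dir='./downloaded_data'):\n",
"    \"\"\"Download a single file from the WebDAV share\"\"\"\n",
"    os.makedirs(local_dir, exist_ok=True)\n",
"    remote_path = remote_dir.rstrip('/') + '/' + filename\n",
"    local_path = os.path.join(local_dir, filename)\n",
"    try:\n",
"        # download_sync mirrors upload_sync in webdavclient3\n",
"        uploader.client.download_sync(remote_path=remote_path, local_path=local_path)\n",
"        print(f\"✅ Downloaded: {remote_path} -> {local_path}\")\n",
"        return True\n",
"    except Exception as e:\n",
"        print(f\"❌ Download failed: {e}\")\n",
"        return False\n",
"\n",
"# Example:\n",
"# quick_download_file('client_secrets.json')"
]
},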
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kaggle": {
"accelerator": "tpu1vmV38",
"dataSources": [
{
"databundleVersionId": 13056355,
"sourceId": 106809,
"sourceType": "competition"
}
],
"dockerImageVersionId": 31091,
"isGpuEnabled": false,
"isInternetEnabled": true,
"language": "python",
"sourceType": "notebook"
},
"kernelspec": {
"display_name": "b2txt25",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.18"
}
},
"nbformat": 4,
"nbformat_minor": 4
}