Merge branch 'b2txt25' into tscizzlebg-patch-1
.gitignore (vendored): 169 changed lines
@@ -1,170 +1,9 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/latest/usage/project/#working-with-version-control
.pdm.toml
.pdm-python
.pdm-build/

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/

# ignore data folder
data/*
# ignore python bytecode files
*.pyc
*.egg-info

# ignore rdb files
*.rdb

# ignore ds_store files
.DS_Store
README.md: 26 changed lines
@@ -23,14 +23,30 @@ The code is organized into five main directories: `utils`, `analyses`, `data`, `
- The `analyses` directory contains the code necessary to reproduce results shown in the main text and supplemental appendix.
- The `data` directory contains the data necessary to reproduce the results in the paper. Download it from Dryad using the link above and place it in this directory.
- The `model_training` directory contains the code necessary to train and evaluate the brain-to-text model. See the README.md in that folder for more detailed instructions.
- The `language_model` directory contains the ngram language model implementation and a pretrained 1gram language model. Pretrained 3gram and 5gram language models can be downloaded [here](https://datadryad.org/dataset/doi:10.5061/dryad.x69p8czpq) (`languageModel.tar.gz` and `languageModel_5gram.tar.gz`). See the `README.md` in this directory for more information.
- The `language_model` directory contains the ngram language model implementation and a pretrained 1gram language model. Pretrained 3gram and 5gram language models can be downloaded [here](https://datadryad.org/dataset/doi:10.5061/dryad.x69p8czpq) (`languageModel.tar.gz` and `languageModel_5gram.tar.gz`). See [`language_model/README.md`](language_model/README.md) for more information.

## Data
The data used in this repository consists of various datasets for recreating figures and training/evaluating the brain-to-text model:
- `t15_copyTask.pkl`: This file contains the online Copy Task results required for generating Figure 2.
- `t15_personalUse.pkl`: This file contains the Conversation Mode data required for generating Figure 4.
- `t15_copyTask_neuralData.zip`: This dataset contains the neural data for the Copy Task. There are more than 11,300 sentences from 45 sessions spanning 20 months. The data is split into training, validation, and test sets. Data for each session/split is stored in `.hdf5` files. An example of how to load this data using the Python `h5py` library is provided in the `model_training/evaluate_model_helpers.py` file in the `load_h5py_file()` function.
- `t15_pretrained_rnn_baseline.zip`: This dataset contains the pretrained RNN baseline model checkpoint and args. An example of how to load this model and use it for inference is provided in the `model_training/evaluate_model.py` file.
- `t15_copyTask_neuralData.zip`: This dataset contains the neural data for the Copy Task.
    - There are more than 11,300 sentences from 45 sessions spanning 20 months. Each trial of data includes:
        - The session date, block number, and trial number
        - 512 neural features (2 features [-4.5 RMS threshold crossings and spike band power] per electrode, 256 electrodes), binned at 20 ms resolution. The data were recorded from the speech motor cortex via four high-density microelectrode arrays (64 electrodes each). The 512 features are ordered as follows in all data files:
            - 0-64: ventral 6v threshold crossings
            - 65-128: area 4 threshold crossings
            - 129-192: 55b threshold crossings
            - 193-256: dorsal 6v threshold crossings
            - 257-320: ventral 6v spike band power
            - 321-384: area 4 spike band power
            - 385-448: 55b spike band power
            - 449-512: dorsal 6v spike band power
        - The ground truth sentence label
        - The ground truth phoneme sequence label
    - The data is split into training, validation, and test sets. The test set does not include ground truth sentence or phoneme labels.
    - Data for each session/split is stored in `.hdf5` files. An example of how to load this data using the Python `h5py` library is provided in the [`model_training/evaluate_model_helpers.py`](model_training/evaluate_model_helpers.py) file in the `load_h5py_file()` function (a minimal loading sketch also follows this list).
    - Each block of data contains sentences drawn from a range of corpora (Switchboard, OpenWebText2, a 50-word corpus, a custom frequent-word corpus, and a corpus of random word sequences). Furthermore, the majority of the data was collected during attempted vocalized speaking, but some of it was collected during attempted silent speaking.
- `t15_pretrained_rnn_baseline.zip`: This dataset contains the pretrained RNN baseline model checkpoint and args. An example of how to load this model and use it for inference is provided in the [`model_training/evaluate_model.py`](model_training/evaluate_model.py) file.
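For orientation, here is a minimal, hedged sketch of opening one of these `.hdf5` files with `h5py` and splitting the 512 feature columns according to the ordering listed above. The file path and dataset names are illustrative assumptions, not the repository's actual layout; `load_h5py_file()` in `model_training/evaluate_model_helpers.py` remains the authoritative loader.

```python
import h5py
import numpy as np

# Placeholder path and key names -- inspect the actual files (or use
# load_h5py_file() in model_training/evaluate_model_helpers.py) for the real layout.
path = "data/t15_copyTask_neuralData/example_session_train.hdf5"  # hypothetical filename

with h5py.File(path, "r") as f:
    f.visit(print)  # print every group/dataset name to discover the structure

    first_trial = f[list(f.keys())[0]]                     # assumed: one group per trial
    features = np.asarray(first_trial["input_features"])   # assumed dataset name, shape (time_bins, 512)

# Per the documented ordering: the first 256 columns are threshold crossings
# (ventral 6v, area 4, 55b, dorsal 6v), the last 256 are spike band power in the same order.
threshold_crossings = features[:, :256]
spike_band_power = features[:, 256:]
print(threshold_crossings.shape, spike_band_power.shape)
```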

Please download these datasets from [Dryad](https://datadryad.org/stash/dataset/doi:10.5061/dryad.dncjsxm85) and place them in the `data` directory. Be sure to unzip both datasets before running the code.

@@ -61,9 +77,9 @@ To create a conda environment with the necessary dependencies, run the following
Verify it worked by activating the conda environment with the command `conda activate b2txt25`.

## Python environment setup for ngram language model and OPT rescoring
We use an ngram language model plus rescoring via the [Facebook OPT 6.7b](https://huggingface.co/facebook/opt-6.7b) LLM. A pretrained 1gram language model is included in this repository at `language_model/pretrained_language_models/openwebtext_1gram_lm_sil`. Pretrained 3gram and 5gram language models are available for download [here](https://datadryad.org/dataset/doi:10.5061/dryad.x69p8czpq) (`languageModel.tar.gz` and `languageModel_5gram.tar.gz`). Note that the 3gram model requires ~60GB of RAM, and the 5gram model requires ~300GB of RAM. Furthermore, OPT 6.7b requires a GPU with at least ~12.4 GB of VRAM to load for inference.
We use an ngram language model plus rescoring via the [Facebook OPT 6.7b](https://huggingface.co/facebook/opt-6.7b) LLM. A pretrained 1gram language model is included in this repository at [`language_model/pretrained_language_models/openwebtext_1gram_lm_sil`](language_model/pretrained_language_models/openwebtext_1gram_lm_sil). Pretrained 3gram and 5gram language models are available for download [here](https://datadryad.org/dataset/doi:10.5061/dryad.x69p8czpq) (`languageModel.tar.gz` and `languageModel_5gram.tar.gz`). Note that the 3gram model requires ~60GB of RAM, and the 5gram model requires ~300GB of RAM. Furthermore, OPT 6.7b requires a GPU with at least ~12.4 GB of VRAM to load for inference.

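As a rough illustration of the OPT rescoring step described above (not the repository's implementation), the sketch below scores candidate sentences with `facebook/opt-6.7b` via the Hugging Face `transformers` library; the hypothesis list and the float16 loading are assumptions.

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Loading in float16 keeps the 6.7B model within roughly the VRAM budget noted above.
tok = AutoTokenizer.from_pretrained("facebook/opt-6.7b")
model = AutoModelForCausalLM.from_pretrained(
    "facebook/opt-6.7b", torch_dtype=torch.float16
).to("cuda").eval()

@torch.no_grad()
def sentence_logprob(text: str) -> float:
    ids = tok(text, return_tensors="pt").input_ids.to("cuda")
    # With labels == inputs, the returned loss is the mean token cross-entropy;
    # scale back to a (negative) summed log-probability for comparison.
    loss = model(ids, labels=ids).loss
    return -loss.item() * (ids.shape[1] - 1)

# Hypothetical n-best list from the ngram decoder; pick the candidate OPT finds most likely.
hypotheses = ["i would like a glass of water", "i would like a class of water"]
print(max(hypotheses, key=sentence_logprob))
```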
Our Kaldi-based ngram implementation requires a different version of torch than our model training pipeline, so running the ngram language models requires an additional, separate Python conda environment. To create this conda environment, run the following command from the root directory of this repository. For more detailed instructions, see the README.md in the `language_model` subdirectory.
Our Kaldi-based ngram implementation requires a different version of torch than our model training pipeline, so running the ngram language models requires an additional, separate Python conda environment. To create this conda environment, run the following command from the root directory of this repository. For more detailed instructions, see the README.md in the [`language_model`](language_model) subdirectory.
```bash
./setup_lm.sh
```

data/.gitignore (vendored, new file): 4 lines
@@ -0,0 +1,4 @@
# ignore everything in this folder except my README.md and myself
*
!README.md
!/.gitignore
data/README.md (new file): 1 line
@@ -0,0 +1 @@
Data can be downloaded from Dryad, [here](https://datadryad.org/stash/dataset/doi:10.5061/dryad.dncjsxm85). Please download this data and place it in the `data` directory before running the code. Be sure to unzip `t15_copyTask_neuralData.zip` and `t15_pretrained_rnn_baseline.zip`.

@@ -1,5 +1,5 @@
# Pretrained ngram language models
A pretrained 1gram language model is included in this repository at `language_model/pretrained_language_models/openwebtext_1gram_lm_sil`. Pretrained 3gram and 5gram language models are available for download [here](https://datadryad.org/dataset/doi:10.5061/dryad.x69p8czpq) (`languageModel.tar.gz` and `languageModel_5gram.tar.gz`) and should likewise be placed in the `language_model/pretrained_language_models/` directory. Note that the 3gram model requires ~60GB of RAM, and the 5gram model requires ~300GB of RAM. Furthermore, OPT 6.7b requires a GPU with at least ~12.4 GB of VRAM to load for inference.
A pretrained 1gram language model is included in this repository at `language_model/pretrained_language_models/openwebtext_1gram_lm_sil`. Pretrained 3gram and 5gram language models are available for download [here](https://datadryad.org/dataset/doi:10.5061/dryad.x69p8czpq) (`languageModel.tar.gz` and `languageModel_5gram.tar.gz`) and should likewise be placed in the [`pretrained_language_models`](pretrained_language_models) directory. Note that the 3gram model requires ~60GB of RAM, and the 5gram model requires ~300GB of RAM. Furthermore, OPT 6.7b requires a GPU with at least ~12.4 GB of VRAM to load for inference.

# Dependencies
```
@@ -13,11 +13,11 @@ sudo apt-get install build-essential
```

# Install language model python package
Use the `setup_lm.sh` script in the root directory of this repository to create the `b2txt_lm` conda env and install the `lm-decoder` package to it. Before installing, make sure that there is no `build` or `fc_base` directory in your `language_model/runtime/server/x86` directory, as this may cause the build to fail.
Use the `setup_lm.sh` script in the root directory of this repository to create the `b2txt_lm` conda env and install the `lm-decoder` package to it. Before installing, make sure that there is no `build` or `fc_base` directory in your [`runtime/server/x86`](runtime/server/x86) directory, as this may cause the build to fail.


# Using a pretrained ngram language model
The `language-model-standalone.py` script included here is made to work with the `evaluate_model.py` script in the `model_training` directory. `language-model-standalone.py` will do the following when run:
The [`language-model-standalone.py`](language-model-standalone.py) script included here is made to work with [`evaluate_model.py`](../model_training/evaluate_model.py). `language-model-standalone.py` will do the following when run:
1. Initialize `opt-6.7b` on the specified GPU (`--gpu_number` arg). The first time you run the script, it will automatically download `opt-6.7b` from Hugging Face.
2. Initialize the ngram language model (specified with the `--lm_path` arg)
3. Connect to the `localhost` redis server (or a different server, specified by the `--redis_ip` and `--redis_port` args); a minimal connectivity check is sketched below

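A minimal connectivity check for step 3, assuming the `redis` Python package; the host and port below mirror the localhost default described above and are not taken from the script itself.

```python
import redis

# Point these at the same server you pass to language-model-standalone.py
# via --redis_ip / --redis_port (localhost:6379 is assumed here).
r = redis.Redis(host="localhost", port=6379)
print(r.ping())  # True if the redis server is reachable
```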
language_model/pretrained_language_models/.gitignore (vendored, new file): 4 lines
@@ -0,0 +1,4 @@
# ignore everything in this folder except a few things
*
!openwebtext_1gram_lm_sil
!/.gitignore
language_model/runtime/core/patch/openfst/src/lib/flags.cc (new file): 166 lines
@@ -0,0 +1,166 @@
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//
// Google-style flag handling definitions.

#include <cstring>

#if _MSC_VER
#include <io.h>
#include <fcntl.h>
#endif

#include <fst/compat.h>
#include <fst/flags.h>

static const char *private_tmpdir = getenv("TMPDIR");

// DEFINE_int32(v, 0, "verbosity level");
// DEFINE_bool(help, false, "show usage information");
// DEFINE_bool(helpshort, false, "show brief usage information");
#ifndef _MSC_VER
DEFINE_string(tmpdir, private_tmpdir ? private_tmpdir : "/tmp",
              "temporary directory");
#else
DEFINE_string(tmpdir, private_tmpdir ? private_tmpdir : getenv("TEMP"),
              "temporary directory");
#endif  // !_MSC_VER

using namespace std;

static string flag_usage;
static string prog_src;

// Sets prog_src to src.
static void SetProgSrc(const char *src) {
  prog_src = src;
#if _MSC_VER
  // This common code is invoked by all FST binaries, and only by them. Switch
  // stdin and stdout into "binary" mode, so that 0x0A won't be translated into
  // a 0x0D 0x0A byte pair in a pipe or a shell redirect. Other streams are
  // already using ios::binary where binary files are read or written.
  // Kudos to @daanzu for the suggested fix.
  // https://github.com/kkm000/openfst/issues/20
  // https://github.com/kkm000/openfst/pull/23
  // https://github.com/kkm000/openfst/pull/32
  _setmode(_fileno(stdin), O_BINARY);
  _setmode(_fileno(stdout), O_BINARY);
#endif
  // Remove "-main" in src filename. Flags are defined in fstx.cc but SetFlags()
  // is called in fstx-main.cc, which results in a filename mismatch in
  // ShowUsageRestrict() below.
  static constexpr char kMainSuffix[] = "-main.cc";
  const int prefix_length = prog_src.size() - strlen(kMainSuffix);
  if (prefix_length > 0 && prog_src.substr(prefix_length) == kMainSuffix) {
    prog_src.erase(prefix_length, strlen("-main"));
  }
}

void SetFlags(const char *usage, int *argc, char ***argv,
              bool remove_flags, const char *src) {
  flag_usage = usage;
  SetProgSrc(src);

  int index = 1;
  for (; index < *argc; ++index) {
    string argval = (*argv)[index];
    if (argval[0] != '-' || argval == "-") break;
    while (argval[0] == '-') argval = argval.substr(1);  // Removes initial '-'.
    string arg = argval;
    string val = "";
    // Splits argval (arg=val) into arg and val.
    auto pos = argval.find("=");
    if (pos != string::npos) {
      arg = argval.substr(0, pos);
      val = argval.substr(pos + 1);
    }
    auto bool_register = FlagRegister<bool>::GetRegister();
    if (bool_register->SetFlag(arg, val))
      continue;
    auto string_register = FlagRegister<string>::GetRegister();
    if (string_register->SetFlag(arg, val))
      continue;
    auto int32_register = FlagRegister<int32>::GetRegister();
    if (int32_register->SetFlag(arg, val))
      continue;
    auto int64_register = FlagRegister<int64>::GetRegister();
    if (int64_register->SetFlag(arg, val))
      continue;
    auto double_register = FlagRegister<double>::GetRegister();
    if (double_register->SetFlag(arg, val))
      continue;
    LOG(FATAL) << "SetFlags: Bad option: " << (*argv)[index];
  }
  if (remove_flags) {
    for (auto i = 0; i < *argc - index; ++i) {
      (*argv)[i + 1] = (*argv)[i + index];
    }
    *argc -= index - 1;
  }
  // if (FLAGS_help) {
  //   ShowUsage(true);
  //   exit(1);
  // }
  // if (FLAGS_helpshort) {
  //   ShowUsage(false);
  //   exit(1);
  // }
}

// If flag is defined in file 'src' and 'in_src' true or is not
// defined in file 'src' and 'in_src' is false, then print usage.
static void
ShowUsageRestrict(const std::set<pair<string, string>> &usage_set,
                  const string &src, bool in_src, bool show_file) {
  string old_file;
  bool file_out = false;
  bool usage_out = false;
  for (const auto &pair : usage_set) {
    const auto &file = pair.first;
    const auto &usage = pair.second;
    bool match = file == src;
    if ((match && !in_src) || (!match && in_src)) continue;
    if (file != old_file) {
      if (show_file) {
        if (file_out) cout << "\n";
        cout << "Flags from: " << file << "\n";
        file_out = true;
      }
      old_file = file;
    }
    cout << usage << "\n";
    usage_out = true;
  }
  if (usage_out) cout << "\n";
}

void ShowUsage(bool long_usage) {
  std::set<pair<string, string>> usage_set;
  cout << flag_usage << "\n";
  auto bool_register = FlagRegister<bool>::GetRegister();
  bool_register->GetUsage(&usage_set);
  auto string_register = FlagRegister<string>::GetRegister();
  string_register->GetUsage(&usage_set);
  auto int32_register = FlagRegister<int32>::GetRegister();
  int32_register->GetUsage(&usage_set);
  auto int64_register = FlagRegister<int64>::GetRegister();
  int64_register->GetUsage(&usage_set);
  auto double_register = FlagRegister<double>::GetRegister();
  double_register->GetUsage(&usage_set);
  if (!prog_src.empty()) {
    cout << "PROGRAM FLAGS:\n\n";
    ShowUsageRestrict(usage_set, prog_src, true, false);
  }
  if (!long_usage) return;
  if (!prog_src.empty()) cout << "LIBRARY FLAGS:\n\n";
  ShowUsageRestrict(usage_set, prog_src, false, true);
}
@@ -12,7 +12,7 @@ All model training and evaluation code was tested on a computer running Ubuntu 2

## Training
### Baseline RNN Model
We have included a custom PyTorch implementation of the RNN model used in the paper (the paper used a TensorFlow implementation). This implementation aims to replicate or improve upon the original model's performance while leveraging PyTorch's features, resulting in a more efficient training process with a slight increase in decoding accuracy. This model includes day-specific input layers (512x512 linear input layers with softsign activation), a 5-layer GRU with 768 hidden units per layer, and a linear output layer. The model is trained to predict phonemes from neural data using CTC loss and the AdamW optimizer. Data is augmented with noise and temporal jitter to improve robustness. All model hyperparameters are specified in the `rnn_args.yaml` file.
We have included a custom PyTorch implementation of the RNN model used in the paper (the paper used a TensorFlow implementation). This implementation aims to replicate or improve upon the original model's performance while leveraging PyTorch's features, resulting in a more efficient training process with a slight increase in decoding accuracy. This model includes day-specific input layers (512x512 linear input layers with softsign activation), a 5-layer GRU with 768 hidden units per layer, and a linear output layer. The model is trained to predict phonemes from neural data using CTC loss and the AdamW optimizer. Data is augmented with noise and temporal jitter to improve robustness. All model hyperparameters are specified in the [`rnn_args.yaml`](rnn_args.yaml) file.
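A simplified PyTorch sketch of the architecture described above follows; it is not the repository's model class, and the class count, learning rate, and day handling are placeholders.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class BaselineRNNSketch(nn.Module):
    """Illustrative only: day-specific linear input layers + softsign, 5-layer GRU, linear output."""

    def __init__(self, n_days: int, n_features: int = 512, hidden: int = 768,
                 n_layers: int = 5, n_classes: int = 41):  # n_classes is a placeholder
        super().__init__()
        # One 512x512 linear input layer per recording day.
        self.day_layers = nn.ModuleList(
            [nn.Linear(n_features, n_features) for _ in range(n_days)]
        )
        self.gru = nn.GRU(n_features, hidden, num_layers=n_layers, batch_first=True)
        self.out = nn.Linear(hidden, n_classes)

    def forward(self, x: torch.Tensor, day_idx: int) -> torch.Tensor:
        # x: (batch, time_bins, 512) binned neural features
        x = F.softsign(self.day_layers[day_idx](x))
        h, _ = self.gru(x)
        return self.out(h).log_softmax(dim=-1)  # log-probs for CTC decoding

model = BaselineRNNSketch(n_days=45)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)  # lr is a placeholder
ctc_loss = nn.CTCLoss(blank=0)  # blank index is an assumption
```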

### Model training script
To train the baseline RNN model, use the `b2txt25` conda environment to run the `train_model.py` script from the `model_training` directory:
@@ -20,7 +20,7 @@ To train the baseline RNN model, use the `b2txt25` conda environment to run the
conda activate b2txt25
python train_model.py
```
The model will train for 120,000 mini-batches (~3.5 hours on an RTX 4090) and should achieve an aggregate phoneme error rate of 10.1% on the validation partition. We note that the number of training batches and specific model hyperparameters may not be optimal here, and this baseline model is only meant to serve as an example. See `rnn_args.yaml` for a list of all hyperparameters.
The model will train for 120,000 mini-batches (~3.5 hours on an RTX 4090) and should achieve an aggregate phoneme error rate of 10.1% on the validation partition. We note that the number of training batches and specific model hyperparameters may not be optimal here, and this baseline model is only meant to serve as an example. See [`rnn_args.yaml`](rnn_args.yaml) for a list of all hyperparameters.

## Evaluation
### Start redis server
