diff --git a/README.md b/README.md
index bceb4da..f63511a 100644
--- a/README.md
+++ b/README.md
@@ -25,6 +25,15 @@ The code is organized into four main directories: `utils`, `analyses`, `data`, a
 - The `model_training` directory contains the code necessary to train and evaluate the brain-to-text model. See the README.md in that folder for more detailed instructions.
 - The `language_model` directory contains the ngram language model implementation and a pretrained 1gram language model. Pretrained 3gram and 5gram language models can be downloaded [here](https://datadryad.org/dataset/doi:10.5061/dryad.x69p8czpq) (`languageModel.tar.gz` and `languageModel_5gram.tar.gz`). See the `README.md` in this directory for more information.
 
+## Data
+The data used in this repository consists of the datasets needed to recreate the paper's figures and to train and evaluate the brain-to-text model:
+- `t15_copyTask.pkl`: This file contains the online Copy Task results required for generating Figure 2.
+- `t15_personalUse.pkl`: This file contains the Conversation Mode data required for generating Figure 4.
+- `t15_copyTask_neuralData.zip`: This dataset contains the neural data for the Copy Task. There are more than 11,300 sentences from 45 sessions spanning 20 months. The data is split into training, validation, and test sets. Data for each session/split is stored in `.hdf5` files. An example of how to load this data using the Python `h5py` library is provided in the `load_h5py_file()` function in `model_training/evaluate_model_helpers.py`.
+- `t15_pretrained_rnn_baseline.zip`: This dataset contains the pretrained RNN baseline model checkpoint and args. An example of how to load this model and use it for inference is provided in `model_training/evaluate_model.py`.
+
+Please download these datasets from [Dryad](https://datadryad.org/stash/dataset/doi:10.5061/dryad.dncjsxm85) and place them in the `data` directory. Be sure to unzip both `.zip` files before running the code.
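+
+For a quick look at what is inside one of the unzipped `.hdf5` session files, a minimal `h5py` sketch such as the one below can be used. The file path is illustrative; `load_h5py_file()` in `model_training/evaluate_model_helpers.py` remains the reference for how these files are actually parsed.
+```python
+import h5py
+
+# Illustrative path -- point this at any session/split file unzipped from t15_copyTask_neuralData.zip.
+session_file = 'data/t15_copyTask_neuralData/t15.2025.04.13/data_train.hdf5'
+
+def summarize(name, obj):
+    # print the name, shape, and dtype of every dataset in the file
+    if isinstance(obj, h5py.Dataset):
+        print(name, obj.shape, obj.dtype)
+
+with h5py.File(session_file, 'r') as f:
+    f.visititems(summarize)
+```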
+
 ## Dependencies
 - The code has only been tested on Ubuntu 22.04 with two NVIDIA RTX 4090 GPUs.
 - We recommend using a conda environment to manage the dependencies. To install miniconda, follow the instructions [here](https://docs.anaconda.com/miniconda/miniconda-install/).
diff --git a/language_model/README.md b/language_model/README.md
index 1ed27aa..b889f8e 100644
--- a/language_model/README.md
+++ b/language_model/README.md
@@ -23,18 +23,26 @@ The `language-model-standalone.py` script included here is made to work with the
 3. Connect to the `localhost` redis server (or a different server, specified by the `--redis_ip` and `--redis_port` args)
 4. Wait to receive phoneme logits via redis, and then make word predictions and pass them back via redis.
+
+### `language-model-standalone.py` input args
+See the bottom of the `language-model-standalone.py` script for a full list of input args.
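+
+The script also writes the decoder settings it is currently running with to the `remote_lm_args` redis stream (see the `r.xadd('remote_lm_args', ...)` call in `language-model-standalone.py`), so once it is up you can inspect them from Python. This is just a convenience check, not part of the documented workflow:
+```python
+import redis
+
+# connect to the same redis server the standalone script uses
+r = redis.Redis(host='localhost', port=6379)
+
+# read the most recent settings entry published by language-model-standalone.py
+entries = r.xrevrange('remote_lm_args', count=1)
+if entries:
+    entry_id, fields = entries[0]
+    print(entry_id, fields)
+else:
+    print('no remote_lm_args entries yet -- is language-model-standalone.py running?')
+```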
+
+### run a 1gram model
 To run the 1gram language model from the root directory of this repository:
 ```bash
 conda activate b2txt_lm
 python language_model/language-model-standalone.py --lm_path language_model/pretrained_language_models/openwebtext_1gram_lm_sil --do_opt --nbest 100 --acoustic_scale 0.325 --blank_penalty 90 --alpha 0.55 --redis_ip localhost --gpu_number 0
 ```
 
+### run a 3gram model
 To run the 3gram language model from the root directory of this repository (requires ~60GB RAM):
 ```bash
 conda activate b2txt_lm
 python language_model/language-model-standalone.py --lm_path language_model/pretrained_language_models/openwebtext_3gram_lm_sil --do_opt --nbest 100 --acoustic_scale 0.325 --blank_penalty 90 --alpha 0.55 --redis_ip localhost --gpu_number 0
 ```
 
+### run a 5gram model
 To run the 5gram language model from the root directory of this repository (requires ~300GB of RAM):
 ```bash
 conda activate b2txt_lm
diff --git a/language_model/language-model-standalone.py b/language_model/language-model-standalone.py
index 3362e5d..173e322 100644
--- a/language_model/language-model-standalone.py
+++ b/language_model/language-model-standalone.py
@@ -11,21 +11,10 @@ import lm_decoder
 from functools import lru_cache
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
-'''
-Command string for normal chang lm:
-~/miniconda3/envs/b2txt25/bin/python ~/brand/brand-modules/npl-davis/nodes/brainToText_closedLoop/language-model-standalone/language-model-standalone.py --lm_path ~/brand/LanguageModels/chang_lm_sil --acoustic_scale 0.8 --blank_penalty 2
-
-Command string for normal open webtext lm:
-~/miniconda3/envs/b2txt25/bin/python ~/brand/brand-modules/npl-davis/nodes/brainToText_closedLoop/language-model-standalone/language-model-standalone.py --lm_path ~/brand/LanguageModels/openwebtext_3gram_lm_sil --acoustic_scale 0.3 --blank_penalty 7 --nbest 100
-
-Command string for giant language model:
-~/miniconda3/envs/b2txt25/bin/python ~/brand/brand-modules/npl-davis/nodes/brainToText_closedLoop/language-model-standalone/language-model-standalone.py --lm_path ~/brand/LanguageModels/openwebtext_5gram_lm_sil_t15_names --do_opt 1 --nbest 100 --rescore 1 --acoustic_scale 0.325 --blank_penalty 90 --alpha 0.55
-'''
-
+# set up logging
 logging.basicConfig(format='%(asctime)s %(levelname)s: %(message)s',level=logging.INFO)
 
-
-
+# function for initializing the ngram decoder
 def build_lm_decoder(
     model_path,
     max_active=7000,
@@ -73,6 +62,7 @@ def build_lm_decoder(
     return decoder
 
 
+# function for updating the ngram decoder parameters
 def update_ngram_params(
     ngramDecoder,
     max_active=200,
@@ -98,6 +88,7 @@ def update_ngram_params(
     ngramDecoder.SetOpt(decode_opts)
 
 
+# function for initializing the OPT model and tokenizer
 def build_opt(
     model_name='facebook/opt-6.7b',
     cache_dir=None,
@@ -131,6 +122,8 @@ def build_opt(
     return model, tokenizer
 
 
+
+# function for rescoring hypotheses with the GPT-2 model
 @torch.inference_mode()
 def rescore_with_gpt2(
     model,
@@ -168,6 +161,7 @@ def rescore_with_gpt2(
     return scores
 
 
+# function for decoding with the GPT-2 model
 def gpt2_lm_decode(
     model,
     tokenizer,
@@ -275,6 +269,7 @@ def get_current_redis_time_ms(redis_conn):
     return int(t[0]*1000 + t[1]/1000)
 
 
+# function to get string differences between two sentences
 def get_string_differences(cue, decoder_output):
     decoder_output_words = decoder_output.split()
     cue_words = cue.split()
@@ -328,6 +323,7 @@ def remove_punctuation(sentence):
     return sentence
 
 
+# function to augment the nbest list by swapping words around, artificially increasing the number of candidates
 def augment_nbest(nbest, top_candidates_to_augment=20, acoustic_scale=0.3, score_penalty_percent=0.01):
     sentences = []
@@ -437,7 +433,6 @@ def main(args):
     opt_cache_dir = args.opt_cache_dir
     alpha = args.alpha
     rescore = args.rescore
-    specific_word_bias = args.specific_word_bias
 
     redis_ip = args.redis_ip
     redis_port = args.redis_port
@@ -469,7 +464,6 @@ def main(args):
         'rescore': rescore,
         'top_candidates_to_augment': top_candidates_to_augment,
         'score_penalty_percent': score_penalty_percent,
-        'specific_word_bias': specific_word_bias,
     }
 
     # pick GPU
@@ -681,7 +675,6 @@ def main(args):
                 rescore = int(entry_data.get(b'rescore', rescore))
                 top_candidates_to_augment = int(entry_data.get(b'top_candidates_to_augment', top_candidates_to_augment))
                 score_penalty_percent = float(entry_data.get(b'score_penalty_percent', score_penalty_percent))
-                specific_word_bias = float(entry_data.get(b'specific_word_bias', specific_word_bias))
 
                 # make sure that the update remote lm args are put into redis nicely
                 lm_args = {
@@ -701,7 +694,6 @@ def main(args):
                     'rescore': rescore,
                     'top_candidates_to_augment': top_candidates_to_augment,
                     'score_penalty_percent': score_penalty_percent,
-                    'specific_word_bias': specific_word_bias,
                 }
 
                 r.xadd('remote_lm_args', lm_args)
@@ -732,8 +724,7 @@ def main(args):
                     f'\n\tdo_opt = {do_opt}' +
                     f'\n\trescore = {rescore}' +
                     f'\n\ttop_candidates_to_augment = {top_candidates_to_augment}' +
-                    f'\n\tscore_penalty_percent = {score_penalty_percent}' +
-                    f'\n\tspecific_word_bias = {specific_word_bias}'
+                    f'\n\tscore_penalty_percent = {score_penalty_percent}'
                 )
 
                 r.xadd('remote_lm_done_updating_params', {'done': 1})
@@ -805,16 +796,15 @@ if __name__ == "__main__":
     parser.add_argument('--ctc_blank_skip_threshold', type=float, default=1., help='ctc_blank_skip_threshold param for LM')
     parser.add_argument('--length_penalty', type=float, default=0.0, help='length_penalty param for LM')
     parser.add_argument('--acoustic_scale', type=float, default=0.3, help='Acoustic scale for LM')
-    parser.add_argument('--nbest', type=int, default=1, help='# of candidate sentences for LM decoding')
+    parser.add_argument('--nbest', type=int, default=100, help='# of candidate sentences for LM decoding')
     parser.add_argument('--top_candidates_to_augment', type=int, default=20, help='# of top candidates to augment')
     parser.add_argument('--score_penalty_percent', type=float, default=0.01, help='Score penalty percent for augmented candidates')
-    parser.add_argument('--blank_penalty', type=float, default=7.0, help='Blank penalty for LM')
+    parser.add_argument('--blank_penalty', type=float, default=9.0, help='Blank penalty for LM')
     parser.add_argument('--rescore', action='store_true', help='Use an unpruned ngram model for rescoring?')
     parser.add_argument('--do_opt', action='store_true', help='Use the opt model for rescoring?')
     parser.add_argument('--opt_cache_dir', type=str, default=None, help='path to opt cache')
-    parser.add_argument('--alpha', type=float, default=0.6, help='alpha value [0-1]: Higher = more weight on OPT rescore. Lower = more weight on ngram rescore')
-    parser.add_argument('--specific_word_bias', type=float, default=0.10, help='percentage to bias the LM score for sentences with specific words')
+    parser.add_argument('--alpha', type=float, default=0.5, help='alpha value [0-1]: Higher = more weight on OPT rescore. Lower = more weight on ngram rescore')
     parser.add_argument('--redis_ip', type=str, default='192.168.150.2', help='IP of the redis stream (string)')
     parser.add_argument('--redis_port', type=int, default=6379, help='Port of the redis stream (int)')
diff --git a/model_training/README.md b/model_training/README.md
index ffe567c..8620b21 100644
--- a/model_training/README.md
+++ b/model_training/README.md
@@ -11,7 +11,11 @@ All model training and evaluation code was tested on a computer running Ubuntu 2
 2. Download the dataset from Dryad: [Dryad Dataset](https://datadryad.org/dataset/doi:10.5061/dryad.dncjsxm85). Place the downloaded data in the `data` directory. Be sure to unzip `t15_copyTask_neuralData.zip` and `t15_pretrained_rnn_baseline.zip`.
 
 ## Training
-To train the baseline RNN model, run the following command from the `model_training` directory:
+### Baseline RNN Model
+We have included a custom PyTorch implementation of the RNN model used in the paper (the paper used a TensorFlow implementation). This implementation aims to replicate or improve upon the original model's performance while leveraging PyTorch's features, resulting in a more efficient training process and a slight increase in decoding accuracy. The model consists of day-specific input layers (512x512 linear layers with softsign activation), a 5-layer GRU with 768 hidden units per layer, and a linear output layer. It is trained to predict phonemes from neural data using CTC loss and the AdamW optimizer, and the training data are augmented with noise and temporal jitter to improve robustness. All model hyperparameters are specified in the `rnn_args.yaml` file.
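+
+For orientation, the sketch below shows roughly how the architecture described above maps onto PyTorch. It is an illustrative stand-in rather than the repository's actual model class; the layer sizes come from the description above, and `n_classes` stands in for the size of the phoneme inventory plus the CTC blank.
+```python
+from torch import nn
+
+class BaselineRNNSketch(nn.Module):
+    """Illustrative sketch of the described architecture -- not the repository's model class."""
+    def __init__(self, n_days, n_classes, n_features=512, n_hidden=768, n_layers=5):
+        super().__init__()
+        # one 512x512 linear input layer per recording day, followed by a softsign nonlinearity
+        self.day_layers = nn.ModuleList([nn.Linear(n_features, n_features) for _ in range(n_days)])
+        self.gru = nn.GRU(n_features, n_hidden, num_layers=n_layers, batch_first=True)
+        self.out = nn.Linear(n_hidden, n_classes)  # per-timestep phoneme (+ CTC blank) logits
+
+    def forward(self, x, day_idx):
+        # x: (batch, time, n_features) neural features from a single recording day
+        x = nn.functional.softsign(self.day_layers[day_idx](x))
+        h, _ = self.gru(x)
+        return self.out(h)  # train with nn.CTCLoss and torch.optim.AdamW, as described above
+```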
+
+### Model training script
+To train the baseline RNN model, use the `b2txt25` conda environment to run the `train_model.py` script from the `model_training` directory:
 ```bash
 conda activate b2txt25
 python train_model.py
@@ -20,13 +24,13 @@ The model will train for 120,000 mini-batches (~3.5 hours on an RTX 4090) and sh
 
 ## Evaluation
 ### Start redis server
-To evaluate the model, first start a redis server in terminal with:
+To evaluate the model, first start a redis server on `localhost` in a terminal with:
 ```bash
 redis-server
 ```
 
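+If you want to confirm the server is reachable before starting the other processes, a quick check from Python (assuming the default `localhost:6379`; this is not part of the repository's scripts) is:
+```python
+import redis
+
+# prints True if the redis server started above is accepting connections
+print(redis.Redis(host='localhost', port=6379).ping())
+```
+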
 ### Start language model
-Next, start the ngram language model in a seperate terminal window. For example, the 1gram language model can be started using the command below. Note that the 1gram model has no gramatical structure built into it. Details on downloading pretrained 3gram and 5gram language models and running them can be found in the README.md in the `language_model` directory.
+Next, use the `b2txt25_lm` conda environment to start the ngram language model in a separate terminal window. For example, the 1gram language model can be started using the command below. Note that the 1gram model has no grammatical structure built into it. Details on downloading pretrained 3gram and 5gram language models and running them can be found in the README.md in the `language_model` directory.
 
 To run the 1gram language model from the root directory of this repository:
 ```bash
 conda activate b2txt_lm
@@ -34,7 +38,7 @@ python language_model/language-model-standalone.py --lm_path language_model/pret
 ```
 
 ### Evaluate
-Finally, run the `evaluate_model.py` script to load the pretrained baseline RNN, use it for inference on the heldout val or test sets to get phoneme logits, pass them through the language model via redis to get word predictions, and then save the predicted sentences to a .txt file in the format required for competition submission.
+Finally, use the `b2txt25` conda environment to run the `evaluate_model.py` script. It loads the pretrained baseline RNN, uses it for inference on the held-out val or test sets to get phoneme logits, passes the logits through the language model via redis to get word predictions, and then saves the predicted sentences to a .txt file in the format required for competition submission. An example output file for the val split can be found at `rnn_baseline_submission_file_valsplit.txt`.
 ```bash
 conda activate b2txt25
 python evaluate_model.py --model_path ../data/t15_pretrained_rnn_baseline --data_dir ../data/t15_copyTask_neuralData --eval_type test --gpu_number 1
diff --git a/model_training/evaluate_model.py b/model_training/evaluate_model.py
index 511deef..97879fe 100644
--- a/model_training/evaluate_model.py
+++ b/model_training/evaluate_model.py
@@ -262,4 +262,9 @@ if eval_type == 'val':
 output_file = os.path.join(model_path, f'baseline_rnn_{eval_type}_predicted_sentences_{time.strftime("%Y%m%d_%H%M%S")}.txt')
 with open(output_file, 'w') as f:
     for i in range(len(lm_results['pred_sentence'])):
-        f.write(f"{remove_punctuation(lm_results['pred_sentence'][i])}\n")
\ No newline at end of file
+        if i < len(lm_results['pred_sentence']) - 1:
+            # write sentence + newline
+            f.write(f"{remove_punctuation(lm_results['pred_sentence'][i])}\n")
+        else:
+            # don't add a newline at the end of the last sentence
+            f.write(f"{remove_punctuation(lm_results['pred_sentence'][i])}")
\ No newline at end of file
diff --git a/model_training/rnn_args.yaml b/model_training/rnn_args.yaml
index db2c966..63f194f 100644
--- a/model_training/rnn_args.yaml
+++ b/model_training/rnn_args.yaml
@@ -130,7 +130,7 @@ dataset:
     - t15.2025.03.30
     - t15.2025.04.13
   dataset_probability_val: # probability of including a trial in the validation set (0 or 1)
-    - 0
+    - 0 # no val or test data from this day
     - 1
     - 1
     - 1
@@ -158,12 +158,12 @@ dataset:
     - 1
     - 1
     - 1
-    - 0
+    - 0 # no val or test data from this day
     - 1
     - 1
     - 1
-    - 0
-    - 0
+    - 0 # no val or test data from this day
+    - 0 # no val or test data from this day
     - 1
     - 1
     - 1
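The `dataset_probability_val` entries flagged above act as per-session 0/1 gates on the validation split. A rough illustration of what that flag means (this is not the repository's data loader; the session names are taken from the `rnn_args.yaml` snippet above):
```python
# Illustrative only: a per-session 0/1 flag gates which sessions feed the validation split.
sessions = ['t15.2025.03.30', 't15.2025.04.13']
dataset_probability_val = [0, 1]  # 0 = no val or test data from this day

val_sessions = [s for s, flag in zip(sessions, dataset_probability_val) if flag == 1]
print(val_sessions)  # ['t15.2025.04.13']
```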