additional documentation

This commit is contained in:
nckcard
2025-07-02 14:28:34 -07:00
parent b2d25bd606
commit d87462df51
6 changed files with 48 additions and 32 deletions

View File

@@ -25,6 +25,15 @@ The code is organized into four main directories: `utils`, `analyses`, `data`, a
- The `model_training` directory contains the code necessary to train and evaluate the brain-to-text model. See the README.md in that folder for more detailed instructions.
- The `language_model` directory contains the ngram language model implementation and a pretrained 1gram language model. Pretrained 3gram and 5gram language models can be downloaded [here](https://datadryad.org/dataset/doi:10.5061/dryad.x69p8czpq) (`languageModel.tar.gz` and `languageModel_5gram.tar.gz`). See the `README.md` in this directory for more information.
## Data
The data used in this repository consists of the following datasets for recreating figures and training/evaluating the brain-to-text model:
- `t15_copyTask.pkl`: This file contains the online Copy Task results required for generating Figure 2.
- `t15_personalUse.pkl`: This file contains the Conversation Mode data required for generating Figure 4.
- `t15_copyTask_neuralData.zip`: This dataset contains the neural data for the Copy Task. There are more than 11,300 sentences from 45 sessions spanning 20 months. The data is split into training, validation, and test sets. Data for each session/split is stored in `.hdf5` files. An example of how to load this data using the Python `h5py` library is provided in the `load_h5py_file()` function in `model_training/evaluate_model_helpers.py` (see also the brief loading sketch at the end of this section).
- `t15_pretrained_rnn_baseline.zip`: This dataset contains the pretrained RNN baseline model checkpoint and args. An example of how to load this model and use it for inference is provided in the `model_training/evaluate_model.py` file.
Please download these datasets from [Dryad](https://datadryad.org/stash/dataset/doi:10.5061/dryad.dncjsxm85) and place them in the `data` directory. Be sure to unzip both `.zip` files before running the code.
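For reference, here is a minimal sketch of opening one of the Copy Task `.hdf5` files with `h5py`. The file path and key names below are placeholders; the canonical loader is `load_h5py_file()` in `model_training/evaluate_model_helpers.py`.
```python
import h5py

# Placeholder path: point this at one session/split file inside data/t15_copyTask_neuralData
path = "data/t15_copyTask_neuralData/<session>/<split>.hdf5"

with h5py.File(path, "r") as f:
    # print every group/dataset path stored in the file
    f.visit(print)
    # individual datasets can then be read into numpy arrays, e.g.:
    # neural_features = f["<trial_group>/<dataset_name>"][:]
```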
## Dependencies
- The code has only been tested on Ubuntu 22.04 with two NVIDIA RTX 4090 GPUs.
- We recommend using a conda environment to manage the dependencies. To install miniconda, follow the instructions [here](https://docs.anaconda.com/miniconda/miniconda-install/).

View File

@@ -23,18 +23,26 @@ The `language-model-standalone.py` script included here is made to work with the
3. Connect to the `localhost` redis server (or a different server, specified by the `--redis_ip` and `--redis_port` args)
4. Wait to receive phoneme logits via redis, decode them into word predictions, and pass the predictions back via redis (a hedged sketch of this exchange is shown below).
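For a concrete sense of step 4, here is a hedged `redis-py` sketch of the exchange from the other side of the connection. The stream and field names below are placeholders, not the keys actually used by `language-model-standalone.py`; see that script for the real stream names.
```python
import numpy as np
import redis

r = redis.Redis(host="localhost", port=6379)

# push a block of phoneme logits to the standalone language model
# (stream/field names are placeholders for illustration only)
logits = np.random.randn(10, 41).astype(np.float32)  # (time, n_phoneme_classes)
r.xadd("phoneme_logits_stream", {"logits": logits.tobytes(), "shape": str(logits.shape)})

# block until the language model publishes a word prediction, then print it
for _stream, entries in r.xread({"word_predictions_stream": "$"}, block=0):
    for _entry_id, fields in entries:
        print(fields[b"sentence"].decode())
```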
### `language-model-standalone.py` input args
See the bottom of the `language-model-standalone.py` script for a full list of input args.
### run a 1gram model
To run the 1gram language model from the root directory of this repository:
```bash
conda activate b2txt_lm
python language_model/language-model-standalone.py --lm_path language_model/pretrained_language_models/openwebtext_1gram_lm_sil --do_opt --nbest 100 --acoustic_scale 0.325 --blank_penalty 90 --alpha 0.55 --redis_ip localhost --gpu_number 0
```
### run a 3gram model
To run the 3gram language model from the root directory of this repository (requires ~60GB of RAM):
```bash
conda activate b2txt_lm
python language_model/language-model-standalone.py --lm_path language_model/pretrained_language_models/openwebtext_3gram_lm_sil --do_opt --nbest 100 --acoustic_scale 0.325 --blank_penalty 90 --alpha 0.55 --redis_ip localhost --gpu_number 0
```
### run a 5gram model
To run the 5gram language model from the root directory of this repository (requires ~300GB of RAM):
```bash
conda activate b2txt_lm

View File

@@ -11,21 +11,10 @@ import lm_decoder
from functools import lru_cache
from transformers import AutoModelForCausalLM, AutoTokenizer
'''
Command string for normal chang lm:
~/miniconda3/envs/b2txt25/bin/python ~/brand/brand-modules/npl-davis/nodes/brainToText_closedLoop/language-model-standalone/language-model-standalone.py --lm_path ~/brand/LanguageModels/chang_lm_sil --acoustic_scale 0.8 --blank_penalty 2
Command string for normal open webtext lm:
~/miniconda3/envs/b2txt25/bin/python ~/brand/brand-modules/npl-davis/nodes/brainToText_closedLoop/language-model-standalone/language-model-standalone.py --lm_path ~/brand/LanguageModels/openwebtext_3gram_lm_sil --acoustic_scale 0.3 --blank_penalty 7 --nbest 100
Command string for giant language model:
~/miniconda3/envs/b2txt25/bin/python ~/brand/brand-modules/npl-davis/nodes/brainToText_closedLoop/language-model-standalone/language-model-standalone.py --lm_path ~/brand/LanguageModels/openwebtext_5gram_lm_sil_t15_names --do_opt 1 --nbest 100 --rescore 1 --acoustic_scale 0.325 --blank_penalty 90 --alpha 0.55
'''
# set up logging
logging.basicConfig(format='%(asctime)s %(levelname)s: %(message)s',level=logging.INFO)
# function for initializing the ngram decoder
def build_lm_decoder(
model_path,
max_active=7000,
@@ -73,6 +62,7 @@ def build_lm_decoder(
return decoder
# function for updating the ngram decoder parameters
def update_ngram_params(
ngramDecoder,
max_active=200,
@@ -98,6 +88,7 @@ def update_ngram_params(
ngramDecoder.SetOpt(decode_opts)
# function for initializing the OPT model and tokenizer
def build_opt(
model_name='facebook/opt-6.7b',
cache_dir=None,
@@ -131,6 +122,8 @@ def build_opt(
return model, tokenizer
# function for rescoring hypotheses with the GPT-2 model
@torch.inference_mode()
def rescore_with_gpt2(
model,
@@ -168,6 +161,7 @@ def rescore_with_gpt2(
return scores
# function for decoding with the GPT-2 model
def gpt2_lm_decode(
model,
tokenizer,
@@ -275,6 +269,7 @@ def get_current_redis_time_ms(redis_conn):
return int(t[0]*1000 + t[1]/1000)
# function to get string differences between two sentences
def get_string_differences(cue, decoder_output):
decoder_output_words = decoder_output.split()
cue_words = cue.split()
@@ -328,6 +323,7 @@ def remove_punctuation(sentence):
return sentence
# function to augment the nbest list by swapping words around, artificially increasing the number of candidates
def augment_nbest(nbest, top_candidates_to_augment=20, acoustic_scale=0.3, score_penalty_percent=0.01):
sentences = []
@@ -437,7 +433,6 @@ def main(args):
opt_cache_dir = args.opt_cache_dir
alpha = args.alpha
rescore = args.rescore
specific_word_bias = args.specific_word_bias
redis_ip = args.redis_ip
redis_port = args.redis_port
@@ -469,7 +464,6 @@ def main(args):
'rescore': rescore,
'top_candidates_to_augment': top_candidates_to_augment,
'score_penalty_percent': score_penalty_percent,
'specific_word_bias': specific_word_bias,
}
# pick GPU
@@ -681,7 +675,6 @@ def main(args):
rescore = int(entry_data.get(b'rescore', rescore))
top_candidates_to_augment = int(entry_data.get(b'top_candidates_to_augment', top_candidates_to_augment))
score_penalty_percent = float(entry_data.get(b'score_penalty_percent', score_penalty_percent))
specific_word_bias = float(entry_data.get(b'specific_word_bias', specific_word_bias))
# make sure the updated remote LM args are written to redis cleanly
lm_args = {
@@ -701,7 +694,6 @@ def main(args):
'rescore': rescore,
'top_candidates_to_augment': top_candidates_to_augment,
'score_penalty_percent': score_penalty_percent,
'specific_word_bias': specific_word_bias,
}
r.xadd('remote_lm_args', lm_args)
@@ -732,8 +724,7 @@ def main(args):
f'\n\tdo_opt = {do_opt}' +
f'\n\trescore = {rescore}' +
f'\n\ttop_candidates_to_augment = {top_candidates_to_augment}' +
f'\n\tscore_penalty_percent = {score_penalty_percent}' +
f'\n\tspecific_word_bias = {specific_word_bias}'
f'\n\tscore_penalty_percent = {score_penalty_percent}'
)
r.xadd('remote_lm_done_updating_params', {'done': 1})
@@ -805,16 +796,15 @@ if __name__ == "__main__":
parser.add_argument('--ctc_blank_skip_threshold', type=float, default=1., help='ctc_blank_skip_threshold param for LM')
parser.add_argument('--length_penalty', type=float, default=0.0, help='length_penalty param for LM')
parser.add_argument('--acoustic_scale', type=float, default=0.3, help='Acoustic scale for LM')
parser.add_argument('--nbest', type=int, default=1, help='# of candidate sentences for LM decoding')
parser.add_argument('--nbest', type=int, default=100, help='# of candidate sentences for LM decoding')
parser.add_argument('--top_candidates_to_augment', type=int, default=20, help='# of top candidates to augment')
parser.add_argument('--score_penalty_percent', type=float, default=0.01, help='Score penalty percent for augmented candidates')
parser.add_argument('--blank_penalty', type=float, default=7.0, help='Blank penalty for LM')
parser.add_argument('--blank_penalty', type=float, default=9.0, help='Blank penalty for LM')
parser.add_argument('--rescore', action='store_true', help='Use an unpruned ngram model for rescoring?')
parser.add_argument('--do_opt', action='store_true', help='Use the opt model for rescoring?')
parser.add_argument('--opt_cache_dir', type=str, default=None, help='path to opt cache')
parser.add_argument('--alpha', type=float, default=0.6, help='alpha value [0-1]: Higher = more weight on OPT rescore. Lower = more weight on ngram rescore')
parser.add_argument('--specific_word_bias', type=float, default=0.10, help='percentage to bias the LM score for sentences with specific words')
parser.add_argument('--alpha', type=float, default=0.5, help='alpha value [0-1]: Higher = more weight on OPT rescore. Lower = more weight on ngram rescore')
parser.add_argument('--redis_ip', type=str, default='192.168.150.2', help='IP of the redis stream (string)')
parser.add_argument('--redis_port', type=int, default=6379, help='Port of the redis stream (int)')

View File

@@ -11,7 +11,11 @@ All model training and evaluation code was tested on a computer running Ubuntu 2
2. Download the dataset from Dryad: [Dryad Dataset](https://datadryad.org/dataset/doi:10.5061/dryad.dncjsxm85). Place the downloaded data in the `data` directory. Be sure to unzip `t15_copyTask_neuralData.zip` and `t15_pretrained_rnn_baseline.zip`.
## Training
To train the baseline RNN model, run the following command from the `model_training` directory:
### Baseline RNN Model
We have included a custom PyTorch implementation of the RNN model used in the paper (the paper used a TensorFlow implementation). This implementation aims to match or improve on the original model's performance while taking advantage of PyTorch's features, and in practice it trains more efficiently with a slight increase in decoding accuracy. The model includes day-specific input layers (512x512 linear layers with softsign activation), a 5-layer GRU with 768 hidden units per layer, and a linear output layer. It is trained to predict phonemes from neural data using CTC loss and the AdamW optimizer, and the training data are augmented with noise and temporal jitter to improve robustness. All model hyperparameters are specified in the `rnn_args.yaml` file.
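For orientation, here is a rough PyTorch sketch of the architecture described above. This is not the actual `train_model.py` code; the number of phoneme classes and the learning rate are assumptions, and the real hyperparameters are defined in `rnn_args.yaml`.
```python
import torch
import torch.nn as nn

class BaselineRNNSketch(nn.Module):
    """Hedged sketch of the baseline RNN; see train_model.py for the real implementation."""

    def __init__(self, n_days, n_features=512, n_hidden=768, n_layers=5, n_classes=41):
        super().__init__()
        # one day-specific 512x512 linear input layer (softsign activation) per session
        self.day_layers = nn.ModuleList(
            [nn.Linear(n_features, n_features) for _ in range(n_days)]
        )
        # 5-layer GRU with 768 hidden units per layer
        self.gru = nn.GRU(n_features, n_hidden, num_layers=n_layers, batch_first=True)
        # linear readout to per-timestep phoneme logits (n_classes=41 is an assumption)
        self.out = nn.Linear(n_hidden, n_classes)

    def forward(self, x, day_idx):
        # x: (batch, time, n_features) neural features recorded on day `day_idx`
        x = torch.nn.functional.softsign(self.day_layers[day_idx](x))
        h, _ = self.gru(x)
        return self.out(h)

model = BaselineRNNSketch(n_days=45)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)  # placeholder learning rate
ctc_loss = nn.CTCLoss(blank=0)  # phoneme targets are scored with CTC loss
```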
### Model training script
To train the baseline RNN model, use the `b2txt25` conda environment to run the `train_model.py` script from the `model_training` directory:
```bash
conda activate b2txt25
python train_model.py
@@ -20,13 +24,13 @@ The model will train for 120,000 mini-batches (~3.5 hours on an RTX 4090) and sh
## Evaluation
### Start redis server
To evaluate the model, first start a redis server in a terminal with:
To evaluate the model, first start a redis server on `localhost` in a terminal with:
```bash
redis-server
```
### Start language model
Next, start the ngram language model in a separate terminal window. For example, the 1gram language model can be started using the command below. Note that the 1gram model has no grammatical structure built into it. Details on downloading pretrained 3gram and 5gram language models and running them can be found in the README.md in the `language_model` directory.
Next, use the `b2txt25_lm` conda environment to start the ngram language model in a separate terminal window. For example, the 1gram language model can be started using the command below. Note that the 1gram model has no grammatical structure built into it. Details on downloading pretrained 3gram and 5gram language models and running them can be found in the README.md in the `language_model` directory.
To run the 1gram language model from the root directory of this repository:
```bash
conda activate b2txt_lm
@@ -34,7 +38,7 @@ python language_model/language-model-standalone.py --lm_path language_model/pret
```
### Evaluate
Finally, run the `evaluate_model.py` script. It loads the pretrained baseline RNN, runs inference on the held-out val or test set to get phoneme logits, passes the logits through the language model via redis to get word predictions, and then saves the predicted sentences to a .txt file in the format required for competition submission.
Finally, use the `b2txt25` conda environment to run the `evaluate_model.py` script. It loads the pretrained baseline RNN, runs inference on the held-out val or test set to get phoneme logits, passes the logits through the language model via redis to get word predictions, and then saves the predicted sentences to a .txt file in the format required for competition submission. An example output file for the val split can be found at `rnn_baseline_submission_file_valsplit.txt`.
```bash
conda activate b2txt25
python evaluate_model.py --model_path ../data/t15_pretrained_rnn_baseline --data_dir ../data/t15_copyTask_neuralData --eval_type test --gpu_number 1

View File

@@ -262,4 +262,9 @@ if eval_type == 'val':
output_file = os.path.join(model_path, f'baseline_rnn_{eval_type}_predicted_sentences_{time.strftime("%Y%m%d_%H%M%S")}.txt')
with open(output_file, 'w') as f:
    for i in range(len(lm_results['pred_sentence'])):
        f.write(f"{remove_punctuation(lm_results['pred_sentence'][i])}\n")
        if i < len(lm_results['pred_sentence']) - 1:
            # write sentence + newline
            f.write(f"{remove_punctuation(lm_results['pred_sentence'][i])}\n")
        else:
            # don't add a newline at the end of the last sentence
            f.write(f"{remove_punctuation(lm_results['pred_sentence'][i])}")

View File

@@ -130,7 +130,7 @@ dataset:
- t15.2025.03.30
- t15.2025.04.13
dataset_probability_val: # probability of including a trial in the validation set (0 or 1)
- 0
- 0 # no val or test data from this day
- 1
- 1
- 1
@@ -158,12 +158,12 @@ dataset:
- 1
- 1
- 1
- 0
- 0 # no val or test data from this day
- 1
- 1
- 1
- 0
- 0
- 0 # no val or test data from this day
- 0 # no val or test data from this day
- 1
- 1
- 1