additional documentation

This commit is contained in:
nckcard
2025-07-02 14:28:34 -07:00
parent b2d25bd606
commit d87462df51
6 changed files with 48 additions and 32 deletions

View File

@@ -23,18 +23,26 @@ The `language-model-standalone.py` script included here is made to work with the
3. Connect to the `localhost` Redis server (or a different server, specified by the `--redis_ip` and `--redis_port` args)
4. Wait to receive phoneme logits via Redis, then make word predictions and pass them back via Redis.
### `language-model-standalone.py` input args
See the bottom of the `language-model-standalone.py` script for a full list of input args.
### run a 1gram model
To run the 1gram language model from the root directory of this repository:
```bash
conda activate b2txt_lm
python language_model/language-model-standalone.py --lm_path language_model/pretrained_language_models/openwebtext_1gram_lm_sil --do_opt --nbest 100 --acoustic_scale 0.325 --blank_penalty 90 --alpha 0.55 --redis_ip localhost --gpu_number 0
```
### run a 3gram model
To run the 3gram language model from the root directory of this repository (requires ~60GB of RAM):
```bash
conda activate b2txt_lm
python language_model/language-model-standalone.py --lm_path language_model/pretrained_language_models/openwebtext_3gram_lm_sil --do_opt --nbest 100 --acoustic_scale 0.325 --blank_penalty 90 --alpha 0.55 --redis_ip localhost --gpu_number 0
```
### run a 5gram model
To run the 5gram language model from the root directory of this repository (requires ~300GB of RAM):
```bash
conda activate b2txt_lm

View File

@@ -11,21 +11,10 @@ import lm_decoder
from functools import lru_cache
from transformers import AutoModelForCausalLM, AutoTokenizer
'''
Command string for normal chang lm:
~/miniconda3/envs/b2txt25/bin/python ~/brand/brand-modules/npl-davis/nodes/brainToText_closedLoop/language-model-standalone/language-model-standalone.py --lm_path ~/brand/LanguageModels/chang_lm_sil --acoustic_scale 0.8 --blank_penalty 2
Command string for normal open webtext lm:
~/miniconda3/envs/b2txt25/bin/python ~/brand/brand-modules/npl-davis/nodes/brainToText_closedLoop/language-model-standalone/language-model-standalone.py --lm_path ~/brand/LanguageModels/openwebtext_3gram_lm_sil --acoustic_scale 0.3 --blank_penalty 7 --nbest 100
Command string for giant language model:
~/miniconda3/envs/b2txt25/bin/python ~/brand/brand-modules/npl-davis/nodes/brainToText_closedLoop/language-model-standalone/language-model-standalone.py --lm_path ~/brand/LanguageModels/openwebtext_5gram_lm_sil_t15_names --do_opt 1 --nbest 100 --rescore 1 --acoustic_scale 0.325 --blank_penalty 90 --alpha 0.55
'''
# set up logging
logging.basicConfig(format='%(asctime)s %(levelname)s: %(message)s',level=logging.INFO)
# function for initializing the ngram decoder
def build_lm_decoder(
model_path,
max_active=7000,
@@ -73,6 +62,7 @@ def build_lm_decoder(
return decoder
# function for updating the ngram decoder parameters
def update_ngram_params(
ngramDecoder,
max_active=200,
@@ -98,6 +88,7 @@ def update_ngram_params(
ngramDecoder.SetOpt(decode_opts)
# function for initializing the OPT model and tokenizer
def build_opt(
model_name='facebook/opt-6.7b',
cache_dir=None,
@@ -131,6 +122,8 @@ def build_opt(
return model, tokenizer
# function for rescoring hypotheses with the GPT-2 model
@torch.inference_mode()
def rescore_with_gpt2(
model,
@@ -168,6 +161,7 @@ def rescore_with_gpt2(
return scores
# function for decoding with the GPT-2 model
def gpt2_lm_decode(
model,
tokenizer,
@@ -275,6 +269,7 @@ def get_current_redis_time_ms(redis_conn):
return int(t[0]*1000 + t[1]/1000)
# function to get string differences between two sentences
def get_string_differences(cue, decoder_output):
decoder_output_words = decoder_output.split()
cue_words = cue.split()
@@ -328,6 +323,7 @@ def remove_punctuation(sentence):
return sentence
# function to augment the nbest list by swapping words around, artificially increasing the number of candidates
def augment_nbest(nbest, top_candidates_to_augment=20, acoustic_scale=0.3, score_penalty_percent=0.01):
sentences = []
@@ -437,7 +433,6 @@ def main(args):
opt_cache_dir = args.opt_cache_dir
alpha = args.alpha
rescore = args.rescore
specific_word_bias = args.specific_word_bias
redis_ip = args.redis_ip
redis_port = args.redis_port
@@ -469,7 +464,6 @@ def main(args):
'rescore': rescore,
'top_candidates_to_augment': top_candidates_to_augment,
'score_penalty_percent': score_penalty_percent,
'specific_word_bias': specific_word_bias,
}
# pick GPU
@@ -681,7 +675,6 @@ def main(args):
rescore = int(entry_data.get(b'rescore', rescore))
top_candidates_to_augment = int(entry_data.get(b'top_candidates_to_augment', top_candidates_to_augment))
score_penalty_percent = float(entry_data.get(b'score_penalty_percent', score_penalty_percent))
specific_word_bias = float(entry_data.get(b'specific_word_bias', specific_word_bias))
# make sure that the update remote lm args are put into redis nicely
lm_args = {
@@ -701,7 +694,6 @@ def main(args):
'rescore': rescore,
'top_candidates_to_augment': top_candidates_to_augment,
'score_penalty_percent': score_penalty_percent,
'specific_word_bias': specific_word_bias,
}
r.xadd('remote_lm_args', lm_args)
@@ -732,8 +724,7 @@ def main(args):
f'\n\tdo_opt = {do_opt}' +
f'\n\trescore = {rescore}' +
f'\n\ttop_candidates_to_augment = {top_candidates_to_augment}' +
f'\n\tscore_penalty_percent = {score_penalty_percent}' +
f'\n\tspecific_word_bias = {specific_word_bias}'
f'\n\tscore_penalty_percent = {score_penalty_percent}'
)
r.xadd('remote_lm_done_updating_params', {'done': 1})
@@ -805,16 +796,15 @@ if __name__ == "__main__":
parser.add_argument('--ctc_blank_skip_threshold', type=float, default=1., help='ctc_blank_skip_threshold param for LM')
parser.add_argument('--length_penalty', type=float, default=0.0, help='length_penalty param for LM')
parser.add_argument('--acoustic_scale', type=float, default=0.3, help='Acoustic scale for LM')
parser.add_argument('--nbest', type=int, default=1, help='# of candidate sentences for LM decoding')
parser.add_argument('--nbest', type=int, default=100, help='# of candidate sentences for LM decoding')
parser.add_argument('--top_candidates_to_augment', type=int, default=20, help='# of top candidates to augment')
parser.add_argument('--score_penalty_percent', type=float, default=0.01, help='Score penalty percent for augmented candidates')
parser.add_argument('--blank_penalty', type=float, default=7.0, help='Blank penalty for LM')
parser.add_argument('--blank_penalty', type=float, default=9.0, help='Blank penalty for LM')
parser.add_argument('--rescore', action='store_true', help='Use an unpruned ngram model for rescoring?')
parser.add_argument('--do_opt', action='store_true', help='Use the opt model for rescoring?')
parser.add_argument('--opt_cache_dir', type=str, default=None, help='path to opt cache')
parser.add_argument('--alpha', type=float, default=0.6, help='alpha value [0-1]: Higher = more weight on OPT rescore. Lower = more weight on ngram rescore')
parser.add_argument('--specific_word_bias', type=float, default=0.10, help='percentage to bias the LM score for sentences with specific words')
parser.add_argument('--alpha', type=float, default=0.5, help='alpha value [0-1]: Higher = more weight on OPT rescore. Lower = more weight on ngram rescore')
parser.add_argument('--redis_ip', type=str, default='192.168.150.2', help='IP of the redis stream (string)')
parser.add_argument('--redis_port', type=int, default=6379, help='Port of the redis stream (int)')