Copy Task figure and environment setup
nejm_b2txt_utils/general_utils.py (new file, 156 lines)
@@ -0,0 +1,156 @@
import numpy as np
import re
from g2p_en import G2p


# Phoneme set for the RNN logits: blank, silence, and the 39 ARPAbet phonemes
LOGIT_PHONE_DEF = [
    'BLANK', 'SIL',  # blank and silence
    'AA', 'AE', 'AH', 'AO', 'AW',
    'AY', 'B', 'CH', 'D', 'DH',
    'EH', 'ER', 'EY', 'F', 'G',
    'HH', 'IH', 'IY', 'JH', 'K',
    'L', 'M', 'N', 'NG', 'OW',
    'OY', 'P', 'R', 'S', 'SH',
    'T', 'TH', 'UH', 'UW', 'V',
    'W', 'Y', 'Z', 'ZH'
]
SIL_DEF = ['SIL']


# Remove punctuation from text
def remove_punctuation(sentence):
    # Keep only letters, spaces, hyphens, and apostrophes
    sentence = re.sub(r'[^a-zA-Z\- \']', '', sentence)
    sentence = sentence.replace('--', '').lower()
    sentence = sentence.replace(" '", "'").lower()

    # Collapse whitespace
    sentence = sentence.strip()
    sentence = ' '.join(sentence.split())

    return sentence
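
# Illustrative usage sketch (example sentence made up, not part of the original file):
#   remove_punctuation("I can't -- believe it!")  ->  "i can't believe it"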


# Convert RNN logits to argmax phonemes
def logits_to_phonemes(logits):
    # Frame-wise argmax over the phoneme classes
    seq = np.argmax(logits, axis=1)
    # Collapse consecutive repeated indices
    seq2 = np.array([seq[0]] + [seq[i] for i in range(1, len(seq)) if seq[i] != seq[i-1]])

    phones = [LOGIT_PHONE_DEF[idx] for idx in seq2]

    # Remove blank and repeated phonemes
    phones = [p for p in phones if p != 'BLANK']
    # Guard against an all-blank decode
    if len(phones) == 0:
        return []
    phones = [phones[0]] + [phones[i] for i in range(1, len(phones)) if phones[i] != phones[i-1]]

    return phones
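
# Decoding sketch (illustrative indices, not from the original file): if the
# frame-wise argmax over LOGIT_PHONE_DEF is [0, 17, 17, 0, 7, 7, 0]
# (BLANK, HH, HH, BLANK, AY, AY, BLANK), blanks and repeats collapse
# to ['HH', 'AY'], i.e. the phonemes of "hi".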


# Convert text to phonemes
def sentence_to_phonemes(thisTranscription, g2p_instance=None):
    if not g2p_instance:
        g2p_instance = G2p()

    # Remove punctuation
    thisTranscription = remove_punctuation(thisTranscription)

    # Convert to phonemes
    phonemes = []
    if len(thisTranscription) == 0:
        phonemes = SIL_DEF
    else:
        for p in g2p_instance(thisTranscription):
            if p == ' ':
                phonemes.append('SIL')

            p = re.sub(r'[0-9]', '', p)  # Remove stress markers
            if re.match(r'[A-Z]+', p):   # Only keep phonemes
                phonemes.append(p)

        # Add one SIL symbol at the end so there's one at the end of each word
        phonemes.append('SIL')

    return phonemes, thisTranscription
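
# Illustrative usage sketch (exact phonemes depend on the g2p_en model;
# this example is not from the original file):
#   sentence_to_phonemes("Hello, world!")
#   ->  (['HH', 'AH', 'L', 'OW', 'SIL', 'W', 'ER', 'L', 'D', 'SIL'], 'hello world')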


# Calculate WER or PER
def calculate_error_rate(r, h):
    """
    Calculation of WER or PER with Levenshtein distance.
    Works only for iterables up to 254 elements (uint8).
    O(nm) time and space complexity.
    ----------
    Parameters:
    r : list of true words or phonemes
    h : list of predicted words or phonemes
    ----------
    Returns:
    Number of word or phoneme errors (edit distance) [int]
    ----------
    Examples:
    >>> calculate_error_rate("who is there".split(), "is there".split())
    1
    >>> calculate_error_rate("who is there".split(), "".split())
    3
    >>> calculate_error_rate("".split(), "who is there".split())
    3
    """
    # initialization
    d = np.zeros((len(r)+1)*(len(h)+1), dtype=np.uint8)
    d = d.reshape((len(r)+1, len(h)+1))
    for i in range(len(r)+1):
        for j in range(len(h)+1):
            if i == 0:
                d[0][j] = j
            elif j == 0:
                d[i][0] = i

    # computation
    for i in range(1, len(r)+1):
        for j in range(1, len(h)+1):
            if r[i-1] == h[j-1]:
                d[i][j] = d[i-1][j-1]
            else:
                substitution = d[i-1][j-1] + 1
                insertion = d[i][j-1] + 1
                deletion = d[i-1][j] + 1
                d[i][j] = min(substitution, insertion, deletion)

    return d[len(r)][len(h)]
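
# PER sketch (illustrative phoneme lists, not from the original file):
#   calculate_error_rate(['HH', 'AY', 'SIL'], ['HH', 'EY', 'SIL'])  ->  1
#   (one substitution; dividing by the reference length gives a PER of 1/3)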


# Calculate aggregate WER or PER
def calculate_aggregate_error_rate(r, h):

    # list setup
    err_count = []
    item_count = []
    error_rate_ind = []

    # calculate individual error rates
    for x in range(len(h)):
        r_x = r[x]
        h_x = h[x]

        n_err = calculate_error_rate(r_x, h_x)

        item_count.append(len(r_x))
        err_count.append(n_err)
        error_rate_ind.append(n_err / len(r_x))

    # Calculate aggregate error rate
    error_rate_agg = np.sum(err_count) / np.sum(item_count)

    # Calculate 95% CI via bootstrap resampling over trials
    item_count = np.array(item_count)
    err_count = np.array(err_count)
    nResamples = 10000
    resampled_error_rate = np.zeros([nResamples,])
    for n in range(nResamples):
        resampleIdx = np.random.randint(0, item_count.shape[0], [item_count.shape[0]])
        resampled_error_rate[n] = np.sum(err_count[resampleIdx]) / np.sum(item_count[resampleIdx])
    error_rate_agg_CI = np.percentile(resampled_error_rate, [2.5, 97.5])

    # Return everything as a tuple
    return (error_rate_agg, error_rate_agg_CI[0], error_rate_agg_CI[1], error_rate_ind)
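
# Illustrative usage sketch (example data made up, not part of the original file):
#   refs = [['who', 'is', 'there'], ['hello', 'world']]
#   hyps = [['who', 'is', 'there'], ['hello', 'word']]
#   wer, ci_low, ci_high, wer_per_trial = calculate_aggregate_error_rate(refs, hyps)
#   # wer == 1/5 = 0.2, wer_per_trial == [0.0, 0.5];
#   # ci_low/ci_high are the bootstrapped 95% CI bounds and vary run to run.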