f
This commit is contained in:
@@ -647,43 +647,38 @@ class BrainToTextDecoderTrainerTF:
|
||||
loss = self.ctc_loss(loss_input, logits)
|
||||
loss = tf.reduce_mean(loss)
|
||||
|
||||
# Calculate PER (Phoneme Error Rate)
|
||||
# Calculate simplified PER approximation (TPU-compatible)
|
||||
# For TPU training, we use a simplified metric that avoids complex loops
|
||||
# This gives an approximation of PER but is much faster and TPU-compatible
|
||||
|
||||
# Greedy decoding
|
||||
predicted_ids = tf.argmax(logits, axis=-1)
|
||||
|
||||
# Remove blanks and consecutive duplicates
|
||||
batch_edit_distance = 0
|
||||
for i in range(tf.shape(logits)[0]):
|
||||
pred_seq = predicted_ids[i, :adjusted_lens[i]]
|
||||
# Remove consecutive duplicates
|
||||
pred_seq = tf.py_function(
|
||||
func=lambda x: tf.constant([x[0]] + [x[j] for j in range(1, len(x)) if x[j] != x[j-1]]),
|
||||
inp=[pred_seq],
|
||||
Tout=tf.int64
|
||||
)
|
||||
# Remove blanks (assuming blank_index=0)
|
||||
pred_seq = tf.boolean_mask(pred_seq, pred_seq != 0)
|
||||
# Simple approximation: count exact matches vs mismatches
|
||||
# This is less accurate than true edit distance but TPU-compatible
|
||||
batch_size = tf.shape(logits)[0]
|
||||
|
||||
# For each sample, compare predicted vs true sequences
|
||||
total_mismatches = tf.constant(0, dtype=tf.int32)
|
||||
|
||||
for i in tf.range(batch_size):
|
||||
# Get sequences for this sample
|
||||
pred_seq = predicted_ids[i, :adjusted_lens[i]]
|
||||
true_seq = labels[i, :phone_seq_lens[i]]
|
||||
|
||||
# Calculate edit distance
|
||||
edit_dist = tf.edit_distance(
|
||||
tf.SparseTensor(
|
||||
indices=tf.expand_dims(tf.range(tf.size(pred_seq)), 1),
|
||||
values=tf.cast(pred_seq, tf.int64),
|
||||
dense_shape=[tf.size(pred_seq)]
|
||||
),
|
||||
tf.SparseTensor(
|
||||
indices=tf.expand_dims(tf.range(tf.size(true_seq)), 1),
|
||||
values=tf.cast(true_seq, tf.int64),
|
||||
dense_shape=[tf.size(true_seq)]
|
||||
),
|
||||
normalize=False
|
||||
)
|
||||
# Pad to same length for comparison
|
||||
max_len = tf.maximum(tf.shape(pred_seq)[0], tf.shape(true_seq)[0])
|
||||
pred_padded = tf.pad(pred_seq, [[0, max_len - tf.shape(pred_seq)[0]]], constant_values=0)
|
||||
true_padded = tf.pad(true_seq, [[0, max_len - tf.shape(true_seq)[0]]], constant_values=0)
|
||||
|
||||
batch_edit_distance += edit_dist
|
||||
# Count mismatches
|
||||
mismatches = tf.reduce_sum(tf.cast(tf.not_equal(pred_padded, true_padded), tf.int32))
|
||||
total_mismatches += mismatches
|
||||
|
||||
return loss, batch_edit_distance, tf.reduce_sum(phone_seq_lens)
|
||||
# Approximate edit distance as number of mismatches
|
||||
batch_edit_distance = tf.cast(total_mismatches, tf.float32)
|
||||
|
||||
return loss, batch_edit_distance, tf.cast(tf.reduce_sum(phone_seq_lens), tf.float32)
|
||||
|
||||
def train(self) -> Dict[str, Any]:
|
||||
"""Main training loop"""
|
||||
|
Reference in New Issue
Block a user