f
This commit is contained in:
@@ -851,23 +851,25 @@ def analyze_dataset_shapes(dataset_tf: BrainToTextDatasetTF, sample_size: int =
|
|||||||
# Utility functions for TPU-optimized data pipeline
|
# Utility functions for TPU-optimized data pipeline
|
||||||
def create_input_fn(dataset_tf: BrainToTextDatasetTF,
|
def create_input_fn(dataset_tf: BrainToTextDatasetTF,
|
||||||
transform_args: Dict[str, Any],
|
transform_args: Dict[str, Any],
|
||||||
|
max_shapes: Dict[str, int],
|
||||||
training: bool = True,
|
training: bool = True,
|
||||||
cache_path: Optional[str] = None) -> tf.data.Dataset:
|
cache_path: Optional[str] = None) -> tf.data.Dataset:
|
||||||
"""
|
"""
|
||||||
Create input function for TPU training with DYNAMIC padding and data augmentation
|
Create input function for TPU training with PRE-ANALYZED FIXED shapes
|
||||||
|
|
||||||
This function uses dynamic shapes to avoid the "pad to a smaller size" error.
|
This function uses pre-computed maximum shapes to create fixed-size batches,
|
||||||
All variable-length dimensions use tf.TensorShape([None, ...]) to allow
|
ensuring XLA compilation success on TPU hardware.
|
||||||
TensorFlow to automatically determine the appropriate padding size for each batch.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dataset_tf: BrainToTextDatasetTF instance
|
dataset_tf: BrainToTextDatasetTF instance
|
||||||
transform_args: Data transformation configuration
|
transform_args: Data transformation configuration
|
||||||
|
max_shapes: Pre-computed maximum shapes dictionary with keys:
|
||||||
|
'max_time_steps', 'max_phone_seq_len', 'max_transcription_len', 'n_features'
|
||||||
training: Whether this is for training (applies augmentations)
|
training: Whether this is for training (applies augmentations)
|
||||||
cache_path: Optional path for disk caching to improve I/O performance
|
cache_path: Optional path for disk caching to improve I/O performance
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
tf.data.Dataset ready for TPU training with dynamic shapes
|
tf.data.Dataset ready for TPU training with fixed shapes
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Create individual example dataset with file-grouping I/O optimization
|
# Create individual example dataset with file-grouping I/O optimization
|
||||||
@@ -916,22 +918,29 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
|
|||||||
num_parallel_calls=tf.data.AUTOTUNE
|
num_parallel_calls=tf.data.AUTOTUNE
|
||||||
)
|
)
|
||||||
|
|
||||||
# ========================= DYNAMIC SHAPES SOLUTION =========================
|
# ========================= FIXED SHAPES SOLUTION =========================
|
||||||
# 使用动态形状避免 "pad to a smaller size" 错误
|
# 使用预分析的固定形状确保 XLA 编译成功
|
||||||
# 这是最简单、最健壮的解决方案
|
print(f"🔧 Using PRE-ANALYZED FIXED shapes for maximum TPU performance:")
|
||||||
print("🔧 Using DYNAMIC shapes for maximum compatibility and robustness.")
|
|
||||||
|
|
||||||
# Calculate number of features based on subset
|
# 从传入的参数中获取形状信息
|
||||||
n_features = len(dataset_tf.feature_subset) if dataset_tf.feature_subset else 512
|
max_time_steps = max_shapes['max_time_steps']
|
||||||
|
max_phone_seq_len = max_shapes['max_phone_seq_len']
|
||||||
|
max_transcription_len = max_shapes['max_transcription_len']
|
||||||
|
n_features = max_shapes['n_features']
|
||||||
|
|
||||||
# Define dynamic padded shapes - all variable dimensions use None
|
print(f" Fixed time steps: {max_time_steps}")
|
||||||
|
print(f" Fixed phone sequence length: {max_phone_seq_len}")
|
||||||
|
print(f" Fixed transcription length: {max_transcription_len}")
|
||||||
|
print(f" Number of features: {n_features}")
|
||||||
|
|
||||||
|
# Define fixed padded shapes - NO None values for XLA compatibility
|
||||||
padded_shapes = {
|
padded_shapes = {
|
||||||
'input_features': tf.TensorShape([None, n_features]), # 时间维度动态
|
'input_features': tf.TensorShape([max_time_steps, n_features]),
|
||||||
'seq_class_ids': tf.TensorShape([None]), # 序列长度动态
|
'seq_class_ids': tf.TensorShape([max_phone_seq_len]),
|
||||||
'n_time_steps': tf.TensorShape([]), # 标量
|
'n_time_steps': tf.TensorShape([]), # 标量
|
||||||
'phone_seq_lens': tf.TensorShape([]), # 标量
|
'phone_seq_lens': tf.TensorShape([]), # 标量
|
||||||
'day_indices': tf.TensorShape([]), # 标量
|
'day_indices': tf.TensorShape([]), # 标量
|
||||||
'transcriptions': tf.TensorShape([None]), # 转录长度动态
|
'transcriptions': tf.TensorShape([max_transcription_len]),
|
||||||
'block_nums': tf.TensorShape([]), # 标量
|
'block_nums': tf.TensorShape([]), # 标量
|
||||||
'trial_nums': tf.TensorShape([]) # 标量
|
'trial_nums': tf.TensorShape([]) # 标量
|
||||||
}
|
}
|
||||||
@@ -948,8 +957,7 @@ def create_input_fn(dataset_tf: BrainToTextDatasetTF,
|
|||||||
'trial_nums': 0
|
'trial_nums': 0
|
||||||
}
|
}
|
||||||
|
|
||||||
# Create batches with dynamic padding - TensorFlow will automatically
|
# Create batches with FIXED padding - XLA compiler will be happy!
|
||||||
# determine the appropriate padding size for each batch
|
|
||||||
dataset = dataset.padded_batch(
|
dataset = dataset.padded_batch(
|
||||||
batch_size=dataset_tf.batch_size,
|
batch_size=dataset_tf.batch_size,
|
||||||
padded_shapes=padded_shapes,
|
padded_shapes=padded_shapes,
|
||||||
|
|||||||
@@ -27,7 +27,8 @@ from dataset_tf import (
|
|||||||
BrainToTextDatasetTF,
|
BrainToTextDatasetTF,
|
||||||
DataAugmentationTF,
|
DataAugmentationTF,
|
||||||
train_test_split_indices,
|
train_test_split_indices,
|
||||||
create_input_fn
|
create_input_fn,
|
||||||
|
analyze_dataset_shapes
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -550,25 +551,25 @@ class BrainToTextDecoderTrainerTF:
|
|||||||
|
|
||||||
# Calculate losses
|
# Calculate losses
|
||||||
if use_full:
|
if use_full:
|
||||||
# Clean CTC loss - use tf.nn.ctc_loss (TPU-compatible)
|
# Clean CTC loss - use tf.nn.ctc_loss with dense labels (fixed shapes)
|
||||||
# tf.nn.ctc_loss expects logits in time-major format [max_time, batch_size, num_classes]
|
# tf.nn.ctc_loss expects logits in time-major format [max_time, batch_size, num_classes]
|
||||||
clean_logits_time_major = tf.transpose(clean_logits, [1, 0, 2])
|
clean_logits_time_major = tf.transpose(clean_logits, [1, 0, 2])
|
||||||
clean_loss = tf.nn.ctc_loss(
|
clean_loss = tf.nn.ctc_loss(
|
||||||
labels=tf.cast(labels, tf.int32),
|
labels=tf.cast(labels, tf.int32), # Use dense labels with fixed shapes
|
||||||
logits=clean_logits_time_major,
|
logits=clean_logits_time_major,
|
||||||
label_length=tf.cast(phone_seq_lens, tf.int32),
|
label_length=tf.cast(phone_seq_lens, tf.int32), # Re-enable label_length
|
||||||
logit_length=tf.cast(adjusted_lens, tf.int32),
|
logit_length=tf.cast(adjusted_lens, tf.int32),
|
||||||
blank_index=0,
|
blank_index=0,
|
||||||
logits_time_major=True
|
logits_time_major=True
|
||||||
)
|
)
|
||||||
clean_loss = tf.reduce_mean(clean_loss)
|
clean_loss = tf.reduce_mean(clean_loss)
|
||||||
|
|
||||||
# Noisy CTC loss - use tf.nn.ctc_loss (TPU-compatible)
|
# Noisy CTC loss - use tf.nn.ctc_loss with dense labels (fixed shapes)
|
||||||
noisy_logits_time_major = tf.transpose(noisy_logits, [1, 0, 2])
|
noisy_logits_time_major = tf.transpose(noisy_logits, [1, 0, 2])
|
||||||
noisy_loss = tf.nn.ctc_loss(
|
noisy_loss = tf.nn.ctc_loss(
|
||||||
labels=tf.cast(labels, tf.int32),
|
labels=tf.cast(labels, tf.int32), # Use dense labels with fixed shapes
|
||||||
logits=noisy_logits_time_major,
|
logits=noisy_logits_time_major,
|
||||||
label_length=tf.cast(phone_seq_lens, tf.int32),
|
label_length=tf.cast(phone_seq_lens, tf.int32), # Re-enable label_length
|
||||||
logit_length=tf.cast(adjusted_lens, tf.int32),
|
logit_length=tf.cast(adjusted_lens, tf.int32),
|
||||||
blank_index=0,
|
blank_index=0,
|
||||||
logits_time_major=True
|
logits_time_major=True
|
||||||
@@ -582,12 +583,12 @@ class BrainToTextDecoderTrainerTF:
|
|||||||
|
|
||||||
loss = clean_loss + self.adv_noisy_loss_weight * noisy_loss + self.adv_noise_l2_weight * noise_l2
|
loss = clean_loss + self.adv_noisy_loss_weight * noisy_loss + self.adv_noise_l2_weight * noise_l2
|
||||||
else:
|
else:
|
||||||
# Standard CTC loss - use tf.nn.ctc_loss (TPU-compatible)
|
# Standard CTC loss - use tf.nn.ctc_loss with dense labels (fixed shapes)
|
||||||
logits_time_major = tf.transpose(clean_logits, [1, 0, 2])
|
logits_time_major = tf.transpose(clean_logits, [1, 0, 2])
|
||||||
loss = tf.nn.ctc_loss(
|
loss = tf.nn.ctc_loss(
|
||||||
labels=tf.cast(labels, tf.int32),
|
labels=tf.cast(labels, tf.int32), # Use dense labels with fixed shapes
|
||||||
logits=logits_time_major,
|
logits=logits_time_major,
|
||||||
label_length=tf.cast(phone_seq_lens, tf.int32),
|
label_length=tf.cast(phone_seq_lens, tf.int32), # Re-enable label_length
|
||||||
logit_length=tf.cast(adjusted_lens, tf.int32),
|
logit_length=tf.cast(adjusted_lens, tf.int32),
|
||||||
blank_index=0,
|
blank_index=0,
|
||||||
logits_time_major=True
|
logits_time_major=True
|
||||||
@@ -652,13 +653,13 @@ class BrainToTextDecoderTrainerTF:
|
|||||||
# Forward pass (inference mode only)
|
# Forward pass (inference mode only)
|
||||||
logits = self.model(features, day_indices, None, False, 'inference', training=False)
|
logits = self.model(features, day_indices, None, False, 'inference', training=False)
|
||||||
|
|
||||||
# Calculate loss - use tf.nn.ctc_loss (TPU-compatible)
|
# Calculate loss - use tf.nn.ctc_loss with dense labels (fixed shapes)
|
||||||
# tf.nn.ctc_loss expects logits in time-major format [max_time, batch_size, num_classes]
|
# tf.nn.ctc_loss expects logits in time-major format [max_time, batch_size, num_classes]
|
||||||
logits_time_major = tf.transpose(logits, [1, 0, 2])
|
logits_time_major = tf.transpose(logits, [1, 0, 2])
|
||||||
loss = tf.nn.ctc_loss(
|
loss = tf.nn.ctc_loss(
|
||||||
labels=tf.cast(labels, tf.int32),
|
labels=tf.cast(labels, tf.int32), # Use dense labels with fixed shapes
|
||||||
logits=logits_time_major,
|
logits=logits_time_major,
|
||||||
label_length=tf.cast(phone_seq_lens, tf.int32),
|
label_length=tf.cast(phone_seq_lens, tf.int32), # Re-enable label_length
|
||||||
logit_length=tf.cast(adjusted_lens, tf.int32),
|
logit_length=tf.cast(adjusted_lens, tf.int32),
|
||||||
blank_index=0,
|
blank_index=0,
|
||||||
logits_time_major=True
|
logits_time_major=True
|
||||||
@@ -679,28 +680,60 @@ class BrainToTextDecoderTrainerTF:
|
|||||||
initial_tpu_status = self._get_detailed_tpu_status()
|
initial_tpu_status = self._get_detailed_tpu_status()
|
||||||
self.logger.info(f"Initial TPU Status: {initial_tpu_status}")
|
self.logger.info(f"Initial TPU Status: {initial_tpu_status}")
|
||||||
|
|
||||||
# Create datasets using modern distribution API
|
# ========================= DATASET SHAPE ANALYSIS =========================
|
||||||
def create_dist_dataset_fn(input_dataset_tf, training):
|
# Perform one-time full dataset analysis for fixed shapes (TPU requirement)
|
||||||
"""Create distributed dataset function for modern TPU strategy"""
|
self.logger.info("🚀 Performing one-time full dataset analysis for fixed shapes...")
|
||||||
|
|
||||||
|
# Analyze training dataset (all data for accurate max shapes)
|
||||||
|
train_analysis_start = time.time()
|
||||||
|
train_max_shapes = analyze_dataset_shapes(self.train_dataset_tf, sample_size=-1)
|
||||||
|
train_analysis_time = time.time() - train_analysis_start
|
||||||
|
self.logger.info(f"✅ Training dataset analysis completed in {train_analysis_time:.2f}s")
|
||||||
|
|
||||||
|
# Analyze validation dataset (all data for accurate max shapes)
|
||||||
|
val_analysis_start = time.time()
|
||||||
|
val_max_shapes = analyze_dataset_shapes(self.val_dataset_tf, sample_size=-1)
|
||||||
|
val_analysis_time = time.time() - val_analysis_start
|
||||||
|
self.logger.info(f"✅ Validation dataset analysis completed in {val_analysis_time:.2f}s")
|
||||||
|
|
||||||
|
# Use maximum shapes across both datasets for consistent padding
|
||||||
|
final_max_shapes = {
|
||||||
|
'max_time_steps': max(train_max_shapes['max_time_steps'], val_max_shapes['max_time_steps']),
|
||||||
|
'max_phone_seq_len': max(train_max_shapes['max_phone_seq_len'], val_max_shapes['max_phone_seq_len']),
|
||||||
|
'max_transcription_len': max(train_max_shapes['max_transcription_len'], val_max_shapes['max_transcription_len']),
|
||||||
|
'n_features': train_max_shapes['n_features']
|
||||||
|
}
|
||||||
|
|
||||||
|
self.logger.info(f"📊 Final fixed shapes for TPU training:")
|
||||||
|
self.logger.info(f" Time steps: {final_max_shapes['max_time_steps']}")
|
||||||
|
self.logger.info(f" Phone sequence length: {final_max_shapes['max_phone_seq_len']}")
|
||||||
|
self.logger.info(f" Transcription length: {final_max_shapes['max_transcription_len']}")
|
||||||
|
self.logger.info(f" Features: {final_max_shapes['n_features']}")
|
||||||
|
# =====================================================================
|
||||||
|
|
||||||
|
# Create datasets using modern distribution API with fixed shapes
|
||||||
|
def create_dist_dataset_fn(input_dataset_tf, training, max_shapes):
|
||||||
|
"""Create distributed dataset function for modern TPU strategy with fixed shapes"""
|
||||||
def dataset_fn(input_context):
|
def dataset_fn(input_context):
|
||||||
# create_input_fn returns a complete, batched tf.data.Dataset
|
# create_input_fn now requires max_shapes parameter for fixed shapes
|
||||||
return create_input_fn(
|
return create_input_fn(
|
||||||
input_dataset_tf,
|
input_dataset_tf,
|
||||||
self.args['dataset']['data_transforms'],
|
self.args['dataset']['data_transforms'],
|
||||||
|
max_shapes=max_shapes, # Pass pre-analyzed shapes
|
||||||
training=training
|
training=training
|
||||||
)
|
)
|
||||||
return self.strategy.distribute_datasets_from_function(dataset_fn)
|
return self.strategy.distribute_datasets_from_function(dataset_fn)
|
||||||
|
|
||||||
# Distribute datasets using modern API
|
# Distribute datasets using modern API with fixed shapes
|
||||||
self.logger.info("🔄 Distributing training dataset across TPU cores...")
|
self.logger.info("🔄 Distributing training dataset across TPU cores...")
|
||||||
dist_start_time = time.time()
|
dist_start_time = time.time()
|
||||||
train_dist_dataset = create_dist_dataset_fn(self.train_dataset_tf, training=True)
|
train_dist_dataset = create_dist_dataset_fn(self.train_dataset_tf, training=True, max_shapes=final_max_shapes)
|
||||||
train_dist_time = time.time() - dist_start_time
|
train_dist_time = time.time() - dist_start_time
|
||||||
self.logger.info(f"✅ Training dataset distributed in {train_dist_time:.2f}s")
|
self.logger.info(f"✅ Training dataset distributed in {train_dist_time:.2f}s")
|
||||||
|
|
||||||
self.logger.info("🔄 Distributing validation dataset across TPU cores...")
|
self.logger.info("🔄 Distributing validation dataset across TPU cores...")
|
||||||
val_start_time = time.time()
|
val_start_time = time.time()
|
||||||
val_dist_dataset = create_dist_dataset_fn(self.val_dataset_tf, training=False)
|
val_dist_dataset = create_dist_dataset_fn(self.val_dataset_tf, training=False, max_shapes=final_max_shapes)
|
||||||
val_dist_time = time.time() - val_start_time
|
val_dist_time = time.time() - val_start_time
|
||||||
self.logger.info(f"✅ Validation dataset distributed in {val_dist_time:.2f}s")
|
self.logger.info(f"✅ Validation dataset distributed in {val_dist_time:.2f}s")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user