简单修复
This commit is contained in:
@@ -94,7 +94,11 @@ def gauss_smooth(inputs: tf.Tensor, smooth_kernel_std: float = 2.0, smooth_kerne
|
|||||||
return tf.nn.conv1d(feature_channel, gauss_kernel, stride=1, padding='SAME')
|
return tf.nn.conv1d(feature_channel, gauss_kernel, stride=1, padding='SAME')
|
||||||
|
|
||||||
indices = tf.range(tf.shape(inputs)[-1])
|
indices = tf.range(tf.shape(inputs)[-1])
|
||||||
smoothed_features_tensor = tf.map_fn(smooth_single_feature, indices, dtype=tf.float32)
|
smoothed_features_tensor = tf.map_fn(
|
||||||
|
smooth_single_feature,
|
||||||
|
indices,
|
||||||
|
fn_output_signature=tf.TensorSpec(shape=[None, None, 1], dtype=tf.float32)
|
||||||
|
)
|
||||||
smoothed = tf.transpose(smoothed_features_tensor, [1, 2, 0, 3])
|
smoothed = tf.transpose(smoothed_features_tensor, [1, 2, 0, 3])
|
||||||
smoothed = tf.squeeze(smoothed, axis=-1)
|
smoothed = tf.squeeze(smoothed, axis=-1)
|
||||||
else:
|
else:
|
||||||
@@ -109,7 +113,25 @@ def gauss_smooth(inputs: tf.Tensor, smooth_kernel_std: float = 2.0, smooth_kerne
|
|||||||
return smoothed
|
return smoothed
|
||||||
```
|
```
|
||||||
|
|
||||||
## 4. ✅ Test Script Fix (`test_tensorflow_implementation.py`)
|
## 4. ✅ TensorFlow Deprecation Warning Fix (`dataset_tf.py`)
|
||||||
|
|
||||||
|
**Problem**: `calling map_fn_v2 with dtype is deprecated and will be removed in a future version`
|
||||||
|
|
||||||
|
**Solution**: Replaced deprecated `dtype` parameter with `fn_output_signature` in `tf.map_fn`:
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Before (deprecated):
|
||||||
|
smoothed_features_tensor = tf.map_fn(smooth_single_feature, indices, dtype=tf.float32)
|
||||||
|
|
||||||
|
# After (modern API):
|
||||||
|
smoothed_features_tensor = tf.map_fn(
|
||||||
|
smooth_single_feature,
|
||||||
|
indices,
|
||||||
|
fn_output_signature=tf.TensorSpec(shape=[None, None, 1], dtype=tf.float32)
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
## 5. ✅ Test Script Fix (`test_tensorflow_implementation.py`)
|
||||||
|
|
||||||
**Problem**: `cannot access local variable 'expected_features' where it is not associated with a value`
|
**Problem**: `cannot access local variable 'expected_features' where it is not associated with a value`
|
||||||
|
|
||||||
|
@@ -362,7 +362,11 @@ class DataAugmentationTF:
|
|||||||
|
|
||||||
# Use tf.map_fn for dynamic features
|
# Use tf.map_fn for dynamic features
|
||||||
indices = tf.range(num_features)
|
indices = tf.range(num_features)
|
||||||
smoothed_features_tensor = tf.map_fn(smooth_single_feature, indices, dtype=tf.float32)
|
smoothed_features_tensor = tf.map_fn(
|
||||||
|
smooth_single_feature,
|
||||||
|
indices,
|
||||||
|
fn_output_signature=tf.TensorSpec(shape=[None, None, 1], dtype=tf.float32)
|
||||||
|
)
|
||||||
# Transpose to get [batch_size, time_steps, features]
|
# Transpose to get [batch_size, time_steps, features]
|
||||||
smoothed = tf.transpose(smoothed_features_tensor, [1, 2, 0, 3])
|
smoothed = tf.transpose(smoothed_features_tensor, [1, 2, 0, 3])
|
||||||
smoothed = tf.squeeze(smoothed, axis=-1)
|
smoothed = tf.squeeze(smoothed, axis=-1)
|
||||||
|
@@ -135,10 +135,15 @@ class BrainToTextDecoderTrainerTF:
|
|||||||
print(f"💻 Available CPU cores: {available_cores}")
|
print(f"💻 Available CPU cores: {available_cores}")
|
||||||
|
|
||||||
# Optimize for data pipeline parallelism
|
# Optimize for data pipeline parallelism
|
||||||
# Use ~1/4 of cores for inter-op (between operations)
|
# For 224 cores, use more threads for better data loading performance
|
||||||
# Use ~1/8 of cores for intra-op (within operations)
|
if available_cores >= 200: # High core count system
|
||||||
inter_op_threads = min(32, available_cores // 4)
|
inter_op_threads = min(64, available_cores // 3) # More aggressive for 224 cores
|
||||||
intra_op_threads = min(16, available_cores // 8)
|
intra_op_threads = min(32, available_cores // 6)
|
||||||
|
else:
|
||||||
|
# Use ~1/4 of cores for inter-op (between operations)
|
||||||
|
# Use ~1/8 of cores for intra-op (within operations)
|
||||||
|
inter_op_threads = min(32, available_cores // 4)
|
||||||
|
intra_op_threads = min(16, available_cores // 8)
|
||||||
|
|
||||||
tf.config.threading.set_inter_op_parallelism_threads(inter_op_threads)
|
tf.config.threading.set_inter_op_parallelism_threads(inter_op_threads)
|
||||||
tf.config.threading.set_intra_op_parallelism_threads(intra_op_threads)
|
tf.config.threading.set_intra_op_parallelism_threads(intra_op_threads)
|
||||||
@@ -148,6 +153,63 @@ class BrainToTextDecoderTrainerTF:
|
|||||||
print(f" Intra-op parallelism: {intra_op_threads} threads")
|
print(f" Intra-op parallelism: {intra_op_threads} threads")
|
||||||
print(f" This will accelerate data loading and preprocessing while TPU handles training")
|
print(f" This will accelerate data loading and preprocessing while TPU handles training")
|
||||||
|
|
||||||
|
def _get_tpu_status(self) -> str:
|
||||||
|
"""Get current TPU status and utilization info"""
|
||||||
|
try:
|
||||||
|
# Get TPU devices
|
||||||
|
tpu_devices = tf.config.list_logical_devices('TPU')
|
||||||
|
|
||||||
|
if not tpu_devices:
|
||||||
|
return "TPU: No devices"
|
||||||
|
|
||||||
|
# Get strategy info
|
||||||
|
num_replicas = self.strategy.num_replicas_in_sync if hasattr(self.strategy, 'num_replicas_in_sync') else 1
|
||||||
|
|
||||||
|
# Get memory usage (simplified)
|
||||||
|
import psutil
|
||||||
|
memory = psutil.virtual_memory()
|
||||||
|
|
||||||
|
return (f"TPU: {len(tpu_devices)}dev {num_replicas}cores "
|
||||||
|
f"RAM: {memory.percent:.1f}%")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return f"TPU: status_error({str(e)[:20]})"
|
||||||
|
|
||||||
|
def _get_detailed_tpu_status(self) -> str:
|
||||||
|
"""Get detailed TPU status for training start"""
|
||||||
|
try:
|
||||||
|
# Get TPU devices
|
||||||
|
tpu_devices = tf.config.list_logical_devices('TPU')
|
||||||
|
|
||||||
|
if not tpu_devices:
|
||||||
|
return "❌ No TPU devices detected"
|
||||||
|
|
||||||
|
# Get strategy info
|
||||||
|
num_replicas = self.strategy.num_replicas_in_sync if hasattr(self.strategy, 'num_replicas_in_sync') else 1
|
||||||
|
strategy_type = type(self.strategy).__name__
|
||||||
|
|
||||||
|
# Get memory info
|
||||||
|
import psutil
|
||||||
|
memory = psutil.virtual_memory()
|
||||||
|
|
||||||
|
# Simple TPU test
|
||||||
|
try:
|
||||||
|
with tf.device('/TPU:0'):
|
||||||
|
test_result = tf.constant([1.0, 2.0])
|
||||||
|
_ = tf.reduce_sum(test_result)
|
||||||
|
tpu_test = "✅ responsive"
|
||||||
|
except Exception as e:
|
||||||
|
tpu_test = f"❌ test_failed({str(e)[:15]})"
|
||||||
|
|
||||||
|
return (f"TPU Devices: {len(tpu_devices)} | "
|
||||||
|
f"Strategy: {strategy_type} | "
|
||||||
|
f"Cores: {num_replicas} | "
|
||||||
|
f"RAM: {memory.percent:.1f}% ({memory.used//1024//1024//1024}GB/{memory.total//1024//1024//1024}GB) | "
|
||||||
|
f"Test: {tpu_test}")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
return f"❌ TPU status check failed: {str(e)[:50]}"
|
||||||
|
|
||||||
def _initialize_datasets(self):
|
def _initialize_datasets(self):
|
||||||
"""Initialize training and validation datasets"""
|
"""Initialize training and validation datasets"""
|
||||||
# Create file paths
|
# Create file paths
|
||||||
@@ -448,6 +510,10 @@ class BrainToTextDecoderTrainerTF:
|
|||||||
"""Main training loop"""
|
"""Main training loop"""
|
||||||
self.logger.info("Starting training loop...")
|
self.logger.info("Starting training loop...")
|
||||||
|
|
||||||
|
# Log initial TPU status
|
||||||
|
initial_tpu_status = self._get_detailed_tpu_status()
|
||||||
|
self.logger.info(f"Initial TPU Status: {initial_tpu_status}")
|
||||||
|
|
||||||
# Create distributed datasets
|
# Create distributed datasets
|
||||||
train_dataset = create_input_fn(
|
train_dataset = create_input_fn(
|
||||||
self.train_dataset_tf,
|
self.train_dataset_tf,
|
||||||
@@ -493,12 +559,14 @@ class BrainToTextDecoderTrainerTF:
|
|||||||
train_step_duration = time.time() - start_time
|
train_step_duration = time.time() - start_time
|
||||||
train_losses.append(float(loss.numpy()))
|
train_losses.append(float(loss.numpy()))
|
||||||
|
|
||||||
# Log training progress
|
# Log training progress with TPU status
|
||||||
if step % self.args['batches_per_train_log'] == 0:
|
if step % self.args['batches_per_train_log'] == 0:
|
||||||
|
tpu_status = self._get_tpu_status()
|
||||||
self.logger.info(f'Train batch {step}: '
|
self.logger.info(f'Train batch {step}: '
|
||||||
f'loss: {float(loss.numpy()):.2f} '
|
f'loss: {float(loss.numpy()):.2f} '
|
||||||
f'grad norm: {float(grad_norm.numpy()):.2f} '
|
f'grad norm: {float(grad_norm.numpy()):.2f} '
|
||||||
f'time: {train_step_duration:.3f}')
|
f'time: {train_step_duration:.3f}s '
|
||||||
|
f'| {tpu_status}')
|
||||||
|
|
||||||
# Validation step
|
# Validation step
|
||||||
if step % self.args['batches_per_val_step'] == 0 or step == (self.args['num_training_batches'] - 1):
|
if step % self.args['batches_per_val_step'] == 0 or step == (self.args['num_training_batches'] - 1):
|
||||||
@@ -508,10 +576,12 @@ class BrainToTextDecoderTrainerTF:
|
|||||||
val_metrics = self._validate(val_dist_dataset)
|
val_metrics = self._validate(val_dist_dataset)
|
||||||
val_step_duration = time.time() - val_start_time
|
val_step_duration = time.time() - val_start_time
|
||||||
|
|
||||||
|
tpu_status = self._get_tpu_status()
|
||||||
self.logger.info(f'Val batch {step}: '
|
self.logger.info(f'Val batch {step}: '
|
||||||
f'PER (avg): {val_metrics["avg_per"]:.4f} '
|
f'PER (avg): {val_metrics["avg_per"]:.4f} '
|
||||||
f'CTC Loss (avg): {val_metrics["avg_loss"]:.4f} '
|
f'CTC Loss (avg): {val_metrics["avg_loss"]:.4f} '
|
||||||
f'time: {val_step_duration:.3f}')
|
f'time: {val_step_duration:.3f}s '
|
||||||
|
f'| {tpu_status}')
|
||||||
|
|
||||||
val_pers.append(val_metrics['avg_per'])
|
val_pers.append(val_metrics['avg_per'])
|
||||||
val_losses.append(val_metrics['avg_loss'])
|
val_losses.append(val_metrics['avg_loss'])
|
||||||
|
Reference in New Issue
Block a user