fix double gradient clipping
model_training_nnn_tpu/ISSUES.md (new file, 36 lines added)
@@ -0,0 +1,36 @@
# ISSUES

## Double gradient clipping

Optimizer level: global_clipnorm=self.args.get('grad_norm_clip_value', 0.0) (line 283)

Manual level: tf.clip_by_global_norm (lines 447-449)

This clips the gradients twice, and in distributed TPU training it can also trigger internal optimizer state conflicts.
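
For reference, a minimal standalone sketch of what the retained manual clipping does; the tensor values here are illustrative only, not taken from the trainer:

```python
import tensorflow as tf

# Two fake "gradients" whose combined (global) norm is 5.0 (3-4-5 triangle).
grads = [tf.constant([3.0]), tf.constant([4.0])]

# tf.clip_by_global_norm rescales the whole list so its global norm does not
# exceed clip_norm, and also returns the norm measured before clipping.
clipped, global_norm = tf.clip_by_global_norm(grads, clip_norm=1.0)

print(global_norm.numpy())                      # 5.0  (norm before clipping)
print(tf.linalg.global_norm(clipped).numpy())   # ~1.0 (norm after clipping)
```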

## Fix summary

Root cause: double gradient clipping led to AdamW internal state conflicts.

Changes:

- Removed the optimizer-level gradient clipping: dropped the global_clipnorm argument
- Kept the manual gradient clipping: _train_step still uses tf.clip_by_global_norm

Why it went wrong:

```python
# Before: double clipping
optimizer = tf.keras.optimizers.AdamW(
    global_clipnorm=clip_value  # first clip
)
```

```python
# In _train_step:
tf.clip_by_global_norm(gradients, clip_value)  # second clip
optimizer.apply_gradients(...)                 # processed again internally, causing the conflict
```

The fix now:

```python
# Now: a single clip
optimizer = tf.keras.optimizers.AdamW(
    # no global_clipnorm
)
```

```python
# In _train_step:
tf.clip_by_global_norm(gradients, clip_value)  # the only clipping step
optimizer.apply_gradients(...)                 # works as expected
```
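
Putting the pieces together, a minimal end-to-end sketch of the single-clipping pattern described above; the model, loss function, and clip_value are placeholders, not the trainer's actual _train_step:

```python
import tensorflow as tf

def train_step(model, optimizer, loss_fn, batch, clip_value=1.0):
    """Single-clipping pattern: clip manually once, then apply_gradients."""
    features, labels = batch
    with tf.GradientTape() as tape:
        predictions = model(features, training=True)
        loss = tf.reduce_mean(loss_fn(labels, predictions))

    gradients = tape.gradient(loss, model.trainable_variables)
    # The only clipping step; the optimizer has no global_clipnorm set.
    gradients, _ = tf.clip_by_global_norm(gradients, clip_value)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    return loss
```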
@@ -88,9 +88,7 @@ class BrainToTextDecoderTrainerTF:
        # Build model within strategy scope
        with self.strategy.scope():
            print("🔨 Building model within TPU strategy scope...")
            self.model = self._build_model()
            print("⚙️ Creating optimizer...")
            self.optimizer = self._create_optimizer()
            print("🔧 Pre-building optimizer state for TPU...")
            # For TPU, we must ensure optimizer is completely ready before training
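
The hunk above references pre-building the optimizer state but does not show how it is done. One plausible way with the Keras optimizer API is to call optimizer.build() on the model's variables inside the strategy scope, as in the hypothetical sketch below (the stand-in strategy and model are assumptions, not repository code):

```python
import tensorflow as tf

strategy = tf.distribute.get_strategy()  # stand-in; the trainer uses a TPU strategy

with strategy.scope():
    model = tf.keras.Sequential([tf.keras.layers.Dense(8)])  # stand-in model
    model.build(input_shape=(None, 4))
    optimizer = tf.keras.optimizers.AdamW(learning_rate=1e-3)
    # build() materializes the optimizer's slot variables (AdamW moments)
    # for these variables up front, without applying any gradient update.
    optimizer.build(model.trainable_variables)
```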
@@ -125,13 +123,8 @@ class BrainToTextDecoderTrainerTF:
                print(f"Full traceback: {traceback.format_exc()}")
                raise RuntimeError(f"Optimizer pre-build failed: {e}") from e

        print("📅 Setting up learning rate scheduler...")
        self.lr_scheduler = self._create_lr_scheduler()
        print("✅ LR scheduler ready")

        print("🎯 Initializing CTC loss...")
        self.ctc_loss = CTCLoss(blank_index=0, reduction='none')
        print("✅ CTC loss initialized")

        # Log model information
        self._log_model_info()
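
CTCLoss in this hunk appears to be a project-level wrapper class; a rough, hypothetical sketch of what such a wrapper could look like on top of tf.nn.ctc_loss is shown below (the label-length handling is an assumption, not the repository's implementation):

```python
import tensorflow as tf

class CTCLoss(tf.keras.losses.Loss):
    """Hypothetical CTC loss wrapper with an explicit blank index."""

    def __init__(self, blank_index=0, reduction='none'):
        super().__init__(reduction=reduction, name='ctc_loss')
        self.blank_index = blank_index

    def call(self, y_true, y_pred):
        # y_true: [batch, max_label_len] integer labels padded with the blank index
        # y_pred: [batch, time, num_classes] unnormalized logits
        label_length = tf.reduce_sum(
            tf.cast(tf.not_equal(y_true, self.blank_index), tf.int32), axis=-1)
        logit_length = tf.fill([tf.shape(y_pred)[0]], tf.shape(y_pred)[1])
        return tf.nn.ctc_loss(
            labels=tf.cast(y_true, tf.int32),
            logits=y_pred,
            label_length=label_length,
            logit_length=logit_length,
            logits_time_major=False,
            blank_index=self.blank_index,
        )
```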
@@ -452,12 +445,9 @@ class BrainToTextDecoderTrainerTF:
                beta_1=self.args['beta0'],
                beta_2=self.args['beta1'],
                epsilon=self.args['epsilon'],
-               weight_decay=0.0,  # Disabled for TPU compatibility
-               # TPU-specific settings
-               global_clipnorm=self.args.get('grad_norm_clip_value', 0.0) if self.args.get('grad_norm_clip_value', 0.0) > 0 else None
+               weight_decay=0.0  # Disabled for TPU compatibility
+               # REMOVE global_clipnorm to avoid double clipping with manual tf.clip_by_global_norm
            )
            print(f"⚠️ Weight decay disabled for TPU compatibility (was {self.args['weight_decay']})")
            print("💡 Consider implementing manual L2 regularization if needed")
        else:
            print("Using standard optimizer configuration")
            optimizer = tf.keras.optimizers.AdamW(
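
For completeness, a hedged sketch of how the corrected optimizer setup interacts with the manual clipping under a distribution strategy; the model, loss, and clip value are placeholders, and the default strategy stands in for the trainer's TPU strategy:

```python
import tensorflow as tf

strategy = tf.distribute.get_strategy()  # stand-in for a TPUStrategy

with strategy.scope():
    model = tf.keras.Sequential([tf.keras.layers.Dense(2)])
    # No clipnorm / global_clipnorm here: clipping happens once, in the step fn.
    optimizer = tf.keras.optimizers.AdamW(learning_rate=1e-3, weight_decay=0.0)

@tf.function
def distributed_step(x, y, clip_value=1.0):
    def step_fn(x, y):
        with tf.GradientTape() as tape:
            loss = tf.reduce_mean(
                tf.keras.losses.mean_squared_error(y, model(x, training=True)))
        grads = tape.gradient(loss, model.trainable_variables)
        grads, _ = tf.clip_by_global_norm(grads, clip_value)  # the single clip
        optimizer.apply_gradients(zip(grads, model.trainable_variables))
        return loss
    return strategy.run(step_fn, args=(x, y))
```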