Author: Zchen
Date: 2025-10-14 23:22:59 +08:00
Parent: 989ba67618
Commit: cd52ba51ba
2 changed files with 26 additions and 5 deletions

@@ -166,8 +166,8 @@ x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype) # bmm + dtyp
```
#### 5. Mixed Precision Dtype Consistency
-**Problem**: Mixed precision training causes dtype mismatches in bmm operations and adversarial residual connections
-**Solution**: Ensure all operands match input tensor dtype
+**Problem**: Mixed precision training causes dtype mismatches in bmm operations, adversarial residual connections, and patch processing operations
+**Solution**: Ensure all operands match the input tensor's dtype and preserve that dtype through all operations
```python
# Error: f32[32,7168] vs bf16[32,7168] in mixed precision training
@@ -176,6 +176,18 @@ x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)
# Fix 2: Ensure dtype consistency in adversarial training residual connections
denoised_input = x_processed - noise_output.to(x_processed.dtype)
+# Fix 3: Preserve dtype through patch processing operations
+if self.patch_size > 0:
+    original_dtype = x.dtype  # Preserve original dtype for XLA/TPU compatibility
+    x = x.unsqueeze(1)                                           # (B, T, F) -> (B, 1, T, F)
+    x = x.permute(0, 3, 1, 2)                                    # -> (B, F, 1, T)
+    x_unfold = x.unfold(3, self.patch_size, self.patch_stride)   # -> (B, F, 1, P, patch_size)
+    x_unfold = x_unfold.squeeze(2)                               # -> (B, F, P, patch_size)
+    x_unfold = x_unfold.permute(0, 2, 3, 1)                      # -> (B, P, patch_size, F)
+    x = x_unfold.reshape(batch_size, x_unfold.size(1), -1)       # -> (B, P, patch_size * F)
+    # Ensure dtype consistency after patch processing operations
+    x = x.to(original_dtype)
```
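For reference, a minimal standalone sketch of the bmm dtype mismatch these casts address. The shapes and the per-sample `day_weights`/`day_biases` banks are illustrative assumptions, not the repo's actual module:
```python
import torch

batch, dim = 32, 64
x = torch.randn(batch, 1, dim, dtype=torch.bfloat16)  # activations run in bf16 under mixed precision
day_weights = torch.randn(batch, dim, dim)            # illustrative parameters kept in float32
day_biases = torch.randn(batch, 1, dim)

# torch.bmm does not promote mixed dtypes, so bf16 activations
# against f32 weights fail outright:
try:
    torch.bmm(x, day_weights)
except RuntimeError as err:
    print(f"dtype mismatch: {err}")

# Casting every operand to x.dtype keeps the whole expression in bf16:
out = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)
print(out.dtype)  # torch.bfloat16
```
On XLA/TPU the same mismatch tends to surface at compile/lowering time rather than as an eager RuntimeError, which is why the fixes above cast at every operand boundary instead of relying on implicit promotion.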
#### 3. Hidden State Initialization