tpu
CLAUDE.md (16 lines changed)
@@ -166,8 +166,8 @@ x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype) # bmm + dtyp
 ```
 
 #### 5. Mixed Precision Dtype Consistency
-**Problem**: Mixed precision training causes dtype mismatches in bmm operations and adversarial residual connections
-**Solution**: Ensure all operands match input tensor dtype
+**Problem**: Mixed precision training causes dtype mismatches in bmm operations, adversarial residual connections, and patch processing operations
+**Solution**: Ensure all operands match input tensor dtype and preserve dtype through all operations
 
 ```python
 # Error: f32[32,7168] vs bf16[32,7168] in mixed precision training
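For reference, Fix 1 (the bmm cast visible in the hunk context above) reproduces outside the model. A minimal sketch; the batch and feature sizes below are illustrative assumptions, not the repo's real dimensions:

```python
import torch

x = torch.randn(4, 32, 512, dtype=torch.bfloat16)  # activations under mixed precision
day_weights = torch.randn(4, 512, 512)             # per-day parameters left in float32
day_biases = torch.randn(4, 1, 512)

# torch.bmm does not type-promote, so f32 weights against bf16 activations
# fail with a dtype mismatch like the f32-vs-bf16 error quoted above:
#   torch.bmm(x, day_weights)  # RuntimeError: expected scalar type ...

# Fix 1: cast the parameters to the activation dtype at the call site
x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)
print(x.dtype)  # torch.bfloat16
```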
@@ -176,6 +176,18 @@ x = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)
 
 # Fix 2: Ensure dtype consistency in adversarial training residual connections
 denoised_input = x_processed - noise_output.to(x_processed.dtype)
 
+# Fix 3: Preserve dtype through patch processing operations
+if self.patch_size > 0:
+    original_dtype = x.dtype  # Preserve original dtype for XLA/TPU compatibility
+    x = x.unsqueeze(1)
+    x = x.permute(0, 3, 1, 2)
+    x_unfold = x.unfold(3, self.patch_size, self.patch_stride)
+    x_unfold = x_unfold.squeeze(2)
+    x_unfold = x_unfold.permute(0, 2, 3, 1)
+    x = x_unfold.reshape(batch_size, x_unfold.size(1), -1)
+    # Ensure dtype consistency after patch processing operations
+    x = x.to(original_dtype)
+
 ```
 
 #### 3. Hidden State Initialization
@@ -94,14 +94,17 @@ class NoiseModel(nn.Module):
         if self.input_dropout > 0:
             x = self.day_layer_dropout(x)
 
-        # Apply patch processing if enabled (keep conditional for now, optimize later)
+        # Apply patch processing if enabled with dtype preservation for mixed precision training
         if self.patch_size > 0:
+            original_dtype = x.dtype  # Preserve original dtype for XLA/TPU compatibility
             x = x.unsqueeze(1)
             x = x.permute(0, 3, 1, 2)
             x_unfold = x.unfold(3, self.patch_size, self.patch_stride)
             x_unfold = x_unfold.squeeze(2)
             x_unfold = x_unfold.permute(0, 2, 3, 1)
             x = x_unfold.reshape(batch_size, x_unfold.size(1), -1)
+            # Ensure dtype consistency after patch processing operations
+            x = x.to(original_dtype)
 
         # XLA-friendly hidden state initialization - avoid dynamic allocation
         if states is None:
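NoiseModel's output is what feeds the adversarial residual covered by Fix 2 in the CLAUDE.md hunk above. A minimal repro of why that cast matters, with assumed shapes: elementwise ops do type-promote, so the mixed subtraction silently upcasts to float32, and the f32 result then trips strict-dtype ops (bmm, GRU kernels) downstream:

```python
import torch

x_processed = torch.randn(32, 7168, dtype=torch.bfloat16)  # clean path in bf16
noise_output = torch.randn(32, 7168)                       # noise path left in float32

bad = x_processed - noise_output                         # silently promoted to float32
good = x_processed - noise_output.to(x_processed.dtype)  # Fix 2: stays in bf16
print(bad.dtype, good.dtype)  # torch.float32 torch.bfloat16
```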
@@ -193,14 +196,17 @@ class CleanSpeechModel(nn.Module):
         if self.input_dropout > 0:
             x = self.day_layer_dropout(x)
 
-        # Apply patch processing if enabled
+        # Apply patch processing if enabled with dtype preservation for mixed precision training
         if self.patch_size > 0:
+            original_dtype = x.dtype  # Preserve original dtype for XLA/TPU compatibility
             x = x.unsqueeze(1)
             x = x.permute(0, 3, 1, 2)
             x_unfold = x.unfold(3, self.patch_size, self.patch_stride)
             x_unfold = x_unfold.squeeze(2)
             x_unfold = x_unfold.permute(0, 2, 3, 1)
             x = x_unfold.reshape(batch_size, x_unfold.size(1), -1)
+            # Ensure dtype consistency after patch processing operations
+            x = x.to(original_dtype)
 
         # XLA-friendly hidden state initialization
         if states is None:
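CleanSpeechModel repeats the NoiseModel block verbatim, so one parameterized smoke test covers both (and the TripleGRUDecoder variant below). A hedged sketch with arbitrary sizes, not the repo's test suite:

```python
import torch

def check_patch_block(batch, T, F, P, S, dtype=torch.bfloat16):
    """Mirror the patched block and assert dtype/shape are preserved."""
    x = torch.randn(batch, T, F, dtype=dtype)
    original_dtype = x.dtype
    x = x.unsqueeze(1).permute(0, 3, 1, 2)
    x_unfold = x.unfold(3, P, S).squeeze(2).permute(0, 2, 3, 1)
    x = x_unfold.reshape(batch, x_unfold.size(1), -1).to(original_dtype)
    assert x.dtype == original_dtype
    assert x.shape == (batch, (T - P) // S + 1, P * F)

check_patch_block(2, 40, 16, 4, 2)                 # mixed precision path
check_patch_block(2, 40, 16, 4, 2, torch.float32)  # full precision path
```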
@@ -383,14 +389,17 @@ class TripleGRUDecoder(nn.Module):
         x_processed = torch.bmm(x, day_weights.to(x.dtype)) + day_biases.to(x.dtype)
         x_processed = self.clean_speech_model.day_layer_activation(x_processed)
 
-        # Apply patch processing if enabled
+        # Apply patch processing if enabled with dtype preservation for mixed precision training
         if self.patch_size > 0:
+            original_dtype = x_processed.dtype  # Preserve original dtype for XLA/TPU compatibility
             x_processed = x_processed.unsqueeze(1)
             x_processed = x_processed.permute(0, 3, 1, 2)
             x_unfold = x_processed.unfold(3, self.patch_size, self.patch_stride)
             x_unfold = x_unfold.squeeze(2)
             x_unfold = x_unfold.permute(0, 2, 3, 1)
             x_processed = x_unfold.reshape(batch_size, x_unfold.size(1), -1)
+            # Ensure dtype consistency after patch processing operations
+            x_processed = x_processed.to(original_dtype)
 
         return x_processed
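All three forwards run under TPU mixed precision, which is where the preserved dtype pays off. A hedged sketch of that context, assuming torch_xla and its torch.autocast("xla") AMP integration, with a stand-in module rather than the repo's TripleGRUDecoder; all sizes except the 7168 echoing the error are arbitrary:

```python
import torch
import torch.nn as nn
import torch_xla.core.xla_model as xm

device = xm.xla_device()
model = nn.Linear(7168, 40).to(device)  # stand-in module; sizes assumed
x = torch.randn(32, 7168, device=device)

with torch.autocast("xla", dtype=torch.bfloat16):
    out = model(x)  # activations computed in bf16 inside autocast
xm.mark_step()      # materialize the lazy XLA graph
print(out.dtype)    # torch.bfloat16
```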