tpu
This commit is contained in:
@@ -336,30 +336,47 @@ class DataAugmentationTF:
|
||||
gauss_kernel = gauss_kernel[valid_idx].flatten()
|
||||
gauss_kernel = gauss_kernel / np.sum(gauss_kernel)
|
||||
|
||||
# Convert to TensorFlow tensor
|
||||
# Convert to TensorFlow tensor and reshape for conv1d
|
||||
gauss_kernel = tf.constant(gauss_kernel, dtype=tf.float32)
|
||||
gauss_kernel = tf.reshape(gauss_kernel, [1, 1, -1]) # [1, 1, kernel_size]
|
||||
kernel_size = tf.shape(gauss_kernel)[0]
|
||||
gauss_kernel = tf.reshape(gauss_kernel, [kernel_size, 1, 1]) # [kernel_size, in_channels, out_channels]
|
||||
|
||||
# Prepare for convolution
|
||||
# Get tensor dimensions
|
||||
batch_size = tf.shape(inputs)[0]
|
||||
time_steps = tf.shape(inputs)[1]
|
||||
num_features = tf.shape(inputs)[2]
|
||||
|
||||
# Reshape for convolution: [batch_size * features, 1, time_steps]
|
||||
inputs_reshaped = tf.transpose(inputs, [0, 2, 1]) # [batch_size, features, time_steps]
|
||||
inputs_reshaped = tf.reshape(inputs_reshaped, [-1, 1, time_steps])
|
||||
# Apply convolution to each feature channel separately
|
||||
smoothed_features = []
|
||||
|
||||
# Apply convolution
|
||||
smoothed = tf.nn.conv1d(
|
||||
inputs_reshaped,
|
||||
gauss_kernel,
|
||||
stride=1,
|
||||
padding='SAME'
|
||||
)
|
||||
# Convert num_features to Python int for loop
|
||||
num_features_py = inputs.shape[-1] if inputs.shape[-1] is not None else tf.shape(inputs)[-1]
|
||||
|
||||
# Reshape back to original format
|
||||
smoothed = tf.reshape(smoothed, [batch_size, num_features, time_steps])
|
||||
smoothed = tf.transpose(smoothed, [0, 2, 1]) # [batch_size, time_steps, features]
|
||||
if isinstance(num_features_py, tf.Tensor):
|
||||
# If dynamic, use tf.map_fn for dynamic number of features
|
||||
def smooth_single_feature(i):
|
||||
# Extract single feature channel: [batch_size, time_steps, 1]
|
||||
feature_channel = tf.expand_dims(inputs[:, :, i], axis=-1)
|
||||
# Apply 1D convolution
|
||||
return tf.nn.conv1d(feature_channel, gauss_kernel, stride=1, padding='SAME')
|
||||
|
||||
# Use tf.map_fn for dynamic features
|
||||
indices = tf.range(num_features)
|
||||
smoothed_features_tensor = tf.map_fn(smooth_single_feature, indices, dtype=tf.float32)
|
||||
# Transpose to get [batch_size, time_steps, features]
|
||||
smoothed = tf.transpose(smoothed_features_tensor, [1, 2, 0, 3])
|
||||
smoothed = tf.squeeze(smoothed, axis=-1)
|
||||
else:
|
||||
# Static number of features - use loop
|
||||
for i in range(num_features_py):
|
||||
# Extract single feature channel: [batch_size, time_steps, 1]
|
||||
feature_channel = tf.expand_dims(inputs[:, :, i], axis=-1)
|
||||
# Apply 1D convolution
|
||||
smoothed_channel = tf.nn.conv1d(feature_channel, gauss_kernel, stride=1, padding='SAME')
|
||||
smoothed_features.append(smoothed_channel)
|
||||
|
||||
# Concatenate all smoothed features
|
||||
smoothed = tf.concat(smoothed_features, axis=-1) # [batch_size, time_steps, features]
|
||||
|
||||
return smoothed
|
||||
|
||||
|
Reference in New Issue
Block a user