This commit is contained in:
Zchen
2025-10-15 20:45:25 +08:00
parent 3b242b908d
commit e8f0308fef
5 changed files with 409 additions and 19 deletions

View File

@@ -336,30 +336,47 @@ class DataAugmentationTF:
gauss_kernel = gauss_kernel[valid_idx].flatten()
gauss_kernel = gauss_kernel / np.sum(gauss_kernel)
# Convert to TensorFlow tensor
# Convert to TensorFlow tensor and reshape for conv1d
gauss_kernel = tf.constant(gauss_kernel, dtype=tf.float32)
gauss_kernel = tf.reshape(gauss_kernel, [1, 1, -1]) # [1, 1, kernel_size]
kernel_size = tf.shape(gauss_kernel)[0]
gauss_kernel = tf.reshape(gauss_kernel, [kernel_size, 1, 1]) # [kernel_size, in_channels, out_channels]
# Prepare for convolution
# Get tensor dimensions
batch_size = tf.shape(inputs)[0]
time_steps = tf.shape(inputs)[1]
num_features = tf.shape(inputs)[2]
# Reshape for convolution: [batch_size * features, 1, time_steps]
inputs_reshaped = tf.transpose(inputs, [0, 2, 1]) # [batch_size, features, time_steps]
inputs_reshaped = tf.reshape(inputs_reshaped, [-1, 1, time_steps])
# Apply convolution to each feature channel separately
smoothed_features = []
# Apply convolution
smoothed = tf.nn.conv1d(
inputs_reshaped,
gauss_kernel,
stride=1,
padding='SAME'
)
# Convert num_features to Python int for loop
num_features_py = inputs.shape[-1] if inputs.shape[-1] is not None else tf.shape(inputs)[-1]
# Reshape back to original format
smoothed = tf.reshape(smoothed, [batch_size, num_features, time_steps])
smoothed = tf.transpose(smoothed, [0, 2, 1]) # [batch_size, time_steps, features]
if isinstance(num_features_py, tf.Tensor):
# If dynamic, use tf.map_fn for dynamic number of features
def smooth_single_feature(i):
# Extract single feature channel: [batch_size, time_steps, 1]
feature_channel = tf.expand_dims(inputs[:, :, i], axis=-1)
# Apply 1D convolution
return tf.nn.conv1d(feature_channel, gauss_kernel, stride=1, padding='SAME')
# Use tf.map_fn for dynamic features
indices = tf.range(num_features)
smoothed_features_tensor = tf.map_fn(smooth_single_feature, indices, dtype=tf.float32)
# Transpose to get [batch_size, time_steps, features]
smoothed = tf.transpose(smoothed_features_tensor, [1, 2, 0, 3])
smoothed = tf.squeeze(smoothed, axis=-1)
else:
# Static number of features - use loop
for i in range(num_features_py):
# Extract single feature channel: [batch_size, time_steps, 1]
feature_channel = tf.expand_dims(inputs[:, :, i], axis=-1)
# Apply 1D convolution
smoothed_channel = tf.nn.conv1d(feature_channel, gauss_kernel, stride=1, padding='SAME')
smoothed_features.append(smoothed_channel)
# Concatenate all smoothed features
smoothed = tf.concat(smoothed_features, axis=-1) # [batch_size, time_steps, features]
return smoothed