import tensorflow as tf
[docs]
def pop_channel(tensor, channel_to_remove):
    """Take one slice out of the last axis of `tensor`.

    Args:
        tensor: Value to split along its last axis.
        channel_to_remove: Position on the last axis of the slice to take
            out. A negative value is shifted by the size of the last axis
            before it is used.

    Returns:
        A `(remaining, popped)` pair. `remaining` is everything before and
        after the selected position, concatenated along the last axis.
        `popped` is the selected slice, still with a last axis of length one.
    """
    last_axis_size = tf.shape(tensor)[-1]

    index = channel_to_remove
    if index < 0:
        # The slices below use `index + 1` as a bound, so a negative index
        # has to be made non-negative first (-1 + 1 would otherwise be 0).
        index = index + last_axis_size
    upper = index + 1

    popped = tensor[..., index:upper]
    remaining = tf.concat(
        [tensor[..., :index], tensor[..., upper:]],
        axis=-1,
    )
    return remaining, popped
[docs]
class LazyLossWrapper(tf.keras.losses.Loss):
    """Loss that strips one channel from `y_true` and gates `y_pred`.

    Before delegating to `base_loss`, the channel at `mask_index` is removed
    from `y_true`, and `y_pred` is multiplied by the sum of the remaining
    `y_true` channels (taken over the last axis). Positions whose remaining
    target channels sum to zero therefore contribute a zeroed prediction.

    NOTE(review): the popped channel itself is not used for gating; only the
    sum of the remaining channels is. Presumably `y_true` is one-hot so the
    gate is 0 or 1 -- confirm against the data pipeline.

    Args:
        base_loss: Callable invoked as `base_loss(y_true, y_pred)`.
        mask_index: Index on the last axis of `y_true` of the channel to
            remove. Defaults to -1 (the last channel).
    """

    def __init__(self, base_loss, mask_index=-1):
        super().__init__()
        self.base_loss = base_loss
        self.mask_index = mask_index

    def call(self, y_true, y_pred):
        """Compute `base_loss` on the stripped targets and gated predictions.

        Args:
            y_true: Targets whose last axis still contains the extra channel.
            y_pred: Predictions; expected to match `y_true` with that channel
                removed.

        Returns:
            Whatever `base_loss` returns for the processed inputs.
        """
        y_pred = tf.convert_to_tensor(y_pred)

        # Drop the extra channel; the popped slice is intentionally unused.
        y_true, _ = pop_channel(y_true, self.mask_index)

        # Gate predictions by the per-position sum of the remaining targets.
        # Cast to the prediction dtype so integer-encoded targets do not
        # trigger a dtype mismatch in the multiplication.
        labeled = tf.reduce_sum(y_true, axis=-1, keepdims=True)
        labeled = tf.cast(labeled, y_pred.dtype)
        masked_y_pred = y_pred * labeled

        return self.base_loss(y_true, masked_y_pred)
[docs]
class LazyMetricWrapper(tf.keras.metrics.MeanMetricWrapper):
    """Mean metric that strips one channel from `y_true` and gates `y_pred`.

    Each update removes the channel at `mask_index` from `y_true` and
    multiplies `y_pred` by the sum of the remaining `y_true` channels (over
    the last axis) before handing both to the parent `MeanMetricWrapper`.

    NOTE(review): the popped channel itself is not used for gating; only the
    sum of the remaining channels is.

    Args:
        metric_fn: Metric callable passed to `MeanMetricWrapper`.
        mask_index: Index on the last axis of `y_true` of the channel to
            remove. Defaults to -1 (the last channel).
        name: Metric name. When omitted (or empty), falls back to
            `metric_fn.name`, then to `metric_fn.__name__`.
        **kwargs: Forwarded to `MeanMetricWrapper.__init__`.
    """

    def __init__(self, metric_fn, mask_index=-1, name=None, **kwargs):
        # An explicitly supplied name always wins. The original one-liner
        # parsed as `(name or fn.name) if hasattr(fn, "name") else fn.__name__`
        # and so discarded `name` for plain functions.
        if not name:
            name = getattr(metric_fn, "name", None) or getattr(
                metric_fn, "__name__", None
            )
        super().__init__(metric_fn, name=name, **kwargs)
        self.mask_index = mask_index

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate the metric on stripped targets and gated predictions.

        Args:
            y_true: Targets whose last axis still contains the extra channel.
            y_pred: Predictions; expected to match `y_true` with that channel
                removed.
            sample_weight: Optional weights, forwarded to the parent update.

        Returns:
            The result of the parent `update_state` call.
        """
        y_pred = tf.convert_to_tensor(y_pred)

        # Drop the extra channel; the popped slice is intentionally unused.
        y_true, _ = pop_channel(y_true, self.mask_index)

        # Cast the gate to the prediction dtype so integer-encoded targets
        # do not trigger a dtype mismatch in the multiplication.
        labeled = tf.reduce_sum(y_true, axis=-1, keepdims=True)
        labeled = tf.cast(labeled, y_pred.dtype)
        masked_y_pred = y_pred * labeled

        return super().update_state(
            y_true, masked_y_pred, sample_weight=sample_weight
        )
# Define the custom U-Net wrapper class
[docs]
class LazySegmentationModel(tf.keras.Model):
    """Thin wrapper that compiles a base model with channel-stripping losses.

    `compile` wraps the given loss in `LazyLossWrapper` and every entry of
    `metrics` in `LazyMetricWrapper`, then compiles the wrapped base model.
    `fit`, `evaluate` and the forward pass are delegated to the base model.

    Args:
        base_model: Model that does the actual work.
        ignore_channel_index: Index on the last axis of `y_true` of the
            channel the loss/metric wrappers remove. Defaults to -1.
        **kwargs: Forwarded to `tf.keras.Model.__init__`.
    """

    def __init__(self, base_model, ignore_channel_index=-1, **kwargs):
        super().__init__(**kwargs)
        self.base_unet = base_model
        self.mask_index = ignore_channel_index

    def compile(
        self,
        optimizer="rmsprop",
        loss=None,
        loss_weights=None,
        metrics=None,
        weighted_metrics=None,
        run_eagerly=False,
        steps_per_execution=1,
        jit_compile="auto",
        **kwargs
    ):
        """Compile the base model with wrapped loss and metrics.

        Args:
            optimizer: Passed through to the base model's `compile`.
            loss: Loss identifier string or callable; wrapped in
                `LazyLossWrapper` when not None.
            loss_weights: Passed through unchanged.
            metrics: A metric identifier string, a single metric callable,
                or an iterable of either; each entry is wrapped in
                `LazyMetricWrapper`.
            weighted_metrics: Passed through unchanged (not wrapped).
            run_eagerly: Passed through unchanged.
            steps_per_execution: Passed through unchanged.
            jit_compile: Passed through unchanged.
            **kwargs: Passed through unchanged.
        """
        if isinstance(loss, str):
            # Resolve the identifier to a loss callable before wrapping.
            loss = tf.keras.losses.get(loss)
        if loss is not None:
            loss = LazyLossWrapper(loss, mask_index=self.mask_index)

        if metrics is not None:
            # Accept a lone identifier or callable as well as any iterable
            # (list, tuple, ...) so every entry is resolved and wrapped.
            if isinstance(metrics, str) or callable(metrics):
                metrics = [metrics]
            metrics = [
                LazyMetricWrapper(
                    tf.keras.metrics.get(metric)
                    if isinstance(metric, str)
                    else metric,
                    mask_index=self.mask_index,
                )
                for metric in metrics
            ]

        # Only the base model is compiled; training and evaluation are
        # delegated to it below.
        self.base_unet.compile(
            optimizer=optimizer,
            loss=loss,
            metrics=metrics,
            loss_weights=loss_weights,
            weighted_metrics=weighted_metrics,
            run_eagerly=run_eagerly,
            steps_per_execution=steps_per_execution,
            jit_compile=jit_compile,
            **kwargs
        )

    def fit(self, *args, **kwargs):
        """Delegate training to the base model's `fit` and return its result."""
        return self.base_unet.fit(*args, **kwargs)

    def call(self, *args, **kwargs):
        """Run the forward pass of the base model.

        The base model is invoked through `__call__` rather than its `call`
        method so that Keras' own call-time handling is not bypassed.
        """
        return self.base_unet(*args, **kwargs)

    def evaluate(self, *args, **kwargs):
        """Delegate evaluation to the base model's `evaluate` and return its result."""
        return self.base_unet.evaluate(*args, **kwargs)