Source code for skself.partial_annotations.lazy_model

import tensorflow as tf


def pop_channel(tensor, channel_to_remove):
    """Split one channel off the last axis of ``tensor``.

    Args:
        tensor: Tensor whose last axis indexes channels.
        channel_to_remove: Python int index of the channel to extract along the
            last axis. Negative values count from the end, as in ordinary
            Python indexing (``-1`` is the last channel).

    Returns:
        A ``(remaining, removed)`` tuple. ``remaining`` is ``tensor`` with the
        selected channel dropped. ``removed`` is the selected channel, with the
        last axis kept at size 1.

    Raises:
        IndexError: If the channel count is statically known and
            ``channel_to_remove`` falls outside it.
    """
    # Prefer the static channel count so that the slices below use Python ints
    # and the results keep a fully defined last dimension. Fall back to the
    # dynamic shape when the static one is unknown.
    num_channels = tensor.shape[-1]
    if num_channels is None:
        num_channels = tf.shape(tensor)[-1]
    elif not -num_channels <= channel_to_remove < num_channels:
        raise IndexError(
            f"channel_to_remove={channel_to_remove} is out of range for a "
            f"tensor with {num_channels} channels"
        )

    # Normalise negative indices. Without this, removing the last channel (-1)
    # would make the "after" slice start at 0 and return the whole tensor.
    if channel_to_remove < 0:
        channel_to_remove += num_channels

    before_channel = tensor[..., :channel_to_remove]
    removed_channel = tensor[..., channel_to_remove:channel_to_remove + 1]
    after_channel = tensor[..., channel_to_remove + 1:]

    # Stitch the channels on either side back together along the channel axis.
    remaining = tf.concat([before_channel, after_channel], axis=-1)
    return remaining, removed_channel
class LazyLossWrapper(tf.keras.losses.Loss):
    """Keras loss that strips an extra channel from ``y_true`` before delegating.

    ``y_true`` is expected to carry the label channels plus one extra channel at
    ``mask_index``. That channel is removed with ``pop_channel``. ``y_pred`` is
    then scaled at every position by the sum of the remaining label channels,
    so positions whose labels are all zero pass a zero prediction to
    ``base_loss``.

    Args:
        base_loss: Callable ``(y_true, y_pred) -> loss``. It is applied to the
            stripped labels and the scaled predictions.
        mask_index: Index, along the last axis of ``y_true``, of the channel to
            strip. Defaults to ``-1`` (the last channel).
        **kwargs: Forwarded to ``tf.keras.losses.Loss`` (e.g. ``name``,
            ``reduction``).
    """

    def __init__(self, base_loss, mask_index=-1, **kwargs):
        super().__init__(**kwargs)
        self.base_loss = base_loss
        self.mask_index = mask_index

    def call(self, y_true, y_pred):
        """Return ``base_loss`` on the label channels and the gated predictions."""
        # Labels may arrive as integers. Cast them so the multiplication below
        # does not fail on a dtype mismatch with the predictions.
        y_true = tf.cast(y_true, y_pred.dtype)
        # The popped channel's values are not used: positions are gated by
        # whether any label channel is set, not by the extra channel itself.
        y_true, _ = pop_channel(y_true, self.mask_index)
        label_sum = tf.reduce_sum(y_true, axis=-1, keepdims=True)
        return self.base_loss(y_true, y_pred * label_sum)
class LazyMetricWrapper(tf.keras.metrics.MeanMetricWrapper):
    """Mean metric that strips an extra channel from ``y_true`` before scoring.

    This mirrors ``LazyLossWrapper``:
    - the channel at ``mask_index`` is removed from ``y_true``;
    - ``y_pred`` is scaled at every position by the sum of the remaining label
      channels;
    - the result is fed to ``metric_fn`` via ``MeanMetricWrapper``.

    Args:
        metric_fn: Callable ``(y_true, y_pred) -> values`` to average.
        mask_index: Index, along the last axis of ``y_true``, of the channel to
            strip. Defaults to ``-1`` (the last channel).
        name: Metric name. If not given, it is taken from ``metric_fn.name``
            or ``metric_fn.__name__``.
        **kwargs: Forwarded to ``tf.keras.metrics.MeanMetricWrapper``.
    """

    def __init__(self, metric_fn, mask_index=-1, name=None, **kwargs):
        # An explicit ``name`` always wins. The original single-line ternary
        # discarded it whenever ``metric_fn`` had no ``name`` attribute.
        if not name:
            name = getattr(metric_fn, "name", None) or getattr(
                metric_fn, "__name__", None
            )
        super().__init__(metric_fn, name=name, **kwargs)
        self.mask_index = mask_index

    def update_state(self, y_true, y_pred, sample_weight=None):
        """Accumulate ``metric_fn`` on the label channels and gated predictions."""
        # Cast integer labels so the multiplication below matches y_pred's dtype.
        y_true = tf.cast(y_true, y_pred.dtype)
        y_true, _ = pop_channel(y_true, self.mask_index)
        label_sum = tf.reduce_sum(y_true, axis=-1, keepdims=True)
        # Forward sample_weight; it was previously dropped silently.
        return super().update_state(
            y_true, y_pred * label_sum, sample_weight=sample_weight
        )
# Model wrapper that compiles a base segmentation model (e.g. a U-Net) with the masked Lazy* loss and metric wrappers
class LazySegmentationModel(tf.keras.Model):
    """Keras model that trains a wrapped segmentation model with masked targets.

    ``y_true`` is expected to hold the label channels plus one extra channel at
    ``ignore_channel_index``. :meth:`compile` wraps the loss in
    ``LazyLossWrapper`` and every metric, including weighted metrics, in
    ``LazyMetricWrapper``, so that channel is stripped before the user's loss
    and metrics see it.

    Delegation to ``base_model``:
    - compilation, ``fit`` and ``evaluate`` are forwarded to it;
    - forward passes call it directly.

    Args:
        base_model: The ``tf.keras.Model`` being wrapped (e.g. a U-Net).
        ignore_channel_index: Index, along the last axis of ``y_true``, of the
            extra channel to strip. Defaults to ``-1`` (the last channel).
        **kwargs: Forwarded to ``tf.keras.Model``.
    """

    def __init__(self, base_model, ignore_channel_index=-1, **kwargs):
        super().__init__(**kwargs)
        self.base_unet = base_model
        self.mask_index = ignore_channel_index

    def _wrap_lazy_metrics(self, metrics):
        """Resolve metric identifiers and wrap each one in ``LazyMetricWrapper``.

        Accepts ``None`` (returned unchanged), a single identifier (string,
        function or Metric object), or a list/tuple of identifiers.
        """
        if metrics is None:
            return None
        if isinstance(metrics, dict):
            raise TypeError(
                "LazySegmentationModel does not support per-output metric "
                "dicts; pass a list of metrics instead."
            )
        # Treat a lone identifier as a one-element list. Previously a single
        # callable would have been iterated over.
        if isinstance(metrics, str) or callable(metrics):
            metrics = [metrics]
        return [
            LazyMetricWrapper(
                tf.keras.metrics.get(metric) if isinstance(metric, str) else metric,
                mask_index=self.mask_index,
            )
            for metric in metrics
        ]

    def compile(
        self,
        optimizer="rmsprop",
        loss=None,
        loss_weights=None,
        metrics=None,
        weighted_metrics=None,
        run_eagerly=False,
        steps_per_execution=1,
        jit_compile="auto",
        **kwargs
    ):
        """Compile ``base_unet`` with the loss and metrics wrapped for masking.

        Arguments mirror ``tf.keras.Model.compile``. String identifiers for the
        loss and metrics are resolved with ``tf.keras.losses.get`` /
        ``tf.keras.metrics.get`` before wrapping.
        """
        if isinstance(loss, str):
            loss = tf.keras.losses.get(loss)
        if loss is not None:
            loss = LazyLossWrapper(loss, mask_index=self.mask_index)

        self.base_unet.compile(
            optimizer=optimizer,
            loss=loss,
            metrics=self._wrap_lazy_metrics(metrics),
            loss_weights=loss_weights,
            # Weighted metrics also receive y_true with the extra channel,
            # so they need the same wrapping as plain metrics.
            weighted_metrics=self._wrap_lazy_metrics(weighted_metrics),
            run_eagerly=run_eagerly,
            steps_per_execution=steps_per_execution,
            jit_compile=jit_compile,
            **kwargs
        )

    def fit(self, *args, **kwargs):
        """Train ``base_unet``; all arguments are forwarded to its ``fit``."""
        return self.base_unet.fit(*args, **kwargs)

    def call(self, *args, **kwargs):
        """Run a forward pass through ``base_unet``."""
        # Invoke the model via ``__call__`` rather than ``.call`` so Keras can
        # build it and apply its usual input/mask handling.
        return self.base_unet(*args, **kwargs)

    def evaluate(self, *args, **kwargs):
        """Evaluate ``base_unet``; all arguments are forwarded to its ``evaluate``."""
        return self.base_unet.evaluate(*args, **kwargs)