Skip to content

runn.keras.layers.gather

Gather(indices, axis=1, **kwargs)

Bases: Layer

Custom Keras layer that gathers elements of its input tensor along a fixed axis, using an index tensor supplied at construction time.

PARAMETER DESCRIPTION
indices

A tensor with the indices of the elements to gather.

TYPE: Tensor

axis

The axis along which to gather the elements. Default: 1.

TYPE: int DEFAULT: 1

Source code in runn/keras/layers/gather.py
def __init__(self, indices: tf.Tensor, axis: int = 1, **kwargs):
    """Configure the layer with fixed gather indices and axis.

    Args:
        indices: Indices to select from the layer input along ``axis``.
        axis: Axis of the input along which elements are gathered.
            Defaults to 1.
        **kwargs: Forwarded unchanged to the base ``Layer`` constructor.
    """
    # Initialise the base Layer before assigning any attributes on self.
    super().__init__(**kwargs)
    self.indices, self.axis = indices, axis

axis = axis instance-attribute #

indices = indices instance-attribute #

call(inputs) #

Source code in runn/keras/layers/gather.py
def call(self, inputs: tf.Tensor) -> tf.Tensor:
    """Return the entries of ``inputs`` at ``self.indices`` along ``self.axis``.

    Args:
        inputs: Tensor to gather from.

    Returns:
        The result of ``tf.gather`` applied with the stored indices and axis.
    """
    gathered = tf.gather(inputs, self.indices, axis=self.axis)
    return gathered

get_config() #

Source code in runn/keras/layers/gather.py
def get_config(self) -> dict:
    """Return the layer configuration for Keras serialization.

    Extends the base ``Layer`` config with this layer's ``indices`` and
    ``axis`` so the layer can be rebuilt via ``from_config``.

    Returns:
        A dict containing the base config plus ``"indices"`` and ``"axis"``.
        ``"indices"`` is converted to a plain Python (nested) list when it is
        a tensor or array, so the config stays JSON-serializable.
    """
    config = super(Gather, self).get_config()
    indices = self.indices
    # A tf.Tensor (or ndarray) is not JSON-serializable, which breaks model
    # saving / to_json(). Convert to plain Python values. A list round-trips
    # through from_config -> __init__, and tf.gather accepts it as indices.
    if hasattr(indices, "numpy"):
        indices = indices.numpy()
    if hasattr(indices, "tolist"):
        indices = indices.tolist()
    config.update({"indices": indices, "axis": self.axis})
    return config