models

Implements adaptive computation time for RNNs from Graves 2016.
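
The halting mechanism follows Graves 2016. As a quick reference, paraphrased from the paper rather than taken from this module's source: at input step t the cell produces intermediate states s_t^n and halting values, and computation stops at the first substep whose cumulative halting value crosses the threshold:

    h_t^n = \sigma(W_h s_t^n + b_h)
    N(t) = \min\{\, n \le M : \textstyle\sum_{i=1}^{n} h_t^i \ge 1 - \epsilon \,\}
    R(t) = 1 - \textstyle\sum_{i=1}^{N(t)-1} h_t^i
    s_t = \textstyle\sum_{n=1}^{N(t)} p_t^n \, s_t^n, \qquad
        p_t^n = \begin{cases} h_t^n & n < N(t) \\ R(t) & n = N(t) \end{cases}
    \mathcal{P}(t) = N(t) + R(t)

Here \epsilon is ponder_epsilon, M is time_limit, and the per-step ponder cost \mathcal{P}(t) is weighted by \tau (time_penalty) in the loss.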

This module offers three interfaces:
  • Adaptive*, which mimics the torch.nn.RNN interface.

  • Adaptive*Cell, which mimics the torch.nn.RNNCell interface.

  • AdaptiveCellWrapper, which wraps any RNN cell to add adaptive computation time.

LSTMs and bidirectional networks are not currently implemented.
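
A minimal sketch of how the three interfaces relate (hyperparameter values here are illustrative, not defaults):

    import torch.nn as nn
    import models

    # Sequence-level interface, analogous to torch.nn.RNN.
    rnn = models.AdaptiveRNN(input_size=8, hidden_size=16, num_layers=2,
                             time_penalty=0.01)

    # Single-timestep interface, analogous to torch.nn.RNNCell.
    cell = models.AdaptiveRNNCell(input_size=8, hidden_size=16,
                                  time_penalty=0.01)

    # Wrap an arbitrary cell; note the +1 on input_size for the
    # first-step flag (see AdaptiveCellWrapper below).
    wrapped = models.AdaptiveCellWrapper(nn.GRUCell(8 + 1, 16),
                                         time_penalty=0.01)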

torch.nn.RNN-style interface

class models.AdaptiveRNN(input_size: int, hidden_size: int, num_layers: int, time_penalty: float, bias: bool = True, nonlinearity: str = 'tanh', initial_halting_bias: float = -1.0, ponder_epsilon: float = 0.01, time_limit: int = 100, batch_first: bool = False, dropout: float = 0.0)

An adaptive-time variant of torch.nn.RNN.

Parameters
  • input_size – The number of expected features in the input.

  • hidden_size – The number of features in the hidden state.

  • num_layers – How many layers to use.

  • time_penalty – How heavily to penalize the model for thinking too long. Tau in Graves 2016.

  • bias – Whether to use a learnable bias for the input-hidden and hidden-hidden functions.

  • nonlinearity – The nonlinearity to use. Can be either “tanh” or “relu”.

  • initial_halting_bias – Value to initialize the halting unit’s bias to. Recommended to set this to a negative number to prevent long ponder sequences early in training.

  • ponder_epsilon – When the halting values sum to more than 1 - ponder_epsilon, stop computation. Used to enable halting on the first step.

  • time_limit – Hard limit for how many substeps any computation can take. Intended to prevent overly long computation early on. M in Graves 2016.

  • batch_first – If True, expects the first dimension of each sequence to be the batch axis and the second to be the sequence axis.

  • dropout – Amount of dropout to apply to the output of each layer except the last.

forward(inputs: torch.Tensor, hidden: Optional[torch.Tensor] = None) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
Parameters
  • inputs

    Input to the network.

    Shape: (seq_len, batch, input_size), or (batch, seq_len, input_size) if batch_first is True.

  • hidden

    Initial hidden state for each element in the batch.

    Shape: (num_layers, batch, hidden_size)

Returns

  • output (torch.Tensor) – The output features from the last layer of the RNN for each timestep.

    Shape: (seq_len, batch, hidden_size), or (batch, seq_len, hidden_size) if batch_first is True.

  • hidden (torch.Tensor) – The hidden state for the final step.

    Shape: (num_layers, batch, hidden_size)

  • ponder_cost (torch.Tensor) – The total ponder cost for this sequence.

    Shape: () (scalar)
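
Typical usage: run a batch through the network and fold the returned ponder cost into the training loss. A sketch; it assumes ponder_cost is already scaled by time_penalty (plausible since tau is a constructor argument, but worth verifying against the source):

    import torch
    import torch.nn.functional as F
    import models

    rnn = models.AdaptiveRNN(input_size=8, hidden_size=16, num_layers=2,
                             time_penalty=0.01)
    inputs = torch.randn(5, 3, 8)      # (seq_len, batch, input_size)
    targets = torch.randn(5, 3, 16)

    output, hidden, ponder_cost = rnn(inputs)
    # output: (5, 3, 16), hidden: (2, 3, 16), ponder_cost: scalar
    loss = F.mse_loss(output, targets) + ponder_cost
    loss.backward()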

class models.AdaptiveGRU(input_size: int, hidden_size: int, num_layers: int, time_penalty: float, bias: bool = True, initial_halting_bias: float = -1.0, ponder_epsilon: float = 0.01, time_limit: int = 100, batch_first: bool = False, dropout: float = 0.0)

An adaptive-time variant of torch.nn.GRU.

Parameters
  • input_size – The number of expected features in the input.

  • hidden_size – The number of features in the hidden state.

  • num_layers – How many layers to use.

  • time_penalty – How heavily to penalize the model for thinking too long. Tau in Graves 2016.

  • bias – Whether to use a learnable bias for the input-hidden and hidden-hidden functions.

  • initial_halting_bias – Value to initialize the halting unit’s bias to. Recommended to set this to a negative number to prevent long ponder sequences early in training.

  • ponder_epsilon – When the halting values sum to more than 1 - ponder_epsilon, stop computation. Used to enable halting on the first step.

  • time_limit – Hard limit for how many substeps any computation can take. Intended to prevent overly long computation early on. M in Graves 2016.

  • batch_first – If True, expects the first dimension of each sequence to be the batch axis and the second to be the sequence axis.

  • dropout – Amount of dropout to apply to the output of each layer except the last.

forward(inputs: torch.Tensor, hidden: Optional[torch.Tensor] = None) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
Parameters
  • inputs

    Input to the network.

    Shape: (seq_len, batch, input_size), or (batch, seq_len, input_size) if batch_first is True.

  • hidden

    Initial hidden state for each element in the batch.

    Shape: (num_layers, batch, hidden_size)

Returns

  • output (torch.Tensor) – The output features from the last layer of the RNN for each timestep.

    Shape: (seq_len, batch, hidden_size), or (batch, seq_len, hidden_size) if batch_first is True.

  • hidden (torch.Tensor) – The hidden state for the final step.

    Shape: (num_layers, batch, hidden_size)

  • ponder_cost (torch.Tensor) – The total ponder cost for this sequence.

    Shape: () (scalar)
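
The GRU variant is used the same way; with batch_first=True the batch axis comes first. A brief sketch:

    import torch
    import models

    gru = models.AdaptiveGRU(input_size=8, hidden_size=16, num_layers=1,
                             time_penalty=0.01, batch_first=True)
    inputs = torch.randn(3, 5, 8)      # (batch, seq_len, input_size)
    output, hidden, ponder_cost = gru(inputs)
    # output: (3, 5, 16), hidden: (1, 3, 16), ponder_cost: scalar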

torch.nn.RNNCell-style interface

class models.AdaptiveRNNCell(input_size: int, hidden_size: int, time_penalty: float, bias: bool = True, nonlinearity: str = 'tanh', initial_halting_bias: float = -1.0, ponder_epsilon: float = 0.01, time_limit: int = 100)

An adaptive-time variant of torch.nn.RNNCell.

Parameters
  • input_size – The number of expected features in the input.

  • hidden_size – The number of features in the hidden state.

  • time_penalty – How heavily to penalize the model for thinking too long. Tau in Graves 2016.

  • bias – Whether to use a learnable bias for the input-hidden and hidden-hidden functions.

  • nonlinearity – The nonlinearity to use. Can be either “tanh” or “relu”.

  • initial_halting_bias – Value to initialize the halting unit’s bias to. Recommended to set this to a negative number to prevent long ponder sequences early in training.

  • ponder_epsilon – When the halting values sum to more than 1 - ponder_epsilon, stop computation. Used to enable halting on the first step.

  • time_limit – Hard limit for how many substeps any computation can take. Intended to prevent overly long computation early on. M in Graves 2016.

forward(inputs: torch.Tensor, hidden: Optional[torch.Tensor] = None) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor]

Execute one timestep of the RNN, which may correspond to several internal steps.

Parameters
  • inputs

    Tensor containing input features.

    Shape: (batch, input_size)

  • hidden

    Initial hidden state for each element in the batch. If not provided, the underlying cell supplies its own initial value.

    Shape: (batch, hidden_size)

Returns

  • next_hiddens (torch.Tensor) – The hidden state for this timestep.

    Shape: (batch, hidden_size)

  • ponder_cost (torch.Tensor) – The ponder cost for this timestep.

    Shape: () (scalar)

  • ponder_steps (torch.Tensor) – The number of ponder steps each element in the batch took.

    Shape: (batch)
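
Cell-level usage runs one timestep at a time, and ponder_steps shows how many substeps each batch element took. A sketch:

    import torch
    import models

    cell = models.AdaptiveRNNCell(input_size=8, hidden_size=16,
                                  time_penalty=0.01)
    x = torch.randn(3, 8)              # one timestep, batch of 3
    hidden, ponder_cost, ponder_steps = cell(x)
    # hidden: (3, 16), ponder_cost: scalar, ponder_steps: (3,)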

class models.AdaptiveGRUCell(input_size: int, hidden_size: int, time_penalty: float, bias: bool = True, initial_halting_bias: float = -1.0, ponder_epsilon: float = 0.01, time_limit: int = 100)

An adaptive-time variant of torch.nn.GRUCell.

Parameters
  • input_size – The number of expected features in the input.

  • hidden_size – The number of features in the hidden state.

  • time_penalty – How heavily to penalize the model for thinking too long. Tau in Graves 2016.

  • bias – Whether to use a learnable bias for the input-hidden and hidden-hidden functions.

  • initial_halting_bias – Value to initialize the halting unit’s bias to. Recommended to set this to a negative number to prevent long ponder sequences early in training.

  • ponder_epsilon – When the halting values sum to more than 1 - ponder_epsilon, stop computation. Used to enable halting on the first step.

  • time_limit – Hard limit for how many substeps any computation can take. Intended to prevent overly long computation early on. M in Graves 2016.

forward(inputs: torch.Tensor, hidden: Optional[torch.Tensor] = None) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor]

Execute one timestep of the RNN, which may correspond to several internal steps.

Parameters
  • inputs

    Tensor containing input features.

    Shape: (batch, input_size)

  • hidden

    Initial hidden state for each element in the batch. If not provided, the underlying cell supplies its own initial value.

    Shape: (batch, hidden_size)

Returns

  • next_hiddens (torch.Tensor) – The hidden state for this timestep.

    Shape: (batch, hidden_size)

  • ponder_cost (torch.Tensor) – The ponder cost for this timestep.

    Shape: () (scalar)

  • ponder_steps (torch.Tensor) – The number of ponder steps each element in the batch took.

    Shape: (batch)
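
When unrolling a cell over a sequence by hand, accumulate the per-timestep ponder costs yourself before adding them to the loss. A sketch:

    import torch
    import models

    cell = models.AdaptiveGRUCell(input_size=8, hidden_size=16,
                                  time_penalty=0.01)
    inputs = torch.randn(5, 3, 8)      # (seq_len, batch, input_size)

    hidden = None
    total_ponder_cost = torch.zeros(())
    for x_t in inputs:                 # iterate over the seq_len axis
        hidden, ponder_cost, ponder_steps = cell(x_t, hidden)
        total_ponder_cost = total_ponder_cost + ponder_cost
    # total_ponder_cost enters the loss alongside the task objective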

Add adaptive computation time to an RNNCell

class models.AdaptiveCellWrapper(cell: torch.nn.modules.rnn.RNNCellBase, time_penalty: float, initial_halting_bias: float = -1.0, ponder_epsilon: float = 0.01, time_limit: int = 100)

Wraps an RNN cell to add adaptive computation time.

Note that the wrapped cell must be constructed with an input size one greater than the inputs you intend to pass, to make room for the extra first-step flag input (see the example at the end of this section).

Parameters
  • cell – The cell to wrap.

  • time_penalty – How heavily to penalize the model for thinking too long. Tau in Graves 2016.

  • initial_halting_bias – Value to initialize the halting unit’s bias to. Recommended to set this to a negative number to prevent long ponder sequences early in training.

  • ponder_epsilon – When the halting values sum to more than 1 - ponder_epsilon, stop computation. Used to enable halting on the first step.

  • time_limit – Hard limit for how many substeps any computation can take. Intended to prevent overly long computation early on. M in Graves 2016.

forward(inputs: torch.Tensor, hidden: Optional[torch.Tensor] = None) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor]

Execute one timestep of the RNN, which may correspond to several internal steps.

Parameters
  • inputs

    Tensor containing input features.

    Shape: (batch, input_size)

  • hidden

    Initial hidden value for the wrapped cell. If not provided, the wrapped cell supplies its own initial value.

    Shape: (batch, hidden_size)

Returns

  • next_hiddens (torch.Tensor) – The hidden state for this timestep.

    Shape: (batch, hidden_size)

  • ponder_cost (torch.Tensor) – The ponder cost for this timestep.

    Shape: () (scalar)

  • ponder_steps (torch.Tensor) – The number of ponder steps each element in the batch took.

    Shape: (batch)
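
To make the input-size note above concrete: construct the inner cell with input_size + 1, since the wrapper appends a first-step flag to each input before handing it to the cell. A sketch; that callers still pass input_size features (with the flag added internally) is an assumption based on the note above:

    import torch
    import torch.nn as nn
    import models

    input_size, hidden_size = 8, 16
    # The inner cell sees one extra feature: the first-step flag.
    inner = nn.GRUCell(input_size + 1, hidden_size)
    cell = models.AdaptiveCellWrapper(inner, time_penalty=0.01)

    x = torch.randn(3, input_size)     # callers pass the un-augmented input
    hidden, ponder_cost, ponder_steps = cell(x)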