models
Implements adaptive computation time for RNNs from Graves 2016.
This module offers three interfaces:
- Adaptive*, which mimics the torch.nn.RNN interface.
- Adaptive*Cell, which mimics the torch.nn.RNNCell interface.
- AdaptiveCellWrapper, which wraps any RNN cell to add adaptive computation time.
LSTMs and bidirectional networks are not currently implemented.
torch.nn.RNN-style interface
class models.AdaptiveRNN(input_size: int, hidden_size: int, num_layers: int, time_penalty: float, bias: bool = True, nonlinearity: str = 'tanh', initial_halting_bias: float = -1.0, ponder_epsilon: float = 0.01, time_limit: int = 100, batch_first: bool = False, dropout: float = 0.0)
An adaptive-time variant of torch.nn.RNN.
- Parameters
input_size – The number of expected features in the input.
hidden_size – The number of features in the hidden state.
num_layers – How many layers to use.
time_penalty – How heavily to penalize the model for thinking too long. Tau in Graves 2016.
bias – Whether to use a learnable bias for the input-hidden and hidden-hidden functions.
nonlinearity – The nonlinearity to use. Can be either “tanh” or “relu”.
initial_halting_bias – Value to initialize the halting unit’s bias to. Recommended to set this to a negative number to prevent long ponder sequences early in training.
ponder_epsilon – When the halting values sum to more than 1 - ponder_epsilon, stop computation. Used to enable halting on the first step.
time_limit – Hard limit on how many substeps any computation can take. Intended to prevent overly long computation early in training. M in Graves 2016.
batch_first – If True, expects the first dimension of each sequence to be the batch axis and the second to be the sequence axis.
dropout – Amount of dropout to apply to the output of each layer except the last.
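How ponder_epsilon and time_limit interact can be sketched in plain Python. This is a minimal illustration of the halting rule described in Graves 2016, not this module's implementation: the helper name halting_weights is hypothetical, and the use of >= for the threshold follows the paper rather than anything confirmed about this module.

```python
from typing import List, Sequence, Tuple


def halting_weights(
    halting_values: Sequence[float],
    ponder_epsilon: float = 0.01,
    time_limit: int = 100,
) -> Tuple[List[float], float]:
    """Return per-substep weights and the ponder cost for one timestep.

    Hypothetical helper: computation stops once the halting values sum to
    at least 1 - ponder_epsilon, or once time_limit substeps have run.
    """
    weights: List[float] = []
    total = 0.0
    for step, h in enumerate(halting_values, start=1):
        if total + h >= 1.0 - ponder_epsilon or step >= time_limit:
            # The final substep receives the remainder so the weights sum to 1.
            remainder = 1.0 - total
            weights.append(remainder)
            # Ponder cost in the paper: number of substeps plus the remainder.
            return weights, step + remainder
        weights.append(h)
        total += h
    raise ValueError("halting_values ended before the halting condition was met")


# Halts on the third substep because 0.1 + 0.3 + 0.7 exceeds 1 - ponder_epsilon.
weights, ponder = halting_weights([0.1, 0.3, 0.7])
print(weights, ponder)

# A single value of 0.995 can halt immediately only because of ponder_epsilon.
print(halting_weights([0.995]))

# time_limit forces a halt even when the halting values stay small.
print(halting_weights([0.1] * 10, time_limit=3))
```

Because a sigmoid halting unit never outputs exactly 1, a threshold of 1 would always require at least two substeps; the epsilon margin is what makes a one-substep halt possible.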
forward(inputs: torch.Tensor, hidden: Optional[torch.Tensor] = None) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
- Parameters
inputs – Input to the network. Shape: (seq_len, batch, input_size), or (batch, seq_len, input_size) if batch_first is True.
hidden – Initial hidden state for each element in the batch. Shape: (num_layers * num_directions, batch, hidden_size), where num_directions is 1 because bidirectional networks are not implemented.
- Returns
output (torch.Tensor) – The output features from the last layer of the RNN for each timestep. Shape: (seq_len, batch, hidden_size), or (batch, seq_len, hidden_size) if batch_first is True.
hidden (torch.Tensor) – The hidden state for the final step. Shape: (num_layers, batch, hidden_size)
ponder_cost (torch.Tensor) – The total ponder cost for this sequence. Shape: () (scalar)
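The shape conventions above can be summarized in a small lookup. This is only a sketch of the documented shapes: expected_shapes is a hypothetical helper, not part of the module, and it assumes a single direction since bidirectional networks are not implemented.

```python
from typing import Dict, Tuple


def expected_shapes(
    seq_len: int,
    batch: int,
    input_size: int,
    hidden_size: int,
    num_layers: int,
    batch_first: bool = False,
) -> Dict[str, Tuple[int, ...]]:
    """Hypothetical helper listing the tensor shapes documented for forward()."""
    if batch_first:
        inputs = (batch, seq_len, input_size)
        output = (batch, seq_len, hidden_size)
    else:
        inputs = (seq_len, batch, input_size)
        output = (seq_len, batch, hidden_size)
    return {
        "inputs": inputs,
        # The hidden state is documented without a batch_first variant.
        "hidden": (num_layers, batch, hidden_size),
        "output": output,
        # The ponder cost is a scalar, so its shape is the empty tuple.
        "ponder_cost": (),
    }


print(expected_shapes(seq_len=5, batch=2, input_size=3, hidden_size=4, num_layers=1))
```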
class models.AdaptiveGRU(input_size: int, hidden_size: int, num_layers: int, time_penalty: float, bias: bool = True, initial_halting_bias: float = -1.0, ponder_epsilon: float = 0.01, time_limit: int = 100, batch_first: bool = False, dropout: float = 0.0)
An adaptive-time variant of torch.nn.GRU.
- Parameters
input_size – The number of expected features in the input.
hidden_size – The number of features in the hidden state.
num_layers – How many layers to use.
time_penalty – How heavily to penalize the model for thinking too long. Tau in Graves 2016.
bias – Whether to use a learnable bias for the input-hidden and hidden-hidden functions.
initial_halting_bias – Value to initialize the halting unit’s bias to. Recommended to set this to a negative number to prevent long ponder sequences early in training.
ponder_epsilon – When the halting values sum to more than 1 - ponder_epsilon, stop computation. Used to enable halting on the first step.
time_limit – Hard limit on how many substeps any computation can take. Intended to prevent overly long computation early in training. M in Graves 2016.
batch_first – If True, expects the first dimension of each sequence to be the batch axis and the second to be the sequence axis.
dropout – Amount of dropout to apply to the output of each layer except the last.
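In Graves 2016, the time penalty enters training by adding tau times the ponder cost to the task loss. The sketch below shows that trade-off with plain floats. The function name penalized_loss and the example numbers are hypothetical, and the explicit multiplication is an assumption: this page does not say whether the ponder_cost returned by the module is already scaled by time_penalty.

```python
def penalized_loss(task_loss: float, ponder_cost: float, time_penalty: float) -> float:
    """Task loss plus tau times the ponder cost, as in Graves 2016 (hypothetical helper)."""
    return task_loss + time_penalty * ponder_cost


# Two made-up candidates: one ponders longer for a slightly lower task loss.
long_run = {"task_loss": 0.50, "ponder_cost": 8.0}
short_run = {"task_loss": 0.55, "ponder_cost": 2.0}

for tau in (0.001, 0.05):
    long_total = penalized_loss(long_run["task_loss"], long_run["ponder_cost"], tau)
    short_total = penalized_loss(short_run["task_loss"], short_run["ponder_cost"], tau)
    preferred = "long" if long_total < short_total else "short"
    print(f"tau={tau}: long={long_total:.3f} short={short_total:.3f} preferred={preferred}")
```

A small time_penalty leaves the longer computation preferable, while a larger one tips the balance toward halting sooner.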
forward(inputs: torch.Tensor, hidden: Optional[torch.Tensor] = None) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
- Parameters
inputs – Input to the network. Shape: (seq_len, batch, input_size), or (batch, seq_len, input_size) if batch_first is True.
hidden – Initial hidden state for each element in the batch. Shape: (num_layers * num_directions, batch, hidden_size), where num_directions is 1 because bidirectional networks are not implemented.
- Returns
output (torch.Tensor) – The output features from the last layer of the RNN for each timestep. Shape: (seq_len, batch, hidden_size), or (batch, seq_len, hidden_size) if batch_first is True.
hidden (torch.Tensor) – The hidden state for the final step. Shape: (num_layers, batch, hidden_size)
ponder_cost (torch.Tensor) – The total ponder cost for this sequence. Shape: () (scalar)
torch.nn.RNNCell-style interface
class models.AdaptiveRNNCell(input_size: int, hidden_size: int, time_penalty: float, bias: bool = True, nonlinearity: str = 'tanh', initial_halting_bias: float = -1.0, ponder_epsilon: float = 0.01, time_limit: int = 100)
An adaptive-time variant of torch.nn.RNNCell.
- Parameters
input_size – The number of expected features in the input.
hidden_size – The number of features in the hidden state.
time_penalty – How heavily to penalize the model for thinking too long. Tau in Graves 2016.
bias – Whether to use a learnable bias for the input-hidden and hidden-hidden functions.
nonlinearity – The nonlinearity to use. Can be either “tanh” or “relu”.
initial_halting_bias – Value to initialize the halting unit’s bias to. Recommended to set this to a negative number to prevent long ponder sequences early in training.
ponder_epsilon – When the halting values sum to more than 1 - ponder_epsilon, stop computation. Used to enable halting on the first step.
time_limit – Hard limit on how many substeps any computation can take. Intended to prevent overly long computation early in training. M in Graves 2016.
forward(inputs: torch.Tensor, hidden: Optional[torch.Tensor] = None) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
Execute one timestep of the RNN, which may correspond to several internal steps.
- Parameters
inputs – Tensor containing input features. Shape: (batch, input_size)
hidden – Initial hidden value for the wrapped cell. If not provided, relies on the wrapped cell to provide its own initial value. Shape: (batch, hidden_size)
- Returns
next_hiddens (torch.Tensor) – The hidden state for this timestep. Shape: (batch, hidden_size)
ponder_cost (torch.Tensor) – The ponder cost for this timestep. Shape: () (scalar)
ponder_steps (torch.Tensor) – The number of ponder steps each element in the batch took. Shape: (batch)
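Because a cell handles one timestep per call, the caller is responsible for looping over the sequence and collecting the three return values. The sketch below shows that loop with a stand-in cell. Both toy_cell and unroll are hypothetical and only mimic the documented return signature, batches are plain lists with one float per element, and summing the per-timestep ponder costs into a sequence total is an assumption based on Graves 2016.

```python
from typing import Callable, List, Optional, Sequence, Tuple

Batch = List[float]
CellOutput = Tuple[Batch, float, List[int]]


def toy_cell(inputs: Batch, hidden: Optional[Batch] = None) -> CellOutput:
    """Stand-in for an adaptive cell; NOT the real implementation."""
    if hidden is None:
        hidden = [0.0] * len(inputs)
    next_hidden = [0.5 * h + x for h, x in zip(hidden, inputs)]
    ponder_cost = 2.0  # made-up constant: one substep plus a remainder of 1
    ponder_steps = [1] * len(inputs)  # one entry per batch element
    return next_hidden, ponder_cost, ponder_steps


def unroll(
    cell: Callable[[Batch, Optional[Batch]], CellOutput],
    sequence: Sequence[Batch],
    hidden: Optional[Batch] = None,
) -> Tuple[Optional[Batch], float, List[List[int]]]:
    """Apply the cell once per timestep, accumulating ponder statistics."""
    total_ponder = 0.0
    steps_per_timestep: List[List[int]] = []
    for inputs in sequence:
        hidden, ponder_cost, ponder_steps = cell(inputs, hidden)
        total_ponder += ponder_cost
        steps_per_timestep.append(ponder_steps)
    return hidden, total_ponder, steps_per_timestep


final_hidden, total_ponder, steps = unroll(toy_cell, [[1.0, 2.0], [0.0, 0.0], [1.0, 1.0]])
print(final_hidden, total_ponder, steps)
```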
class models.AdaptiveGRUCell(input_size: int, hidden_size: int, time_penalty: float, bias: bool = True, initial_halting_bias: float = -1.0, ponder_epsilon: float = 0.01, time_limit: int = 100)
An adaptive-time variant of torch.nn.GRUCell.
- Parameters
input_size – The number of expected features in the input.
hidden_size – The number of features in the hidden state.
time_penalty – How heavily to penalize the model for thinking too long. Tau in Graves 2016.
bias – Whether to use a learnable bias for the input-hidden and hidden-hidden functions.
initial_halting_bias – Value to initialize the halting unit’s bias to. Recommended to set this to a negative number to prevent long ponder sequences early in training.
ponder_epsilon – When the halting values sum to more than 1 - ponder_epsilon, stop computation. Used to enable halting on the first step.
time_limit – Hard limit on how many substeps any computation can take. Intended to prevent overly long computation early in training. M in Graves 2016.
forward(inputs: torch.Tensor, hidden: Optional[torch.Tensor] = None) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
Execute one timestep of the RNN, which may correspond to several internal steps.
- Parameters
inputs – Tensor containing input features. Shape: (batch, input_size)
hidden – Initial hidden value for the wrapped cell. If not provided, relies on the wrapped cell to provide its own initial value. Shape: (batch, hidden_size)
- Returns
next_hiddens (torch.Tensor) – The hidden state for this timestep. Shape: (batch, hidden_size)
ponder_cost (torch.Tensor) – The ponder cost for this timestep. Shape: () (scalar)
ponder_steps (torch.Tensor) – The number of ponder steps each element in the batch took. Shape: (batch)
Add adaptive computation time to an RNNCell
class models.AdaptiveCellWrapper(cell: torch.nn.modules.rnn.RNNCellBase, time_penalty: float, initial_halting_bias: float = -1.0, ponder_epsilon: float = 0.01, time_limit: int = 100)
Wraps an RNN cell to add adaptive computation time.
Note that the wrapped cell must have an input size of 1 plus the desired input size, to make room for the extra first-step flag input.
- Parameters
cell – The cell to wrap.
time_penalty – How heavily to penalize the model for thinking too long. Tau in Graves 2016.
initial_halting_bias – Value to initialize the halting unit’s bias to. Recommended to set this to a negative number to prevent long ponder sequences early in training.
ponder_epsilon – When the halting values sum to more than 1 - ponder_epsilon, stop computation. Used to enable halting on the first step.
time_limit – Hard limit on how many substeps any computation can take. Intended to prevent overly long computation early in training. M in Graves 2016.
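The extra input feature required of the wrapped cell comes from the first-step flag. In Graves 2016, the input at each substep is augmented with a binary flag that is set only on the first substep, so the network can tell a new input apart from continued pondering on the same input. The sketch below illustrates this with plain lists. Both helper names are hypothetical, and placing the flag at the front of the feature vector is an assumption.

```python
from typing import List


def with_first_step_flag(features: List[float], first_step: bool) -> List[float]:
    """Prepend a binary first-step flag (hypothetical helper; flag position assumed)."""
    return [1.0 if first_step else 0.0] + list(features)


def substep_inputs(features: List[float], num_substeps: int) -> List[List[float]]:
    """Build the augmented input seen at each substep of one timestep."""
    return [with_first_step_flag(features, n == 0) for n in range(num_substeps)]


features = [0.2, 0.4]  # desired input size is 2
augmented = substep_inputs(features, num_substeps=3)
for row in augmented:
    print(row)

# Each augmented vector has 1 + len(features) entries, which is the input
# size the wrapped cell must be constructed with.
print(len(augmented[0]))
```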
forward(inputs: torch.Tensor, hidden: Optional[torch.Tensor] = None) → Tuple[torch.Tensor, torch.Tensor, torch.Tensor]
Execute one timestep of the RNN, which may correspond to several internal steps.
- Parameters
inputs – Tensor containing input features. Shape: (batch, input_size)
hidden – Initial hidden value for the wrapped cell. If not provided, relies on the wrapped cell to provide its own initial value. Shape: (batch, hidden_size)
- Returns
next_hiddens (torch.Tensor) – The hidden state for this timestep. Shape: (batch, hidden_size)
ponder_cost (torch.Tensor) – The ponder cost for this timestep. Shape: () (scalar)
ponder_steps (torch.Tensor) – The number of ponder steps each element in the batch took. Shape: (batch)
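Graves 2016 forms the state for a timestep as the halting-weighted mean of the states from its substeps. Assuming the wrapper follows the paper here, next_hiddens would be such a weighted combination. The sketch below shows the arithmetic for a single batch element using plain lists; weighted_state is a hypothetical helper, not part of the module.

```python
from typing import List, Sequence


def weighted_state(states: Sequence[Sequence[float]], weights: Sequence[float]) -> List[float]:
    """Combine substep states using halting weights that sum to 1 (hypothetical helper)."""
    if len(states) != len(weights):
        raise ValueError("need exactly one weight per substep state")
    if not states:
        raise ValueError("need at least one substep state")
    size = len(states[0])
    # Weighted sum over substeps, computed independently for each hidden unit.
    return [sum(w * s[i] for w, s in zip(weights, states)) for i in range(size)]


substep_states = [[1.0, 0.0], [0.0, 1.0]]  # two substeps, hidden_size of 2
halting_weights = [0.25, 0.75]  # the last weight is the remainder
print(weighted_state(substep_states, halting_weights))
```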