tasks

Includes data, code, and configuration to reproduce the experiments from Graves (2016). Each module provides a torch.utils.data.IterableDataset that generates data for its task.
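For readers new to the pattern, here is a minimal sketch of an infinite IterableDataset consumed through a DataLoader. The toy task (`CoinFlipDataset`, labelling a random vector by the sign of its sum) is hypothetical and is not one of the tasks in this package. Because the stream never ends, training has to be stopped from outside, for example after a fixed number of steps. Presumably this is the role of each script's --max-steps flag, though that is an inference rather than something this page states.

```python
import itertools

import torch
from torch.utils.data import DataLoader, IterableDataset


class CoinFlipDataset(IterableDataset):
    """Hypothetical toy task: endless (x, y) pairs, y = 1 if x sums to > 0."""

    def __iter__(self):
        # The loop never returns, so the dataset is infinite.
        while True:
            x = torch.randn(4)
            y = (x.sum() > 0).long()
            yield x, y


# batch_size=32 mirrors the documented --batch_size default.
loader = DataLoader(CoinFlipDataset(), batch_size=32)

# An infinite loader must be cut off externally; here we take 3 batches.
for step, (x, y) in enumerate(itertools.islice(loader, 3)):
    print(step, tuple(x.shape), tuple(y.shape))  # e.g. 0 (32, 4) (32,)
```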

Each module can also be run as a script, configured through command-line arguments. Besides the arguments listed below, all flags supported by pytorch-lightning’s Trainer are accepted.

For example, to run an experiment on one or more GPUs (recommended, as these experiments take a long time), use --gpus n, where n is the number of GPUs available; use --tpu_cores in the same way for TPUs.

Parity

Implements the Parity task from Graves (2016): determining the parity of a statically presented binary vector.

class tasks.parity.ParityDataset(bits: int)

An infinite IterableDataset for binary parity problems.

Examples are pairs (vector, parity), where:
  • vector has a random number of elements set to +1 or -1, and the rest are zero.

  • parity is 1 if the vector contains an odd number of ones (+1 entries) and 0 otherwise.
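The definition above can be sketched in plain Python. This is an illustration, not the package's ParityDataset: the helper name `make_parity_example` is hypothetical, it returns lists rather than tensors, and it assumes that the number of nonzero elements is drawn uniformly from 1 to `bits` and that their positions are chosen at random.

```python
import random
from typing import List, Tuple


def make_parity_example(bits: int, rng: random.Random) -> Tuple[List[float], int]:
    """Hypothetical helper producing one (vector, parity) pair."""
    # Assumption: the count of nonzero places is uniform in 1..bits.
    n_set = rng.randint(1, bits)
    positions = rng.sample(range(bits), n_set)

    vector = [0.0] * bits
    for i in positions:
        vector[i] = rng.choice((1.0, -1.0))

    # Target: 1 if the number of +1 entries is odd, 0 if it is even.
    num_ones = sum(1 for v in vector if v == 1.0)
    return vector, num_ones % 2


rng = random.Random(0)
vector, parity = make_parity_example(16, rng)
```

Wrapping such a function in an infinite `__iter__` loop gives an IterableDataset of the kind this module describes.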

Parameters

bits – The maximum length of each binary vector.

By default, the script runs an easier version of the task (16 bits). To reproduce the paper's setup, use --bits 64 --hidden_size 128.

usage: poetry run pytorch_adaptive_computation_time/tasks/parity.py
       [-h] [--max-steps MAX_STEPS] [--bits BITS] [--hidden_size HIDDEN_SIZE]
       [--time_penalty TIME_PENALTY] [--batch_size BATCH_SIZE]
       [--learning_rate LEARNING_RATE] [--time-limit TIME_LIMIT]
       [--data-workers DATA_WORKERS]

Named Arguments

  • --max-steps (default: 200000)
  • --bits (default: 16)
  • --hidden_size (default: 64)
  • --time_penalty (default: 0.001)
  • --batch_size (default: 32)
  • --learning_rate (default: 0.0001)
  • --time-limit (default: 20)
  • --data-workers (default: 1)
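As a sketch of how these flags behave, the parser below re-creates only the task-specific arguments with the documented defaults using the standard argparse module. The function name `build_parity_parser` and the argument types are assumptions, and the real script additionally accepts the pytorch-lightning Trainer flags, which this sketch omits. Note that argparse converts dashes to underscores, so --max-steps is read back as `args.max_steps`.

```python
import argparse


def build_parity_parser() -> argparse.ArgumentParser:
    # Hypothetical re-creation of the parity flags with the documented
    # defaults; the int/float types are assumptions.
    parser = argparse.ArgumentParser(prog="parity.py")
    parser.add_argument("--max-steps", type=int, default=200000)
    parser.add_argument("--bits", type=int, default=16)
    parser.add_argument("--hidden_size", type=int, default=64)
    parser.add_argument("--time_penalty", type=float, default=0.001)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--learning_rate", type=float, default=0.0001)
    parser.add_argument("--time-limit", type=int, default=20)
    parser.add_argument("--data-workers", type=int, default=1)
    return parser


# The paper's settings: --bits 64 --hidden_size 128.
args = build_parity_parser().parse_args(["--bits", "64", "--hidden_size", "128"])
print(args.bits, args.hidden_size, args.max_steps)  # 64 128 200000
```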

Addition

Implements the Addition task from Graves 2016: adding a collection of decimal numbers. See Section 3.3 in the paper for full details.

class tasks.addition.AdditionDataset(sequence_length: int, max_digits: int)

An infinite IterableDataset for addition problems.

Examples are pairs (sequence of numbers, sequence of sums), where:
  • Each number is a vector of concatenated one-hot encoded decimal digits.

  • Each target is the sum of all prior numbers in the sequence, encoded as one 11-way classification per output digit position (a decimal digit or an empty space).
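One way to realise this encoding is sketched below in plain Python. Every layout detail here is an assumption rather than the package's actual format: digits are written most significant first and left-aligned, unused input digit slots are all zeros, class 10 (`EMPTY`) marks an empty output position, and each target is the running total up to and including the current number. The names `encode_number`, `encode_sum`, `encode_example`, and `out_digits` are hypothetical.

```python
from typing import List, Tuple

EMPTY = 10  # hypothetical index of the 11th class: "no digit here"


def encode_number(n: int, max_digits: int) -> List[float]:
    """Concatenate 10-way one-hot vectors for each decimal digit of n."""
    digits = str(n)
    assert len(digits) <= max_digits
    vec = [0.0] * (10 * max_digits)  # unused slots stay all-zero
    for slot, ch in enumerate(digits):
        vec[10 * slot + int(ch)] = 1.0
    return vec


def encode_sum(total: int, out_digits: int) -> List[int]:
    """One 11-way class per output position: 0-9 for a digit, EMPTY otherwise."""
    digits = str(total)
    assert len(digits) <= out_digits
    return [int(ch) for ch in digits] + [EMPTY] * (out_digits - len(digits))


def encode_example(
    numbers: List[int], max_digits: int, out_digits: int
) -> Tuple[List[List[float]], List[List[int]]]:
    inputs = [encode_number(n, max_digits) for n in numbers]
    targets, running = [], 0
    for n in numbers:
        running += n  # assumption: the target includes the current number
        targets.append(encode_sum(running, out_digits))
    return inputs, targets


inputs, targets = encode_example([7, 15], max_digits=2, out_digits=3)
print(targets)  # [[7, 10, 10], [2, 2, 10]]
```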

Parameters
  • sequence_length – The number of numbers in each sequence to add.

  • max_digits – The maximum number of decimal digits each number can have.
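Together, the two parameters bound how many output digit positions a target can need: the largest possible running total is sequence_length × (10^max_digits − 1). The hypothetical helper below computes that width, assuming the output uses a fixed number of digit positions. With the defaults (5 numbers of up to 5 digits), the largest sum is 5 × 99999 = 499995, which has 6 digits.

```python
def max_sum_digits(sequence_length: int, max_digits: int) -> int:
    """Digits needed for the largest possible running total (hypothetical helper)."""
    largest_number = 10 ** max_digits - 1
    return len(str(sequence_length * largest_number))


print(max_sum_digits(5, 5))  # 6, since 5 * 99999 = 499995
```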

NOTE: this implementation uses a GRU rather than the LSTM used in the original paper.
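In practical terms, the swap changes the recurrent state and the parameter count. The sketch below uses PyTorch's built-in nn.GRUCell and nn.LSTMCell. It takes hidden_size 512 from the documented --hidden_size default and assumes an input size of 50 (5 digit slots × 10 one-hot entries). A GRU cell carries a single hidden tensor, whereas an LSTM cell carries a (hidden, cell) pair. Because a GRU has 3 gate blocks to the LSTM's 4, it has three quarters as many parameters. This illustrates the difference only; the page does not state why the GRU was chosen.

```python
import torch
from torch import nn

hidden_size = 512  # documented --hidden_size default for this task
input_size = 50    # assumption: 5 digit slots x 10 one-hot entries
batch = 32

gru = nn.GRUCell(input_size, hidden_size)
lstm = nn.LSTMCell(input_size, hidden_size)

x = torch.randn(batch, input_size)
h0 = torch.zeros(batch, hidden_size)
c0 = torch.zeros(batch, hidden_size)

# GRU: one state tensor in, one state tensor out.
h_next = gru(x, h0)

# LSTM: a (hidden, cell) tuple in, a (hidden, cell) tuple out.
h_lstm, c_lstm = lstm(x, (h0, c0))

n_gru = sum(p.numel() for p in gru.parameters())
n_lstm = sum(p.numel() for p in lstm.parameters())
print(n_gru, n_lstm)  # the GRU has 3/4 the parameters of the LSTM
```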

usage: poetry run pytorch_adaptive_computation_time/tasks/addition.py
       [-h] [--max-steps MAX_STEPS] [--sequence-length SEQUENCE_LENGTH]
       [--max-digits MAX_DIGITS] [--hidden_size HIDDEN_SIZE]
       [--time_penalty TIME_PENALTY] [--batch_size BATCH_SIZE]
       [--learning_rate LEARNING_RATE] [--time-limit TIME_LIMIT]
       [--data-workers DATA_WORKERS]

Named Arguments

  • --max-steps (default: 200000)
  • --sequence-length (default: 5)
  • --max-digits (default: 5)
  • --hidden_size (default: 512)
  • --time_penalty (default: 0.001)
  • --batch_size (default: 32)
  • --learning_rate (default: 0.0001)
  • --time-limit (default: 20)
  • --data-workers (default: 1)