tasks
Includes the data, code, and configuration needed to reproduce the experiments from Graves (2016). Each module defines a torch.utils.data.IterableDataset that generates examples for its task.
Each module can also be run as a script, configured through command-line arguments. In addition to the arguments listed below, any flag supported by PyTorch Lightning's Trainer may be used.
For example, to run an experiment on one or more GPUs (recommended, as these experiments take a long time), use --gpus n, where n is the number of GPUs available; --tpu_cores works the same way for TPUs.
Parity
Implements the Parity task from Graves 2016: determining the parity of a statically presented binary vector.
class tasks.parity.ParityDataset(bits: int)
An infinite IterableDataset for binary parity problems.
- Examples are pairs (vector, parity), where:
vector has a random number of entries set to +1 or -1; the rest are zero.
parity is 1 if the vector contains an odd number of ones (+1 entries) and 0 otherwise.
- Parameters
bits – The maximum length of each binary vector.
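The generation process described above can be sketched in plain Python. This is a minimal, hypothetical sketch rather than the module's actual code: the helper names parity_example and ParityStream, the uniform choice of how many entries are non-zero, and the use of Python lists instead of tensors are all assumptions. ParityStream mimics the IterableDataset protocol by defining only __iter__ and never terminating; ParityDataset itself is an IterableDataset and presumably yields tensors.

```python
import itertools
import random
from typing import Iterator, List, Optional, Tuple


def parity_example(bits: int, rng: random.Random) -> Tuple[List[int], int]:
    """Draw one (vector, parity) pair; a hypothetical helper, not the module's API."""
    # Assumption: the number of non-zero entries is uniform in [1, bits].
    n_nonzero = rng.randint(1, bits)
    vector = [0] * bits
    for position in rng.sample(range(bits), n_nonzero):
        vector[position] = rng.choice((-1, 1))
    # Parity counts only the +1 entries: 1 if there is an odd number of them.
    parity = sum(1 for v in vector if v == 1) % 2
    return vector, parity


class ParityStream:
    """Infinite iterable of parity examples, mimicking the IterableDataset protocol."""

    def __init__(self, bits: int, seed: Optional[int] = None) -> None:
        self.bits = bits
        self.rng = random.Random(seed)

    def __iter__(self) -> Iterator[Tuple[List[int], int]]:
        while True:  # never stops, like an infinite dataset
            yield parity_example(self.bits, self.rng)


# Usage: take a few examples from the endless stream.
for vector, parity in itertools.islice(ParityStream(bits=16, seed=0), 3):
    print(vector, parity)
```

One design caveat applies to any such stream: when an IterableDataset is loaded by several DataLoader workers (compare --data-workers), each worker receives its own copy of the object, including its random state, so the workers can yield identical examples unless each is seeded differently, for instance via torch.utils.data.get_worker_info.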
By default, the script runs an easier version of the task (--bits 16, --hidden_size 64). To reproduce the paper's setup, use --bits 64 --hidden_size 128.
usage: poetry run pytorch-adaptive-computation-time/tasks/parity.py
[-h] [--max-steps MAX_STEPS] [--bits BITS] [--hidden_size HIDDEN_SIZE]
[--time_penalty TIME_PENALTY] [--batch_size BATCH_SIZE]
[--learning_rate LEARNING_RATE] [--time-limit TIME_LIMIT]
[--data-workers DATA_WORKERS]
Named Arguments
- --max-steps (default: 200000)
- --bits (default: 16)
- --hidden_size (default: 64)
- --time_penalty (default: 0.001)
- --batch_size (default: 32)
- --learning_rate (default: 0.0001)
- --time-limit (default: 20)
- --data-workers (default: 1)
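The argument list above can be mirrored with Python's standard argparse module. This is a hypothetical reconstruction for illustration: only the flag names and defaults come from the list, while the value types are assumptions (the meaning and type of --time-limit in particular are not documented here). The helper split_args, also hypothetical, uses parse_known_args so that unrecognized flags such as --gpus are returned instead of rejected, standing in for the Trainer flags the real script also accepts; the actual script may register those flags differently.

```python
import argparse
from typing import List, Tuple


def build_parity_parser() -> argparse.ArgumentParser:
    """Hypothetical reconstruction of the parity script's own flags.

    Flag names and defaults follow the documented list; the types are assumptions.
    """
    parser = argparse.ArgumentParser(description="Parity task (illustrative sketch)")
    parser.add_argument("--max-steps", type=int, default=200000)
    parser.add_argument("--bits", type=int, default=16)
    parser.add_argument("--hidden_size", type=int, default=64)
    parser.add_argument("--time_penalty", type=float, default=0.001)
    parser.add_argument("--batch_size", type=int, default=32)
    parser.add_argument("--learning_rate", type=float, default=0.0001)
    parser.add_argument("--time-limit", type=int, default=20)  # type and meaning assumed
    parser.add_argument("--data-workers", type=int, default=1)
    return parser


def split_args(argv: List[str]) -> Tuple[argparse.Namespace, List[str]]:
    """Parse the task flags and return any unrecognized flags (e.g. --gpus 2) untouched.

    The leftovers stand in for Trainer flags; this is not known to be how the real script works.
    """
    return build_parity_parser().parse_known_args(argv)


# Paper settings: override the easier defaults and pass a GPU flag through.
args, trainer_flags = split_args(["--bits", "64", "--hidden_size", "128", "--gpus", "2"])
print(args.bits, args.hidden_size, args.max_steps)  # 64 128 200000
print(trainer_flags)  # ['--gpus', '2']
```

Note that argparse converts dashes in flag names to underscores in the parsed namespace, so --max-steps becomes args.max_steps and --data-workers becomes args.data_workers.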
Addition
Implements the Addition task from Graves 2016: adding a collection of decimal numbers. See Section 3.3 in the paper for full details.
class tasks.addition.AdditionDataset(sequence_length: int, max_digits: int)
An infinite IterableDataset for addition problems.
- Examples are pairs (sequence of numbers, sequence of sums), where:
Each number is a vector of concatenated one-hot encodings of its decimal digits.
Each target is the sum of all prior numbers in the sequence, encoded as one 11-way classification per output digit (the ten decimal digits plus one class for an empty position).
- Parameters
sequence_length – The length of sequence to add.
max_digits – The maximum number of decimal digits each number can have.
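The input and target encodings described above can be illustrated with a small plain-Python sketch. It is not the module's code, and several details are assumptions: the helper names (encode_number, encode_target, addition_example) are hypothetical, digits are written most significant first, unused input digit slots are left all-zero, targets are right-padded with the empty class, the number of output digits is taken as the length of the largest possible sum, and the running sum is assumed to include the current number.

```python
import random
from typing import List, Tuple

EMPTY = 10  # 11th class: no digit at this output position


def encode_number(n: int, max_digits: int) -> List[int]:
    """Concatenate one 10-way one-hot block per decimal digit (most significant first).

    Digit slots beyond the number's length stay all-zero (assumption).
    """
    vec = [0] * (10 * max_digits)
    for i, d in enumerate(str(n)):
        vec[10 * i + int(d)] = 1
    return vec


def encode_target(total: int, width: int) -> List[int]:
    """One class index per output digit, right-padded with EMPTY (assumption)."""
    digits = [int(d) for d in str(total)]
    return digits + [EMPTY] * (width - len(digits))


def addition_example(
    sequence_length: int, max_digits: int, rng: random.Random
) -> Tuple[List[List[int]], List[List[int]]]:
    # Enough output digits for the largest possible sum, e.g. 5 * 99999 = 499995 -> 6 digits.
    width = len(str(sequence_length * (10 ** max_digits - 1)))
    inputs, targets, running_sum = [], [], 0
    for _ in range(sequence_length):
        n_digits = rng.randint(1, max_digits)  # assumption: uniform digit count
        number = rng.randint(0, 10 ** n_digits - 1)
        running_sum += number  # assumption: the sum includes the current number
        inputs.append(encode_number(number, max_digits))
        targets.append(encode_target(running_sum, width))
    return inputs, targets


inputs, targets = addition_example(sequence_length=5, max_digits=5, rng=random.Random(0))
print(len(inputs[0]), len(targets[0]))  # 50 6
```

With the default --sequence-length 5 and --max-digits 5, this gives 50 input units per number and 6 output digits per target, each classified among 11 classes.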
NOTE: this implementation uses a GRU instead of the LSTM used in the original paper.
usage: poetry run pytorch_adaptive_computation_time/tasks/addition.py
[-h] [--max-steps MAX_STEPS] [--sequence-length SEQUENCE_LENGTH]
[--max-digits MAX_DIGITS] [--hidden_size HIDDEN_SIZE]
[--time_penalty TIME_PENALTY] [--batch_size BATCH_SIZE]
[--learning_rate LEARNING_RATE] [--time-limit TIME_LIMIT]
[--data-workers DATA_WORKERS]
Named Arguments
- --max-steps (default: 200000)
- --sequence-length (default: 5)
- --max-digits (default: 5)
- --hidden_size (default: 512)
- --time_penalty (default: 0.001)
- --batch_size (default: 32)
- --learning_rate (default: 0.0001)
- --time-limit (default: 20)
- --data-workers (default: 1)