
Train using supervised examples

Requirements

pip install -e .. (pyproject.toml resides in the parent directory)

Make sure the oasst_data module is installed

python -m pip install ../../oasst-data/

Run tests: pytest .

You might run into a SystemExit here for the test tests/test_patched_gpt_neox.py::test_flash_attention_patch. If so, follow the warning and install flash_attn:

python -m pip install flash_attn

Start training SFT model

python trainer_sft.py --configs galactica-125m

To get started with a small amount of test data, add the config webgpt_dataset_only.

If you kill the training run and want to resume it, see the --resume_from_checkpoint option.

For wandb: update the entity argument in trainer_sft.py's call to wandb.init to your Weights & Biases username, per the docs.

Dataset choices

To choose the translation pair for the WMT and TED Talk translation datasets, append a supported language pair to the dataset name as a postfix:

  datasets:
    - wmt2019_zh-en
    - wmt2019_ru-en
    - wmt2019_de-en
    - ted_trans_nl-en
    - ted_trans_de-ja

Currently only these languages are supported via prompt translation:

ar,de,fr,en,it,nl,tr,ru,ms,ko,ja,zh
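
The postfix convention above can be sketched in Python as follows. The helper names `split_language_pair` and `check_supported` are hypothetical and not part of the repository, and the sketch assumes that the supported-language list applies to both halves of the postfix.

```python
# Hypothetical helpers for illustration only; not taken from the repository.
# Assumption: the language list from the README applies to both languages of a pair.
SUPPORTED_LANGS = {"ar", "de", "fr", "en", "it", "nl", "tr", "ru", "ms", "ko", "ja", "zh"}


def split_language_pair(dataset_name: str) -> tuple[str, str, str]:
    """Return (base_name, source_lang, target_lang) for a name like 'ted_trans_nl-en'."""
    # The language pair is whatever follows the last underscore.
    base, sep, pair = dataset_name.rpartition("_")
    if not sep or "-" not in pair:
        raise ValueError(f"no language-pair postfix in {dataset_name!r}")
    source, target = pair.split("-", 1)
    return base, source, target


def check_supported(dataset_name: str) -> bool:
    """True if both languages of the pair appear in SUPPORTED_LANGS."""
    _, source, target = split_language_pair(dataset_name)
    return source in SUPPORTED_LANGS and target in SUPPORTED_LANGS
```

For instance, `split_language_pair("wmt2019_zh-en")` yields the base name `wmt2019` with the pair `zh` and `en`, while a name without a postfix such as `webgpt` raises a `ValueError`.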

Dataset sub-sampling

We can subsample the training data by passing either the fraction or size argument in the configs/config.yaml file. Don't forget the additional colon ":" after the dataset name when doing this.

Example:

  datasets:
    - webgpt:
        fraction : 0.05
    - prompt_dialogue:
        size : 500
    - adversarial_qa
    - trivia_qa_nocontext

In this example, per epoch we will use:

  • A random 5% of webgpt;
  • A random 500 examples from prompt_dialogue;
  • All examples from datasets for which we don't specify the fraction or size argument.

The subsample is redrawn every epoch, so each epoch uses a different 5% of webgpt and a different 500 examples from prompt_dialogue.

This works with torch.distributed.
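
As a rough illustration of the per-epoch subsampling described above, here is a minimal sketch. The function `subsample_indices` and its `seed` and `epoch` parameters are assumptions made for this example and do not describe the repository's actual sampler. Seeding by epoch is one way to let all processes in a distributed run draw the same subset; whether the trainer does it this way is not stated here.

```python
import random


def subsample_indices(n_examples, fraction=None, size=None, seed=0, epoch=0):
    """Pick the example indices to use for one epoch (hypothetical helper).

    Pass either `fraction` (share of the dataset) or `size` (absolute count).
    With neither argument, every example is used.
    """
    if fraction is not None and size is not None:
        raise ValueError("pass either fraction or size, not both")
    if fraction is not None:
        k = int(n_examples * fraction)
    elif size is not None:
        k = min(size, n_examples)
    else:
        return list(range(n_examples))
    # A generator seeded with (seed + epoch) gives a different subset per epoch,
    # yet the same subset for every caller that uses the same seed and epoch.
    rng = random.Random(seed + epoch)
    return rng.sample(range(n_examples), k)
```

Calling it with `fraction=0.05` on a dataset of 15662 examples selects 783 indices, and changing `epoch` changes which indices are drawn.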

Training only on OA internal data:

To experiment with the Open Assistant data simply run:

python trainer_sft.py --configs oasst_export_eu galactica-125m

Set input_file_path in the oasst_export_eu config (in configs/config.yaml) to the correct path.

Training the Reward Model

To experiment with the reward model run:

python trainer_rm.py --configs defaults_rm oasst-rm-1-pythia-1b

Since the model configs are kept quite minimal, it is important to override the default options (as given by defaults_rm) with the model-specific ones.
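
The overriding behaviour can be pictured with the sketch below. It assumes that the names passed to --configs are applied left to right as a shallow merge, so later entries win; this is an assumption for illustration, not a description of the trainer's actual code. The helper `merge_configs` is hypothetical and all config values are placeholders.

```python
def merge_configs(all_configs: dict, names: list) -> dict:
    """Shallow-merge the named config sections in order; later names override earlier ones."""
    merged = {}
    for name in names:
        if name not in all_configs:
            raise KeyError(f"unknown config: {name}")
        merged.update(all_configs[name])
    return merged


# Placeholder values, NOT the real contents of the repository's config files.
all_configs = {
    "defaults_rm": {"learning_rate": 1e-5, "max_length": 512, "model_name": "placeholder-default"},
    "oasst-rm-1-pythia-1b": {"learning_rate": 8e-6, "model_name": "placeholder-pythia-1b"},
}

conf = merge_configs(all_configs, ["defaults_rm", "oasst-rm-1-pythia-1b"])
```

Under this assumption the model-specific section replaces `learning_rate` and `model_name`, while `max_length` is inherited from the defaults because the model section does not set it.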

Training with RL

To train using trlx try:

python trainer_rl.py --configs defaults_rlhf

Test your model

You can interactively test your model like this:

python3 tools/model_cli.py --model_path <saved_path/huggingface>
# For example, if you trained with the default config:
python3 tools/model_cli.py --model_path saved_model
# Add --8bit if it is an 8-bit model

Or start an interactive conversation with your bot, mainly to test its ability to handle context switches:

python3 tools/model_chat.py --model_path <saved_path/huggingface>
# For example, if you trained with the default config:
python3 tools/model_chat.py --model_path saved_model

Model

Normally you should be able to add new models in configs/config.yaml:

your-model-name:
  learning_rate: 2e-6
  model_name: <huggingface model name>
  weight_decay: 0.01
  max_length: 812
  warmup_steps: 600
  gradient_checkpointing: false
  gradient_accumulation_steps: 5
  per_device_train_batch_size: 4
  per_device_eval_batch_size: 4

python trainer_sft.py --configs defaults your-model-name

However, if the model of your choice doesn't have a pad_token, eos_token, or sep_token, you have to update get_tokenizer in utils.py to use the right tokens.
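
One way to check which of these tokens a tokenizer lacks is sketched below. The helpers `missing_special_tokens` and `fill_missing_tokens` are hypothetical, the token strings are placeholders, and a `SimpleNamespace` stands in for a real tokenizer so the example runs without any model download. This is not the logic of get_tokenizer in utils.py.

```python
from types import SimpleNamespace

REQUIRED_TOKENS = ("pad_token", "eos_token", "sep_token")


def missing_special_tokens(tokenizer, required=REQUIRED_TOKENS):
    """Return the names of required special tokens that are unset on the tokenizer."""
    return [name for name in required if getattr(tokenizer, name, None) is None]


def fill_missing_tokens(tokenizer, fallbacks):
    """Build a {token_name: token_string} dict for the tokens that are missing."""
    return {name: fallbacks[name] for name in missing_special_tokens(tokenizer) if name in fallbacks}


# Stand-in for a real tokenizer; the token strings below are placeholders.
fake_tokenizer = SimpleNamespace(pad_token=None, eos_token="</s>", sep_token=None)
to_add = fill_missing_tokens(fake_tokenizer, {"pad_token": "<pad>", "sep_token": "<sep>"})
```

With a Hugging Face tokenizer, a dict of this shape can be passed to `tokenizer.add_special_tokens(...)`; if that adds new tokens, the model's embedding matrix usually needs `model.resize_token_embeddings(len(tokenizer))` as well.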

Deepspeed support

You can edit configs/zero_config.json and use any stage you wish. The current config uses ZeRO stage 3. For more details on how to set up the config, check out this page.

Once you are satisfied with your DeepSpeed ZeRO config, add the --deepspeed flag at the end of the command to enable DeepSpeed:

python trainer_sft.py --configs defaults your-model-name --deepspeed
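
Changing the ZeRO stage amounts to editing one field of the JSON config. The sketch below shows this on an in-memory dict; it assumes the file follows the usual DeepSpeed layout with a `zero_optimization` section containing a `stage` key. The helper `set_zero_stage` is hypothetical and the example dict is a placeholder, not the actual contents of configs/zero_config.json.

```python
import json


def set_zero_stage(config: dict, stage: int) -> dict:
    """Return a copy of a DeepSpeed-style config with the ZeRO stage changed."""
    if stage not in (0, 1, 2, 3):
        raise ValueError("ZeRO stage must be 0, 1, 2 or 3")
    updated = dict(config)
    # Copy the nested section too, so the input config is left untouched.
    zero_section = dict(updated.get("zero_optimization", {}))
    zero_section["stage"] = stage
    updated["zero_optimization"] = zero_section
    return updated


# Placeholder config for illustration only.
example = {"zero_optimization": {"stage": 3}, "train_micro_batch_size_per_gpu": "auto"}
print(json.dumps(set_zero_stage(example, 2), indent=2))
```

To apply this to the real file, load it with `json.load`, pass the result through the helper, and write it back with `json.dump`.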

Datasets

Here is an incomplete overview of datasets for sft:

dataset_name | train_counts | eval_counts | total_counts
--- | --- | --- | ---
webgpt | 15662 | 3916 | 19578
squad_v2 | 130319 | 11873 | 142192
adversarial_qa | 30000 | 3000 | 33000
trivia_qa_nocontext | 138384 | 17944 | 156328
xsum | 204045 | 11332 | 215377
cnn_dailymail | 287113 | 13368 | 300481
multi_news | 44972 | 5622 | 50594
scitldr | 1992 | 619 | 2611
joke | 301 | 76 | 377
gsm8k | 7473 | 1319 | 8792
dive_mt | 6192 | 1548 | 7740

This list can be generated with the following command, but beware that this downloads all available datasets (>100GB):

python check_dataset_counts.py --datasets all --mode sft

You can also pass specific datasets, which can be found in the config corresponding to the mode (e.g. configs/config.yaml for sft, configs/config_rm.yaml for rm):

python check_dataset_counts.py --datasets webgpt squad_v2 --mode sft

Troubleshooting

  • If training on a VM, you might need to install OpenMPI. Check out this blog post by Lambda on how to install OpenMPI on their machines.
  • Installing mpi4py requires python-dev, which can be installed via sudo apt install libpython3.10-dev (replace 3.10 with whatever Python version you're running).

Results

Experimental results in wandb here.

TODOS