import json
import os

import torch
import torch.nn.functional as F
from tqdm import tqdm

from bitdelta.data import get_dataloader, get_dataset
from bitdelta.diff import compress_diff, save_diff, save_full_model
from bitdelta.misc import find_corr_stddev
from bitdelta.utils import get_model, get_tokenizer, parse_args

args = parse_args()

# create save_dir if it doesn't exist
os.makedirs(args.save_dir, exist_ok=True)

tokenizer = get_tokenizer(args.base_model)

# Load the base and finetuned models without autograd tracking; neither is updated.
with torch.no_grad():
    base_model = get_model(args.base_model, args.base_model_device, args.base_model_memory_map)
    finetuned_model = get_model(args.finetuned_model, args.finetuned_model_device, args.finetuned_model_memory_map)

# get corr/stddev stats between the two models' weights (debug only)
if args.debug:
    print("finding corr/stddev stats...")
    corrs, stddevs = find_corr_stddev(base_model, finetuned_model)
    corr = sum(corrs) / len(corrs)
    stddev = sum(stddevs) / len(stddevs)
    # save in args.save_dir as csv
    with open(os.path.join(args.save_dir, "corr_stddev.csv"), "w") as f:
        f.write(f"corr,stddev\n{corr},{stddev}")

# A third copy of the finetuned model, whose weight delta against the base
# model will be compressed in place by compress_diff.
finetuned_compressed_model = get_model(
    args.finetuned_model,
    args.finetuned_compressed_model_device,
    args.finetuned_compressed_model_memory_map,
)

print("compressing diff...")
compress_diff(base_model, finetuned_model, finetuned_compressed_model)

# Calibration data: one batch per optimization step.
train_num_samples = args.batch_size * args.num_steps
train_dataset = get_dataset(
    args.dataset_name,
    args.subset,
    "train",
    size=train_num_samples,
)
train_dataloader = get_dataloader(
    train_dataset,
    tokenizer,
    args.batch_size,
    num_workers=4,
    max_length=args.max_length,
)

# save untrained delta (sign bits with uncalibrated scales)
save_diff(finetuned_compressed_model, os.path.join(args.save_dir, "diff_untrained.pt"))

optimizer = torch.optim.AdamW(finetuned_compressed_model.parameters(), lr=args.lr)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.num_steps)

bar = tqdm(train_dataloader)

train_loss_list = []

# Train loop: distill the finetuned model's logits into the compressed model
# by minimizing the MSE between them (only the scales introduced by
# compress_diff are trainable; the sign matrices stay fixed).
for step, batch in enumerate(bar):
    # Teacher forward pass under inference mode; the teacher is never updated.
    batch1 = {k: v.to(finetuned_model.device) for k, v in batch.items()}
    with torch.inference_mode():
        finetuned_outputs = finetuned_model(**batch1)

    # Student forward pass on the compressed model.
    batch2 = {k: v.to(finetuned_compressed_model.device) for k, v in batch.items()}
    finetuned_compressed_outputs = finetuned_compressed_model(**batch2)

    # clone() lifts the teacher logits out of inference mode so they can be
    # used as a constant target inside the autograd graph.
    loss = F.mse_loss(
        finetuned_outputs.logits.clone().to(finetuned_compressed_outputs.logits.device),
        finetuned_compressed_outputs.logits,
    )

    train_loss_list.append(loss.item())

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    scheduler.step()

    bar.set_description(f"train loss: {loss.item()}")

# save loss list (debug only)
if args.debug:
    with open(os.path.join(args.save_dir, f"train_loss_{args.num_groups}.json"), "w") as f:
        json.dump(train_loss_list, f)

# save trained (scale-calibrated) delta
save_diff(finetuned_compressed_model, os.path.join(args.save_dir, "diff.pt"))

del base_model, finetuned_model, finetuned_compressed_model
torch.cuda.empty_cache()

# Optionally materialize full models (base weights + delta) for both deltas.
if args.save_full_model:
    print("saving uncalibrated model")
    save_full_model(
        args.base_model,
        args.finetuned_model,
        os.path.join(args.save_dir, "diff_untrained.pt"),
        os.path.join(args.save_dir, "uncalibrated_model"),
        device="cpu",
    )
    print("saving calibrated model")
    save_full_model(
        args.base_model,
        args.finetuned_model,
        os.path.join(args.save_dir, "diff.pt"),
        os.path.join(args.save_dir, "calibrated_model"),
        device="cpu",
    )
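
# Example invocation (a sketch only: the flag names below are assumed to
# mirror the args.* attributes read above via argparse, and the model names
# and hyperparameters are illustrative; see parse_args in bitdelta/utils.py
# for the authoritative interface):
#
#   python bitdelta/train.py \
#       --base_model meta-llama/Llama-2-7b-hf \
#       --finetuned_model lmsys/vicuna-7b-v1.5 \
#       --save_dir ./outputs/vicuna-7b \
#       --dataset_name c4 --subset en \
#       --batch_size 4 --num_steps 200 --max_length 128 --lr 1e-4 \
#       --save_full_model
#
# Under these assumptions, the run would write diff_untrained.pt and diff.pt
# to save_dir, plus (with --save_full_model) the uncalibrated_model/ and
# calibrated_model/ directories.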