Skip to content

Commit d9316bf

Browse files
authored
Fix mutable proj_out weight in the Attention layer (huggingface#73)
* Catch unused params in DDP
* Fix proj_out, add test
1 parent 3abf4bc commit d9316bf

4 files changed

Lines changed: 29 additions & 8 deletions

File tree

examples/train_unconditional.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import torch
55
import torch.nn.functional as F
66

7-
from accelerate import Accelerator
7+
from accelerate import Accelerator, DistributedDataParallelKwargs
88
from accelerate.logging import get_logger
99
from datasets import load_dataset
1010
from diffusers import DDIMPipeline, DDIMScheduler, UNetModel
@@ -27,8 +27,14 @@
2727

2828

2929
def main(args):
30+
ddp_unused_params = DistributedDataParallelKwargs(find_unused_parameters=True)
3031
logging_dir = os.path.join(args.output_dir, args.logging_dir)
31-
accelerator = Accelerator(mixed_precision=args.mixed_precision, log_with="tensorboard", logging_dir=logging_dir)
32+
accelerator = Accelerator(
33+
mixed_precision=args.mixed_precision,
34+
log_with="tensorboard",
35+
logging_dir=logging_dir,
36+
kwargs_handlers=[ddp_unused_params],
37+
)
3238

3339
model = UNetModel(
3440
attn_resolutions=(16,),

src/diffusers/models/attention.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,7 @@ def __init__(
7070
if encoder_channels is not None:
7171
self.encoder_kv = nn.Conv1d(encoder_channels, channels * 2, 1)
7272

73-
self.proj_out = zero_module(nn.Conv1d(channels, channels, 1))
73+
self.proj = zero_module(nn.Conv1d(channels, channels, 1))
7474

7575
self.overwrite_qkv = overwrite_qkv
7676
if overwrite_qkv:
@@ -108,15 +108,15 @@ def set_weights(self, module):
108108
proj_out.weight.data = module.proj_out.weight.data[:, :, :, 0]
109109
proj_out.bias.data = module.proj_out.bias.data
110110

111-
self.proj_out = proj_out
111+
self.proj = proj_out
112112
elif self.overwrite_linear:
113113
self.qkv.weight.data = torch.concat(
114114
[self.NIN_0.W.data.T, self.NIN_1.W.data.T, self.NIN_2.W.data.T], dim=0
115115
)[:, :, None]
116116
self.qkv.bias.data = torch.concat([self.NIN_0.b.data, self.NIN_1.b.data, self.NIN_2.b.data], dim=0)
117117

118-
self.proj_out.weight.data = self.NIN_3.W.data.T[:, :, None]
119-
self.proj_out.bias.data = self.NIN_3.b.data
118+
self.proj.weight.data = self.NIN_3.W.data.T[:, :, None]
119+
self.proj.bias.data = self.NIN_3.b.data
120120

121121
self.norm.weight.data = self.GroupNorm_0.weight.data
122122
self.norm.bias.data = self.GroupNorm_0.bias.data
@@ -149,7 +149,7 @@ def forward(self, x, encoder_out=None):
149149
a = torch.einsum("bts,bcs->bct", weight, v)
150150
h = a.reshape(bs, -1, length)
151151

152-
h = self.proj_out(h)
152+
h = self.proj(h)
153153
h = h.reshape(b, c, *spatial)
154154

155155
result = x + h

src/diffusers/training_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ def __init__(
3030
min_value (float): The minimum EMA decay rate. Default: 0.
3131
"""
3232

33-
self.averaged_model = copy.deepcopy(model)
33+
self.averaged_model = copy.deepcopy(model).eval()
3434
self.averaged_model.requires_grad_(False)
3535

3636
self.update_after_step = update_after_step

tests/test_modeling_utils.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@
5252
from diffusers.pipeline_utils import DiffusionPipeline
5353
from diffusers.pipelines.bddm.pipeline_bddm import DiffWave
5454
from diffusers.testing_utils import floats_tensor, slow, torch_device
55+
from diffusers.training_utils import EMAModel
5556

5657

5758
torch.backends.cuda.matmul.allow_tf32 = False
@@ -197,6 +198,20 @@ def test_training(self):
197198
loss = torch.nn.functional.mse_loss(output, noise)
198199
loss.backward()
199200

201+
def test_ema_training(self):
202+
init_dict, inputs_dict = self.prepare_init_args_and_inputs_for_common()
203+
204+
model = self.model_class(**init_dict)
205+
model.to(torch_device)
206+
model.train()
207+
ema_model = EMAModel(model, device=torch_device)
208+
209+
output = model(**inputs_dict)
210+
noise = torch.randn((inputs_dict["x"].shape[0],) + self.output_shape).to(torch_device)
211+
loss = torch.nn.functional.mse_loss(output, noise)
212+
loss.backward()
213+
ema_model.step(model)
214+
200215

201216
class UnetModelTests(ModelTesterMixin, unittest.TestCase):
202217
model_class = UNetModel

0 commit comments

Comments (0)