forked from Nerogar/OneTrainer
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsample.py
More file actions
60 lines (47 loc) · 1.63 KB
/
Copy pathsample.py
File metadata and controls
60 lines (47 loc) · 1.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
import os
import sys
sys.path.append(os.getcwd())
import torch
from modules.util.config.SampleConfig import SampleConfig
from modules.util.enum.ImageFormat import ImageFormat
from modules.util.enum.TrainingMethod import TrainingMethod
from modules.util import create
from modules.util.args.SampleArgs import SampleArgs
def main(*, height: int = 512, width: int = 512, seed: int = 42):
    """Load a model described by the command line and render one sample image.

    Arguments come from ``SampleArgs.parse_args()``. The model is loaded,
    moved to a CUDA device, switched to eval mode, and sampled once. The
    resulting image is written to ``args.destination`` in JPG format.

    Args:
        height: Sample height in pixels. Defaults to 512, the value this
            script previously hard-coded.
        width: Sample width in pixels. Defaults to 512.
        seed: Sampling seed. Defaults to 42.
    """
    args = SampleArgs.parse_args()
    # NOTE(review): CUDA is required; there is no CPU fallback.
    device = torch.device("cuda")

    # If an embedding name was supplied, build the loader/setup for the
    # EMBEDDING training method; otherwise use FINE_TUNE.
    training_method = TrainingMethod.FINE_TUNE
    if args.embedding_name is not None:
        training_method = TrainingMethod.EMBEDDING

    model_loader = create.create_model_loader(args.model_type, training_method=training_method)
    # The model setup object is never used below. The call is kept (without
    # binding an unused local) in case construction validates the
    # model_type / training_method combination.
    # NOTE(review): confirm it has no effect and drop it if so.
    create.create_model_setup(args.model_type, device, device, training_method=training_method)

    # f-strings instead of "+" so a non-str value (None, a path object, ...)
    # does not raise TypeError merely while printing progress.
    print(f"Loading model {args.base_model_name}")
    # NOTE(review): base_model_name is only printed here; it is not passed to
    # load(). Confirm the loader actually receives the model to load.
    model = model_loader.load(
        model_type=args.model_type,
        weight_dtypes=args.weight_dtypes(),
    )
    model.to(device)
    # Presumably switches the wrapped networks to inference mode; the exact
    # effect depends on the project's model class.
    model.eval()

    # The sampler uses the same device for both its train and temp devices.
    model_sampler = create.create_model_sampler(
        train_device=device,
        temp_device=device,
        model=model,
        model_type=args.model_type,
    )

    print(f"Sampling {args.destination}")
    model_sampler.sample(
        # Start from SampleConfig defaults and override only these fields.
        sample_params=SampleConfig.default_values().from_dict(
            {
                "prompt": args.prompt,
                "negative_prompt": args.negative_prompt,
                "height": height,
                "width": width,
                "seed": seed,
            }
        ),
        image_format=ImageFormat.JPG,
        destination=args.destination,
        text_encoder_layer_skip=args.text_encoder_layer_skip,
    )
# Run the sampler only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()