File size: 1,524 Bytes
bf277fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from random import choice

import torch

from dia.model import Dia


# torch.compile (Inductor) tuning knobs: coordinate-descent autotuning of
# generated kernel configs, unique (human-readable) Triton kernel names —
# useful when profiling — and on-disk caching of compiled FX graphs so
# repeat runs skip recompilation.
torch._inductor.config.coordinate_descent_tuning = True
torch._inductor.config.triton.unique_kernel_names = True
torch._inductor.config.fx_graph_cache = True

# Debugging: log every graph break and recompilation so compile-time issues
# that would skew the benchmark numbers below are visible in the output.
torch._logging.set_logs(graph_breaks=True, recompiles=True)

# Checkpoint identifier (Hugging Face Hub style "org/name") and the dtype
# the model computes in.
model_name = "nari-labs/Dia-1.6B-0626"
compute_dtype = "float16"

# Load the pretrained model — presumably fetches weights from the Hub on
# first use (network/disk I/O); verify against Dia.from_pretrained.
model = Dia.from_pretrained(model_name, compute_dtype=compute_dtype)


# Input scripts of varying lengths (one short, one multi-speaker, two longer
# single/dual-speaker ones) so the benchmark exercises different sequence
# lengths. [S1]/[S2] appear to be speaker tags and parentheticals like
# "(laughs)" non-verbal cues — presumably the text format Dia expects.
test_cases = [
    "[S1] Dia is an open weights text to dialogue model.",
    "[S1] Dia is an open weights text to dialogue model. [S2] You get full control over scripts and voices. [S1] Wow. Amazing. (laughs) [S2] Try it now on Git hub or Hugging Face.",
    "[S1] torch.compile is a new feature in PyTorch that allows you to compile your model with a single line of code.",
    "[S1] torch.compile is a new feature in PyTorch that allows you to compile your model with a single line of code. [S2] It is a new feature in PyTorch that allows you to compile your model with a single line of code.",
]


# Warm up: run both generation paths (with and without an audio prompt) so
# torch.compile compilation/autotuning happens before the timed loop below.
for _ in range(2):
    text = choice(test_cases)
    output = model.generate(text, audio_prompt="./example_prompt.mp3", use_torch_compile=True, verbose=True)
    output = model.generate(text, use_torch_compile=True, verbose=True)

# Benchmark: for each of 10 rounds pick a random script and run the compiled
# model twice — first without, then with the example audio prompt.
for _ in range(10):
    script = choice(test_cases)
    output = model.generate(script, use_torch_compile=True, verbose=True)
    output = model.generate(script, audio_prompt="./example_prompt.mp3", use_torch_compile=True, verbose=True)