# Copyright (c) 2026 ByteDance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT

import torch
import numpy as np
import torch.nn.functional as F

from transformers import AutoTokenizer, AutoModel, AutoModelForCausalLM
from transformers.cache_utils import DynamicCache

def add_gumbel_noise(logits, temperature):
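    """Inject Gumbel-max-style noise into the logits for sampling.

    Dividing exp(logits) by (-log(u)) ** temperature with u ~ Uniform(0, 1) and
    taking the argmax downstream is equivalent to temperature sampling; with
    temperature == 0 the raw logits are returned, so the argmax is greedy.
    """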
    if temperature == 0:
        return logits
    logits = logits.to(torch.float64)
    noise = torch.rand_like(logits, dtype=torch.float64)
    gumbel_noise = (- torch.log(noise)) ** temperature
    return logits.exp() / gumbel_noise


def get_num_transfer_tokens(mask_index, steps):
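    """Split each row's masked-token count evenly across `steps` denoising steps;
    the first `remainder` steps receive one extra token each."""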
    mask_num = mask_index.sum(dim=1, keepdim=True)

    base = mask_num // steps
    remainder = mask_num % steps

    num_transfer_tokens = torch.zeros(mask_num.size(0), steps, device=mask_index.device, dtype=torch.int64) + base

    for i in range(mask_num.size(0)):
        num_transfer_tokens[i, :remainder[i]] += 1

    return num_transfer_tokens

def make_block_causal_mask(seq_len, block_size=2, device=None, dtype=torch.float32):
    """Build an additive block-causal attention mask of shape (1, 1, seq_len, seq_len):
    each token attends to every token in its own block and in all earlier blocks."""
    num_blocks = (seq_len + block_size - 1) // block_size
    # Lower-triangular pattern over blocks, expanded to token resolution via a Kronecker product.
    block_mask = torch.tril(torch.ones((num_blocks, num_blocks), dtype=torch.bool, device=device))
    local_block = torch.ones((block_size, block_size), dtype=torch.bool, device=device)
    mask = torch.kron(block_mask, local_block)[:seq_len, :seq_len]

    # Additive mask: 0 where attention is allowed, -inf where it is blocked.
    attention_mask = torch.zeros((seq_len, seq_len), device=device)
    attention_mask.masked_fill_(~mask, float('-inf'))
    attention_mask = attention_mask.unsqueeze(0).unsqueeze(0).to(dtype)
    return attention_mask

@torch.no_grad()
def generate_block(model, prompt, steps=128, gen_length=128, block_length=128, temperature=0.,
            remasking='low_confidence', tokenizer=None, mask_id=5, threshold=0.95, shift=False, eos_id=None):
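    """Block-wise diffusion decoding with a KV cache.

    The sequence is decoded block by block; within each block, masked tokens are
    filled iteratively until none remain, committing either a fixed number per
    step (when `threshold` is None) or every token whose confidence clears
    `threshold`, always at least one. If `eos_id` appears in a finished block,
    generation stops early and `x` is truncated at that block. Returns the
    decoded sequence and the number of model forward passes (nfe).
    """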
    x = torch.full((1, prompt.shape[1] + gen_length), mask_id, dtype=torch.long).to(model.device)
    x[:, :prompt.shape[1]] = prompt.clone()

    assert gen_length % block_length == 0
    num_blocks = gen_length // block_length

    assert steps % num_blocks == 0
    steps = steps // num_blocks

    # Lay out the generation region in blocks that stay aligned to multiples of
    # block_length from position 0. If the prompt ends mid-block, the first block
    # covers only the masked remainder of that block and a final partial block
    # absorbs what is left of gen_length.
    prompt_len = prompt.shape[1]
    res_block = (block_length - prompt_len % block_length) % block_length
    every_block = [block_length for _ in range(num_blocks)]
    if res_block > 0:
        every_block = [res_block] + every_block
        every_block[-1] = block_length - res_block
    cum_block = [sum(every_block[:i + 1]) for i in range(len(every_block))]
    total_blocks = len(cum_block)

    block_diffusion_attention_mask = make_block_causal_mask(prompt.shape[1] + gen_length, block_length, model.device, dtype=torch.bfloat16)
    nfe = 0
    final_flag = 0
    # Prefill the KV cache with the fully block-aligned part of the prompt.
    past_key_values = DynamicCache()
    prefill_length = prompt_len // block_length * block_length
    if prefill_length > 0:
        cur_attn_mask = block_diffusion_attention_mask[:, :, :prefill_length, :prefill_length]
        past_key_values = model(x[:, :prefill_length], attention_mask=cur_attn_mask,
                                past_key_values=past_key_values, use_cache=True).past_key_values
    for blk in range(total_blocks):
        current_block_start = prompt_len + cum_block[blk - 1] if blk > 0 else prefill_length
        current_block_end = prompt_len + cum_block[blk]

        block_mask_index = (x[:, current_block_start:current_block_end] == mask_id)
        num_transfer_tokens = get_num_transfer_tokens(block_mask_index, steps)

        replace_position = torch.zeros_like(x, dtype=torch.bool)
        replace_position[:, current_block_start:current_block_end] = 1
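        # replace_position marks the span of the current block; its nonzero indices are
        # passed as cache_position so the block's keys/values land at the right absolute offsets.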
        i = 0
        while True:
            nfe += 1
            mask_index = (x[:, current_block_start:current_block_end] == mask_id)
            cur_attn_mask = block_diffusion_attention_mask[:, :, current_block_start:current_block_end, :current_block_end]
            output = model(x[:, current_block_start:current_block_end], attention_mask=cur_attn_mask,
                           past_key_values=past_key_values, use_cache=True,
                           cache_position=replace_position.nonzero(as_tuple=True)[1])
            logits = output.logits
            x0, transfer_index = get_transfer_index(logits, temperature, remasking, mask_index, 
                                            x[:, current_block_start:current_block_end], num_transfer_tokens[:, i] if threshold is None else None, threshold, shift=False)
            x[:, current_block_start:current_block_end][transfer_index] = x0[transfer_index]
            if (x[:, current_block_start:current_block_end] == mask_id).sum() == 0:
                if eos_id is not None and (x[:, current_block_start:current_block_end] == eos_id).sum() > 0:
                    final_flag = 1
                    x = x[:, :current_block_end]
                    break
                # One extra forward pass so the finished block is written into the KV cache.
                past_key_values = model(x[:, current_block_start:current_block_end], attention_mask=cur_attn_mask,
                                        past_key_values=past_key_values, use_cache=True,
                                        cache_position=replace_position.nonzero(as_tuple=True)[1]).past_key_values
                break
            i += 1
        if final_flag == 1:
            break
    return x, nfe


def get_transfer_index(logits, temperature, remasking, mask_index, x, num_transfer_tokens, threshold=None, shift=False):
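    """Pick which masked positions to commit at this step.

    Returns the candidate tokens `x0` and a boolean `transfer_index`: the top-k
    most confident masked positions when `threshold` is None (k taken from
    `num_transfer_tokens`), otherwise every masked position whose confidence
    exceeds `threshold`, always keeping at least the single most confident one.
    """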
    logits_with_noise = add_gumbel_noise(logits, temperature=temperature)
    x0 = torch.argmax(logits_with_noise, dim=-1) # b, l
    if shift == True:
        x0 = torch.cat([x[:, :1], x0[:, :-1]], dim=-1)
        pad = torch.zeros_like(logits[:, :1])
        logits = torch.cat([pad, logits[:, :-1]], dim=1)
    if remasking == 'low_confidence':
        p = F.softmax(logits.to(torch.float64), dim=-1)
        x0_p = torch.squeeze(
            torch.gather(p, dim=-1, index=torch.unsqueeze(x0, -1)), -1) # b, l
    elif remasking == 'random':
        x0_p = torch.rand((x0.shape[0], x0.shape[1]), device=x0.device)
    else:
        raise NotImplementedError(remasking)
    
    x0 = torch.where(mask_index, x0, x)
    confidence = torch.where(mask_index, x0_p, -np.inf)

    transfer_index = torch.zeros_like(x0, dtype=torch.bool, device=x0.device)
    if threshold is not None:
        num_transfer_tokens = mask_index.sum(dim=1, keepdim=True)
    for j in range(confidence.shape[0]):
        _, select_index = torch.topk(confidence[j], k=num_transfer_tokens[j])
        transfer_index[j, select_index] = True
        if threshold is not None:
            for k in range(1, num_transfer_tokens[j]):
                if confidence[j, select_index[k]] < threshold:
                    transfer_index[j, select_index[k]] = False
    return x0, transfer_index
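

# ---------------------------------------------------------------------------
# Minimal usage sketch. Assumptions: the checkpoint path is a placeholder; the
# model is expected to be a block-diffusion checkpoint whose remote modeling
# code accepts the additive block-causal attention_mask and cache_position used
# above, and whose tokenizer exposes suitable mask/EOS token ids.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    model_path = "path/to/block-diffusion-checkpoint"  # placeholder, replace with a real checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        model_path, torch_dtype=torch.bfloat16, trust_remote_code=True
    ).to("cuda").eval()

    prompt_ids = tokenizer("Explain block diffusion decoding in one paragraph.",
                           return_tensors="pt").input_ids.to(model.device)

    # gen_length must be divisible by block_length, and steps by the number of blocks.
    out, nfe = generate_block(
        model, prompt_ids,
        steps=128, gen_length=128, block_length=32,
        temperature=0.0, remasking="low_confidence",
        mask_id=tokenizer.mask_token_id if tokenizer.mask_token_id is not None else 5,
        threshold=0.95,
        eos_id=tokenizer.eos_token_id,
    )
    print(tokenizer.decode(out[0, prompt_ids.shape[1]:], skip_special_tokens=True))
    print(f"NFE: {nfe}")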