diffusers中的AttnProcessor源码解析(key_padding_mask和attn_mask如何在MSA中作用)

1. prepare_attention_mask

这里结合Mutil Head Attention了解下不同mask的作用。key_padding_maskattn_mask两个实际上都是作用到attn_output_weights来影响最终的output,前者专注处理序列中的<PAD>,而后者专注处理序列交叉中的“不可见”逻辑

首先先建立一个概念:多头的分头,分的是QKV的 dim维度

query=[batch_size, source_length, dim], key=[batch_size, target_len, dim], value=[batch_size, target_len, dim]
# 分头(split heads)
head_dim = dim // heads # heads=8
query=[batch_size, source_length, heads, dim//heads], key=[batch_size, target_len, heads, dim//heads], value=[batch_size, target_len, heads, dim//heads]

1.1 key_padding_mask

  • key_padding_mask,长度是(B, S),B为batch_size,S为源序列长度,即query的seq_len(NLP的token个数S/CV的patch个数HW),序列中没有到达max_len的token用<PAD>填充,key_padding_mask中对应的位置为True,计算attention时会将key中mask=True的部分省略掉。
query=[batch_size, source_length, dim], key=[batch_size, target_len, dim], value=[batch_size, target_len, dim]

在这里插入图片描述
计算self-attention时,key_padding_mask只屏蔽key中的mask:即非mask的token作为query时,和sequence中所有非mask的token作为key计算self-attention;而mask的token也可以作为query,和sequence中所有非mask的token作为key计算self-attention。(mask的token不作为key token参与计算)因为就算mask的token作为key参与了计算,最后reshape会原来的形状后,也不使用padding的部分,所以这部分注意力的计算是冗余的。
在这里插入图片描述torch实现方法,key_padding_maskattn_mask上进行实现的。如下图的伪代码实现是(不考虑多头时 attn_mask.shape=[batch, seq_len, seq_len]):torch.baddbmm计算QK然后将attn_mask加到QK矩阵上,然后mask的部分就算负无穷-inf,再经过softmax就变为0.
在这里插入图片描述
a t t e n t i o n = S o f t m a x ( Q K T d k + a t t n _ m a s k ) ⋅ V attention=Softmax(\frac{QK^T}{\sqrt{d_k}}+attn\_mask)·V attention=Softmax(dk QKT+attn_mask)V在这里插入图片描述

# 模拟key_pad_mask加到attn_mask上
import torch
from einops import rearrange, repeat
batch_size, seq_len, dim = 1, 9, 8
key_pad_mask = torch.tensor([False, False, True, False, False, True, False, False, True]).unsqueeze(0)
# tensor([[False, False,  True, False, False,  True, False, False,  True]])
key_pad_mask = torch.where(key_pad_mask, float('-inf'), 0)
# tensor([[0., 0., -inf, 0., 0., -inf, 0., 0., -inf]])
key_pad_mask = repeat(key_pad_mask, 'b s -> b ss s', ss=seq_len)
'''
tensor([[[0., 0., -inf, 0., 0., -inf, 0., 0., -inf],
         [0., 0., -inf, 0., 0., -inf, 0., 0., -inf],
         [0., 0., -inf, 0., 0., -inf, 0., 0., -inf],
         [0., 0., -inf, 0., 0., -inf, 0., 0., -inf],
         [0., 0., -inf, 0., 0., -inf, 0., 0., -inf],
         [0., 0., -inf, 0., 0., -inf, 0., 0., -inf],
         [0., 0., -inf, 0., 0., -inf, 0., 0., -inf],
         [0., 0., -inf, 0., 0., -inf, 0., 0., -inf],
         [0., 0., -inf, 0., 0., -inf, 0., 0., -inf]]])
'''
# 假设用casual attention: 下三角attn_mask
attn_mask = torch.ones(batch_size, seq_len, seq_len, dtype=torch.bool).tril(diagonal=0)
attn_mask = torch.where(attn_mask, float('-inf'), 0)  # attn_mask.masked_fill_(attn_mask.logical_not(), float("-inf"))
'''
tensor([[[-inf, 0., 0., 0., 0., 0., 0., 0., 0.],
         [-inf, -inf, 0., 0., 0., 0., 0., 0., 0.],
         [-inf, -inf, -inf, 0., 0., 0., 0., 0., 0.],
         [-inf, -inf, -inf, -inf, 0., 0., 0., 0., 0.],
         [-inf, -inf, -inf, -inf, -inf, 0., 0., 0., 0.],
         [-inf, -inf, -inf, -inf, -inf, -inf, 0., 0., 0.],
         [-inf, -inf, -inf, -inf, -inf, -inf, -inf, 0., 0.],
         [-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, 0.],
         [-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf]]])
'''
attn_mask += key_pad_mask
'''
tensor([[[-inf, 0., -inf, 0., 0., -inf, 0., 0., -inf],
         [-inf, -inf, -inf, 0., 0., -inf, 0., 0., -inf],
         [-inf, -inf, -inf, 0., 0., -inf, 0., 0., -inf],
         [-inf, -inf, -inf, -inf, 0., -inf, 0., 0., -inf],
         [-inf, -inf, -inf, -inf, -inf, -inf, 0., 0., -inf],
         [-inf, -inf, -inf, -inf, -inf, -inf, 0., 0., -inf],
         [-inf, -inf, -inf, -inf, -inf, -inf, -inf, 0., -inf],
         [-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf],
         [-inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf, -inf]]])
'''

# query=[batch, seq_len, dim], key=[batch, tgt_len, dim], value=[batch, tgt_len, dim]
attn_score = torch.softmax(torch.baddbmm(attn_mask, query, key.transpose(-2, -1)), dim=-1)
attn_output = torch.bmm(attn_score, value)

1.2 attn_mask

  • attn_mask,长度是(B, source_length, target_length),其中B表示batch_sizesource_length表示源序列长度(Q的seq_len),target_length是目标序列长度(KV的seq_len),表示对权重矩阵做mask;
query=[batch_size, source_length, dim], key=[batch_size, target_len, dim], value=[batch_size, target_len, dim]

如果考虑多头,则要在scaled_dot_product_attention之前,把attn_mask为每个head复制一份(diffusers中使用prepare_attention_mask函数实现):

  • 如果attn_mask的shape是4维度的,初始(batch, source_length, target_length),则unseqeeze出一个head维度,沿第1维度(heads维度)复制heads份,变成(batch, heads, source_length, target_length)
  • 如果attn_mask的shape是3维度的,初始(batch, source_length, target_length),直接将注意力掩码沿着第0维度(batch维度)重复head_size次,变成(batch x heads, source_length, target_length)

这样batch x heads个头[source_length, target_length]@[target_length, source_length] 的矩阵乘法后,分别相同batch的head使用相同的attn_mask然后再进行softmax

# Efficient implementation equivalent to the following:
def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.0, is_causal=False, scale=None) -> torch.Tensor:
    # Efficient implementation equivalent to the following:
    L, S = query.size(-2), key.size(-2)
    scale_factor = 1 / math.sqrt(query.size(-1)) if scale is None else scale
    attn_bias = torch.zeros(L, S, dtype=query.dtype)
    if is_causal:
        assert attn_mask is None
        temp_mask = torch.ones(L, S, dtype=torch.bool).tril(diagonal=0)
        attn_bias.masked_fill_(temp_mask.logical_not(), float("-inf"))
        attn_bias.to(query.dtype)

    if attn_mask is not None:
        if attn_mask.dtype == torch.bool:
            attn_bias.masked_fill_(attn_mask.logical_not(), float("-inf"))
        else:
            attn_bias += attn_mask
    attn_weight = query @ key.transpose(-2, -1) * scale_factor
    attn_weight += attn_bias
    attn_weight = torch.softmax(attn_weight, dim=-1)
    attn_weight = torch.dropout(attn_weight, dropout_p, train=True)
    return attn_weight @ value

2. AttnProcessor

用于执行 self-attention 或 cross-attention:

class AttnProcessor:
    r"""
    Default processor for performing attention-related computations.
    """

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: Optional[torch.FloatTensor] = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        temb: Optional[torch.FloatTensor] = None,
        scale: float = 1.0,
    ) -> torch.Tensor:
        residual = hidden_states

        args = () if USE_PEFT_BACKEND else (scale,)

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states, *args)

        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states, *args)
        value = attn.to_v(encoder_hidden_states, *args)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)
        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states, *args)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states
<think>嗯,用户遇到了Python中的ModuleNotFoundError,具体是找不到diffusers.models.unet_2d_condition模块。根据提供的引用信息,这可能是由于diffusers库的版本更新导致模块结构变化。我需要先分析可能的原因,然后提供解决方案。 首先,查看引用[4]提到diffusers库的版本问题,高版本中某些模块可能被重命名或移除。例如,CrossAttention被移到了attention_processor模块。同样,unet_2d_condition可能在较新版本中不再存在,或者路径发生了变化。 用户需要确认安装的diffusers版本。如果使用的是最新版,可能需要降级到兼容的版本,比如引用[4]中提到的0.14.0。或者检查是否有其他依赖库需要更新,比如transformers或modelscope,如引用[5]提到的文件路径,可能涉及其他库的兼容性问题。 另外,检查代码中的导入语句是否正确。比如,是否应该从不同的子模块导入,或者类名是否更改。例如,引用[3]中的PositionNet导入错误,可能是因为在新版本中该类的位置或名称发生了变化。 所以,解决方案可能包括降级diffusers,更新相关库,或者修改代码中的导入路径。需要用户先尝试安装旧版本,如果不行,再检查代码是否有需要调整的地方。同时,确保所有依赖库的版本兼容,避免冲突。</think>### 解决 `diffusers.models.unet_2d_condition` 模块缺失错误的方法 #### 问题分析 错误 `ModuleNotFoundError: No module named 'diffusers.models.unet_2d_condition'` 通常由以下原因导致: 1. **版本不兼容**:新版本 `diffusers` 库可能移除了旧模块路径,或模块被重命名[^4][^3]。 2. **依赖库冲突**:项目可能依赖特定版本的 `diffusers` 或关联库(如 `modelscope` 或 `transformers`)[^5]。 --- #### 分步解决方案 ##### 方案一:安装兼容的 `diffusers` 版本 1. **降级到旧版本**(推荐尝试): ```bash pip install diffusers==0.14.0 ``` 若报错仍存在,可尝试更低版本(如 `0.13.0`)。 2. **验证安装**: ```python import diffusers print(diffusers.__version__) # 确保版本为 0.14.0 ``` ##### 方案二:调整代码中的导入路径 若无法降级,需根据新版本库的接口修改代码: 1. **查找新模块路径**: - 检查 `diffusers` 官方文档或源码,确认 `unet_2d_condition` 是否被移动(例如合并到 `diffusers.models.unet`)。 2. **修改导入语句**: ```python # 原错误代码 from diffusers.models.unet_2d_condition import UNet2DConditionModel # 新版本可能改为 from diffusers import UNet2DConditionModel # 或从其他子模块导入 ``` ##### 方案三:更新关联依赖库 某些项目(如 `modelscope`)需要同步更新: ```bash pip install --upgrade modelscope transformers ``` --- #### 附加建议 - **虚拟环境管理**:使用 `conda` 或 `venv` 隔离项目环境,避免版本冲突。 - **参考官方文档**:查看 `diffusers` 的 [Release Notes](https://github.com/huggingface/diffusers/releases) 确认模块变更历史[^3]。 ---
评论
添加红包

请填写红包祝福语或标题

红包个数最小为10个

红包金额最低5元

当前余额3.43前往充值 >
需支付:10.00
成就一亿技术人!
领取后你会自动成为博主和红包主的粉丝 规则
hope_wisdom
发出的红包

打赏作者

Yuezero_

你的鼓励将是我创作的最大动力

¥1 ¥2 ¥4 ¥6 ¥10 ¥20
扫码支付:¥1
获取中
扫码支付

您的余额不足,请更换扫码支付或充值

打赏作者

实付
使用余额支付
点击重新获取
扫码支付
钱包余额 0

抵扣说明:

1.余额是钱包充值的虚拟货币,按照1:1的比例进行支付金额的抵扣。
2.余额无法直接购买下载,可以购买VIP、付费专栏及课程。

余额充值
OSZAR »