24 changes: 24 additions & 0 deletions defuser/model_registry.py
@@ -4,6 +4,10 @@
# Contact: qubitium@modelcloud.ai, x.com/qubitium
from enum import Enum

from transformers.core_model_loading import WeightConverter

from defuser.checkpoint_ops import OwnedChunk


class PATCH(str, Enum):
REPLACE_MODULE = "replace_module"
@@ -71,4 +75,24 @@ class PATCH(str, Enum):
)
],
},
"glm4v": {
"min_transformers_version": "5.0.0",
PATCH.REPLACE_MODULE: [
(
"transformers.models.glm4v.modeling_glm4v.Glm4vTextMLP",
"defuser.modeling.glm4v.LinearGlm4vTextMLP",
)
],
# Split HF checkpoints that still store `gate_up_proj` as one fused tensor.
"checkpoint_mapping": [
WeightConverter(
source_patterns="mlp.gate_up_proj.weight",
target_patterns=[
"mlp.gate_proj.weight",
"mlp.up_proj.weight",
],
operations=[OwnedChunk(dim=0)],
),
],
},
}
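
For context on the `checkpoint_mapping` entry above: `OwnedChunk(dim=0)` halves the fused `gate_up_proj` tensor along its output dimension and gives each target tensor its own storage. A minimal sketch of the equivalent torch math, not the actual `OwnedChunk` implementation (the 6x8 shape mirrors the fixture used in the tests below):

```python
import torch

# Fused [gate; up] projection: (2 * intermediate_size, hidden_size).
fused = torch.arange(48, dtype=torch.float32).reshape(6, 8)

# Slicing yields views into `fused`; clone() stands in for the "owned"
# part, so each split half gets independent storage.
gate = fused[:3].clone()
up = fused[3:].clone()
assert gate.data_ptr() != up.data_ptr()
```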
20 changes: 20 additions & 0 deletions defuser/modeling/glm4v.py
@@ -0,0 +1,20 @@
from torch import nn
from transformers.activations import ACT2FN


class LinearGlm4vTextMLP(nn.Module):
"""GLM4V text MLP with the fused gate/up projection split into two linears."""

def __init__(self, config):
super().__init__()
self.config = config
self.gate_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
self.up_proj = nn.Linear(config.hidden_size, config.intermediate_size, bias=False)
self.down_proj = nn.Linear(config.intermediate_size, config.hidden_size, bias=False)
self.activation_fn = ACT2FN[config.hidden_act]

def forward(self, hidden_states):
gate = self.gate_proj(hidden_states)
up = self.up_proj(hidden_states)
# Match the original fused `gate_up_proj.chunk(2, dim=-1)` activation path.
return self.down_proj(up * self.activation_fn(gate))
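
As a quick usage note (mirrored by the forward-equivalence test below), the module only reads `hidden_size`, `intermediate_size`, and `hidden_act` from its config, so a lightweight namespace is enough to exercise it standalone; a minimal sketch with illustrative sizes:

```python
from types import SimpleNamespace

import torch

from defuser.modeling.glm4v import LinearGlm4vTextMLP

config = SimpleNamespace(hidden_size=8, intermediate_size=6, hidden_act="silu")
mlp = LinearGlm4vTextMLP(config)

hidden_states = torch.randn(3, config.hidden_size)
out = mlp(hidden_states)
assert out.shape == (3, config.hidden_size)  # down_proj maps back to hidden_size
```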
2 changes: 1 addition & 1 deletion pyproject.toml
@@ -9,7 +9,7 @@ build-backend = "setuptools.build_meta"

[project]
name = "Defuser"
version = "0.0.12"
version = "0.0.13"
description = "Model defuser helper for HF Transformers."
readme = "README.md"
requires-python = ">=3.9"
101 changes: 101 additions & 0 deletions tests/test_convert_model.py
@@ -6,6 +6,9 @@

import torch
from torch import nn
from transformers.core_model_loading import WeightConverter
from transformers.models.glm4v.configuration_glm4v import Glm4vConfig
from transformers.models.glm4v.modeling_glm4v import Glm4vForConditionalGeneration
from transformers.models.glm4_moe.modeling_glm4_moe import Glm4MoeConfig, Glm4MoeForCausalLM
from transformers.models.mixtral.configuration_mixtral import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralForCausalLM
@@ -18,6 +21,7 @@
from transformers.models.qwen3_omni_moe.modeling_qwen3_omni_moe import Qwen3OmniMoeForConditionalGeneration

from defuser import convert_model, replace_fused_blocks
from defuser.checkpoint_ops import OwnedChunk
from defuser.modeling.replace_modules import ReplacementModuleBase, apply_replacements, materialize_model


@@ -143,6 +147,34 @@ def _tiny_glm4_moe_config():
first_k_dense_replace=-1, # Ensure that the first layer is Glm4MoeMoE.
)


def _tiny_glm4v_config():
return Glm4vConfig(
text_config={
"vocab_size": 128,
"hidden_size": 64,
"num_hidden_layers": 1,
"num_attention_heads": 4,
"num_key_value_heads": 1,
"head_dim": 16,
"intermediate_size": 128,
"hidden_act": "silu",
"pad_token_id": 0,
"bos_token_id": 1,
"eos_token_id": 2,
},
vision_config={
"hidden_size": 64,
"intermediate_size": 128,
"num_hidden_layers": 1,
"num_attention_heads": 4,
"num_channels": 3,
"image_size": 16,
"patch_size": 4,
"out_hidden_size": 64,
},
)

def _assert_unfused_expert_module(experts):
assert hasattr(experts, "0")
expert0 = getattr(experts, "0")
@@ -330,3 +362,72 @@ def test_glm4_moe():
assert not converted

_assert_unfused_expert_module(model.model.layers[0].mlp.experts)


def test_glm4v():
model_type = "glm4v"
replace_fused_blocks(model_type)

from defuser.modeling.glm4v import LinearGlm4vTextMLP

model = Glm4vForConditionalGeneration(_tiny_glm4v_config())
assert model.config.model_type == model_type

mlp = model.model.language_model.layers[0].mlp
assert isinstance(mlp, LinearGlm4vTextMLP)
assert hasattr(mlp, "gate_proj")
assert hasattr(mlp, "up_proj")
assert hasattr(mlp, "down_proj")
assert not hasattr(mlp, "gate_up_proj")

converted = convert_model(model, cleanup_original=False, max_layers=1)
assert not converted


def test_glm4v_checkpoint_mapping_splits_gate_up_proj():
from defuser.defuser import get_checkpoint_conversion_mapping

mapping = get_checkpoint_conversion_mapping("glm4v")
converter = next(
item
for item in mapping
if isinstance(item, WeightConverter) and item.source_patterns == ["mlp.gate_up_proj.weight"]
)
assert isinstance(converter.operations[0], OwnedChunk)

assert converter.target_patterns == [
"mlp.gate_proj.weight",
"mlp.up_proj.weight",
]

fused = torch.arange(48, dtype=torch.float32).reshape(6, 8)
split = converter.operations[0].convert(
{"mlp.gate_up_proj.weight": fused},
converter.source_patterns,
converter.target_patterns,
)

torch.testing.assert_close(split["mlp.gate_proj.weight"], fused[:3])
torch.testing.assert_close(split["mlp.up_proj.weight"], fused[3:])
assert split["mlp.gate_proj.weight"].data_ptr() != split["mlp.up_proj.weight"].data_ptr()


def test_glm4v_split_forward_matches_fused_math():
from defuser.modeling.glm4v import LinearGlm4vTextMLP

config = SimpleNamespace(hidden_size=8, intermediate_size=6, hidden_act="silu")
fused_gate_up = torch.randn(2 * config.intermediate_size, config.hidden_size, dtype=torch.float32)
down_proj = torch.randn(config.hidden_size, config.intermediate_size, dtype=torch.float32)
hidden_states = torch.randn(3, config.hidden_size, dtype=torch.float32)

mlp = LinearGlm4vTextMLP(config)
with torch.no_grad():
mlp.gate_proj.weight.copy_(fused_gate_up[: config.intermediate_size])
mlp.up_proj.weight.copy_(fused_gate_up[config.intermediate_size :])
mlp.down_proj.weight.copy_(down_proj)

fused_gate, fused_up = (hidden_states @ fused_gate_up.transpose(0, 1)).chunk(2, dim=-1)
expected = (torch.nn.functional.silu(fused_gate) * fused_up) @ down_proj.transpose(0, 1)

# The split module should exactly reproduce the original fused MLP math.
torch.testing.assert_close(mlp(hidden_states), expected)
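
Putting the pieces together, the intended call pattern (as exercised by `test_glm4v` above) is to register the patch before the model class is instantiated; a minimal sketch, with the tests' tiny config helper standing in for a real checkpoint:

```python
from transformers.models.glm4v.modeling_glm4v import Glm4vForConditionalGeneration

from defuser import convert_model, replace_fused_blocks

# Swap Glm4vTextMLP for LinearGlm4vTextMLP before the model is constructed.
replace_fused_blocks("glm4v")

model = Glm4vForConditionalGeneration(_tiny_glm4v_config())

# glm4v has no fused expert blocks left to convert, so this is a no-op here.
converted = convert_model(model, cleanup_original=False, max_layers=1)
assert not converted
```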