We built a custom kernel that assigns different bit widths to individual experts in MoE models. It preserved model quality perfectly — and was too slow to use in production.
Introduction
After profiling expert activation patterns in Qwen3-235B and Qwen3.5-397B (see Article 1), we had a clear classification: some experts are critical (need 8-bit), most are standard (4-bit is fine), some can be aggressively compressed (2-bit), and many can be pruned entirely. The question was how to implement this.
The answer turned out to be harder than expected, because the standard MLX framework doesn't support per-expert bit widths. This article describes the constraint we hit, the custom kernel we built to work around it, the quality results (excellent), and the speed results (fatal for production).
The QuantizedSwitchLinear Constraint
MLX's QuantizedSwitchLinear is the standard layer type for quantized MoE expert weights. It stores all experts' quantized weights in a single tensor and dispatches to the selected experts during the forward pass.
The critical constraint: all experts in a QuantizedSwitchLinear must share the same bit width. The quantized weight tensor has a fixed element size — you can't have expert 0's weights at 8-bit and expert 47's weights at 4-bit in the same tensor.
This means the standard quantization path gives you one choice per layer: all experts at 4-bit, or all experts at 8-bit. No mixing.
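To make the constraint concrete, here is a minimal sketch of why experts at different bit widths cannot share one stacked tensor. It assumes quantized values are packed back to back into 32-bit words, as in MLX-style packing; the function name `packed_row_words` and the hidden size of 4096 are illustrative choices, not values taken from the models discussed here.

```python
def packed_row_words(in_features: int, bits: int, word_bits: int = 32) -> int:
    """Number of 32-bit words needed to store one row of quantized weights.

    Assumption: values are packed back to back into fixed-size words,
    so the row width scales linearly with the bit width.
    """
    total_bits = in_features * bits
    if total_bits % word_bits != 0:
        raise ValueError("row does not pack evenly into whole words")
    return total_bits // word_bits


in_features = 4096  # hypothetical hidden size, for illustration only
widths = {bits: packed_row_words(in_features, bits) for bits in (2, 4, 8)}
for bits, words in widths.items():
    print(f"{bits}-bit: {words} words per row")

# Experts stacked in one tensor need identical per-expert shapes.
# Different bit widths give different row widths, so they cannot be stacked.
can_stack = len(set(widths.values())) == 1
print("can stack 2/4/8-bit experts in one tensor:", can_stack)
```

Under these assumptions a 2-bit row takes 256 words, a 4-bit row 512, and an 8-bit row 1024, so the per-expert slices have incompatible shapes.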
MixedBitSwitchGLU: Our Solution
We implemented MixedBitSwitchGLU, a drop-in replacement for the standard SwitchGLU (the gated linear unit used in Qwen's MoE blocks). The key idea: group experts by bit width into separate QuantizedSwitchLinear instances, then mask-and-combine the results.
Architecture
For a layer with experts at three bit widths (8, 4, 2) plus some pruned:
            Input tokens + routing indices
                           |
            ┌──────────────┼──────────────┐
            v              v              v
         Group A        Group B        Group C
         (8-bit)        (4-bit)        (2-bit)
          23 exp        345 exp         80 exp
            |              |              |
            v              v              v
         QSwitchL       QSwitchL       QSwitchL
            |              |              |
            v              v              v
          Mask A         Mask B         Mask C
            |              |              |
            └──────────────┼──────────────┘
                           v
                     Sum (combine)
                           |
                           v
                        Output
Each group is a standard QuantizedSwitchLinear with uniform bits within the group. The mask zeroes out contributions from experts not in that group, so each token only gets the result from whichever group contains its routed expert.
The Mask-and-Combine Dispatch
    def __call__(self, x, indices):
        output = None
        for group in self.groups:
            # Run ALL tokens through this group's QuantizedSwitchLinear
            group_output = group.switch_linear(x, indices)
            # Mask: only keep results for tokens routed to experts in this group
            mask = build_group_mask(indices, group.expert_indices)
            masked = group_output * mask
            # Accumulate; the first group initializes the output buffer
            output = masked if output is None else output + masked
        return output
Every group processes every token, but the mask ensures only the correct group's output contributes for each token-expert pair. This is wasteful — we're doing 3x the computation if there are 3 groups — but it keeps everything as standard MLX tensor operations.
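The dispatch pattern can be checked end to end on a toy model. The sketch below is pure Python and is not the MixedBitSwitchGLU implementation: each expert is reduced to a scalar multiplier, and the names `switch_apply` and `mask_and_combine` are hypothetical. It also makes one detail explicit that the pseudocode above glosses over: global expert ids have to be remapped to group-local ids, and slots that belong to another group are pointed at an arbitrary local expert whose result the mask then discards.

```python
def switch_apply(weights, x, indices):
    """Toy switch layer: out[t][k] = weights[indices[t][k]] * x[t]."""
    return [[weights[e] * xt for e in row] for xt, row in zip(x, indices)]


def mask_and_combine(groups, x, indices):
    """Run every token through every group, then keep only in-group results."""
    out = [[0.0] * len(row) for row in indices]
    for experts, weights in groups:
        local = {e: i for i, e in enumerate(experts)}  # global id -> local id
        # Out-of-group slots are pointed at local expert 0; the mask discards them.
        local_idx = [[local.get(e, 0) for e in row] for row in indices]
        mask = [[1.0 if e in local else 0.0 for e in row] for row in indices]
        group_out = switch_apply(weights, x, local_idx)
        out = [
            [o + g * m for o, g, m in zip(o_row, g_row, m_row)]
            for o_row, g_row, m_row in zip(out, group_out, mask)
        ]
    return out


# Six experts, each a scalar "weight"; three bit-width groups (ids are illustrative).
all_weights = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
groups = [
    ([0, 3], [1.0, 4.0]),          # e.g. the 8-bit group
    ([1, 2, 5], [2.0, 3.0, 6.0]),  # e.g. the 4-bit group
    ([4], [5.0]),                  # e.g. the 2-bit group
]
x = [10.0, 20.0, 30.0]              # one scalar activation per token
indices = [[0, 1], [4, 3], [5, 2]]  # top-2 routing per token

reference = switch_apply(all_weights, x, indices)
combined = mask_and_combine(groups, x, indices)
print(combined)
```

In this toy, the combined output matches a single ungrouped switch layer, while every group still evaluates all six (token, slot) pairs, which is the redundancy described above.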
Memory-Efficient Conversion
Converting 512 experts at once would require dequantizing the entire layer (~26 GB peak for Qwen3.5). Instead, we convert per-group:
    for group in bit_groups:
        # Only dequantize THIS group's experts (e.g., 23 out of 512)
        expert_weights = extract_experts(full_weight, group.indices)
        quantized = mx.quantize(expert_weights, bits=group.bits)
        group.switch_linear = QuantizedSwitchLinear.from_quantized(quantized)
        mx.clear_cache()  # Free the dequantized intermediates
Peak memory: ~3 GB per group instead of ~26 GB for the full layer.
Index Shape Gotcha
One non-obvious issue: the expert indices tensor can be 2D [batch, top_k] during generation but 3D [batch, seq_len, top_k] during prefill. Our mask construction had to be shape-agnostic:
    # Wrong: assumes 2D
    mask = (indices.unsqueeze(-1) == group_ids).any(-1)
    # Right: works for both 2D and 3D
    mask = mx.expand_dims(
        mx.any(
            mx.equal(
                mx.expand_dims(indices, -1),
                group_ids.reshape((1,) * (indices.ndim - 1) + (-1,)),
            ),
            axis=-1,
        ),
        axis=-1,
    )
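A shape bug like this is easy to reintroduce, so a small reference implementation is useful in unit tests. The sketch below is a pure-Python stand-in (the name `reference_group_mask` is hypothetical) that works on nested lists of any depth and appends the same trailing singleton axis as the array version; the array mask, converted to nested lists, could be compared against it.

```python
def reference_group_mask(indices, group_ids):
    """Nested-list reference: True where a routed expert belongs to the group.

    Works for any nesting depth and appends a trailing singleton axis,
    mirroring the final expand_dims(..., axis=-1) in the array version.
    """
    ids = set(group_ids)

    def walk(node):
        if isinstance(node, list):
            return [walk(child) for child in node]
        return [node in ids]

    return walk(indices)


group_ids = [5, 7]

# Generation: [batch, top_k]
decode_indices = [[0, 5], [7, 2]]
decode_mask = reference_group_mask(decode_indices, group_ids)

# Prefill: [batch, seq_len, top_k]
prefill_indices = [[[0, 5], [7, 2]], [[5, 7], [1, 3]]]
prefill_mask = reference_group_mask(prefill_indices, group_ids)

print(decode_mask)
print(prefill_mask)
```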
Results: Quality
Quality preservation was excellent across both models.
Qwen3-235B-A22B
Using the profiling-driven bit allocation (23 critical@8bit, 10,845 standard@4bit, 365 deprioritized@2bit, 799 pruned):
| Benchmark | ExpertQuant v2 | 4-bit Baseline | BF16 Official |
|---|---|---|---|
| MMLU-Pro | 76.7% | 72.1% | 75.7% |
| ARC-Challenge | 96.2% | 96.0% | — |
| GSM8K | 92.0% | 88.7% | 91.5% |
| HumanEval | 88.0% | 78.7% | 80.5% |
ExpertQuant v2 exceeded the BF16 reference on every benchmark for which a BF16 score was available (MMLU-Pro, GSM8K, HumanEval) and beat the uniform 4-bit baseline on all four. Allocating 8 bits to critical experts and 2 bits to deprioritized ones produced better results than uniform 4-bit, likely because the critical experts handle the hardest decisions.
MoE weight savings: 112.4 GB → 104.1 GB (7.3% reduction).
Qwen3.5-397B-A17B
Using the profiling-driven bit allocation (879 critical@8bit, 22,466 standard@4bit, 1,813 deprioritized@2bit, 5,562 pruned):
- 15/15 collapse tests passed (0 warnings)
- MoE weight savings: 202.5 GB → 167.0 GB (17.5% reduction, saving 35.5 GB)
- Response quality: identical to baseline on all test prompts
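As a sanity check on the reported savings, the expert counts in the two allocations can be turned into a nominal average bit width. This is a rough sketch under stated assumptions: every expert is taken to be the same size, pruned experts count as 0 bits, and quantization scales and biases are ignored. The helper name `nominal_average_bits` is illustrative.

```python
def nominal_average_bits(allocation):
    """Average bits per weight over all expert slots; pruned experts count as 0 bits."""
    total_experts = sum(allocation.values())
    total_bits = sum(bits * count for bits, count in allocation.items())
    return total_bits / total_experts


# Expert counts per bit width, taken from the allocations listed above (0 = pruned).
qwen3_235b = {8: 23, 4: 10_845, 2: 365, 0: 799}
qwen35_397b = {8: 879, 4: 22_466, 2: 1_813, 0: 5_562}

for name, allocation in (("Qwen3-235B", qwen3_235b), ("Qwen3.5-397B", qwen35_397b)):
    avg = nominal_average_bits(allocation)
    reduction = 1 - avg / 4  # relative to a uniform 4-bit baseline
    print(f"{name}: {avg:.2f} bits/weight, {reduction:.1%} nominal reduction")
```

The nominal reductions come out slightly above the measured 7.3% and 17.5%. One plausible reason is that per-group scales and biases do not shrink in proportion to the weight bits, but we have not verified this.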
Results: Speed (The Problem)
Here's where the approach fails. The mask-and-combine dispatch pattern introduces substantial overhead:
| Model | Standard Kernels | MixedBitSwitchGLU | Overhead |
|---|---|---|---|
| Qwen3-235B | ~16s/prompt | ~21s/prompt | ~30% |
| Qwen3.5-397B | ~8s/prompt | ~47s/prompt | ~490% |
For Qwen3.5-397B with 512 experts classified into 3 bit-width groups per layer, the overhead is catastrophic. Each forward pass through a MoE layer runs the computation 3 times (once per group), and the masking adds additional memory traffic.
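For readers who want to relate the overhead percentages to slowdown factors, the arithmetic is sketched below. It uses the rounded per-prompt latencies from the table, so the results are approximate; `overhead_percent` is just an illustrative helper.

```python
def overhead_percent(baseline_s: float, measured_s: float) -> float:
    """Relative slowdown of measured vs. baseline, in percent."""
    return (measured_s - baseline_s) / baseline_s * 100


# Approximate per-prompt latencies from the table above (seconds).
timings = {"Qwen3-235B": (16, 21), "Qwen3.5-397B": (8, 47)}
for model, (standard, mixed) in timings.items():
    print(f"{model}: {overhead_percent(standard, mixed):.0f}% overhead, "
          f"{mixed / standard:.1f}x slower")
```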
The overhead is worse for 512-expert models because:
- More experts → more groups (typically 3 vs 2 for 128-expert models)
- Each QuantizedSwitchLinear group has more experts
- The mask tensors are larger (512 entries vs 128)
- Mask-and-combine requires reading and writing the full output 3 times
At 47.3 seconds per prompt, the model is unusable for interactive applications. Users won't wait 47 seconds for each response when a uniform 4-bit model responds in 8 seconds.
Why Not Sorted Dispatch?
The mask-and-combine approach is the simplest but not the only option. An alternative is sorted dispatch:
- Sort tokens by which bit-width group their routed expert belongs to
- Process each group's tokens in a contiguous batch (no masking needed)
- Unsort the results back to original order
This would avoid the redundant computation (each token processed only once) but requires:
- Sorting and unsorting indices at every layer
- Handling variable batch sizes per group
- Potentially worse memory access patterns
We did not implement sorted dispatch. It remains the most promising optimization path if per-expert mixed-bit quantization is to be viable.
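To illustrate the idea, and only the idea, here is a toy pure-Python sketch of sorted dispatch. It is not something we built or benchmarked: experts are again reduced to scalar multipliers, the function name `sorted_dispatch` is hypothetical, and routing is assumed never to select a pruned expert. The point is that each (token, slot) pair is evaluated once, instead of once per group.

```python
from itertools import groupby


def sorted_dispatch(groups, x, indices):
    """Toy sorted dispatch: each (token, slot) pair is evaluated exactly once."""
    expert_to_group = {}
    for gi, (experts, _weights) in enumerate(groups):
        for local_id, e in enumerate(experts):
            expert_to_group[e] = (gi, local_id)

    # Flatten routing into (group, local expert, token, slot) records.
    records = [
        (*expert_to_group[e], t, k)
        for t, row in enumerate(indices)
        for k, e in enumerate(row)
    ]
    records.sort(key=lambda r: r[0])  # step 1: sort by bit-width group

    out = [[0.0] * len(row) for row in indices]
    evaluations = 0
    for gi, batch in groupby(records, key=lambda r: r[0]):
        weights = groups[gi][1]
        for _, local_id, t, k in batch:  # step 2: contiguous per-group batch
            out[t][k] = weights[local_id] * x[t]  # step 3: scatter back ("unsort")
            evaluations += 1
    return out, evaluations


# Same illustrative setup as a 3-group layer: (global expert ids, scalar weights).
groups = [
    ([0, 3], [1.0, 4.0]),
    ([1, 2, 5], [2.0, 3.0, 6.0]),
    ([4], [5.0]),
]
x = [10.0, 20.0, 30.0]              # one scalar activation per token
indices = [[0, 1], [4, 3], [5, 2]]  # top-2 routing per token

out, evaluations = sorted_dispatch(groups, x, indices)
print(out, evaluations)
```

With three tokens and top-2 routing, this performs 6 expert evaluations, where a 3-group mask-and-combine pass would perform 18. A real implementation would still pay for the sort, the scatter, and the variable-sized per-group batches listed above.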
Version Comparison: More 8-Bit Layers ≠ Better Quality
We iterated through multiple versions of ExpertQuant for Qwen3-235B, varying how many layers receive 8-bit experts:
| Version | Strategy | 8-bit Layers | MMLU-Pro | ARC | GSM8K | HumanEval | Size |
|---|---|---|---|---|---|---|---|
| v2 | Critical wins | 17 | 76.7% | 96.2% | 92.0% | 88.0% | 149 GB |
| v3 | Relaxed threshold | 25 | 68.6% | 95.4% | 93.0% | 88.0% | 151 GB |
| v3b | Same as v2 | 17 | 76.7% | 96.2% | 92.0% | 88.0% | 149 GB |
| v4 | Aggressive 8-bit | 35 | 71.7% | 96.2% | 94.0% | 84.0% | 153 GB |
| v4b | Most layers 8-bit | 40 | 69.3% | 96.2% | 95.0% | 86.0% | 153 GB |
Key finding: v2, with only 17 layers at 8-bit, scored the highest MMLU-Pro (76.7%). Versions with more 8-bit layers (v3, v4, v4b) scored 5 to 8 points lower on MMLU-Pro despite using more storage, even though their GSM8K scores rose slightly (93.0% to 95.0%). The relationship between bit allocation and quality is non-monotonic: over-allocating high precision to non-critical layers can hurt, possibly by changing the relative precision balance.
v3b, which uses the same configuration as v2, reproduced v2's scores exactly, so the result is repeatable. Among the configurations we tested, 17 layers at 8-bit is the sweet spot for this model.
Conclusions
- Per-expert mixed-bit quantization preserves quality excellently. Profiling-guided bit allocation outperforms uniform quantization on every benchmark.
- Custom dispatch kernels are too slow for production. The overhead from mask-and-combine dispatch (~30% on Qwen3-235B, ~490% on Qwen3.5-397B) is unacceptable for interactive use.
- The quality results suggest that expert importance varies significantly. Giving critical experts more precision and de-prioritized experts less precision yields better results than uniform allocation.
- More precision is not always better. The v2→v4 progression shows that blindly adding 8-bit layers can degrade quality.
- The speed problem is solvable in principle. Sorted dispatch or native kernel support for mixed-bit QuantizedSwitchLinear would eliminate the overhead. This requires framework-level changes (in MLX or PyTorch) rather than Python-level workarounds.
For our production deployment, we ultimately abandoned MixedBitSwitchGLU in favor of layer-level uniform quantization with standard kernels — trading the per-expert granularity for a 5.8x speed improvement. The quality cost of using layer-level instead of expert-level decisions was negligible.
Next in this series: Expert Pruning in MoE Models — When Dead Experts Aren't Dead — what happened when we removed 18% of experts, and why the model started speaking Chinese in the middle of Spanish translations.