We quantized 400B+ parameter MoE models on a 512 GB Mac Studio using MLX. Along the way, we found a data corruption bug, a dtype footgun, a GPU timeout trap, and a Python method-resolution gotcha.
Introduction
Apple's MLX framework is impressively capable for running large language models on Apple Silicon. Thanks to the unified memory architecture, a 512 GB Mac Studio can load models that would need multi-GPU setups on NVIDIA hardware. But when you push MLX to its limits (quantizing and post-processing 400-billion-parameter models with bfloat16 weights), you hit edge cases that aren't documented anywhere.
This article documents four engineering pitfalls we ran into during the ExpertQuant project, with concrete workarounds and code examples. Each one cost us hours of debugging and gigabytes of wasted computation. Hopefully this saves you the same pain.
Bug 1: mx.save_safetensors Corrupts bfloat16 Data
Severity: Critical. Causes complete model collapse.
The Symptom
After quantizing Qwen3.5-397B and post-processing the gate weights (setting pruned expert rows to -1e9), the model produced garbage:
User: What is 2+2?
Model: !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
Total collapse. Every response was either random tokens or repeated punctuation.
The Investigation
We spent hours ruling out causes:
- Verified gate masking was correct (5,562 experts masked, correct layer assignments) ✓
- Verified quantization config had correct per-module overrides (351 overrides) ✓
- Compared configs with baseline model ✓
- Reconverted with bfloat16 instead of float16, still collapsed ✗
The breakthrough came when we tested without post-processing: we converted the model, skipped the gate-weight masking step entirely, and it worked perfectly. 15/15 collapse tests passed.
The Root Cause
The post-processing code loaded each safetensors shard with mx.load, modified the gate weights, and saved with mx.save_safetensors:
# THIS CORRUPTS DATA:
tensors = mx.load(str(shard_file))
tensors[gate_key][expert_idx, :] = mx.array(-1e9)
mx.save_safetensors(str(shard_file), tensors)
We compared the byte contents of shards before and after the roundtrip. Every tensor in the shard was corrupted, not just the gate weights we modified. All weights, all biases, everything.
The bug is in mx.save_safetensors (mlx 0.29.3): when saving tensors loaded as bfloat16, the serialization produces wrong bytes. The corruption is completely silent. No error, no warning. The file is a valid safetensors file with correct shape and dtype metadata, but the numerical values are garbage.
The Fix
Use the safetensors Python library (from Hugging Face) instead of MLX for the roundtrip:
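import os
from safetensors import safe_open
from safetensors.torch import save_file
import torch  # the loaded tensors are torch.Tensor objects

# Load with safetensors (preserves bfloat16 exactly)
tensors = {}
with safe_open(str(shard_file), framework="pt") as f:
    metadata = f.metadata()  # keep the shard's header metadata, if any
    for key in f.keys():
        tensors[key] = f.get_tensor(key)

# Modify the gate weights on a copy
gate_weight = tensors[gate_key].clone()
gate_weight[expert_idx, :] = -1e9
tensors[gate_key] = gate_weight

# Save with safetensors (not MLX). Writing to a temporary file and then
# swapping it in avoids overwriting a file that may still be memory-mapped.
tmp_file = str(shard_file) + ".tmp"
save_file(tensors, tmp_file, metadata=metadata)
os.replace(tmp_file, str(shard_file))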
This preserves bfloat16 data perfectly through the roundtrip. After applying this fix, the model passed 15/15 collapse tests.
How to Detect
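If you suspect this bug, run a pure load/save roundtrip (no modifications) and compare shard hashes before and after:
import hashlib
from pathlib import Path

def file_md5(path):
    return hashlib.md5(Path(path).read_bytes()).hexdigest()

before = file_md5(shard)
# ... load and save WITHOUT modifying any tensor ...
after = file_md5(shard)
assert before == after  # Will FAIL if the roundtrip changed the file
A hash mismatch can also come from harmless header differences (key order, metadata), so confirm by comparing tensor values. Tensors you modified on purpose, such as the masked gate weights, will show up here too:
import mlx.core as mx

original = mx.load("shard_backup.safetensors")
roundtripped = mx.load("shard_modified.safetensors")
for key, ref in original.items():
    # Upcast to float32 so the subtraction isn't done in bfloat16
    diff = mx.abs(ref.astype(mx.float32) - roundtripped[key].astype(mx.float32))
    max_diff = mx.max(diff).item()
    has_nan = mx.any(mx.isnan(diff)).item()  # corrupted bits can decode as NaN
    if has_nan or max_diff > 0.01:
        print(f"CORRUPTED: {key}, max diff: {max_diff}")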
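The value check above still relies on MLX's own loader. As a framework-independent cross-check, here is a minimal sketch that parses the safetensors container directly with the standard library and compares each tensor's raw bytes, so neither MLX nor torch is involved in reading the data. It assumes the published safetensors layout (an 8-byte little-endian header length, a JSON header mapping each tensor name to `dtype`, `shape`, and `data_offsets`, then the data buffer); `read_raw_tensors` and `diff_shards` are our own hypothetical helper names, not part of any library.

```python
import json
import struct


def read_raw_tensors(path):
    """Return {name: (dtype, shape, raw_bytes)} without decoding any values."""
    with open(path, "rb") as f:
        # First 8 bytes: little-endian u64 length of the JSON header
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len))
        data = f.read()  # data buffer; offsets below are relative to its start
    tensors = {}
    for name, info in header.items():
        if name == "__metadata__":
            continue  # free-form string metadata, not a tensor
        begin, end = info["data_offsets"]
        tensors[name] = (info["dtype"], tuple(info["shape"]), data[begin:end])
    return tensors


def diff_shards(path_a, path_b, skip=()):
    """List tensors whose dtype, shape, or bytes differ between two shards."""
    a, b = read_raw_tensors(path_a), read_raw_tensors(path_b)
    problems = sorted(set(a) ^ set(b))  # tensors present in only one file
    for name in sorted(set(a) & set(b)):
        if name in skip:
            continue  # e.g. gate tensors you modified on purpose
        if a[name] != b[name]:
            problems.append(name)
    return problems
```

For example, `diff_shards("shard_backup.safetensors", "shard_modified.safetensors", skip={gate_key})` should return an empty list when every tensor you didn't touch survived byte-for-byte. Each shard is read fully into memory, which is acceptable at a few GB per shard.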
Bug 2: float16 Causes GatedDeltaNet State Overflow
Severity: Critical. Causes model collapse on Qwen3.5 and other GatedDeltaNet models.
The Symptom
Same as Bug 1: garbage output after conversion. But this time, the conversion used dtype="float16" instead of dtype="bfloat16".
The Root Cause
Qwen3.5-397B uses GatedDeltaNet attention (not standard multi-head attention). GatedDeltaNet maintains a recurrent state that accumulates across the sequence. The accumulated values can exceed the range of float16 (±65,504), causing overflow to infinity, then NaN, then collapse.
bfloat16 has the same 8-bit exponent, and therefore the same range, as float32 (about ±3.4 × 10^38), so it handles the accumulation without overflow.
The relevant parameter is A_log in each attention layer, which controls the decay rate of the recurrent state. During conversion, the sanitization step computes v + 1.0, which upcasts A_log to float32. Even so, intermediate computations during float16 inference can still overflow.
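To make the range argument concrete, here is a toy sketch, not the actual GatedDeltaNet update: a single scalar state s ← g·s + u with decay g = 0.99 and a constant per-token update u = 1000, iterated in emulated float16 and emulated bfloat16. The emulation uses only the standard library (`struct`'s half-precision `e` format for float16, and rounding a float32 to its upper 16 bits for bfloat16). The constants are illustrative assumptions, chosen so the state's fixed point (1000 / (1 − 0.99) = 100,000 in exact arithmetic) lies above the float16 maximum but far below the bfloat16 one.

```python
import math
import struct


def to_fp16(x):
    """Round a Python float to the nearest IEEE float16 value."""
    if not math.isfinite(x):
        return x
    try:
        return struct.unpack("<e", struct.pack("<e", x))[0]
    except OverflowError:  # struct raises instead of returning inf
        return math.copysign(math.inf, x)


def to_bf16(x):
    """Round a Python float (via float32) to the nearest bfloat16 value."""
    if not math.isfinite(x):
        return x
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    bits += 0x7FFF + ((bits >> 16) & 1)  # round-half-to-even on the dropped 16 bits
    bits &= 0xFFFF0000  # keep sign, 8 exponent bits, 7 mantissa bits
    return struct.unpack("<f", struct.pack("<I", bits))[0]


def run_state(round_fn, decay=0.99, update=1000.0, steps=1000):
    """Iterate s <- decay * s + update, rounding every intermediate result."""
    g, u, s = round_fn(decay), round_fn(update), 0.0
    for _ in range(steps):
        s = round_fn(round_fn(g * s) + u)
    return s


fp16_state = run_state(to_fp16)
bf16_state = run_state(to_bf16)
print(f"float16 state:  {fp16_state}")
print(f"bfloat16 state: {bf16_state}")
```

With these constants the float16 run ends at `inf` while the bfloat16 run stays finite (coarsely rounded, but finite). Once a value is `inf`, later operations such as `inf - inf` or `0 * inf` produce NaN, which is how an overflow turns into NaN and then collapse.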
The Fix
Always use dtype="bfloat16" when converting Qwen3.5 and other models with GatedDeltaNet or similar recurrent attention:
from mlx_lm import convert

convert(
    hf_path=source_model,
    mlx_path=output_path,
    quantize=True,
    q_bits=4,
    q_group_size=128,
    dtype="bfloat16",  # NOT "float16"
)
How to Detect
If your converted model produces garbage, check the architecture:
- Standard transformer attention (Qwen3, Llama, etc.): float16 is fine
- GatedDeltaNet, RWKV, Mamba, or any recurrent/state-space attention: use bfloat16
You can check programmatically:
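import json
from pathlib import Path

config = json.loads((Path(model_path) / "config.json").read_text())
model_type = config.get("model_type", "").lower()
# Heuristic only: model_type naming varies, so no match does not prove
# the model lacks recurrent attention; check the architecture as well.
if "gated" in model_type or "delta" in model_type:
    print("WARNING: Use bfloat16 for this model type")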
Bug 3: GPU Timeout on Large Model Lazy Loading
Severity: Medium. Server crashes on first request.
The Symptom
Starting mlx_lm.server with a 237 GB model:
mlx_lm.server --model ./qwen35-397b-hybrid-v1 --port 8888
The server starts and reports "Starting httpd at 127.0.0.1 on port 8888...". The first HTTP request triggers model loading (104 safetensors files). Then:
libc++abi: terminating due to uncaught exception of type std::runtime_error:
[METAL] Command buffer execution failed:
Caused GPU Timeout Error (00000002:kIOGPUCommandBufferCallbackErrorTimeout)
The Root Cause
mlx_lm.server uses lazy loading. The model weights get pulled from disk on first use, not at startup. When the first request comes in, it triggers loading 237 GB from 104 safetensors files into unified memory, then immediately tries the first forward pass.
The Metal GPU has a command buffer timeout. If that first forward pass takes too long (because the system is still paging in weights), the GPU kills the command buffer.
The Fix
Use mlx_lm.load directly instead of the server for large models:
from mlx_lm import load, generate
from mlx_lm.sample_utils import make_sampler
# Eager loading, all 237 GB loaded upfront
model, tokenizer = load("./qwen35-397b-hybrid-v1") # Takes ~38 seconds
sampler = make_sampler(temp=0.0)
# First inference works immediately
result = generate(model, tokenizer, prompt="Hello", max_tokens=50, sampler=sampler)
# Completes in 1.9 seconds
If you must use the server, you could try warming it up with a dummy request after loading. But the root issue is that lazy loading plus the first forward pass exceeds the GPU timeout for very large models.
Threshold
In our testing on a Mac Studio M3 Ultra (512 GB):
- Models up to ~200 GB: mlx_lm.server works fine
- Models above ~230 GB: GPU timeout on first request
- Direct mlx_lm.load: works for any size that fits in memory
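A minimal sketch for picking a loading path from the on-disk size of a model's weights. The helper names `weights_size_bytes` and `choose_loader` and the 200 GB default are our own assumptions, taken from the thresholds above on one M3 Ultra; they are not an MLX setting, and the safe limit on your machine may differ.

```python
from pathlib import Path

GB = 10**9  # decimal gigabytes; use 2**30 if you think in GiB


def weights_size_bytes(model_dir):
    """Total size of all .safetensors shards in a local model directory."""
    return sum(p.stat().st_size for p in Path(model_dir).glob("*.safetensors"))


def choose_loader(model_dir, server_limit_gb=200):
    """Return 'server' for models at or under the limit, else 'direct' (mlx_lm.load)."""
    size_gb = weights_size_bytes(model_dir) / GB
    return "server" if size_gb <= server_limit_gb else "direct"
```

We didn't test the 200 to 230 GB gap, so this sketch errs toward direct loading for anything above 200 GB.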
Bug 4: Python Dunder Method Resolution Bypasses Instance Patching
Severity: Low (development-time only). Causes profiling hooks to silently not fire.
The Symptom
Trying to hook into MoE layer __call__ methods for activation profiling:
import types

original_call = type(layer).__call__

def hooked_call(self, x):
    # Capture activations...
    return original_call(self, x)

# This appears to work:
layer.__call__ = types.MethodType(hooked_call, layer)
# But the hook NEVER fires when the model calls layer(x) during inference!
The Root Cause
Python's data model says that "special methods" (dunder methods like __call__, __getattr__, __len__) are looked up on the type (class), not the instance. When you set layer.__call__, you're adding an instance attribute. But Python's method resolution for layer(x) goes through type(layer).__call__, which finds the original class method.
This is documented in Python's data model but trips people up all the time: https://docs.python.org/3/reference/datamodel.html#special-lookup
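A self-contained sketch of the rule, using a toy `Layer` class in place of an MLX module (the class and hook names are illustrative): the instance attribute is visible to an explicit `layer.__call__(x)`, but `layer(x)` goes through the type and ignores it.

```python
import types


class Layer:
    def __call__(self, x):
        return x * 2


calls = []


def hooked_call(self, x):
    calls.append(x)  # record the input, then defer to the class method
    return Layer.__call__(self, x)


layer = Layer()
layer.__call__ = types.MethodType(hooked_call, layer)

implicit = layer(3)           # special lookup on type(layer): hook skipped
explicit = layer.__call__(3)  # ordinary attribute lookup: hook runs
print(implicit, explicit, calls)
```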
The Fix
Swap the instance's __class__ to a dynamically created subclass:
original_class = layer.__class__

def hooked_call(self, x):
    # Your profiling code here (activations and capture_routing are your own)
    activations.append(capture_routing(self, x))
    return original_class.__call__(self, x)

HookedClass = type(
    f"Hooked{original_class.__name__}",
    (original_class,),
    {"__call__": hooked_call},
)
layer.__class__ = HookedClass
Now type(layer).__call__ resolves to hooked_call because the instance's class has been swapped. It's safe, reversible (layer.__class__ = original_class to unhook), and works for all dunder methods.
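Because the swap is reversible, it can be wrapped in a context manager so a hook cannot outlive a profiling run, even if generation raises. A minimal sketch, assuming the layer is invoked as `layer(x)`; `hooked`, `Expert`, and the `record` callback are hypothetical names, and the toy class stands in for an MLX MoE layer.

```python
from contextlib import contextmanager


@contextmanager
def hooked(obj, record):
    """Temporarily route obj(x) through record(self, x) before the original call."""
    original_class = obj.__class__

    def hooked_call(self, x):
        record(self, x)
        return original_class.__call__(self, x)

    obj.__class__ = type(
        f"Hooked{original_class.__name__}", (original_class,), {"__call__": hooked_call}
    )
    try:
        yield obj
    finally:
        obj.__class__ = original_class  # unhook, even if the body raised


class Expert:  # toy stand-in for an MoE layer
    def __call__(self, x):
        return x + 1


seen = []
expert = Expert()
with hooked(expert, lambda self, x: seen.append(x)):
    inside = expert(10)
outside = expert(20)
```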
Bonus: MLX bfloat16 Array Gotchas
Three smaller issues that bit us during development:
No .copy() Method
MLX arrays don't have a .copy() method. Multiply by 1 instead:
# WRONG:
arr_copy = arr.copy() # AttributeError
# RIGHT:
arr_copy = arr * 1
Can't Convert bfloat16 to NumPy Directly
import mlx.core as mx
import numpy as np
arr = mx.array([1.0, 2.0], dtype=mx.bfloat16)
# WRONG:
np_arr = np.array(arr) # TypeError or wrong values
# RIGHT:
np_arr = np.array(arr.astype(mx.float32))
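NumPy has no built-in bfloat16 dtype, which is why the upcast is needed. The upcast is also lossless: a bfloat16 value is exactly the upper 16 bits of the corresponding float32. A stdlib-only sketch of that relationship, decoding raw little-endian bfloat16 bytes (for example, a tensor's slice of a safetensors data buffer) into Python floats; `decode_bf16` is our own hypothetical name, not an MLX or NumPy function.

```python
import struct


def decode_bf16(raw):
    """Decode little-endian bfloat16 bytes into a list of Python floats."""
    if len(raw) % 2:
        raise ValueError("bfloat16 data must have an even number of bytes")
    halves = struct.unpack(f"<{len(raw) // 2}H", raw)
    # Widen each 16-bit pattern to float32 by appending 16 zero mantissa bits
    return [struct.unpack("<f", struct.pack("<I", h << 16))[0] for h in halves]


values = decode_bf16(bytes([0x80, 0x3F, 0x49, 0x40, 0x00, 0xC0]))
print(values)
```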
mx.metal.clear_cache is Deprecated
# WRONG (deprecated):
mx.metal.clear_cache()
# RIGHT:
mx.clear_cache()
Summary
| Bug | Severity | Symptom | Fix |
|---|---|---|---|
| mx.save_safetensors bfloat16 | Critical | Model collapse after post-processing | Use safetensors.torch.save_file |
| float16 with GatedDeltaNet | Critical | Model collapse after conversion | Use dtype="bfloat16" |
| GPU timeout on lazy loading | Medium | Server crash on first request | Use mlx_lm.load directly |
| Dunder method resolution | Low | Profiling hooks don't fire | Use __class__ swapping |
We hit these issues with mlx 0.29.3 and mlx_lm 0.29.1, and future versions may fix some of them (Bug 4 is standard Python behavior, not an MLX bug). The bfloat16 serialization bug in particular is a correctness issue that should be fixed at the framework level.
If you're working with large MoE models on Apple Silicon, here's what we'd recommend:
- Always verify your conversion with a quick inference test before spending hours on evaluation
- Use bfloat16 as your default dtype unless you have a specific reason for float16
- Use the safetensors library (not MLX) for any load-modify-save workflows
- Test with direct model loading before deploying to a server
Next in this series: Layer-Level vs Expert-Level Granularity in MoE Quantization, why giving every expert its own bit width barely beats treating entire layers uniformly.
Read the Full Paper
The full MoE expert quantization paper, covering expert activation profiling, per-expert mixed-bit allocation, and evaluation across 512-expert architectures, is available on our Hugging Face Space:
MoE Expert Quantization: Per-Expert Mixed-Precision for Mixture-of-Experts Models, Full Paper
huggingface.co/spaces/baa-ai/MoE-Expert-Quantization
Licensed under CC BY-NC-ND 4.0