We quantized 400B+ parameter MoE models on a 512 GB Mac Studio using MLX. Along the way, we found a data corruption bug, a dtype footgun, and a GPU timeout trap.
Introduction
Apple's MLX framework is remarkably capable for running large language models on Apple Silicon. The unified memory architecture means a 512 GB Mac Studio can load models that would require multi-GPU setups on NVIDIA hardware. But when you push MLX to its limits — quantizing and post-processing models with 400 billion parameters and bfloat16 weights — you encounter edge cases that aren't in any documentation.
This article documents four engineering pitfalls we encountered during the ExpertQuant project, along with concrete workarounds and code examples. Each issue cost us hours of debugging and gigabytes of wasted computation. We hope this saves you the same.
Bug 1: mx.save_safetensors Corrupts bfloat16 Data
Severity: Critical. Causes complete model collapse.
The Symptom
After quantizing Qwen3.5-397B and post-processing the gate weights (setting pruned expert rows to -1e9), the model produced garbage output:
User: What is 2+2?
Model: !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!
The model had collapsed completely. Every response was either random tokens or repeated punctuation.
The Investigation
We spent hours ruling out causes:
- Verified gate masking was correct (5,562 experts masked, correct layer assignments) ✓
- Verified quantization config had correct per-module overrides (351 overrides) ✓
- Compared configs with baseline model ✓
- Reconverted with bfloat16 instead of float16 — still collapsed ✗
The breakthrough came from testing without post-processing: we converted the model, skipped the gate weight masking step entirely, and the model worked perfectly. 15/15 collapse tests passed.
The Root Cause
The post-processing code was loading each safetensors shard with mx.load, modifying the gate weights, and saving with mx.save_safetensors:
# THIS CORRUPTS DATA:
tensors = mx.load(str(shard_file))
tensors[gate_key][expert_idx, :] = mx.array(-1e9)
mx.save_safetensors(str(shard_file), tensors)
We compared the byte contents of shards before and after this load→save roundtrip. Every single tensor in the shard was corrupted — not just the gate weights we modified, but all weights, all biases, everything.
The bug is in mx.save_safetensors (mlx 0.29.3): when saving tensors that were loaded as bfloat16, the serialization produces incorrect bytes. The corruption is silent — no error, no warning, the file is a valid safetensors file with the right shape and dtype metadata, but the numerical values are wrong.
The Fix
Use the safetensors Python library (from Hugging Face) instead of MLX for the roundtrip:
from safetensors import safe_open
from safetensors.torch import save_file
import torch
# Load with safetensors (preserves bfloat16 exactly)
tensors = {}
with safe_open(str(shard_file), framework="pt") as f:
    for key in f.keys():
        tensors[key] = f.get_tensor(key)
# Modify the gate weights
gate_weight = tensors[gate_key].clone()
gate_weight[expert_idx, :] = -1e9
tensors[gate_key] = gate_weight
# Save with safetensors (not MLX)
save_file(tensors, str(shard_file))
This preserves bfloat16 data perfectly through the roundtrip. After applying this fix, the model passed 15/15 collapse tests.
How to Detect
If you suspect this bug, hash the shard file before and after the roundtrip (note that benign metadata differences can also change the hash, so follow up with a value comparison):
import hashlib
before = hashlib.md5(open(shard, "rb").read()).hexdigest()
# ... load and save ...
after = hashlib.md5(open(shard, "rb").read()).hexdigest()
assert before == after # Will FAIL if corruption occurred
Or compare specific tensor values:
original = mx.load("shard_backup.safetensors")
roundtripped = mx.load("shard_modified.safetensors")
for key in original:
    diff = mx.abs(original[key].astype(mx.float32) - roundtripped[key].astype(mx.float32))
    if mx.max(diff).item() > 0.01:
        print(f"CORRUPTED: {key}, max diff: {mx.max(diff).item()}")
Bug 2: float16 Causes GatedDeltaNet State Overflow
Severity: Critical. Causes model collapse on Qwen3.5 and other GatedDeltaNet models.
The Symptom
Identical to Bug 1 — garbage output after conversion. But this time, the conversion used dtype="float16" instead of dtype="bfloat16".
The Root Cause
Qwen3.5-397B uses GatedDeltaNet attention (not standard multi-head attention). GatedDeltaNet maintains recurrent state that accumulates across the sequence. The accumulated values can exceed the range of float16 (±65,504), causing overflow → NaN → collapse.
bfloat16 has the same exponent range as float32 (±3.4 × 10³⁸), so it handles the accumulation without overflow.
The relevant parameter is A_log in each attention layer, which controls the decay rate of the recurrent state. During conversion, the sanitization step computes v + 1.0 which upcasts A_log to float32. But during inference, intermediate computations in float16 can still overflow.
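The overflow behavior is easy to see outside MLX. Here is a minimal sketch using NumPy; NumPy has no bfloat16 type, so float32, which shares bfloat16's 8-bit exponent, stands in for it:

```python
import numpy as np

# float16 tops out at +/-65,504; accumulating past that overflows to inf
acc16 = np.float16(60000.0) + np.float16(10000.0)
print(acc16)  # inf

# float32 has the same exponent range as bfloat16, so the same sum is fine
acc32 = np.float32(60000.0) + np.float32(10000.0)
print(acc32)  # 70000.0
```

A recurrent state that grows across a long sequence hits this ceiling the same way a single large sum does.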
The Fix
Always use dtype="bfloat16" when converting Qwen3.5 and other models with GatedDeltaNet or similar recurrent attention:
from mlx_lm import convert
convert(
    hf_path=source_model,
    mlx_path=output_path,
    quantize=True,
    q_bits=4,
    q_group_size=128,
    dtype="bfloat16",  # NOT "float16"
)
How to Detect
If your converted model produces garbage, check the model architecture:
- Standard transformer attention (Qwen3, Llama, etc.): float16 is fine
- GatedDeltaNet, RWKV, Mamba, or any recurrent/state-space attention: use bfloat16
You can check programmatically:
import json
config = json.load(open(f"{model_path}/config.json"))
model_type = config.get("model_type", "")
if "gated" in model_type.lower() or "delta" in model_type.lower():
    print("WARNING: Use bfloat16 for this model type")
Bug 3: GPU Timeout on Large Model Lazy Loading
Severity: Medium. Server crashes on first request.
The Symptom
Starting mlx_lm.server with a 237 GB model:
mlx_lm.server --model ./qwen35-397b-hybrid-v1 --port 8888
The server starts and reports "Starting httpd at 127.0.0.1 on port 8888...". The first HTTP request triggers model loading (104 safetensors files). Then:
libc++abi: terminating due to uncaught exception of type std::runtime_error:
[METAL] Command buffer execution failed:
Caused GPU Timeout Error (00000002:kIOGPUCommandBufferCallbackErrorTimeout)
The Root Cause
mlx_lm.server uses lazy loading — the model weights are loaded from disk on first use, not at startup. When the first request arrives, it triggers loading 237 GB of data from 104 safetensors files into unified memory, then immediately attempts the first forward pass.
The Metal GPU has a command buffer timeout. If the first forward pass takes too long (because the system is still paging in model weights), the GPU kills the command buffer.
The Fix
Use mlx_lm.load directly instead of the server for large models:
from mlx_lm import load, generate
from mlx_lm.sample_utils import make_sampler
# Eager loading — all 237 GB loaded upfront
model, tokenizer = load("./qwen35-397b-hybrid-v1") # Takes ~38 seconds
sampler = make_sampler(temp=0.0)
# First inference works immediately
result = generate(model, tokenizer, prompt="Hello", max_tokens=50, sampler=sampler)
# Completes in 1.9 seconds
If you must use the server, you could potentially warm it up with a dummy request after loading, but the fundamental issue is that the lazy loading + first forward pass exceeds the GPU timeout for very large models.
Threshold
In our testing on a Mac Studio M3 Ultra (512 GB):
- Models up to ~200 GB: mlx_lm.server works fine
- Models above ~230 GB: GPU timeout on first request
- Direct mlx_lm.load: works for any size that fits in memory
Bug 4: Python Dunder Method Resolution Bypasses Instance Patching
Severity: Low (development-time only). Causes profiling hooks to silently not fire.
The Symptom
When trying to hook into MoE layer __call__ methods for activation profiling:
import types

original_call = type(layer).__call__

def hooked_call(self, x):
    # Capture activations...
    return original_call(self, x)

# This appears to work:
layer.__call__ = types.MethodType(hooked_call, layer)
# But the hook NEVER fires during inference!
The Root Cause
Python's data model specifies that "special methods" (dunder methods like __call__, __getattr__, __len__, etc.) are looked up on the type (class), not the instance. When you set layer.__call__, you're adding an instance attribute, but Python's method resolution for layer(x) goes through type(layer).__call__, which finds the original class method.
This is documented in Python's data model but is a common source of confusion: https://docs.python.org/3/reference/datamodel.html#special-lookup
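The special-lookup rule can be demonstrated in a few lines with a toy class (no MLX involved):

```python
class Layer:
    def __call__(self, x):
        return "original"

layer = Layer()
layer.__call__ = lambda x: "hooked"  # instance attribute, not a class method

print(layer(0))           # "original" -- layer(x) resolves via type(layer).__call__
print(layer.__call__(0))  # "hooked"   -- explicit attribute access sees the instance
```

The call syntax `layer(0)` never consults the instance dictionary, which is why the patch silently does nothing.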
The Fix
Swap the instance's __class__ to a dynamically created subclass:
original_class = layer.__class__
activations = []  # collected routing captures

def hooked_call(self, x):
    # Your profiling code here
    activations.append(capture_routing(self, x))
    return original_class.__call__(self, x)

HookedClass = type(
    f"Hooked{original_class.__name__}",
    (original_class,),
    {"__call__": hooked_call},
)
layer.__class__ = HookedClass
Now type(layer).__call__ resolves to hooked_call because the instance's class has been swapped. This is safe, reversible (layer.__class__ = original_class to unhook), and works for all dunder methods.
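A self-contained illustration of the same pattern, using a toy `Layer` class rather than an MLX module:

```python
calls = []

class Layer:
    def __call__(self, x):
        return x * 2

class HookedLayer(Layer):
    def __call__(self, x):
        calls.append(x)                 # profiling hook
        return Layer.__call__(self, x)

layer = Layer()
layer.__class__ = HookedLayer           # swap in the subclass
assert layer(3) == 6 and calls == [3]   # hook fired, behavior preserved

layer.__class__ = Layer                 # unhook: restore the original class
layer(4)
assert calls == [3]                     # hook no longer fires
```

Because `HookedLayer` subclasses the original, `isinstance` checks and all other methods continue to work unchanged.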
Bonus: MLX Array and API Gotchas
Three smaller issues that caused bugs during development:
No .copy() Method
MLX arrays don't have a .copy() method. Use multiplication by 1:
# WRONG:
arr_copy = arr.copy() # AttributeError
# RIGHT:
arr_copy = arr * 1
Can't Convert bfloat16 to NumPy Directly
import mlx.core as mx
import numpy as np
arr = mx.array([1.0, 2.0], dtype=mx.bfloat16)
# WRONG:
np_arr = np.array(arr) # TypeError or wrong values
# RIGHT:
np_arr = np.array(arr.astype(mx.float32))
mx.metal.clear_cache is Deprecated
# WRONG (deprecated):
mx.metal.clear_cache()
# RIGHT:
mx.clear_cache()
Summary
| Bug | Severity | Symptom | Fix |
|---|---|---|---|
| mx.save_safetensors bfloat16 | Critical | Model collapse after post-processing | Use safetensors.torch.save_file |
| float16 with GatedDeltaNet | Critical | Model collapse after conversion | Use dtype="bfloat16" |
| GPU timeout on lazy loading | Medium | Server crash on first request | Use mlx_lm.load directly |
| Dunder method resolution | Low | Profiling hooks don't fire | Use __class__ swapping |
These bugs are specific to mlx 0.29.3 and mlx_lm 0.29.1. Future versions may fix some of them. The bfloat16 serialization bug in particular is a correctness issue that should be addressed at the framework level.
If you're working with large MoE models on Apple Silicon, we recommend:
- Always verify your conversion with a quick inference test before spending hours on evaluation
- Use bfloat16 as your default dtype unless you have a specific reason for float16
- Use the safetensors library (not MLX) for any load→modify→save workflows
- Test with direct model loading before deploying to a server
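For the first recommendation, a quick automated check helps. This is a hypothetical heuristic (`looks_collapsed` is our own name, not part of mlx_lm) that flags output dominated by a single repeated token, like the `!!!!` collapse above:

```python
from collections import Counter

def looks_collapsed(text: str, threshold: float = 0.5) -> bool:
    """Flag output where one token dominates (e.g. '!!!!!!!!')."""
    tokens = text.split() or list(text)
    if not tokens:
        return True
    top = Counter(tokens).most_common(1)[0][1]
    return top / len(tokens) > threshold

print(looks_collapsed("!!!!!!!!!!!!!!"))           # True
print(looks_collapsed("The answer to 2+2 is 4."))  # False
```

Run it on a handful of short generations right after conversion; a single failing prompt is far cheaper than discovering collapse mid-evaluation.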
Next in this series: Layer-Level vs Expert-Level Granularity in MoE Quantization — why giving every expert its own bit width barely beats treating entire layers uniformly.