MLX Quantization on Apple Silicon, Engineering Pitfalls and Workarounds

February 2026 · Black Sheep AI Research

We quantized 400B+ parameter MoE models on a 512 GB Mac Studio using MLX. Along the way, we found a data corruption bug, a dtype footgun, and a GPU timeout trap.

Introduction

Apple's MLX framework is impressively capable for running large language models on Apple Silicon. The unified memory architecture means a 512 GB Mac Studio can load models that would need multi-GPU setups on NVIDIA hardware. But when you push MLX to its limits, quantizing and post-processing 400-billion-parameter models with bfloat16 weights, you hit edge cases that we could not find documented anywhere.

This article documents four engineering pitfalls we ran into during the ExpertQuant project, with concrete workarounds and code examples. Each one cost us hours of debugging and gigabytes of wasted computation. Hopefully this saves you the same pain.

Bug 1: mx.save_safetensors Corrupts bfloat16 Data

Severity: Critical. Causes complete model collapse.

The Symptom

After quantizing Qwen3.5-397B and post-processing the gate weights (setting pruned expert rows to -1e9), the model produced garbage:


User: What is 2+2?
Model: !!!!!!!!!!!!!!!!!!!!!!!!!!!!!!

Total collapse. Every response was either random tokens or repeated punctuation.

The Investigation

We spent hours ruling out other causes before finding the real one.

The breakthrough: we tested without post-processing. Converted the model, skipped the gate weight masking step entirely, and it worked perfectly. 15/15 collapse tests passed.

The Root Cause

The post-processing code loaded each safetensors shard with mx.load, modified the gate weights, and saved with mx.save_safetensors:


# THIS CORRUPTS DATA (mlx 0.29.3, bfloat16 shards):
import mlx.core as mx

tensors = mx.load(str(shard_file))
tensors[gate_key][expert_idx, :] = mx.array(-1e9)  # mask the pruned expert row
mx.save_safetensors(str(shard_file), tensors)

We compared the byte contents of shards before and after the roundtrip. Every tensor in the shard was corrupted, not just the gate weights we modified. All weights, all biases, everything.

The bug is in mx.save_safetensors (mlx 0.29.3): when saving tensors loaded as bfloat16, the serialization produces wrong bytes. The corruption is completely silent. No error, no warning. The file is a valid safetensors file with correct shape and dtype metadata, but the numerical values are garbage.

The Fix

Use the safetensors Python library (from Hugging Face) instead of MLX for the roundtrip:


from safetensors import safe_open
from safetensors.torch import save_file
import torch

# Load with safetensors (preserves bfloat16 exactly)
tensors = {}
with safe_open(str(shard_file), framework="pt") as f:
    for key in f.keys():
        tensors[key] = f.get_tensor(key)

# Modify the gate weights
gate_weight = tensors[gate_key].clone()
gate_weight[expert_idx, :] = -1e9
tensors[gate_key] = gate_weight

# Save with safetensors (not MLX)
save_file(tensors, str(shard_file))

This preserves bfloat16 data perfectly through the roundtrip. After applying this fix, the model passed 15/15 collapse tests.

How to Detect

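If you suspect this bug, roundtrip a shard without modifying any tensor and compare file hashes before and after. A mismatch means the bytes changed, although a reordered header or different metadata can also change the hash, so confirm with the tensor comparison further down:


import hashlib

def file_md5(path):
    # Close the file promptly; shards are large
    with open(path, "rb") as f:
        return hashlib.md5(f.read()).hexdigest()

before = file_md5(shard)
# ... load and save, without modifying any tensor ...
after = file_md5(shard)
assert before == after  # Will FAIL if corruption occurred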

Or compare specific tensor values:


import mlx.core as mx

original = mx.load("shard_backup.safetensors")
roundtripped = mx.load("shard_modified.safetensors")
for key in original:
    # Upcast both sides to float32 so the subtraction adds no rounding of its own
    diff = mx.abs(original[key].astype(mx.float32) - roundtripped[key].astype(mx.float32))
    max_diff = mx.max(diff).item()
    if max_diff > 0.01:
        print(f"CORRUPTED: {key}, max diff: {max_diff}")
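
A whole-file hash cannot say which tensors changed, and a value comparison needs both shards loaded into memory. As a complement, here is a minimal sketch that hashes each tensor's raw bytes by reading the safetensors header directly (an 8-byte little-endian header length, a JSON header, then the data buffer). The helper names tensor_digests and changed_tensors are hypothetical, and this is an illustration rather than the script used in the investigation above.

```python
import hashlib
import json
import struct

def tensor_digests(path):
    """Return {tensor_name: sha256 of its raw bytes} for one safetensors shard."""
    digests = {}
    with open(path, "rb") as f:
        # Layout: 8-byte little-endian header length, JSON header, raw data buffer
        (header_len,) = struct.unpack("<Q", f.read(8))
        header = json.loads(f.read(header_len))
        data_start = 8 + header_len
        for name, info in header.items():
            if name == "__metadata__":  # optional free-form metadata, not a tensor
                continue
            begin, end = info["data_offsets"]  # offsets are relative to the data buffer
            f.seek(data_start + begin)
            digests[name] = hashlib.sha256(f.read(end - begin)).hexdigest()
    return digests

def changed_tensors(before_path, after_path):
    """Names of tensors whose bytes differ, or that are missing from the second shard."""
    before = tensor_digests(before_path)
    after = tensor_digests(after_path)
    return sorted(name for name in before if before[name] != after.get(name))
```

After an intentional edit such as the gate masking, the returned list should contain only the keys you meant to change. Anything else in the list points to a corrupted roundtrip.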

Bug 2: float16 Causes GatedDeltaNet State Overflow

Severity: Critical. Causes model collapse on Qwen3.5 and other GatedDeltaNet models.

The Symptom

Same as Bug 1: garbage output after conversion. But this time, the conversion used dtype="float16" instead of dtype="bfloat16".

The Root Cause

Qwen3.5-397B uses GatedDeltaNet attention (not standard multi-head attention). GatedDeltaNet maintains recurrent state that accumulates across the sequence. The accumulated values can exceed the range of float16 (±65,504), causing overflow, then NaN, then collapse.

bfloat16 has the same exponent range as float32 (about ±3.4 × 10^38), so it handles the accumulation without overflow.

The relevant parameter is A_log in each attention layer. It controls the decay rate of the recurrent state. During conversion, the sanitization step computes v + 1.0 which upcasts A_log to float32. But during inference, intermediate computations in float16 can still overflow.
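
To make the range difference concrete, here is a small standard-library sketch. It derives the largest finite value of each format from its bit layout and uses Python's struct "e" format (IEEE 754 binary16) to test whether a value still fits in float16. The recurrent update at the end is a toy with made-up constants, not GatedDeltaNet's actual state equation, and the function names are hypothetical.

```python
import struct

def max_finite(exponent_bits, mantissa_bits):
    """Largest finite value of an IEEE-style binary float with these field widths."""
    bias = 2 ** (exponent_bits - 1) - 1
    # The all-ones exponent is reserved for inf/NaN, so the top usable one is 2**bits - 2
    max_exponent = (2 ** exponent_bits - 2) - bias
    return (2 - 2.0 ** -mantissa_bits) * 2.0 ** max_exponent

FLOAT16_MAX = max_finite(5, 10)   # 65504.0
BFLOAT16_MAX = max_finite(8, 7)   # roughly 3.39e38
FLOAT32_MAX = max_finite(8, 23)   # roughly 3.40e38

def fits_float16(value):
    """True if value can be stored as a finite float16."""
    try:
        struct.pack("<e", value)
        return True
    except OverflowError:
        return False

# Toy recurrent state: decay close to 1 plus a constant input each step.
# The constants are invented for illustration only.
state, steps = 0.0, 0
while fits_float16(state):
    state = 0.999 * state + 100.0
    steps += 1
# The loop exits once the accumulated state no longer fits in float16.
```

The same accumulation stays far below BFLOAT16_MAX, which is the property the bfloat16 conversion relies on.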

The Fix

Always use dtype="bfloat16" when converting Qwen3.5 and other models with GatedDeltaNet or similar recurrent attention:


from mlx_lm import convert

convert(
    hf_path=source_model,
    mlx_path=output_path,
    quantize=True,
    q_bits=4,
    q_group_size=128,
    dtype="bfloat16",  # NOT "float16"
)

How to Detect

If your converted model produces garbage, check whether the architecture uses GatedDeltaNet or a similar recurrent layer. You can check programmatically:


import json

with open(f"{model_path}/config.json") as f:
    config = json.load(f)
model_type = config.get("model_type", "")
# Heuristic only: model_type does not always name the layer type, so no match proves nothing
if "gated" in model_type.lower() or "delta" in model_type.lower():
    print("WARNING: Use bfloat16 for this model type")

Bug 3: GPU Timeout on Large Model Lazy Loading

Severity: Medium. Server crashes on first request.

The Symptom

Starting mlx_lm.server with a 237 GB model:


mlx_lm.server --model ./qwen35-397b-hybrid-v1 --port 8888

The server starts and reports "Starting httpd at 127.0.0.1 on port 8888...". The first HTTP request triggers model loading (104 safetensors files). Then:


libc++abi: terminating due to uncaught exception of type std::runtime_error:
[METAL] Command buffer execution failed:
Caused GPU Timeout Error (00000002:kIOGPUCommandBufferCallbackErrorTimeout)

The Root Cause

mlx_lm.server uses lazy loading. The model weights get pulled from disk on first use, not at startup. When the first request comes in, it triggers loading 237 GB from 104 safetensors files into unified memory, then immediately tries the first forward pass.

The Metal GPU has a command buffer timeout. If that first forward pass takes too long (because the system is still paging in weights), the GPU kills the command buffer.

The Fix

Use mlx_lm.load directly instead of the server for large models:


from mlx_lm import load, generate
from mlx_lm.sample_utils import make_sampler

# Eager loading, all 237 GB loaded upfront
model, tokenizer = load("./qwen35-397b-hybrid-v1")  # Takes ~38 seconds
sampler = make_sampler(temp=0.0)

# First inference works immediately
result = generate(model, tokenizer, prompt="Hello", max_tokens=50, sampler=sampler)
# Completes in 1.9 seconds

If you must use the server, you could try warming it up with a dummy request after loading. But the root issue is that lazy loading plus the first forward pass exceeds the GPU timeout for very large models.
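
If you want to try the warm-up idea, a dummy request can be sent from a deployment script so that any crash happens before real traffic arrives. The sketch below assumes the server's OpenAI-style /v1/chat/completions endpoint and the port from the command above. The function names and the 600-second timeout are hypothetical choices, and whether this avoids the Metal timeout is untested here.

```python
import json
import urllib.request

def build_warmup_request(host="127.0.0.1", port=8888):
    """Build a one-token chat request; the prompt content is arbitrary."""
    payload = {"messages": [{"role": "user", "content": "Hi"}], "max_tokens": 1}
    return urllib.request.Request(
        f"http://{host}:{port}/v1/chat/completions",
        data=json.dumps(payload).encode("utf-8"),
        headers={"Content-Type": "application/json"},
    )

def warm_up(host="127.0.0.1", port=8888, timeout_s=600):
    """Send the dummy request and return the HTTP status code.

    The client timeout is generous because the first request also pays for
    loading the weights. It does not change the GPU command buffer timeout.
    """
    request = build_warmup_request(host, port)
    with urllib.request.urlopen(request, timeout=timeout_s) as response:
        return response.status
```

Calling warm_up() right after starting the server turns a first-user crash into a failed deployment check, which is easier to notice and retry.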

Threshold

In our testing on a Mac Studio M3 Ultra (512 GB):

Bug 4: Python Dunder Method Resolution Bypasses Instance Patching

Severity: Low (development-time only). Causes profiling hooks to silently not fire.

The Symptom

Trying to hook into MoE layer __call__ methods for activation profiling:


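import types

# Keep a reference to the class-level method so the hook can delegate to it
original_call = type(layer).__call__

def hooked_call(self, x):
    # Capture activations...
    return original_call(self, x)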

# This appears to work:
layer.__call__ = types.MethodType(hooked_call, layer)

# But the hook NEVER fires during inference!

The Root Cause

Python's data model says that "special methods" (dunder methods like __call__, __getattr__, __len__) are looked up on the type (class), not the instance. When you set layer.__call__, you're adding an instance attribute. But Python's method resolution for layer(x) goes through type(layer).__call__, which finds the original class method.

This is documented in Python's data model but trips people up all the time: https://docs.python.org/3/reference/datamodel.html#special-lookup
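
The lookup rule is easy to reproduce without MLX. In this minimal sketch, Layer is a hypothetical stand-in for an MoE layer class. The implicit call skips the instance attribute, while an explicit attribute access finds it.

```python
class Layer:
    """Stand-in for a model layer; doubles its input."""
    def __call__(self, x):
        return x * 2

layer = Layer()
calls = []

def hooked(x):
    calls.append(x)  # record that the hook ran
    return x * 2

# Adds an instance attribute; type(layer).__call__ is untouched
layer.__call__ = hooked

implicit = layer(3)           # resolved via type(layer).__call__, hook is skipped
explicit = layer.__call__(3)  # ordinary attribute lookup, finds the hook
```

Only the explicit call appends to calls, which is why a hook installed this way never fires during normal inference.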

The Fix

Swap the instance's __class__ to a dynamically created subclass:


original_class = layer.__class__

def hooked_call(self, x):
    # Your profiling code here
    activations.append(capture_routing(self, x))
    return original_class.__call__(self, x)

HookedClass = type(
    f"Hooked{original_class.__name__}",
    (original_class,),
    {"__call__": hooked_call}
)

layer.__class__ = HookedClass

Now type(layer).__call__ resolves to hooked_call because the instance's class has been swapped. It's safe, reversible (layer.__class__ = original_class to unhook), and works for all dunder methods.
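
Because unhooking is just restoring the original class, the swap fits naturally into a context manager. The following is a sketch under the assumption that the hook only needs to observe the call arguments. The name hooked and its callback signature are hypothetical, not part of MLX.

```python
from contextlib import contextmanager

@contextmanager
def hooked(obj, before_call):
    """Temporarily route obj(...) through before_call, then the original __call__."""
    original_class = obj.__class__

    def call_with_hook(self, *args, **kwargs):
        before_call(self, *args, **kwargs)  # e.g. record routing decisions
        return original_class.__call__(self, *args, **kwargs)

    # Dynamic subclass whose only difference is the overridden __call__
    obj.__class__ = type(
        f"Hooked{original_class.__name__}",
        (original_class,),
        {"__call__": call_with_hook},
    )
    try:
        yield obj
    finally:
        obj.__class__ = original_class  # always unhook, even if inference raises
```

Usage would look like `with hooked(layer, record): model(inputs)`, where record is your own profiling callback. The finally clause keeps a failed profiling run from leaving hooked layers behind.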

Bonus: MLX bfloat16 Array Gotchas

Three smaller issues that bit us during development:

No .copy() Method

MLX arrays don't have a .copy() method. Multiply by 1 instead:


# WRONG:
arr_copy = arr.copy()  # AttributeError

# RIGHT:
arr_copy = arr * 1

Can't Convert bfloat16 to NumPy Directly


import mlx.core as mx
import numpy as np

arr = mx.array([1.0, 2.0], dtype=mx.bfloat16)

# WRONG:
np_arr = np.array(arr)  # TypeError or wrong values

# RIGHT:
np_arr = np.array(arr.astype(mx.float32))
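
The float32 detour loses nothing, because bfloat16 is the upper 16 bits of the float32 layout (1 sign bit, 8 exponent bits, 7 mantissa bits). This standard-library sketch shows the relationship. The function names are hypothetical, and the encoder truncates instead of rounding to nearest, so it illustrates the bit layout and is not a replacement for a real conversion.

```python
import struct

def bf16_bits_to_float(bits):
    """Decode a 16-bit bfloat16 pattern by placing it in the top half of a float32."""
    return struct.unpack("<f", struct.pack("<I", (bits & 0xFFFF) << 16))[0]

def float_to_bf16_bits(value):
    """Encode by keeping the top 16 bits of the float32 pattern (truncation)."""
    (as_int,) = struct.unpack("<I", struct.pack("<f", value))
    return as_int >> 16

one = bf16_bits_to_float(0x3F80)  # float32 1.0 is 0x3F800000
masked = bf16_bits_to_float(float_to_bf16_bits(-1e9))  # nearby value; -1e9 needs more mantissa bits
```

Decoding never changes the value, while encoding can, which is also why a constant like -1e9 is stored only approximately in bfloat16.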

mx.metal.clear_cache is Deprecated


# WRONG (deprecated):
mx.metal.clear_cache()

# RIGHT:
mx.clear_cache()

Summary

Bug | Severity | Symptom | Fix
mx.save_safetensors bfloat16 | Critical | Model collapse after post-processing | Use safetensors.torch.save_file
float16 with GatedDeltaNet | Critical | Model collapse after conversion | Use dtype="bfloat16"
GPU timeout on lazy loading | Medium | Server crash on first request | Use mlx_lm.load directly
Dunder method resolution | Low | Profiling hooks don't fire | Use __class__ swapping

These bugs are specific to mlx 0.29.3 and mlx_lm 0.29.1. Future versions may fix some of them. The bfloat16 serialization bug in particular is a correctness issue that should be fixed at the framework level.

If you're working with large MoE models on Apple Silicon, the fixes summarized above are what we'd recommend.


Next in this series: Layer-Level vs Expert-Level Granularity in MoE Quantization, why giving every expert its own bit width barely beats treating entire layers uniformly.

Read the Full Paper

The full MoE expert quantization paper, covering expert activation profiling, per-expert mixed-bit allocation, and evaluation across 512-expert architectures, is available on our HuggingFace:

MoE Expert Quantization: Per-Expert Mixed-Precision for Mixture-of-Experts Models, Full Paper

huggingface.co/spaces/baa-ai/MoE-Expert-Quantization

Licensed under CC BY-NC-ND 4.0
