Fix fp_qmv_impl small-output-dim branch using raw fp8 scale byte#3763
Open
jax-0n-git wants to merge 1 commit into
Open
Fix fp_qmv_impl small-output-dim branch using raw fp8 scale byte#3763jax-0n-git wants to merge 1 commit into
jax-0n-git wants to merge 1 commit into
Conversation
The out_vec_size < 8 branch's full-block loop loaded the fp8 scale as a raw byte and passed it to qdot without dequantize_scale, so fp-quantized (mxfp4/mxfp8/nvfp4) matvec with output dim < 8 multiplied by the raw e8m0/e4m3 byte instead of the decoded scale (gross error, ~1e2-1e4 relative). Every other fp scale-load site decodes it, including the remainder loop of this same branch. Add test_fp_qmv_small_non_multiples covering the fp modes at output dim < 8 with K large enough to run the full-block loop (the existing fp tests use output dim >= 8 or K below block_size). Co-Authored-By: Claude Opus 4.8 <noreply@anthropic.com>
b8079fd to
ce93101
Compare
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.This suggestion is invalid because no changes were made to the code.Suggestions cannot be applied while the pull request is closed.Suggestions cannot be applied while viewing a subset of changes.Only one suggestion per line can be applied in a batch.Add this suggestion to a batch that can be applied as a single commit.Applying suggestions on deleted lines is not supported.You must change the existing code in this line in order to create a valid suggestion.Outdated suggestions cannot be applied.This suggestion has been applied or marked resolved.Suggestions cannot be applied from pending reviews.Suggestions cannot be applied on multi-line comments.Suggestions cannot be applied while the pull request is queued to merge.Suggestion cannot be applied right now. Please check back later.
Fixes #3762.
Problem
In
fp_qmv_impl(mlx/backend/metal/kernels/fp_quantized.h), theout_vec_size < 8branch's full-block loop loads the fp8 scale as a raw byte and passes it straight toqdot, skippingdequantize_scale:Every other fp scale-load site in the file decodes the byte first — including the remainder loop of this same branch (:464) and the
>= 8branch loops (:306/:368/:495/:514):U s = dequantize_scale<U, group_size>(sl[0]);qdotapplies the scale directly (scale * accum); it does not decode internally. So fp-quantized (mxfp4/mxfp8/nvfp4) mat-vec with output dim< 8multiplies by the literal0–255exponent/mantissa byte instead of the real fp8 scale, producing grossly wrong output.The line looks copy-pasted from the integer
quantized.hqmv_impl, whereU s = sl[0];is correct because therescalesis a realconst device T*float array. Infp_quantized.h,scalesisconst device uint8_t*(packed e8m0/e4m3) and must be decoded.When it bites
Both conditions are required:
N < 8→ enters theout_vec_size < (num_simdgroups * results_per_simdgroup) = 8branch, andK > block_size(values_per_thread * SIMD_SIZE; 256 for 4-bit) → the full-block loop at the buggy line actually runs. (ForK ≤ block_sizeonly the correct remainder loop runs.)This is why it slipped past CI:
test_fp_qmvuses output dim ≥ 8 (wrong branch), andtest_qmv_small_non_multiplesusesK = 32(belowblock_size, so only the correct remainder loop executes) and doesn't covermxfp4.Repro on
0.31.2(M5 Max) — relative error vsdequantize+matmul,K=512:The integer
affinepath is clean at everyN, isolating the defect to the fp scale-decode.Fix
Decode the scale, matching the four sibling loops:
Test
Adds
test_fp_qmv_small_non_multiplesinpython/tests/test_quantized.py: fp modes {mxfp4, mxfp8, nvfp4} × M∈{1,2} × N∈{1,2,3,5,7} atK=512(forcing at least one full block), inputs normalized by1/sqrt(K)like the other large-dim tests, asserting(y_q - y_hat).abs().max() < 1e-3.Verified locally on an M5 Max (Metal) by building MLX from source at the current
main:>= 8path (~5e-4 relative).test_quantized.pysuite (32 tests) passes with the fix — no regressions.pre-commit(clang-format, black, isort) is clean on both files.