FastKMeans Suite is an enterprise-grade, GPU-accelerated framework that bridges the gap between the accuracy of traditional Metric Learning (like kNN or SVM) and the massive scalability of Deep Learning.
By condensing infinite datasets into an intelligent, differentiable grid of topological prototypes, FastKMeans drops inference complexity from
- Installation
- Available Models
- The Core Philosophy
- Master API & Stream Learning
- Global Parameters Configuration
- Model-Specific Parameters
- Advanced Features Deep Dive
- Quick Start Example
To install the framework locally in editable mode (so your code updates apply immediately):
```bash
git clone https://github.com/yourusername/FastKMeans.git
cd FastKMeans
pip install -e .
```

Requirements: `torch >= 2.0.0`, `numpy`, `scikit-learn`, `tqdm`.
(Optional for extreme scaling: faiss-cpu or faiss-gpu)
FastKMeans provides four highly specialized classes inheriting from BaseFastKMeans.
| Model Class | Task | Loss Function | Key Architectural Features |
|---|---|---|---|
| `FastKMeansClusterer` | Unsupervised clustering | Feature distance | Evaluates raw metric distances. Uses `diversity_reg` to prevent mode collapse. |
| `FastKMeansClassifier` | Multiclass classification | Cross-entropy | Supports negative sampling (sampled softmax) for extreme class spaces (>100k classes). LVQ repulsion pushes boundaries away from wrong classes. |
| `FastMultiLabelKMeansClassifier` | Multi-label / tagging | Asymmetric Loss (ASL) | Builds independent sub-topologies per tag. Uses ASL to combat extreme label sparsity. Includes a zero-hyperparameter `'gap'` prediction strategy. |
| `FastKMeansRegressor` | Continuous / vector regression | Mean squared error (MSE) | Uses 2-stage stratification: it clusters Y, then spawns X prototypes inside each cluster. Supports multi-dimensional vector outputs. |
- Phase 1: Hard Stream Learning (EMA). Data streams in batch by batch via `fit_batch(gradient_step=False)`.
  - Each sample is hard-assigned to its nearest centroid, which partitions the space into Voronoi cells.
  - Each centroid then shifts toward its assigned samples via an Exponential Moving Average (EMA).
- Phase 2: Differentiable Fine-tuning. The topological grid is unfrozen. With `fit_batch(gradient_step=True)`, the Adam optimizer moves the centroids by gradient descent to directly minimize the loss function.
- Phase 3: Soft Inference (Smart kNN). During `.predict()`, the targets of the `top_k` closest prototypes are blended using Inverse Distance Weighting (IDW).
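To make Phases 1 and 3 concrete, here is a minimal plain-Python sketch. It is not the library's code. It shows a hard nearest-centroid assignment followed by an EMA shift, and IDW blending over the `top_k` closest prototypes. The helper names and the `decay` rate are hypothetical. The real implementation works on batched PyTorch tensors and supports other metrics.

```python
import math

def nearest(centroids, x):
    """Index of the centroid closest to x (Euclidean distance)."""
    return min(range(len(centroids)), key=lambda k: math.dist(centroids[k], x))

def ema_step(centroids, x, decay=0.9):
    """Phase 1 sketch: hard-assign x to its Voronoi cell, then EMA-shift that centroid toward x."""
    k = nearest(centroids, x)
    centroids[k] = [decay * c + (1 - decay) * v for c, v in zip(centroids[k], x)]
    return k

def idw_predict(centroids, targets, x, top_k=5, eps=1e-8):
    """Phase 3 sketch: weight each of the top_k closest prototypes' targets by 1 / distance."""
    dists = [math.dist(c, x) for c in centroids]
    closest = sorted(range(len(centroids)), key=lambda k: dists[k])[:top_k]
    weights = [1.0 / (dists[k] + eps) for k in closest]
    return sum(w * targets[k] for w, k in zip(weights, closest)) / sum(weights)

# Usage: two prototypes on a line with scalar targets 0.0 and 1.0
centroids = [[0.0, 0.0], [2.0, 0.0]]
ema_step(centroids, [0.5, 0.0])  # nudges prototype 0 toward the sample
print(idw_predict(centroids, [0.0, 1.0], [0.5, 0.0], top_k=2))
```

Only the winning centroid moves in Phase 1. In Phase 3, every one of the `top_k` neighbours contributes, with closer ones weighted more heavily.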
You are no longer constrained by RAM. The entire API revolves around the fit_batch master function, allowing you to stream billions of rows directly from a database.
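```python
from fast_kmeans import FastKMeansClassifier

model = FastKMeansClassifier()
# `infinite_stream` is a placeholder for your own iterable of (X, y, sample_weight, feature_mask)
# batches, e.g. read from a database. Use a fresh iterator for Phase 2 if it is a one-shot generator.

# 1. Build the grid (Phase 1: hard EMA updates)
for batch_X, batch_y, batch_sw, batch_mask in infinite_stream:
    model.fit_batch(batch_X, batch_y, sample_weight=batch_sw, feature_mask=batch_mask, gradient_step=False)

# 2. Refine with gradients (Phase 2)
model.find_learning_rate(X_sample, y_sample)  # auto-find the optimal LR on an in-memory sample
for batch_X, batch_y, batch_sw, batch_mask in infinite_stream:
    model.fit_batch(batch_X, batch_y, sample_weight=batch_sw, feature_mask=batch_mask, gradient_step=True)
```

(The helper functions `.fit()` and `.finetune()` are convenience wrappers around `fit_batch` for static datasets that fit in RAM.)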
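`infinite_stream` in the loop above is a placeholder. One hypothetical way to build it is a generator that groups rows from any iterable, such as a database cursor or file reader. It yields `(X, y, sample_weight, feature_mask)` batches and zero-fills and masks missing values. The names `batch_stream` and `collate` are illustrative, not part of the library. The sketch yields plain lists. You would likely wrap each element in `torch.tensor(...)` before calling `fit_batch`, assuming it accepts tensors the way `fit` does in the Quick Start.

```python
def batch_stream(rows, batch_size):
    """Hypothetical generator: yields (X, y, sample_weight, feature_mask) batches.

    `rows` is any iterable of (features, label) pairs; a feature of None
    counts as missing, is filled with 0.0, and gets mask value 0.
    """
    batch = []
    for features, label in rows:
        batch.append((features, label))
        if len(batch) == batch_size:
            yield collate(batch)
            batch = []
    if batch:  # flush the final, possibly smaller, batch
        yield collate(batch)


def collate(batch):
    """Turn a list of (features, label) pairs into one batch tuple."""
    X = [[0.0 if v is None else v for v in feats] for feats, _ in batch]
    mask = [[0 if v is None else 1 for v in feats] for feats, _ in batch]
    y = [label for _, label in batch]
    sample_weight = [1.0] * len(batch)  # uniform weights in this sketch
    return X, y, sample_weight, mask


rows = [([1.0, None], 0), ([2.0, 3.0], 1), ([4.0, 5.0], 0)]
for batch_X, batch_y, batch_sw, batch_mask in batch_stream(rows, batch_size=2):
    print(batch_X, batch_mask)
```

A generator is exhausted after one pass. Call `batch_stream(...)` again on a fresh source for the Phase 2 loop.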
These parameters are available in the `__init__` of all models.

| Parameter | Type | Default | Description |
|---|---|---|---|
| `distance` | str | `'cosine'` | Distance metric. Options: `'cosine'`, `'euclidean'`, `'l1'`. |
| `dtype` | str | `'float32'` | Memory precision. Options: `'float32'`, `'float16'`, `'bfloat16'`. |
| `init_mode` | str | `'kmeans++'` | Prototype spawning logic. Options: `'kmeans++'`, `'random'`. |
| `soft_type` | str | `'scaled'` | Crucial: routing logic for targets and gradients. Options: `'hard'`, `'mean'`, `'scaled'`, `'softmax_scaled'`. (Note: `'hard'` and `'mean'` disable gradients.) |
| `temperature` | float | `1.0` | Scalar that divides distances when `soft_type='softmax_scaled'`. |
| `top_k` | int | `5` | Number of prototypes aggregated during inference. Use `-1` for all. |
| `auto_feature_weights` | bool | `False` | Enables real-time, ANOVA-style feature importance scaling. |
| `negative_sampling` | int/None | `None` | Restricts loss calculation to … |
| `diversity_reg` | float | `0.0` | Orthogonality penalty that spreads prototypes apart (prevents mode collapse). |
| `l2_reg` | float | `0.0` | Standard L2 weight decay applied to centroid coordinates. |
| `use_faiss` | bool | `False` | Uses HNSW graph indexing for … (requires `faiss-cpu` or `faiss-gpu`). |
| `use_compile` | bool | `False` | Uses the PyTorch 2.0 compiler `torch.compile` (Triton kernels) for speedups. |
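The table only says that `temperature` divides the distances under `soft_type='softmax_scaled'`. A common way to turn scaled distances into routing weights is a softmax over their negatives. The sketch below assumes that reading; the library's exact formula may differ. It shows how a lower temperature concentrates weight on the closest prototype.

```python
import math

def softmax_scaled_weights(distances, temperature=1.0):
    """Assumed routing rule: weights = softmax(-distance / temperature)."""
    scaled = [-d / temperature for d in distances]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(s - m) for s in scaled]
    total = sum(exps)
    return [e / total for e in exps]

# The closer prototype (distance 0.2) gets more weight; a lower temperature sharpens that preference.
print(softmax_scaled_weights([0.2, 1.0], temperature=1.0))
print(softmax_scaled_weights([0.2, 1.0], temperature=0.1))
```

Unlike a hard argmin, these weights are smooth in the distances. That is why a soft routing mode can pass gradients back to the prototypes.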
`FastKMeansClassifier` parameters:

| Parameter | Type | Default | Description |
|---|---|---|---|
| `k_init` | int/str | `3` | Number of prototypes per class. If set to `'auto'`, the model uses Geometric Information Dispersion to choose the number of prototypes from each class's complexity. |
| `repulsion_factor` | float | `0.05` | (LVQ) Push force applied by negative samples to expel wrong-class prototypes. |
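`repulsion_factor` borrows from Learning Vector Quantization (LVQ). In the classic LVQ1 rule, the winning prototype moves toward a sample of its own class and away from a sample of a different class. The sketch below shows that rule for a single prototype, with the push step set to `repulsion_factor`. The pull rate `lr` and the exact way the framework applies the factor are assumptions.

```python
def lvq_update(prototype, proto_label, x, x_label, lr=0.1, repulsion_factor=0.05):
    """LVQ1-style update of the winning prototype for one sample x."""
    # Same class: pull toward x. Different class: push away with strength repulsion_factor.
    step = lr if proto_label == x_label else -repulsion_factor
    return [p + step * (xi - p) for p, xi in zip(prototype, x)]

print(lvq_update([0.0, 0.0], "cat", [1.0, 2.0], "cat"))  # pulled toward x
print(lvq_update([0.0, 0.0], "cat", [1.0, 2.0], "dog"))  # pushed away from x
```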
`FastMultiLabelKMeansClassifier` parameters:

| Parameter | Type | Default | Description |
|---|---|---|---|
| `asl_gamma_neg` | float | `4.0` | Asymmetric Loss focusing exponent (decay) for easy negative samples. |
| `asl_gamma_pos` | float | `1.0` | Asymmetric Loss focusing exponent (decay) for positive samples. |
| `asl_clip` | float | `0.05` | Probability margin below which negative samples are fully discarded from the gradients. |
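For reference, here is a per-label sketch of the standard Asymmetric Loss formulation (Ridnik et al., 2021). It assumes `asl_gamma_pos`, `asl_gamma_neg` and `asl_clip` map to its positive focusing exponent, negative focusing exponent and probability margin. Negatives whose probability falls below the margin contribute exactly zero loss, which matches the "fully discarded" behaviour above. This is not the library's implementation, which operates on tensors.

```python
import math

def asl_loss(p, target, gamma_pos=1.0, gamma_neg=4.0, clip=0.05, eps=1e-8):
    """Asymmetric Loss for one label with predicted probability p and target 0 or 1."""
    if target == 1:
        # Positive: focal-style term; confident positives (p near 1) are down-weighted.
        return -((1 - p) ** gamma_pos) * math.log(max(p, eps))
    # Negative: shift the probability down by the margin; p < clip gives zero loss.
    p_m = max(p - clip, 0.0)
    return -(p_m ** gamma_neg) * math.log(max(1 - p_m, eps))

print(asl_loss(0.03, 0))  # easy negative below the margin: no loss
print(asl_loss(0.90, 0))  # confident false positive: large loss
```

The large `gamma_neg` strongly suppresses easy negatives, which dominate when tags are sparse.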
`FastKMeansRegressor` parameters:

| Parameter | Type | Default | Description |
|---|---|---|---|
| `k_targets` | int/str | `20` | Stage 1: number of target (Y) clusters, or `'auto'`. |
| `k_features` | int/str | `10` | Stage 2: number of feature (X) prototypes spawned inside each target cluster, or `'auto'`. |
| `target_distance` | str | `'euclidean'` | Metric used to cluster the targets (Y) in Stage 1. |
| `target_assignment` | str | `'scaled'` | How EMA target averages are generated. Options: `'hard'`, `'mean'`, `'scaled'`, `'softmax_scaled'`. |
If you set `k_init='auto'`, you no longer have to tune the number of prototypes per class. The framework computes the Geometric Information Dispersion of each class:

- If a class is dense and homogeneous (dispersion $\approx 0$), it spawns $1$ prototype.
- If a class is chaotic and spread out, it scales up to $\mathcal{O}(\sqrt{N})$ prototypes for a class of $N$ samples.
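The exact Geometric Information Dispersion formula is not given here. The following is therefore only a hypothetical rule with the same two endpoints:

1. Measure the mean distance of a class's points to their centroid, relative to a `scale` you supply.
2. Squash that value into $[0, 1)$.
3. Map the result onto 1 to $\lceil\sqrt{N}\rceil$ prototypes.

```python
import math

def auto_k(points, scale=1.0):
    """Hypothetical k_init='auto' rule: 1 prototype for a tight class, up to ceil(sqrt(N)) for a spread one."""
    n, dim = len(points), len(points[0])
    centroid = [sum(p[j] for p in points) / n for j in range(dim)]
    # Dispersion: mean distance to the class centroid, relative to a reference scale.
    dispersion = sum(math.dist(p, centroid) for p in points) / (n * scale)
    k_max = math.ceil(math.sqrt(n))
    frac = dispersion / (1.0 + dispersion)  # squash into [0, 1) so k stays in [1, k_max]
    return max(1, min(k_max, 1 + round(frac * (k_max - 1))))

tight = [[1.0, 1.0]] * 50                       # identical points: dispersion 0
spread = [[float(i), 0.0] for i in range(100)]  # points spread along a line
print(auto_k(tight), auto_k(spread))            # 1 10
```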
Setting `auto_feature_weights=True` runs Welford's online variance algorithm inline. For each feature, it compares the within-cluster variance with the global variance.
Features that mostly carry noise (high within-cluster variance relative to their global variance) are automatically down-weighted in the distance computation. During `.finetune()`, these weights become an `nn.Parameter` and are tuned via backpropagation.
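Below is a plain-Python sketch of the two ingredients. The first is Welford's online mean/variance update. The second is one plausible ANOVA-style weight: 1 minus the ratio of average within-cluster variance to global variance. The weight formula is an assumption, as is averaging the within-cluster variances without weighting by cluster size.

```python
class RunningVariance:
    """Welford's online algorithm: numerically stable streaming mean and (population) variance."""

    def __init__(self):
        self.n, self.mean, self.m2 = 0, 0.0, 0.0

    def update(self, x):
        self.n += 1
        delta = x - self.mean
        self.mean += delta / self.n
        self.m2 += delta * (x - self.mean)  # uses the updated mean

    @property
    def variance(self):
        return self.m2 / self.n if self.n > 0 else 0.0


def feature_weight(within_cluster_vars, global_var, eps=1e-12):
    """Assumed ANOVA-style weight: near 1 if clusters separate the feature, near 0 if it is noise."""
    within = sum(within_cluster_vars) / len(within_cluster_vars)
    return max(0.0, 1.0 - within / (global_var + eps))


def variance_of(values):
    rv = RunningVariance()
    for v in values:
        rv.update(v)
    return rv.variance


cluster_a, cluster_b = [0.0, 0.1, -0.1], [10.0, 10.1, 9.9]
w = feature_weight([variance_of(cluster_a), variance_of(cluster_b)], variance_of(cluster_a + cluster_b))
print(w)  # close to 1: this feature separates the two clusters well
```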
In `fit`, `fit_batch`, and `predict`, you can pass a `feature_mask` tensor (`1` for active, `0` for padded or missing).
- The magic: the distance computation ignores masked features entirely. During EMA, the framework allocates a $(K, D)$ matrix so that prototypes accumulate "age" independently per coordinate. This prevents padding tokens from pulling prototypes toward zero.
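Here is a minimal sketch of both ideas for a single prototype, assuming Euclidean distance. Masked coordinates are dropped from the distance sum. Each coordinate keeps its own update count, which is one row of the $(K, D)$ age matrix. For simplicity the sketch uses a running mean (step size `1 / age`) rather than the framework's EMA. The function names are hypothetical.

```python
import math

def masked_distance(x, mask, centroid):
    """Euclidean distance over active coordinates only (mask: 1 = active, 0 = pad/missing)."""
    return math.sqrt(sum(m * (xi - ci) ** 2 for xi, m, ci in zip(x, mask, centroid)))

def masked_update(centroid, age, x, mask):
    """Update only active coordinates; `age` is this prototype's row of the (K, D) age matrix."""
    for j, m in enumerate(mask):
        if m:
            age[j] += 1
            centroid[j] += (x[j] - centroid[j]) / age[j]  # running mean with step 1/age

centroid, age = [1.0, 1.0], [1, 1]
x, mask = [3.0, 0.0], [1, 0]                # the second coordinate is padding
print(masked_distance(x, mask, centroid))   # 2.0: the padded coordinate is ignored
masked_update(centroid, age, x, mask)
print(centroid, age)                        # [2.0, 1.0] [2, 1]: padding did not drag the second coordinate to 0
```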
If you do not pass an `lr` argument to `.finetune()`, the framework runs a short learning-rate range test:

1. It scales the learning rate exponentially across a few trial batches.
2. It selects the learning rate at which the loss curve has its steepest negative slope.
3. It uses `copy.deepcopy` to restore the untouched topological grid before real training starts.
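The procedure resembles the well-known learning-rate range test. Below is a minimal sketch on a toy quadratic loss. It assumes parameters are a plain list updated in place by gradient steps; `find_lr`, `loss_fn` and `grad_fn` are hypothetical names. The sketch does three things:

- It stops the sweep if the loss diverges.
- It picks the learning rate at the steepest negative slope of loss versus log(lr).
- It restores a snapshot taken with `copy.deepcopy`, so the caller's parameters are untouched.

```python
import copy
import math

def find_lr(params, loss_fn, grad_fn, lr_min=1e-4, lr_max=1.0, steps=30):
    """LR range test: sweep lr exponentially, then return the lr where the loss falls fastest."""
    snapshot = copy.deepcopy(params)              # save the untouched parameters
    factor = (lr_max / lr_min) ** (1.0 / (steps - 1))
    lrs, losses = [], []
    lr = lr_min
    for _ in range(steps):
        grads = grad_fn(params)
        for i, g in enumerate(grads):
            params[i] -= lr * g                   # trial gradient step, applied in place
        loss = loss_fn(params)
        if not math.isfinite(loss) or (losses and loss > 4 * min(losses)):
            break                                 # loss diverged: stop the sweep
        lrs.append(lr)
        losses.append(loss)
        lr *= factor
    params[:] = snapshot                          # restore before real training
    if len(losses) < 2:
        return lr_min
    # Slope of loss vs. log(lr); the most negative slope marks the steepest descent.
    slopes = [(losses[i + 1] - losses[i]) / math.log(lrs[i + 1] / lrs[i])
              for i in range(len(losses) - 1)]
    return lrs[min(range(len(slopes)), key=lambda i: slopes[i])]


# Toy problem: loss = sum(p^2), gradient = 2p
params = [5.0, -3.0]
best_lr = find_lr(params, lambda p: sum(v * v for v in p), lambda p: [2 * v for v in p])
print(best_lr, params)  # params are restored to [5.0, -3.0]
```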
Instead of guessing a fixed probability threshold (e.g. `p > 0.5`), call `predict(X, strategy='gap')`. It sorts the label probabilities and places the cutoff at the largest drop between consecutive probabilities.
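Here is a minimal sketch of the gap rule as described. It sorts the label probabilities, finds the largest drop between neighbours, and keeps every label above it. `gap_predict` is a hypothetical name; the library may handle ties or single-label edge cases differently.

```python
def gap_predict(probs):
    """Return the indices of labels ranked above the largest probability drop."""
    if len(probs) < 2:
        return list(range(len(probs)))
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    ranked = [probs[i] for i in order]
    gaps = [ranked[i] - ranked[i + 1] for i in range(len(ranked) - 1)]
    cut = max(range(len(gaps)), key=lambda i: gaps[i])  # position of the largest drop
    return sorted(order[: cut + 1])

print(gap_predict([0.9, 0.85, 0.2, 0.1]))  # [0, 1]
```

A fixed 0.5 threshold would also keep labels 0 and 1 here. For `[0.45, 0.42, 0.05]`, however, the gap rule keeps both top labels, while a 0.5 threshold would return nothing.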
```python
import torch
from sklearn.datasets import fetch_california_housing
from sklearn.preprocessing import StandardScaler
from fast_kmeans import FastKMeansRegressor

# 1. Load data
X_np, y_np = fetch_california_housing(return_X_y=True)
X = torch.tensor(StandardScaler().fit_transform(X_np), dtype=torch.float32)
y = torch.tensor(y_np, dtype=torch.float32).unsqueeze(1)  # shape (N, 1); multi-dimensional targets also work

# 2. Initialize model
reg = FastKMeansRegressor(
    k_targets='auto',            # Analytically choose the number of target buckets
    k_features=10,               # 10 prototypes per bucket
    distance='euclidean',
    soft_type='softmax_scaled',  # Differentiable routing ('hard' and 'mean' would disable gradients)
    diversity_reg=0.01,          # Keep prototypes from collapsing together
)

# 3. Phase 1: EMA topology construction
print("Building rigid Voronoi map...")
reg.fit(X, y, max_iters=20)

# 4. Phase 2: differentiable fine-tuning (no lr given, so the LR finder runs automatically)
print("Finetuning with gradients...")
reg.finetune(X, y, epochs=30, early_stopping_rounds=5)

# 5. Inference
predictions = reg.predict(X)
print(predictions[:5])
```