
Commit fb6db87

Merge pull request #4 from cachevector/example-notebooks

Add example notebooks

2 parents: f153d66 + 91ac882

4 files changed: 451 additions & 0 deletions

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
```diff
@@ -13,6 +13,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
   p50/p90/p99, min/max, and throughput. `cx.compare_benchmarks()` returns a
   before/after comparison with speedup and latency/throughput deltas. Quantized
   models are automatically run on CPU. New `comprexx bench` CLI command.
+- Example notebooks: ResNet18 edge deployment (prune + quantize + ONNX export)
+  and BERT-tiny quantization (low-rank decomposition + INT4 weight quant).
 - GitHub Actions CI workflow running `pytest` on Python 3.10, 3.11, and 3.12,
   plus a `ruff check` lint job.
 - `CHANGELOG.md` with history for v0.1.0 and v0.2.0.
```
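For context on the stats the changelog entry lists, the headline numbers (p50/p90/p99, min/max, throughput) reduce to a few lines of stdlib Python. This is an illustrative sketch only — `summarize_latencies` is a hypothetical helper, not comprexx's actual implementation:

```python
import statistics


def summarize_latencies(latencies_s):
    """Derive the stats the changelog describes from raw per-iteration
    timings (in seconds): p50/p90/p99, min/max, and throughput."""
    n = len(latencies_s)
    # statistics.quantiles with n=100 yields 99 cut points; index them
    # for the 50th, 90th, and 99th percentiles.
    qs = statistics.quantiles(latencies_s, n=100)
    return {
        "p50_ms": qs[49] * 1e3,
        "p90_ms": qs[89] * 1e3,
        "p99_ms": qs[98] * 1e3,
        "min_ms": min(latencies_s) * 1e3,
        "max_ms": max(latencies_s) * 1e3,
        "throughput_per_s": n / sum(latencies_s),
    }


# Synthetic example: 200 iterations at exactly 1 ms each.
stats = summarize_latencies([0.001] * 200)
```

With identical 1 ms timings every percentile is 1 ms and throughput is 1000 iterations/s, which makes the mapping between raw timings and the reported fields easy to sanity-check.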

README.md

Lines changed: 7 additions & 0 deletions
```diff
@@ -229,6 +229,13 @@ And for picking what to compress:
 |------|-------------|
 | Sensitivity analysis | `cx.analyze_sensitivity()` probes each Conv2d/Linear layer with a prune or noise perturbation, re-runs your `eval_fn`, and ranks layers by metric drop. Can also suggest `exclude_layers` above a chosen threshold. |
 
+## Examples
+
+Check out the example notebooks in [`examples/`](./examples/):
+
+- [ResNet18 edge deployment](./examples/resnet18_edge_deploy.ipynb): profile, fuse, prune, quantize, benchmark, and export a ResNet18 to ONNX.
+- [BERT-tiny quantization](./examples/bert_tiny_quantize.ipynb): low-rank decomposition + INT4 weight quantization on a small transformer, with latency benchmarks.
+
 ## License
 
 Apache 2.0
```
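The ResNet18 notebook's diff is not included in this view, but the "prune" step its README bullet advertises maps onto PyTorch's built-in pruning utilities. A minimal sketch on a stand-in Conv2d — illustrative only, not the notebook's actual code:

```python
import torch.nn as nn
import torch.nn.utils.prune as prune

# Stand-in conv layer; the notebook itself targets ResNet18's convs.
conv = nn.Conv2d(16, 32, kernel_size=3)

# L1 unstructured pruning: zero the 50% of weights with smallest magnitude.
prune.l1_unstructured(conv, name="weight", amount=0.5)
prune.remove(conv, "weight")  # bake the mask into the weight tensor

sparsity = (conv.weight == 0).float().mean().item()
```

After `prune.remove`, the layer carries an ordinary dense weight tensor with half its entries zeroed, which is the form a downstream quantize/export step would consume.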

examples/bert_tiny_quantize.ipynb

Lines changed: 233 additions & 0 deletions
```diff
@@ -0,0 +1,233 @@
```

```json
{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# BERT-tiny Quantization with Comprexx\n",
        "\n",
        "This notebook shows how to compress a small transformer model:\n",
        "\n",
        "1. Profile the model\n",
        "2. Apply low-rank decomposition to shrink Linear layers\n",
        "3. Apply weight-only INT4 quantization\n",
        "4. Benchmark before/after\n",
        "\n",
        "We use a minimal 2-layer transformer so this runs in seconds on CPU.\n",
        "\n",
        "Install: `pip install comprexx`"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import torch.nn as nn\n",
        "\n",
        "import comprexx as cx"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## 1. Define a small transformer\n",
        "\n",
        "A 2-layer encoder with d_model=128, 4 heads, and feedforward dim 512: small enough for a notebook, large enough to show compression working."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "class TinyBERT(nn.Module):\n",
        "    def __init__(self, vocab_size=1000, d_model=128, nhead=4, num_layers=2, num_classes=4):\n",
        "        super().__init__()\n",
        "        self.embedding = nn.Embedding(vocab_size, d_model)\n",
        "        encoder_layer = nn.TransformerEncoderLayer(\n",
        "            d_model=d_model, nhead=nhead, dim_feedforward=512, batch_first=True,\n",
        "        )\n",
        "        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)\n",
        "        self.classifier = nn.Linear(d_model, num_classes)\n",
        "\n",
        "    def forward(self, x):\n",
        "        x = self.embedding(x)\n",
        "        x = self.encoder(x)\n",
        "        # Mean pooling over sequence dim\n",
        "        x = x.mean(dim=1)\n",
        "        return self.classifier(x)\n",
        "\n",
        "model = TinyBERT()\n",
        "model.eval()\n",
        "print(f\"Model: {sum(p.numel() for p in model.parameters()):,} parameters\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## 2. Profile the model\n",
        "\n",
        "The full model takes integer token IDs, but the profiler needs a float tensor, so we profile the encoder + classifier on embedding-shaped inputs. Note that the embedding table's parameters are therefore not included in this profile."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "# The full model takes integer token IDs, so we profile the\n",
        "# encoder + classifier separately on float inputs.\n",
        "class EncoderClassifier(nn.Module):\n",
        "    \"\"\"Wraps encoder + classifier with float input for profiling.\"\"\"\n",
        "    def __init__(self, encoder, classifier):\n",
        "        super().__init__()\n",
        "        self.encoder = encoder\n",
        "        self.classifier = classifier\n",
        "\n",
        "    def forward(self, x):\n",
        "        x = self.encoder(x)\n",
        "        return self.classifier(x.mean(dim=1))\n",
        "\n",
        "profiling_model = EncoderClassifier(model.encoder, model.classifier)\n",
        "profile = cx.analyze(profiling_model, input_shape=(1, 32, 128))  # (batch, seq_len, d_model)\n",
        "print(profile.summary())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## 3. Low-rank decomposition\n",
        "\n",
        "The feedforward layers inside the transformer encoder are 128x512 and 512x128. SVD can factorize these into pairs of smaller layers. We keep enough singular values to retain 90% of the spectral energy (`energy_threshold=0.9`)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from comprexx.stages.base import StageContext\n",
        "\n",
        "stage_lr = cx.stages.LowRankDecomposition(mode=\"energy\", energy_threshold=0.9)\n",
        "ctx = StageContext(input_shape=(1, 32, 128), device=\"cpu\")\n",
        "\n",
        "model_lr, report_lr = stage_lr.apply(profiling_model, ctx)\n",
        "print(report_lr.summary())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## 4. Weight-only INT4 quantization\n",
        "\n",
        "After SVD, we quantize the remaining Linear weights to INT4 (group size 64, symmetric). Activations stay in float32."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "stage_wq = cx.stages.WeightOnlyQuant(bits=4, group_size=64, symmetric=True)\n",
        "\n",
        "model_quant, report_wq = stage_wq.apply(model_lr, ctx)\n",
        "print(report_wq.summary())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## 5. Or use a Pipeline for the same thing"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "pipeline = cx.Pipeline([\n",
        "    cx.stages.LowRankDecomposition(mode=\"energy\", energy_threshold=0.9),\n",
        "    cx.stages.WeightOnlyQuant(bits=4, group_size=64),\n",
        "])\n",
        "\n",
        "result = pipeline.run(profiling_model, input_shape=(1, 32, 128))\n",
        "print(result.report.summary())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## 6. Benchmark latency"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "cmp = cx.compare_benchmarks(\n",
        "    profiling_model, result.model,\n",
        "    input_shape=(1, 32, 128),\n",
        "    warmup=10,\n",
        "    iters=50,\n",
        ")\n",
        "print(cmp.summary())"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## 7. Dynamic PTQ as an alternative\n",
        "\n",
        "If weight-only quantization isn't giving you enough speedup, dynamic PTQ quantizes both weights and activations to INT8 at runtime. It's a simpler path that works well for inference on CPU."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "pipeline_ptq = cx.Pipeline([\n",
        "    cx.stages.LowRankDecomposition(mode=\"energy\", energy_threshold=0.9),\n",
        "    cx.stages.PTQDynamic(),\n",
        "])\n",
        "\n",
        "result_ptq = pipeline_ptq.run(profiling_model, input_shape=(1, 32, 128))\n",
        "print(result_ptq.report.summary())\n",
        "\n",
        "cmp_ptq = cx.compare_benchmarks(\n",
        "    profiling_model, result_ptq.model,\n",
        "    input_shape=(1, 32, 128),\n",
        "    warmup=10,\n",
        "    iters=50,\n",
        ")\n",
        "print(\"\\n\" + cmp_ptq.summary())"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Python 3",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "name": "python",
      "version": "3.10.0"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 4
}
```
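The low-rank step in section 3 can be reproduced without comprexx: the "energy" criterion is a truncated SVD of each Linear weight, keeping the smallest rank whose squared singular values cover the threshold. A self-contained sketch in plain PyTorch — `decompose_linear` is illustrative, not the library's implementation:

```python
import torch
import torch.nn as nn


def decompose_linear(linear: nn.Linear, energy: float = 0.9) -> nn.Sequential:
    """Replace a Linear layer with two smaller ones via truncated SVD,
    keeping the smallest rank that retains `energy` of the squared
    singular-value mass (the notebook's 'energy' criterion)."""
    W = linear.weight.data  # shape (out_features, in_features)
    U, S, Vh = torch.linalg.svd(W, full_matrices=False)
    cum = torch.cumsum(S**2, dim=0) / torch.sum(S**2)
    rank = int((cum < energy).sum().item()) + 1
    # W ≈ (U[:, :r] * S[:r]) @ Vh[:r, :], so x @ W.T factors into two matmuls.
    first = nn.Linear(linear.in_features, rank, bias=False)
    second = nn.Linear(rank, linear.out_features, bias=linear.bias is not None)
    first.weight.data = Vh[:rank, :].clone()
    second.weight.data = (U[:, :rank] * S[:rank]).clone()
    if linear.bias is not None:
        second.bias.data = linear.bias.data.clone()
    return nn.Sequential(first, second)


# Same shape as one of the encoder's feedforward projections (512 -> 128).
lin = nn.Linear(512, 128)
low_rank = decompose_linear(lin, energy=0.9)
x = torch.randn(4, 512)
err = (lin(x) - low_rank(x)).norm() / lin(x).norm()
```

Retaining 90% of the squared singular-value mass bounds the Frobenius-norm approximation error, so the relative output error on random inputs stays modest; how much rank (and hence parameter count) this saves depends entirely on the weight's spectrum.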

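Likewise, the weight-only setting in section 4 (`bits=4, group_size=64, symmetric=True`) can be sketched in a few lines of plain PyTorch. This is fake quantization for illustration (quantize then immediately dequantize, so the error is easy to inspect), not the library's packed-INT4 kernel:

```python
import torch


def int4_symmetric_quantize(w: torch.Tensor, group_size: int = 64) -> torch.Tensor:
    """Weight-only symmetric INT4 quantization with per-group scales.
    Assumes in_features is divisible by group_size. Returns the
    dequantized weights."""
    out_features, in_features = w.shape
    groups = w.reshape(out_features, in_features // group_size, group_size)
    # Symmetric: one scale per group maps the max |w| onto the INT4 limit 7.
    scales = groups.abs().amax(dim=-1, keepdim=True) / 7.0
    q = torch.clamp(torch.round(groups / scales), -8, 7)  # INT4 range [-8, 7]
    return (q * scales).reshape(out_features, in_features)


w = torch.randn(128, 512)
w_dq = int4_symmetric_quantize(w)
rel_err = (w - w_dq).norm() / w.norm()
```

Grouping trades metadata for accuracy: smaller groups store more scales but let each scale track its group's local magnitude, which is why group size is exposed as a knob in the notebook's `WeightOnlyQuant` stage.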