
Delete AffineQuantizedTensor, AQTTensorImpl, and Layout #4245

Open
andrewor14 wants to merge 1 commit into main from delete-aqt

Conversation

Contributor

@andrewor14 commented Apr 6, 2026

Summary: AffineQuantizedTensor was the v1 quantized tensor
system, now fully superseded by v2 tensor types (Int8Tensor,
Int4Tensor, Float8Tensor, IntxUnpackedToInt8Tensor, etc.) that
inherit from TorchAOBaseTensor.

BC-Breaking notes:

Before (AQT):

```python
from torchao.dtypes import to_affine_quantized_intx
from torchao.quantization import quantize_, Int4WeightOnlyConfig

# Low-level AQT API
weight = to_affine_quantized_intx(
    weight, mapping_type, block_size, target_dtype,
    quant_min, quant_max, eps, _layout=Layout(),
)

# High-level API (unchanged)
quantize_(model, Int4WeightOnlyConfig())
```

After (v2 tensors):

```python
from torchao.quantization import quantize_, Int4WeightOnlyConfig

# High-level API (unchanged, recommended)
quantize_(model, Int4WeightOnlyConfig())

# Low-level v2 API (if needed)
from torchao.quantization import Int4Tensor, IntxUnpackedToInt8Tensor
weight = Int4Tensor.from_hp(weight, block_size)
weight = IntxUnpackedToInt8Tensor.from_hp(weight, block_size, torch.int4)
```

Detailed changes:

Core deletions:

  • torchao/dtypes/affine_quantized_tensor.py (class definition)
  • torchao/dtypes/affine_quantized_tensor_ops.py (aten dispatch)
  • torchao/dtypes/floatx/, torchao/dtypes/uintx/ (empty subpackages)
  • torchao/dtypes/README.md (stale AQT-centric docs)
  • torchao/dtypes/utils.py: removed Layout class and AQTTensorImpl class
  • torchao/dtypes/__init__.py: removed all AQT and Layout exports
  • torchao/utils.py: removed _register_layout, _get_tensor_impl_constructor,
    and their classmethod registrations on TorchAOBaseTensor
  • test/dtypes/test_affine_quantized.py
  • test/dtypes/test_affine_quantized_tensor_parallel.py

Core updates:

  • quant_api.py: removed AQT from _is_linear check, removed 5 dead
    activation quant helpers
  • testing/utils.py: switched defaults from AQT to Int8Tensor
  • Updated test assertions, docstrings, and docs to remove AQT references

Prototype updates:

  • prototype/autoround/: removed broken AQT imports, updated isinstance
    checks to TorchAOBaseTensor. Everything works except apply_auto_round(),
    which was already broken before this PR (issue #1690).
  • prototype/dtypes/uintx/uintx_utils.py: removed UintxLayout,
    UintxAQTTensorImpl, and AQT imports (fixes codebook import breakage)
  • prototype/quantization/mixed_precision/: added an assertion error since
    the feature was already broken by the PlainLayout deletion (#4151)

Still broken (tracked with TODOs):

  • tutorials/calibration_flow/ (uses to_affine_quantized_intx_static)
  • tutorials/developer_api_guide/ (uses Layout)

Docs/comments only (not broken, just stale references):

  • prototype/quantization/module_swap/ (README)
  • prototype/parq/ (README)
  • prototype/quantized_training/ (comments)


pytorch-bot bot commented Apr 6, 2026

🔗 Helpful Links

🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/4245

Note: Links to docs will display an error until the docs builds have been completed.

✅ No Failures

As of commit 9eb99cb with merge base 1a1debb:
💚 Looks good so far! There are no failures yet. 💚

This comment was automatically generated by Dr. CI and updates every 15 minutes.

@meta-cla bot added the CLA Signed label Apr 6, 2026
@andrewor14 added the module: bc-breaking label and removed the module: core label Apr 6, 2026
@andrewor14
Contributor Author

@claude review please?


claude bot commented Apr 6, 2026

Claude finished @andrewor14's task in 6m 11s


Review of PR #4245: [draft] Delete AQT

  • Gather context and read the full diff
  • Review deleted files for remaining references
  • Review modified files for correctness
  • Check for dangling imports or broken references
  • Post review feedback

Overall this is a clean deletion of the deprecated AffineQuantizedTensor system. The core changes — removing affine_quantized_tensor.py, affine_quantized_tensor_ops.py, AQTTensorImpl, the Layout docstring/code in utils.py, the UintxLayout/UintxAQTTensorImpl shim, and the helper functions in quant_api.py — all look correct. No dangling imports were found. A few items worth addressing:


Issues

1. Unrelated change to Float8TrainingOpConfig in torchao/prototype/moe_training/config.py

This PR removes @register_as_pytree_constant and the custom __eq__/__hash__ from Float8TrainingOpConfig. This is unrelated to the AQT deletion and is a behavioral change:

  • Without @register_as_pytree_constant, this config can no longer be used with dynamo nonstrict trace (the sibling MXFP8TrainingOpConfig retains the decorator with an explicit comment explaining why it's needed).
  • Without __hash__, the dataclass becomes unhashable (since @dataclass with eq=True, the default, sets __hash__ = None unless the class is frozen or defines __hash__ explicitly), which could break code that uses it in sets or dicts.

This should either be reverted or split into a separate PR with proper justification.
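
For illustration, a minimal sketch of the unhashability point, using a hypothetical Config dataclass (not the real Float8TrainingOpConfig):

```python
from dataclasses import dataclass

# @dataclass with eq=True (the default) sets __hash__ = None, so
# instances are unhashable unless __hash__ is defined explicitly
# or the class is frozen.
@dataclass
class Config:  # hypothetical stand-in for Float8TrainingOpConfig
    scale: float = 1.0

try:
    hash(Config())
    hashable = True
except TypeError:
    hashable = False
print(hashable)  # False
```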


2. Weak assertion in test/dtypes/test_affine_quantized.py:90

```python
assert "Tensor" in ql_str, (
    f"Expected quantized tensor in repr, got: {ql_str}"
)
```

The string "Tensor" will match virtually any PyTorch repr (e.g. a plain torch.Tensor), making this assertion meaningless. Consider checking for something more specific like "Int4" or "Int8" or "Float8", or checking isinstance(ql.weight, TorchAOBaseTensor) instead, to actually verify quantization occurred.
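
A minimal sketch of the suggested stronger check, using stand-in classes rather than real torchao imports (TorchAOBaseTensor, Int8Tensor, and PlainTensor here are hypothetical stand-ins):

```python
# Stand-in classes; not imported from torchao.
class TorchAOBaseTensor: ...

class Int8Tensor(TorchAOBaseTensor):
    def __repr__(self):
        return "Int8Tensor(...)"

class PlainTensor:  # stands in for an unquantized tensor
    def __repr__(self):
        return "Tensor(...)"

weight = Int8Tensor()
# Weak: the substring check also passes for the unquantized stand-in
assert "Tensor" in repr(weight) and "Tensor" in repr(PlainTensor())
# Stronger: isinstance actually distinguishes quantized from plain
assert isinstance(weight, TorchAOBaseTensor)
assert not isinstance(PlainTensor(), TorchAOBaseTensor)
```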


3. Stale comments referencing deleted functions

Two files still reference the now-deleted _int8_asymm_per_token_quant function in comments:

  • torchao/quantization/qat/linear.py:325-326:
    # Use torch.float32 to match torchao.quantization.quant_api._int8_asymm_per_token_quant,
  • torchao/quantization/linear_quant_modules.py:347-348:
    # uses fp32 to match torchao.quantization.quant_api._int8_asymm_per_token_quant

These comments should be updated to reference the current implementation.



Nits

4. _N_TO_TARGET_DTYPE in naive_intNwo.py has unreachable entries

The dict maps n=4 and n=8 to dtypes, but the code raises AssertionError for both those cases at lines 70-77 before the dict is ever consulted. This isn't wrong, but the entries at keys 4 and 8 are dead code.

5. autoround/core.py: apply_auto_round() now only raises

The function body was entirely replaced with a raise AssertionError(...). The existing tests that call it (test_auto_round, test_wrap_model_with_multi_tensor) are already @pytest.mark.skip, so this is fine for now, but worth noting that apply_auto_round is now a dead function that could be removed entirely (or left as-is if there are downstream callers that need the error message).


What looks good

  • The _is_linear simplification to use not isinstance(mod.weight, TorchAOBaseTensor) is correct — AffineQuantizedTensor, LinearActivationQuantizedTensor, and _AffineFakeQuantizedTensor all inherit from TorchAOBaseTensor.
  • Documentation updates correctly replace AQT references with the new tensor types (Int4Tensor, Int8Tensor).
  • The deleted tutorial files (awq_like.py, static_quant.py) were AQT-dependent and their removal is appropriate.
  • The gptq_like.py tutorial migration from AQT to Int4TilePackedTo4dTensor is clean.
  • Test class renames dropping "AffineQuantized" from names are good cleanup.
  • The static_quantization.rst tutorial migration from to_affine_quantized_intx_static to Int8Tensor.from_hp is a clean replacement.

Contributor

@jerryzh168 Apr 6, 2026


we can follow up to delete this as well I think, maybe migrate the useful ones if any

setattr(other_cls, f"{name}_{suffix}", new_test)


class TorchAOBasicTestCase(common_utils.TestCase):
Contributor


we can probably delete this as well later

Contributor Author


this doesn't seem related to AQT?

Contributor


not related, but we already have tests in workflow tests, so not needed anymore

@andrewor14 force-pushed the delete-aqt branch 5 times, most recently from dc8f0eb to 84cdede April 7, 2026 23:14
@andrewor14 changed the title from [draft] Delete AQT to Delete AffineQuantizedTensor Apr 7, 2026
@andrewor14 marked this pull request as ready for review April 7, 2026 23:14
@andrewor14 requested review from svekars and vkuzo as code owners April 7, 2026 23:14
@andrewor14 requested a review from jerryzh168 April 7, 2026 23:15
@andrewor14
Contributor Author

@claude review again please?


claude bot commented Apr 7, 2026

Claude finished @andrewor14's task in 8m 19s


Re-review of PR #4245: Delete AffineQuantizedTensor (AQT)

  • Gather context and read the full diff
  • Check if previous review issues were addressed
  • Review all modified files for correctness
  • Check for dangling imports or broken references
  • Post review feedback

Previous review issues #1 (unrelated Float8TrainingOpConfig change) and #3 (stale _int8_asymm_per_token_quant comments) have been fixed. Issue #2 (weak test assertion) is moot since the test file was deleted per jerryzh168's feedback. The diff looks much cleaner now. Two issues remain — one is a CI-breaking import error, the other is a behavioral regression.


Issues

1. CI failure: AQTTensorImpl import breaks codebook tests

The aarch64 regression test is failing:

```
torchao/prototype/dtypes/uintx/uintx_utils.py:22: in <module>
    from torchao.dtypes.utils import AQTTensorImpl, Layout
E   ImportError: cannot import name 'AQTTensorImpl' from 'torchao.dtypes.utils'
```

Import chain: test_groupwise_lowbit_weight_lut_quantizer.py → codebook/__init__.py → codebook_quantized_tensor.py → uintx_utils.py → from torchao.dtypes.utils import AQTTensorImpl (deleted).

This is a top-level import in uintx_utils.py:22, so it's a hard failure that breaks any code importing from the codebook module. There are also two more dangling imports in the same file at line 221 (from torchao.dtypes.affine_quantized_tensor import register_layout), though those are behind a class definition and would only fail when the class is instantiated.

Fix options:

  • (a) Keep a minimal AQTTensorImpl stub in torchao/dtypes/utils.py that the prototype code can inherit from (this was jerryzh168's suggestion to "delete in a follow-up PR since there are a lot of deps")
  • (b) Inline the base class into uintx_utils.py directly


2. _is_linear no longer prevents double-quantization of v2 tensor types

The old code had:

```python
not isinstance(mod.weight, AffineQuantizedTensor)
```

This was removed, but no replacement guard was added. The remaining checks only cover LinearActivationQuantizedTensor and _AffineFakeQuantizedTensor. This means calling quantize_() twice on a model with v2 tensors (Int8Tensor, Int4Tensor, Float8Tensor, etc.) will silently attempt to double-quantize weights.

The comment on line 183 explicitly states this check exists "to make sure the weight is only quantized once when it is shared by multiple linear modules."

Suggested fix: Replace the removed AffineQuantizedTensor check with:

```python
and not isinstance(mod.weight, TorchAOBaseTensor)
```

This covers all v2 quantized tensor types since they all inherit from TorchAOBaseTensor. Note that LinearActivationQuantizedTensor and _AffineFakeQuantizedTensor also inherit from TorchAOBaseTensor, so their explicit checks become redundant — but keeping them for clarity is fine.



Nits

3. Doc path for Int4Tensor in serialization.rst

serialization.rst:99 shows <class 'torchao.quantization.Int4Tensor'> but at runtime type() will show the defining module path: torchao.quantization.quantize_.workflows.int4.int4_tensor.Int4Tensor. Per jerryzh168's feedback about shorter paths, this output example will be incorrect at runtime. Consider either noting this is illustrative or matching the actual runtime output.
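
The runtime behavior this nit describes can be seen with a stdlib class that is re-exported at a package's top level: json.JSONDecoder is defined in json.decoder, so its repr shows the defining module path rather than the shorter public import path.

```python
import json

# JSONDecoder is importable as json.JSONDecoder, but the class repr
# reflects the module where it is defined (json.decoder):
print(json.JSONDecoder)  # <class 'json.decoder.JSONDecoder'>
```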

4. Dead code after raise in naive_intNwo.py:52

The code after raise AssertionError(...) at line 52 (lines 57-108) is unreachable. It also contains dangling imports of to_affine_quantized_intx at lines 67 and 91. Not a problem since it's unreachable, but could be cleaned up.


What looks good

  • The autoround migration from UintxLayout/to_affine_quantized_intx_static to IntxUnpackedToInt8Tensor.from_hp() is correct — the zero_point shift from unsigned to signed range is well-documented and the from_hp() parameters match.
  • Stale comment fixes in qat/linear.py:325 and linear_quant_modules.py:347 properly reference "the PTQ activation quantization scale dtype" instead of the deleted function.
  • The static_quantization.rst tutorial migration from to_affine_quantized_intx_static to Int8Tensor.from_hp is clean.
  • Deletion of all 5 dead activation quant helpers from quant_api.py is correct.
  • TODO markers on known-broken prototype code are appropriate.

@andrewor14 changed the title from Delete AffineQuantizedTensor to Delete AffineQuantizedTensor, AQTTensorImpl, and Layout Apr 8, 2026
@andrewor14 force-pushed the delete-aqt branch 3 times, most recently from ac92ac1 to bf3628c April 8, 2026 15:06
```python
class StretchedIntxWeightConfig(AOBaseConfig):
    granularity: Granularity = PerAxis(0)
    scale_dtype: Optional[torch.dtype] = None
    layout: Optional[Layout] = None
```
Contributor Author


@lisjin FYI deleting this field which I think is unused

@andrewor14
Contributor Author

@claude review again please


claude bot commented Apr 8, 2026

Claude finished @andrewor14's task in 5m 2s


Re-review of PR #4245: Delete AffineQuantizedTensor (AQT)

  • Gather context and read the full diff
  • Check if previous review issues were addressed
  • Review all modified files for correctness
  • Check for dangling imports or broken references
  • Post review feedback

Previous review status: Issues #1 (unrelated Float8TrainingOpConfig change) and #3 (stale _int8_asymm_per_token_quant comments) were fixed. Issue #2 (_is_linear double-quantization guard) is partially addressed with a TODO comment but not yet implemented. Issue #1 (CI-breaking AQTTensorImpl import) from the second review is fixed — uintx_utils.py no longer imports AQTTensorImpl or Layout. Ruff passes. The diff is clean.


Issue

_is_linear no longer prevents double-quantization of v2 tensor types

torchao/quantization/quant_api.py:184-192 — The AffineQuantizedTensor isinstance check was removed but no replacement was added. The TODO on line 185 acknowledges this (# TODO: check isinstance(TorchAOBaseTensor)?) but the guard is missing. All v2 quantized tensor types (Int4Tensor, Int8Tensor, Float8Tensor, IntxUnpackedToInt8Tensor, Int4TilePackedTo4dTensor) inherit from TorchAOBaseTensor, so adding not isinstance(mod.weight, TorchAOBaseTensor) would cover all of them. The remaining LinearActivationQuantizedTensor and _AffineFakeQuantizedTensor checks would become redundant but could be kept for clarity.

This was flagged in the second review and the TODO was added, but since this is a behavioral regression (calling quantize_() twice on a model will now silently attempt double-quantization on shared weights), it would be better to fix it in this PR rather than defer it.

```python
# Suggested fix at line 185-192:
return (
    isinstance(mod, torch.nn.Linear)
    and hasattr(mod, "weight")
    and not isinstance(mod.weight, TorchAOBaseTensor)
    and not isinstance(mod, nn.modules.linear.NonDynamicallyQuantizableLinear)
)
```



Nits

1. Dead code after raise in naive_intNwo.py

torchao/prototype/quantization/mixed_precision/scripts/naive_intNwo.py:51-126 — Lines 56-126 are unreachable after the raise AssertionError(...) at line 51. This includes dangling imports of to_affine_quantized_intx (lines 66, 90) and references to Layout() (lines 84, 108) with # noqa: F821 comments. Consider deleting the dead code entirely rather than keeping it with noqa suppression — it will only confuse future readers.

2. autoround/core.py dead code after raise

torchao/prototype/autoround/core.py:173-334 — Similarly, apply_auto_round() raises AssertionError at line 173, making the entire function body (lines 178-334) dead code. The UintxLayout import at line 196 and to_affine_quantized_intx_static / ZeroPointDomain references at lines 204/212 use # noqa: F821 to suppress undefined-name warnings. This is fine for now since the tests are already skipped, but could be cleaned up.

3. Tutorial files will fail at import time

Five tutorial files will fail with ImportError when run:

  • tutorials/calibration_flow/static_quant.py — imports to_affine_quantized_intx_static from torchao.dtypes
  • tutorials/calibration_flow/awq_like.py — same
  • tutorials/calibration_flow/gptq_like.py — imports to_affine_quantized_intx and to_affine_quantized_intx_static
  • tutorials/developer_api_guide/my_dtype_tensor_subclass.py — imports Layout from torchao.dtypes.utils
  • tutorials/developer_api_guide/my_trainable_tensor_subclass.py — imports Layout from torchao.dtypes.utils

All have TODO comments — just noting for visibility since andrewor14 confirmed these are deferred to follow-up PRs.


What looks good

  • The uintx_utils.py CI-breaking import error from the previous review is fixed — AQTTensorImpl, Layout, UintxLayout, and UintxAQTTensorImpl are all cleanly removed.
  • The stale comments referencing _int8_asymm_per_token_quant in qat/linear.py:325 and linear_quant_modules.py:347 are fixed with updated wording.
  • The static_quantization.rst tutorial migration from to_affine_quantized_intx_static to Int8Tensor.from_hp is clean.
  • The testing/utils.py migration from AQT to Int8Tensor.from_hp with PerRow() granularity is correct.
  • All deleted files (affine_quantized_tensor.py, affine_quantized_tensor_ops.py, test files, README, empty subpackages) are appropriately removed.
  • The parq/quant/config_torchao.py removal of the layout: Optional[Layout] = None field is correct — the field was unused in the implementation below it.
  • No dangling live imports of deleted symbols were found in production code.
  • autoround tests that call apply_auto_round() are correctly @pytest.mark.skip'd.



meta-codesync bot commented Apr 8, 2026

@andrewor14 has imported this pull request. If you are a Meta employee, you can view this in D100013598.

@andrewor14
Contributor Author

@claude For "_is_linear no longer prevents double-quantization of v2 tensor types", this doesn't seem like a regression because for all v2 configs we already don't guard the double quantization before this PR? We just do this for v1 configs which use AQT. Can you check again?


claude bot commented Apr 8, 2026

Claude finished @andrewor14's task in 1m 45s


Re-examining _is_linear double-quantization concern

  • Gather context on the _is_linear function (before and after this PR)
  • Check how v2 tensor types relate to AffineQuantizedTensor in the class hierarchy
  • Verify whether v2 configs were ever guarded by the AQT check
  • Provide corrected analysis

You're right — my previous review was incorrect on this point. The AffineQuantizedTensor check was not a guard against double-quantization of v2 tensor types.

Here's the class hierarchy:

```
TorchAOBaseTensor (base)
├── AffineQuantizedTensor (v1 — deleted in this PR)
├── Int4Tensor (v2)
├── Int8Tensor (v2)
├── Float8Tensor (v2)
├── IntxUnpackedToInt8Tensor (v2)
├── LinearActivationQuantizedTensor
└── _AffineFakeQuantizedTensor
```

The old check not isinstance(mod.weight, AffineQuantizedTensor) only caught v1 AQT weights. V2 tensor types are siblings of AffineQuantizedTensor, not subclasses — so isinstance(int4_weight, AffineQuantizedTensor) was always False. The v2 configs were never guarded against double-quantization by this check, either before or after this PR.

Removing the AffineQuantizedTensor isinstance check is simply removing a guard for a type that no longer exists, not introducing a regression. The TODO comment about potentially adding isinstance(TorchAOBaseTensor) is a separate enhancement, not a fix for something this PR broke. Apologies for the noise on this in the previous reviews.
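
A minimal sketch of this hierarchy argument, with stand-in classes rather than torchao imports (the class names are hypothetical stand-ins):

```python
# Stand-in classes; not imported from torchao.
class TorchAOBaseTensor: ...
class AffineQuantizedTensor(TorchAOBaseTensor): ...  # v1, deleted
class Int4Tensor(TorchAOBaseTensor): ...             # v2: sibling, not subclass

w = Int4Tensor()
assert not isinstance(w, AffineQuantizedTensor)  # old guard never matched v2
assert isinstance(w, TorchAOBaseTensor)          # a base-class check would
```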




```python
def is_device(target_device_str: str, device: Union[str, torch.device]):
```
Contributor


we can follow up to move this to torchao/utils.py

Contributor


this seems unused

Contributor Author


I can delete

```diff
  pack_dim = -1
  _layout = UintxLayout(dtype=dtype, pack_dim=pack_dim)
- return to_affine_quantized_intx_static(
+ return to_affine_quantized_intx_static(  # noqa: F821
```
Contributor


isn't this removed

Contributor Author


Yeah this whole block is dead code (AssertionError above), it already used deleted functions before

Contributor

@jerryzh168 left a comment


LGTM, thanks!
