Commit fd89d9d

Refactor layers. (#1866)
# What does this PR do?

Fixes # (issue)

## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you read the [contributor guideline](https://github.com/huggingface/transformers/blob/main/CONTRIBUTING.md#start-contributing-pull-requests), Pull Request section?
- [ ] Was this discussed/approved via a Github issue or the [forum](https://discuss.huggingface.co/)? Please add a link to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes? Here are the [documentation guidelines](https://github.com/huggingface/transformers/tree/main/docs), and [here are tips on formatting docstrings](https://github.com/huggingface/transformers/tree/main/docs#writing-source-documentation).
- [ ] Did you write any new necessary tests?

## Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag members/contributors who may be interested in your PR.
1 parent 59b3ffe commit fd89d9d

59 files changed: +2055 / -1911 lines

server/tests/utils/test_layers.py

Lines changed: 1 addition & 1 deletion

@@ -1,5 +1,5 @@
 import torch
-from text_generation_server.utils.layers import (
+from text_generation_server.layers import (
     TensorParallelEmbedding,
 )

Lines changed: 14 additions & 0 deletions
@@ -0,0 +1,14 @@
+from text_generation_server.layers.tensor_parallel import (
+    TensorParallelColumnLinear,
+    TensorParallelRowLinear,
+    TensorParallelEmbedding,
+)
+from text_generation_server.layers.speculative import SpeculativeHead
+from text_generation_server.layers.linear import (
+    get_linear,
+    FastLinear,
+)
+
+# Just to add the `load` methods.
+from text_generation_server.layers.layernorm import load_layer_norm
+from text_generation_server.layers.conv import load_conv2d
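
If this new file is the `__init__.py` of the `text_generation_server.layers` package (the re-exports above suggest it is, but the file path is hidden in this view), the refactor means callers import layer classes from `text_generation_server.layers` rather than `text_generation_server.utils.layers`. A minimal sketch of the new import surface under that assumption:

```python
# Sketch only: assumes the file above is text_generation_server/layers/__init__.py,
# so these names are re-exported from the package root after the refactor.
from text_generation_server.layers import (
    TensorParallelColumnLinear,
    TensorParallelRowLinear,
    TensorParallelEmbedding,
    SpeculativeHead,
    get_linear,
)

# Importing the package also pulls in the layernorm/conv modules above, which
# exist mainly to attach `load` helpers to the corresponding torch.nn classes.
```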
Lines changed: 106 additions & 0 deletions
@@ -0,0 +1,106 @@
+import torch
+from loguru import logger
+from functools import lru_cache
+import bitsandbytes as bnb
+from bitsandbytes.nn import Int8Params, Params4bit
+
+
+@lru_cache(1)
+def warn_deprecate_bnb():
+    logger.warning(
+        "Bitsandbytes 8bit is deprecated, using `eetq` is a drop-in replacement, and has much better performance"
+    )
+
+
+class Linear8bitLt(torch.nn.Module):
+    def __init__(
+        self,
+        weight,
+        bias,
+        has_fp16_weights=True,
+        memory_efficient_backward=False,
+        threshold=0.0,
+        index=None,
+    ):
+        super().__init__()
+        assert (
+            not memory_efficient_backward
+        ), "memory_efficient_backward is no longer required and the argument is deprecated in 0.37.0 and will be removed in 0.39.0"
+        self.state = bnb.MatmulLtState()
+        self.index = index
+
+        # Necessary for stacked layers
+        self.state.threshold = threshold
+        self.state.has_fp16_weights = has_fp16_weights
+        self.state.memory_efficient_backward = memory_efficient_backward
+        if threshold > 0.0 and not has_fp16_weights:
+            self.state.use_pool = True
+
+        self.weight = Int8Params(
+            weight.data,
+            has_fp16_weights=has_fp16_weights,
+            requires_grad=has_fp16_weights,
+        )
+        self.weight.cuda(weight.device)
+        self.bias = bias
+
+    def init_8bit_state(self):
+        self.state.CB = self.weight.CB
+        self.state.SCB = self.weight.SCB
+        self.weight.CB = None
+        self.weight.SCB = None
+
+    def forward(self, x: torch.Tensor):
+        self.state.is_training = self.training
+        if self.weight.CB is not None:
+            self.init_8bit_state()
+
+        # Weights are cast automatically as Int8Params, but the bias has to be cast manually.
+        if self.bias is not None and self.bias.dtype != x.dtype:
+            self.bias.data = self.bias.data.to(x.dtype)
+
+        out = bnb.matmul(x, self.weight, bias=self.bias, state=self.state)
+
+        if not self.state.has_fp16_weights:
+            if self.state.CB is not None and self.state.CxB is not None:
+                # We converted the 8-bit row-major weight to turing/ampere format
+                # in the first inference pass; the row-major copy is no longer needed.
+                del self.state.CB
+                self.weight.data = self.state.CxB
+        return out
+
+
+class Linear4bit(torch.nn.Module):
+    def __init__(self, weight, bias, quant_type):
+        super().__init__()
+        self.weight = Params4bit(
+            weight.data,
+            requires_grad=False,
+            compress_statistics=True,
+            quant_type=quant_type,
+        )
+        self.compute_dtype = None
+        self.weight.cuda(weight.device)
+        self.bias = bias
+
+    def forward(self, x: torch.Tensor):
+        # Weights are cast automatically as Params4bit, but the bias has to be cast manually.
+        if self.bias is not None and self.bias.dtype != x.dtype:
+            self.bias.data = self.bias.data.to(x.dtype)
+
+        if getattr(self.weight, "quant_state", None) is None:
+            print(
+                "FP4 quantization state not initialized. Please call .cuda() or .to(device) on the LinearFP4 layer first."
+            )
+        inp_dtype = x.dtype
+        if self.compute_dtype is not None:
+            x = x.to(self.compute_dtype)
+
+        bias = None if self.bias is None else self.bias.to(self.compute_dtype)
+        out = bnb.matmul_4bit(
+            x, self.weight.t(), bias=bias, quant_state=self.weight.quant_state
+        )
+
+        out = out.to(inp_dtype)
+
+        return out
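
For orientation, a hedged usage sketch of the `Linear8bitLt` wrapper above (not part of the commit). It assumes bitsandbytes is installed, a CUDA device is available, and that this file lands at `text_generation_server.layers.bnb`; the file path is hidden in this view, so treat it as an assumption:

```python
import torch

# Usage sketch only; the module path and tensor sizes are assumptions.
from text_generation_server.layers.bnb import Linear8bitLt

# fp16 weights on the GPU; Int8Params quantizes them when .cuda() runs in __init__.
weight = torch.randn(4096, 4096, dtype=torch.float16, device="cuda")
bias = torch.zeros(4096, dtype=torch.float16, device="cuda")

# has_fp16_weights=False keeps only the int8 copy; threshold > 0 enables the
# outlier decomposition path inside bnb.matmul.
linear = Linear8bitLt(weight, bias, has_fp16_weights=False, threshold=6.0)

x = torch.randn(1, 4096, dtype=torch.float16, device="cuda")
y = linear(x)  # int8 matmul via bnb.matmul, output back in fp16
```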
Lines changed: 41 additions & 0 deletions
@@ -0,0 +1,41 @@
+from accelerate import init_empty_weights
+import torch
+
+
+@classmethod
+def load_conv2d(cls, prefix, weights, in_channels, out_channels, kernel_size, stride):
+    weight = weights.get_tensor(f"{prefix}.weight")
+    bias = weights.get_tensor(f"{prefix}.bias")
+    with init_empty_weights():
+        conv2d = cls(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+        )
+
+    conv2d.weight = torch.nn.Parameter(weight)
+    conv2d.bias = torch.nn.Parameter(bias)
+    return conv2d
+
+
+@classmethod
+def load_conv2d_no_bias(
+    cls, prefix, weights, in_channels, out_channels, kernel_size, stride
+):
+    weight = weights.get_tensor(f"{prefix}.weight")
+    with init_empty_weights():
+        conv2d = cls(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            kernel_size=kernel_size,
+            stride=stride,
+        )
+
+    conv2d.weight = torch.nn.Parameter(weight)
+    conv2d.bias = None
+    return conv2d
+
+
+torch.nn.Conv2d.load = load_conv2d
+torch.nn.Conv2d.load_no_bias = load_conv2d_no_bias
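
Because these loaders are attached to `torch.nn.Conv2d` itself as `load` / `load_no_bias`, model code can build a convolution directly from checkpoint tensors. A minimal sketch, assuming a `weights` object that exposes `get_tensor` like the server's weights wrapper; the stand-in class and tensor names below are purely illustrative:

```python
import torch

# Assumed import: the package __init__ above imports load_conv2d, which applies the patch.
import text_generation_server.layers  # noqa: F401


class FakeWeights:
    """Illustrative stand-in for the server's checkpoint-weights wrapper."""

    def __init__(self, tensors):
        self.tensors = tensors

    def get_tensor(self, name):
        return self.tensors[name]


weights = FakeWeights(
    {
        "vision.patch_embed.weight": torch.randn(64, 3, 4, 4),
        "vision.patch_embed.bias": torch.zeros(64),
    }
)

# cls is bound to torch.nn.Conv2d by the @classmethod descriptor; the layer is first
# created on the meta device via init_empty_weights, then real parameters are swapped in.
conv = torch.nn.Conv2d.load(
    prefix="vision.patch_embed",
    weights=weights,
    in_channels=3,
    out_channels=64,
    kernel_size=4,
    stride=4,
)
```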
Lines changed: 25 additions & 0 deletions
@@ -0,0 +1,25 @@
+import torch
+from EETQ import quant_weights, w8_a16_gemm
+
+
+class EETQLinear(torch.nn.Module):
+    def __init__(
+        self,
+        weight,
+        bias,
+    ) -> None:
+        super().__init__()
+        device = weight.device
+        if weight.dtype != torch.float16:
+            weight = weight.to(dtype=torch.float16)
+        weight = torch.t(weight).contiguous().cpu()
+        weight, scale = quant_weights(weight, torch.int8, False)
+
+        self.weight = weight.cuda(device)
+        self.scale = scale.cuda(device)
+        self.bias = bias.cuda(device) if bias is not None else None
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        output = w8_a16_gemm(input, self.weight, self.scale)
+        output = output + self.bias if self.bias is not None else output
+        return output
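
A hedged usage sketch for the EETQ wrapper above (not part of the commit); it assumes the EETQ extension is installed, a CUDA device is present, and an assumed module path of `text_generation_server.layers.eetq`:

```python
import torch

# Sketch only; module path and shapes are assumptions.
from text_generation_server.layers.eetq import EETQLinear

# (out_features, in_features) fp16 weight; __init__ transposes it, quantizes to int8
# with per-channel scales on the CPU, then moves weight and scales back to the GPU.
weight = torch.randn(11008, 4096, dtype=torch.float16, device="cuda")
linear = EETQLinear(weight, bias=None)

x = torch.randn(2, 4096, dtype=torch.float16, device="cuda")
y = linear(x)  # w8_a16_gemm: int8 weights, fp16 activations -> shape (2, 11008)
```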
Lines changed: 43 additions & 0 deletions
@@ -0,0 +1,43 @@
+import torch
+
+
+def fp8_quantize(weight, qdtype=torch.float8_e4m3fn):
+    device = weight.device
+    # weight, scale = quant_weights(weight, torch.int8, False)
+    finfo = torch.finfo(qdtype)
+    # Calculate the scale as dtype max divided by absmax.
+    scale = finfo.max / weight.abs().max().clamp(min=1e-12)
+    # Scale and clamp the tensor to bring it into
+    # the representable range of the float8 data type
+    # (the default cast is unsaturated).
+    qweight = (weight * scale).clamp(min=finfo.min, max=finfo.max)
+    # Return both the float8 data and the inverse scale (as float),
+    # since both are required as inputs to torch._scaled_mm.
+    qweight = qweight.to(qdtype)
+    scale = scale.float().reciprocal()
+    return qweight, scale
+
+
+class Fp8Linear(torch.nn.Module):
+    def __init__(
+        self,
+        weight,
+        bias,
+    ) -> None:
+        super().__init__()
+        self.dtype = weight.dtype
+        self.qweight, self.scale = fp8_quantize(weight)
+
+        self.bias = bias if bias is not None else None
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        qinput, scale = fp8_quantize(input)
+        output, _ = torch._scaled_mm(
+            qinput,
+            self.qweight.t(),
+            out_dtype=self.dtype,
+            scale_a=scale,
+            scale_b=self.scale,
+            bias=self.bias,
+        )
+        return output
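
To make the scaling convention concrete: `fp8_quantize` returns the float8 tensor together with the *inverse* scale, so dequantization is a single multiply. A small round-trip sketch, assuming a PyTorch build with `torch.float8_e4m3fn` support and the (assumed) module path below:

```python
import torch

# Sketch only; the module path is an assumption based on the diff above.
from text_generation_server.layers.fp8 import fp8_quantize

w = torch.randn(128, 128, dtype=torch.float16)
qweight, inv_scale = fp8_quantize(w)

# scale = finfo.max / absmax was applied before the cast and its reciprocal is
# returned, so multiplying the float8 values by inv_scale approximates w.
w_approx = qweight.to(torch.float32) * inv_scale
print((w.float() - w_approx).abs().max())  # small quantization error
```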
Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
+import os
+import torch
+from text_generation_server.utils.import_utils import (
+    SYSTEM,
+)
+
+try:
+    major, _minor = torch.cuda.get_device_capability()
+except Exception:
+    major = 1
+
+HAS_EXLLAMA = False
+CAN_EXLLAMA = major >= 8 or SYSTEM == "rocm"
+V2 = os.getenv("EXLLAMA_VERSION", "2") == "2"
+if os.getenv("DISABLE_EXLLAMA") == "True":
+    HAS_EXLLAMA = False
+elif CAN_EXLLAMA:
+    try:
+        if V2:
+            from text_generation_server.layers.gptq.exllamav2 import (
+                QuantLinear as ExllamaQuantLinear,
+                create_exllama_buffers,
+                set_device,
+            )
+
+            HAS_EXLLAMA = "2"
+        else:
+            from text_generation_server.layers.gptq.exllama import (
+                Ex4bitLinear as ExllamaQuantLinear,
+                create_exllama_buffers,
+                set_device,
+            )
+
+            HAS_EXLLAMA = "1"
+
+    except ImportError:
+        pass
+
+from text_generation_server.layers.gptq.quant_linear import QuantLinear
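
Because `CAN_EXLLAMA`, `V2`, and `HAS_EXLLAMA` are resolved at import time from the device capability and the `EXLLAMA_VERSION` / `DISABLE_EXLLAMA` environment variables, callers simply branch on the exported flags. A sketch of how the flags might be consumed, assuming the file above is the `__init__.py` of `text_generation_server.layers.gptq` (illustrative only):

```python
import os

# Set before the first import of the package, since the flags above are
# computed when the module is imported.
os.environ["EXLLAMA_VERSION"] = "2"

from text_generation_server.layers.gptq import CAN_EXLLAMA, HAS_EXLLAMA, QuantLinear

if HAS_EXLLAMA:
    # Only importable when the exllama kernel import above succeeded.
    from text_generation_server.layers.gptq import ExllamaQuantLinear

    print(f"exllama v{HAS_EXLLAMA} kernels available (CAN_EXLLAMA={CAN_EXLLAMA})")
else:
    print("falling back to the default QuantLinear implementation")
```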

server/text_generation_server/utils/gptq/exllamav2.py renamed to server/text_generation_server/layers/gptq/exllamav2.py

Lines changed: 2 additions & 0 deletions

@@ -119,6 +119,8 @@ def ext_make_q_matrix(w: dict, temp_dq, key: str = None):
             none_tensor,
             temp_dq,
         )
+    else:
+        raise RuntimeError("Cannot create handle")


 DEVICE = None
