import argparse
import os
import sys
+ import json
+ from math import prod
from pathlib import Path
- from types import EllipsisType
- from typing import TYPE_CHECKING, Callable, Iterable, Iterator, Sequence, SupportsIndex, cast
+ from typing import TYPE_CHECKING, Any, Callable, Iterable, Iterator, Sequence, SupportsIndex, cast

import torch

import gguf

# reuse model definitions from convert_hf_to_gguf.py
- from convert_hf_to_gguf import Model
+ from convert_hf_to_gguf import LazyTorchTensor, Model

logger = logging.getLogger("lora-to-gguf")
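For reference, a minimal sketch of what the exported pair represents, assuming the usual PEFT convention that A is (n_rank, row_size), B is (col_size, n_rank), and the effective scale at load time is lora_alpha divided by the rank (the helper name apply_lora is illustrative, not part of the script):

import torch

def apply_lora(W: torch.Tensor, A: torch.Tensor, B: torch.Tensor, alpha: float) -> torch.Tensor:
    # B @ A has the same shape as the base weight W; alpha / rank is the standard LoRA scaling
    rank = B.shape[-1]
    assert A.shape[-2] == rank and (B @ A).shape == W.shape
    return W + (alpha / rank) * (B @ A)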
@@ -35,45 +36,53 @@ class PartialLoraTensor:

# magic to support tensor shape modifications and splitting
class LoraTorchTensor:
-     _lora_A: Tensor
-     _lora_B: Tensor
+     _lora_A: Tensor  # (n_rank, row_size)
+     _lora_B: Tensor  # (col_size, n_rank)
    _rank: int

    def __init__(self, A: Tensor, B: Tensor):
        assert len(A.shape) == len(B.shape)
+         assert A.shape[-2] == B.shape[-1]
        if A.dtype != B.dtype:
            A = A.to(torch.float32)
            B = B.to(torch.float32)
        self._lora_A = A
        self._lora_B = B
-         assert self._lora_A.shape[-2] == self._lora_B.shape[-1]
-         self._rank = self._lora_B.shape[-1]
+         self._rank = B.shape[-1]
+
+     def get_lora_A_B(self) -> tuple[Tensor, Tensor]:
+         return (self._lora_A, self._lora_B)

    def __getitem__(
        self,
        indices: (
            SupportsIndex
            | slice
-             | tuple[SupportsIndex | slice | EllipsisType | Tensor, ...]
+             | tuple[SupportsIndex | slice | Tensor, ...]  # TODO: add ellipsis in the type signature
        ),
    ) -> LoraTorchTensor:
        shape = self.shape
-         if isinstance(indices, (SupportsIndex, slice)):
+         if isinstance(indices, SupportsIndex):
            if len(shape) > 2:
                return LoraTorchTensor(self._lora_A[indices], self._lora_B[indices])
            else:
-                 raise NotImplementedError
+                 raise NotImplementedError  # can't return a vector
+         elif isinstance(indices, slice):
+             if len(shape) > 2:
+                 return LoraTorchTensor(self._lora_A[indices], self._lora_B[indices])
+             else:
+                 return LoraTorchTensor(self._lora_A, self._lora_B[indices])
        elif isinstance(indices, tuple):
            assert len(indices) > 0
-             if isinstance(indices[-1], EllipsisType):
+             if indices[-1] is Ellipsis:
                return self[indices[:-1]]
            # expand ellipsis
            indices = tuple(
                u
                for v in (
                    (
                        (slice(None, None) for _ in range(len(indices) - 1))
-                         if isinstance(i, EllipsisType)
+                         if i is Ellipsis
                        else (i,)
                    )
                    for i in indices
@@ -85,19 +94,22 @@ def __getitem__(
                indices = (*indices, *(slice(None, None) for _ in range(len(indices), len(shape))))

            # TODO: make sure this is correct
-             # lora_A has a shape which looks like (..., 1, 1, rank, self.shape[-1])
            indices_A = (
                *(
-                     0 if isinstance(i, SupportsIndex) else slice(None, None)
-                     for i in indices[:-2]
+                     (
+                         j.__index__() % self._lora_A.shape[i]
+                         if isinstance(j, SupportsIndex)
+                         else slice(None, None)
+                     )
+                     for i, j in enumerate(indices[:-2])
                ),
                slice(None, None),
                indices[-1],
            )
            indices_B = indices[:-1]
            return LoraTorchTensor(self._lora_A[indices_A], self._lora_B[indices_B])
        else:
-             raise NotImplementedError
+             raise NotImplementedError  # unknown indice type

    @property
    def dtype(self) -> torch.dtype:
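A worked example of the indexing rules above, assuming the LoraTorchTensor class as defined in this file: indices on all but the last dimension are applied to B, the last index is applied to A, and integer indices on leading dimensions are reduced modulo A's (usually size-1) broadcast dimensions:

t = LoraTorchTensor(torch.ones(1, 8, 64), torch.ones(4, 128, 8))  # rank 8
assert t.shape == (4, 128, 64)
assert t[2, :, :16].shape == (128, 16)     # leading index selects from B; 2 % 1 == 0 selects A's only slot
assert t[:, :, :16].shape == (4, 128, 16)  # a column slice only narrows A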
@@ -106,23 +118,37 @@ def dtype(self) -> torch.dtype:

    @property
    def shape(self) -> tuple[int, ...]:
+         assert len(self._lora_A.shape) == len(self._lora_B.shape)
        return (*self._lora_B.shape[:-1], self._lora_A.shape[-1])

    def size(self, dim=None):
        assert dim is None
        return self.shape

-     def reshape(self, *shape: int | tuple[int]) -> LoraTorchTensor:
+     def reshape(self, *shape: int | tuple[int, ...]) -> LoraTorchTensor:
        if isinstance(shape[0], tuple):
-             new_shape: tuple[int] = shape[0]
+             new_shape: tuple[int, ...] = shape[0]
        else:
-             new_shape = cast(tuple[int], shape)
+             new_shape = cast(tuple[int, ...], shape)
        orig_shape = self.shape
+         if len(new_shape) < 2:
+             raise NotImplementedError  # can't become a vector
+
+         # expand -1 in the shape
+         if any(dim == -1 for dim in new_shape):
+             n_elems = prod(orig_shape)
+             n_new_elems = prod(dim if dim != -1 else 1 for dim in new_shape)
+             assert n_elems % n_new_elems == 0
+             new_shape = (*(dim if dim != -1 else n_elems // n_new_elems for dim in new_shape),)
+
        if new_shape[-1] != orig_shape[-1]:
-             raise NotImplementedError
+             raise NotImplementedError  # can't reshape the row size trivially
+
+         shape_A = (*(1 for _ in new_shape[:-2]), self._rank, orig_shape[-1])
+         shape_B = (*new_shape[:-1], self._rank)
        return LoraTorchTensor(
-             self._lora_A.reshape((*(1 for _ in new_shape[:-2]), *self._lora_A.shape[-2:])),
-             self._lora_B.reshape((*new_shape[:-1], self._rank)),
+             self._lora_A.reshape(shape_A),
+             self._lora_B.reshape(shape_B),
        )

    def reshape_as(self, other: Tensor) -> LoraTorchTensor:
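A worked example of the new -1 expansion in reshape, using the same assumed shapes as above:

t = LoraTorchTensor(torch.ones(1, 8, 64), torch.ones(4, 128, 8))
r = t.reshape(-1, 64)
# prod((4, 128, 64)) == 32768 and the explicit dims multiply to 64, so -1 expands to 512
assert r.shape == (512, 64)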
@@ -134,12 +160,15 @@ def view(self, *size: int) -> LoraTorchTensor:
    def permute(self, *dims: int) -> LoraTorchTensor:
        shape = self.shape
        dims = tuple(dim - len(shape) if dim >= 0 else dim for dim in dims)
-         if dims[-1] == -2 and dims[-2] == -1:
-             return LoraTorchTensor(self._lora_B.permute(*dims), self._lora_A.permute(*dims))
-         else:
-             assert dims[-1] == -1
+         if dims[-1] == -1:
+             # TODO: support higher dimensional A shapes bigger than 1
            assert all(dim == 1 for dim in self._lora_A.shape[:-2])
            return LoraTorchTensor(self._lora_A, self._lora_B.permute(*dims))
+         if len(shape) == 2 and dims[-1] == -2 and dims[-2] == -1:
+             return LoraTorchTensor(self._lora_B.permute(*dims), self._lora_A.permute(*dims))
+         else:
+             # TODO: compose the above two
+             raise NotImplementedError

    def transpose(self, dim0: int, dim1: int) -> LoraTorchTensor:
        shape = self.shape
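The 2-D branch is valid because transposing a product reverses it: (B @ A).T == A.T @ B.T, so the transposed tensor is represented by the swapped pair (B.T, A.T). A small sketch:

t = LoraTorchTensor(torch.ones(8, 64), torch.ones(128, 8))
assert t.shape == (128, 64)
assert t.permute(1, 0).shape == (64, 128)  # built from the swapped (B.T, A.T) pair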
@@ -181,11 +210,13 @@ def __torch_function__(cls, func: Callable, types, args=(), kwargs=None):
                    torch.cat([a._lora_A for a in args[0]], dim),
                    torch.cat([b._lora_B for b in args[0]], dim),
                )
-             else:
+             elif all(torch.equal(args[0][0]._lora_A, t._lora_A) for t in args[0][1:]):
                return LoraTorchTensor(
-                     args[0][0]._lora_A,  # TODO: is this correct? (can't cat over the rank)
+                     args[0][0]._lora_A,
                    torch.cat([b._lora_B for b in args[0]], dim),
                )
+             else:
+                 raise NotImplementedError
        else:
            raise NotImplementedError
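The new elif relies on concatenation factoring through B when A is shared: cat([B1 @ A, B2 @ A], dim=0) == cat([B1, B2], dim=0) @ A. A sketch with assumed shapes:

A = torch.ones(8, 64)
t1 = LoraTorchTensor(A, torch.ones(128, 8))
t2 = LoraTorchTensor(A, torch.ones(32, 8))
assert torch.cat([t1, t2], dim=0).shape == (160, 64)  # only the B factors are concatenated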
@@ -205,13 +236,17 @@ def parse_args() -> argparse.Namespace:
        help="path to write to; default: based on input. {ftype} will be replaced by the outtype.",
    )
    parser.add_argument(
-         "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0"], default="f16",
-         help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0",
+         "--outtype", type=str, choices=["f32", "f16", "bf16", "q8_0", "auto"], default="f16",
+         help="output format - use f32 for float32, f16 for float16, bf16 for bfloat16, q8_0 for Q8_0, auto for the highest-fidelity 16-bit float type depending on the first loaded tensor type",
    )
    parser.add_argument(
        "--bigendian", action="store_true",
        help="model is executed on big endian machine",
    )
+     parser.add_argument(
+         "--no-lazy", action="store_true",
+         help="use more RAM by computing all outputs before writing (use in case lazy evaluation is broken)",
+     )
    parser.add_argument(
        "--verbose", action="store_true",
        help="increase output verbosity",
@@ -237,13 +272,16 @@ def parse_args() -> argparse.Namespace:
        "f16": gguf.LlamaFileType.MOSTLY_F16,
        "bf16": gguf.LlamaFileType.MOSTLY_BF16,
        "q8_0": gguf.LlamaFileType.MOSTLY_Q8_0,
+         "auto": gguf.LlamaFileType.GUESSED,
    }
+
    ftype = ftype_map[args.outtype]

-     dir_base_model = args.base
-     dir_lora = args.lora_path
-     input_json = os.path.join(dir_lora, "adapter_config.json")
-     input_model = os.path.join(dir_lora, "adapter_model.safetensors")
+     dir_base_model: Path = args.base
+     dir_lora: Path = args.lora_path
+     lora_config = dir_lora / "adapter_config.json"
+     input_model = dir_lora / "adapter_model.safetensors"
+
    if args.outfile is not None:
        fname_out = args.outfile
    else:
@@ -276,6 +314,8 @@ def get_tensors(self) -> Iterator[tuple[str, Tensor]]:
            tensor_map: dict[str, PartialLoraTensor] = {}

            for name, tensor in lora_model.items():
+                 if self.lazy:
+                     tensor = LazyTorchTensor.from_eager(tensor)
                base_name = get_base_tensor_name(name)
                is_lora_a = ".lora_A.weight" in name
                is_lora_b = ".lora_B.weight" in name
@@ -305,16 +345,30 @@ def modify_tensors(self, data_torch: Tensor, name: str, bid: int | None) -> Iter
            dest = super().modify_tensors(data_torch, name, bid)
            for dest_name, dest_data in dest:
                assert isinstance(dest_data, LoraTorchTensor)
-                 # logger.info(f"{orig_name} --> {dest_name}")
-                 yield (dest_name + ".lora_a", dest_data._lora_A)
-                 yield (dest_name + ".lora_b", dest_data._lora_B)
-
-     model_instance = LoraModel(dir_base_model, ftype, fname_out, args.bigendian, False, False, None)
+                 lora_a, lora_b = dest_data.get_lora_A_B()
+
+                 yield (dest_name + ".lora_a", lora_a)
+                 yield (dest_name + ".lora_b", lora_b)
+
+     model_instance = LoraModel(
+         dir_base_model,
+         ftype,
+         fname_out,
+         is_big_endian=args.bigendian,
+         use_temp_file=False,
+         eager=args.no_lazy,
+         model_name=None,
+     )
    logger.info("Set model parameters")
    model_instance.set_gguf_parameters()

-     # adapter_config = json.load(input_json)
+     with open(lora_config, "r") as f:
+         lparams: dict[str, Any] = json.load(f)
+
+     alpha = lparams["lora_alpha"]
+
    model_instance.gguf_writer.add_string("training.type", "finetune_lora")
+     model_instance.gguf_writer.add_float32("training.lora.alpha", float(alpha))

    model_instance.gguf_writer.add_quantization_version(gguf.GGML_QUANT_VERSION)
    logger.info("Exporting model...")
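For context, the lora_alpha read above comes from the adapter's PEFT config; a minimal sketch of the fields this relies on (the path and values are illustrative):

import json

with open("adapter_config.json") as f:  # the script reads dir_lora / "adapter_config.json"
    cfg = json.load(f)
# a typical PEFT config contains e.g. {"r": 8, "lora_alpha": 16, "target_modules": [...], ...}
alpha = float(cfg["lora_alpha"])        # the value written to "training.lora.alpha"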