adding Context Length Specialization (CCL) #388
base: main
Conversation
Signed-off-by: vjanfaza <[email protected]>
@@ -1388,6 +1389,9 @@ def from_pretrained(
         kv_offload = kwargs.pop("kv_offload", None)

+        comp_ctx_lengths = kwargs.pop("comp_ctx_lengths", None)
+        cls.comp_ctx_lengths = comp_ctx_lengths
This is not needed. You can pass comp_ctx_lengths=comp_ctx_lengths as a kwarg at line 1407 when instantiating the class.
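A minimal sketch of the suggested shape (the constructor call and surrounding names are assumptions based on the quoted diff, not the actual source):

    comp_ctx_lengths = kwargs.pop("comp_ctx_lengths", None)
    model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
    # Hypothetical instantiation around line 1407: forward the value directly
    # instead of storing it on the class object.
    return cls(model, comp_ctx_lengths=comp_ctx_lengths)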
We can't remove these lines: at line 1396 we call cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs), which would otherwise receive comp_ctx_lengths in its input arguments and raise an error.
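A self-contained sketch of the failure mode described here (all names are assumed stand-ins): if the key is left in **kwargs, it reaches a callee that does not accept it.

    def hf_loader(path, *, revision=None):  # stand-in for the HF auto-class loader
        return object()

    def from_pretrained_sketch(path, **kwargs):
        comp_ctx_lengths = kwargs.pop("comp_ctx_lengths", None)  # consume it first
        # Without the pop above, hf_loader(path, **kwargs) would raise
        # TypeError: ... got an unexpected keyword argument 'comp_ctx_lengths'.
        model = hf_loader(path, **kwargs)
        return model, comp_ctx_lengths

    model, ccl = from_pretrained_sketch("some/model", comp_ctx_lengths=[512, 1024])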
@@ -1422,7 +1426,7 @@ def model_hash(self) -> str:
     def get_model_config(self) -> dict:
         return self.model.config.__dict__

-    def export(self, export_dir: Optional[str] = None) -> str:
+    def export(self, comp_ctx_lengths: Optional[List[int]] = None, export_dir: Optional[str] = None) -> str:
Since comp_ctx_lengths is an instance variable, we would not need an argument here.
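A minimal sketch of that change (the body below is assumed, not the actual implementation):

    from typing import Optional

    def export(self, export_dir: Optional[str] = None) -> str:
        # Read the CCL windows from the instance rather than a parameter;
        # self.comp_ctx_lengths is assumed to be set in from_pretrained.
        comp_ctx_lengths = self.comp_ctx_lengths
        ...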
@@ -1442,10 +1446,12 @@ def export(self, export_dir: Optional[str] = None) -> str:
             "input_ids": torch.zeros((bs, seq_len), dtype=torch.int64),
             "position_ids": torch.arange(seq_len, dtype=torch.int64).view(1, seq_len).repeat(bs, 1),
             "past_key_values": [[] for _ in range(self.num_layers)],
+            "comp_ctx_lengths": torch.randint(0, 100, (40,), dtype=torch.long),
Do we want to add the example inputs and dynamic_axes by default? Shouldn't we check whether comp_ctx_lengths is not None and only add them then?
No, we don't need to do that. I generalized the code so that the same ONNX file is used both with CCL and for the default run without CCL. This reduces the number of changes in the model file and avoids generating separate ONNX files for the with- and without-CCL experiments.
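A runnable sketch of the generalized example inputs described above (toy sizes; the dynamic-axis name is an assumption taken from the quoted diff):

    import torch

    bs, seq_len, num_layers = 1, 32, 2  # toy sizes for illustration
    example_inputs = {
        "input_ids": torch.zeros((bs, seq_len), dtype=torch.int64),
        "position_ids": torch.arange(seq_len, dtype=torch.int64).view(1, seq_len).repeat(bs, 1),
        "past_key_values": [[] for _ in range(num_layers)],
        # Always included, with placeholder values, so a single ONNX graph
        # serves both the CCL run and the default run without CCL.
        "comp_ctx_lengths": torch.randint(0, 100, (40,), dtype=torch.long),
    }
    dynamic_axes = {"comp_ctx_lengths": {0: "comp_ctx_lengths"}}  # assumed axis name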
max_ccl_id = len(self.comp_ctx_lengths) - 1                # index of the largest CCL window
max_position_id = np.max(decode_inputs["position_ids"])    # furthest position this request reached
ccl_id = 1
for i in range(1, len(self.comp_ctx_lengths)):             # pick the smallest window that still fits
Can we have a reversed list and pop out the last value if max_position_id < self.comp_ctx_lengths[-1]? This way we can avoid the loop.
Why should we check against only the last element? Each request can finish at a different position_id, and we need to search for the most suitable CCL window to get the best performance. This for loop only runs at the end of a request, and its cost is on the order of len(comp_ctx_lengths), which can't be more than a few values because of the compiler's limit on the number of specializations.
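For reference, a loop-free equivalent is possible with bisect, assuming comp_ctx_lengths is sorted ascending, the quoted loop breaks at the first match, and its unquoted tail falls back to max_ccl_id when no window fits:

    import bisect

    comp_ctx_lengths = [512, 1024, 2048, 4096]  # example ascending CCL windows
    max_ccl_id = len(comp_ctx_lengths) - 1
    max_position_id = 1500  # example: furthest position reached by the request

    # Index of the first window strictly greater than max_position_id,
    # clamped to the valid range [1, max_ccl_id].
    ccl_id = max(1, min(bisect.bisect_right(comp_ctx_lengths, max_position_id), max_ccl_id))
    print(ccl_id)  # -> 2, i.e. the 2048 window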
ccl_id = 1
for i in range(1, len(self.comp_ctx_lengths)):
    if max_position_id < self.comp_ctx_lengths[i]:
        ccl_id = i
Same as above.
No description provided.