adding Context Length Specialization (CCL) #388

Open
wants to merge 2 commits into
base: main

Conversation

quic-vjanfaza

No description provided.

@@ -1388,6 +1389,9 @@ def from_pretrained(

kv_offload = kwargs.pop("kv_offload", None)

comp_ctx_lengths = kwargs.pop("comp_ctx_lengths", None)
cls.comp_ctx_lengths = comp_ctx_lengths
Contributor

This is not needed. You can pass comp_ctx_lengths=comp_ctx_lengths as a kwarg on line 1407 when instantiating the class.

Author

We can't remove these lines: on line 1396 we call cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs), and if comp_ctx_lengths is left in kwargs it is forwarded to that call, which doesn't recognize the argument and raises an error.
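
For context, a minimal sketch of the pattern under discussion, assuming the wrapper class accepts comp_ctx_lengths in its constructor (the constructor signature here is an assumption, not the exact code from the diff):

```python
from typing import List, Optional

@classmethod
def from_pretrained(cls, pretrained_model_name_or_path: str, *args, **kwargs):
    # Pop the custom kwarg *before* forwarding **kwargs; otherwise the
    # Hugging Face loader receives an unexpected `comp_ctx_lengths` argument.
    comp_ctx_lengths: Optional[List[int]] = kwargs.pop("comp_ctx_lengths", None)

    model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)

    # The popped value can then be handed to the wrapper explicitly, as suggested
    # above, instead of being stored on `cls` (assumed constructor signature).
    return cls(model, comp_ctx_lengths=comp_ctx_lengths, **kwargs)
```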

@@ -1422,7 +1426,7 @@ def model_hash(self) -> str:
def get_model_config(self) -> dict:
return self.model.config.__dict__

def export(self, export_dir: Optional[str] = None) -> str:
def export(self, comp_ctx_lengths: Optional[List[int]] = None, export_dir: Optional[str] = None) -> str:
Contributor

Since comp_ctx_lengths is an instance variable, we wouldn't need an argument here.
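
A minimal sketch of that suggestion, assuming comp_ctx_lengths is set on the instance during construction:

```python
from typing import Optional

def export(self, export_dir: Optional[str] = None) -> str:
    # Sketch only: read the attribute set in __init__ / from_pretrained instead
    # of threading comp_ctx_lengths through the export() signature.
    comp_ctx_lengths = getattr(self, "comp_ctx_lengths", None)
    ...
```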

@@ -1442,10 +1446,12 @@ def export(self, export_dir: Optional[str] = None) -> str:
"input_ids": torch.zeros((bs, seq_len), dtype=torch.int64),
"position_ids": torch.arange(seq_len, dtype=torch.int64).view(1, seq_len).repeat(bs, 1),
"past_key_values": [[] for _ in range(self.num_layers)],
"comp_ctx_lengths": torch.randint(0, 100, (40,), dtype=torch.long),
Contributor

Do we want to add the example inputs and dynamic_axes by default? Shouldn't we check whether comp_ctx_lengths is not None and only add them then?

Author

No, we don't need to do that. I generalized the code so that the same ONNX file is used both for CCL runs and for the default runs without CCL. This reduces the number of changes in the model file and avoids generating separate ONNX files for the with- and without-CCL experiments.
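
To illustrate the idea (names, shapes, and the fallback value are assumptions, not the PR's exact code): when CCL is disabled, the comp_ctx_lengths input can simply degenerate to the full context length, so a single exported graph covers both configurations.

```python
import torch

ctx_len = 2048           # assumed full context length
comp_ctx_lengths = None  # e.g. [512, 1024, 2048] when CCL is enabled

# Single code path: without CCL the bucket list collapses to one entry, the full context.
ccl_buckets = comp_ctx_lengths if comp_ctx_lengths is not None else [ctx_len]

# The ONNX example inputs always contain a comp_ctx_lengths tensor (matching the diff
# above), so no separate export is needed for the no-CCL configuration.
example_inputs = {"comp_ctx_lengths": torch.randint(0, 100, (40,), dtype=torch.long)}
print(ccl_buckets, example_inputs["comp_ctx_lengths"].shape)
```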

max_ccl_id = len(self.comp_ctx_lengths) - 1
max_position_id = np.max(decode_inputs["position_ids"])
ccl_id = 1
for i in range(1, len(self.comp_ctx_lengths)):
Contributor

Can we keep a reversed list and pop the last value when max_position_id < self.comp_ctx_lengths[-1]? That way we could avoid the loop.

Author

Why should we check only against the last element? Each request can finish at a different position_id, so we need to search for the most suitable CCL window to get the best performance. This for loop only runs at the end of a request, and its cost is on the order of len(CCL), which can't be more than a few values because of the compiler's limit on the number of specializations.
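
For reference, a small self-contained sketch of the selection being described, next to a loop-free variant along the lines hinted at above using bisect. The bucket values and positions are made up for illustration, and the loop assumes a break after the first matching bucket, which the truncated diff excerpt does not show.

```python
import bisect

import numpy as np

comp_ctx_lengths = [512, 1024, 2048, 4096]       # assumed sorted CCL buckets
position_ids = np.array([[695, 696, 697, 698]])  # assumed decode position_ids
max_position_id = int(np.max(position_ids))

# Loop as described: pick the smallest bucket (index >= 1) that covers max_position_id.
ccl_id = 1
for i in range(1, len(comp_ctx_lengths)):
    if max_position_id < comp_ctx_lengths[i]:
        ccl_id = i
        break

# Loop-free equivalent for positions below the largest bucket; either way the cost
# is O(len(CCL)), and the list stays short due to the specialization limit.
ccl_id_bisect = max(1, bisect.bisect_right(comp_ctx_lengths, max_position_id))
assert ccl_id == ccl_id_bisect  # both pick the 1024 bucket for max_position_id == 698
```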

ccl_id = 1
for i in range(1, len(self.comp_ctx_lengths)):
if max_position_id < self.comp_ctx_lengths[i]:
ccl_id = i
Contributor

same as above
