Skip to content
This repository has been archived by the owner on Oct 25, 2024. It is now read-only.

Commit

Permalink
Update huggingface.py (#1536)
Browse files Browse the repository at this point in the history
Signed-off-by: Wang, Chang <[email protected]>
  • Loading branch information
changwangss authored May 10, 2024
1 parent e559929 commit ab93e47
Showing 1 changed file with 12 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -999,9 +999,14 @@ def _model_call(self, inps, attn_mask=None, labels=None):
labels=labels,
).logits
else:
return self.model(
output = self.model(
input_ids=inps, attention_mask=attn_mask, labels=labels
).logits
)
if isinstance(output, tuple):
output = output[0]
else:
output = output.logits
return output
else:
assert self.AUTO_MODEL_CLASS == transformers.AutoModelForCausalLM
if hasattr(self.model, "config") and hasattr(self.model.config, "auto_map") and \
Expand Down Expand Up @@ -1036,7 +1041,11 @@ def _model_call(self, inps, attn_mask=None, labels=None):
inps, torch.ones(inps.shape, dtype=torch.int64)
).logits
else:
output = self.model(inps).logits
output = self.model(inps)
if isinstance(output, tuple):
output = output[0]
else:
output = output.logits
return output

def _model_generate(self, context, max_length, stop, **generation_kwargs):
Expand Down

0 comments on commit ab93e47

Please sign in to comment.