
Commit

update
ConorOBrien-Foxx committed May 13, 2024
1 parent 5fc3bce commit 89b39d3
Showing 3 changed files with 541 additions and 157 deletions.
23 changes: 15 additions & 8 deletions model_wrapper.py
@@ -51,6 +51,7 @@ def clean_cache_dir(confirm=False):
 
     @staticmethod
     def prob_from_logit(logit):
+        assert False, "Do not use this function"
         # TODO: check if this is the correct way to scale CodeGen logits;
         # it stands to reason that the falloff could be steeper or less steep
 
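The disabled prob_from_logit path converted a single logit to a probability in isolation, but a next-token probability only has meaning relative to every other logit through the softmax normalization, which is what the change further down switches to. A minimal sketch of the difference, assuming PyTorch; the tensor values and the sigmoid conversion are made up for illustration, since the original function body is not shown here:

import torch

# hypothetical next-token logits over a 4-token vocabulary
logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])

per_logit = torch.sigmoid(logits[:, 0]).item()            # one plausible logit-only conversion
normalized = torch.softmax(logits, dim=-1)[:, 0].item()   # probability relative to the whole vocabulary

print(per_logit, normalized)  # the two values generally disagree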
@@ -143,7 +144,7 @@ def tokenize(self, prompt, time=False):
         prompt_tokens = prompt_tokens.to(self.device)
         if time:
             time_end()
-        self.yap("Token count in input:", prompt_tokens["input_ids"].shape[1])
+        # self.yap("Token count in input:", prompt_tokens["input_ids"].shape[1])
         return prompt_tokens


@@ -156,7 +157,7 @@ def model_no_grad(self, *args, **kwargs):
 
     def generate(self, inputs, time=False, *args, **kwargs):
         if isinstance(inputs, str):
-            self.yap("Tokenizing input prompt...")
+            # self.yap("Tokenizing input prompt...")
             inputs = self.tokenize(inputs, time=time)
 
         self.yap("Generating...")
@@ -174,7 +175,7 @@ def generate(self, inputs, time=False, *args, **kwargs):
     def multiple_choice_token(self, inputs, targets, time=False):
         assert len(targets) >= 2, "Expected at least 2 targets"
         if isinstance(inputs, str):
-            self.yap("Tokenizing input prompt...")
+            # self.yap("Tokenizing input prompt...")
             inputs = self.tokenize(inputs, time=time)
 
         if time:
@@ -345,9 +346,13 @@ def _multiple_choice_prompts_multiply(self, input_tokens, target_tokens, time=False):
         # P(A|H) = P(a0|H) * P(a1|H.a0) * ... * P(aN|H.a0.a1...a(N-1))
         for idx, tokens in target_tokens:
             # goal: calculate P(tokens|H) = 𝚷 P(aj|H.∑ak 0<=k<j) 0<=j<=N
-            logit_score = base_logits[:, tokens[0, 0]].item()
+            first_token = tokens[0, 0]
+            # logit_score = base_logits[:, tokens[0, 0]].item()
             # P(a0|H)
-            total_prob = Model.prob_from_logit(logit_score)
+            # total_prob = Model.prob_from_logit(logit_score)
+            initial_distribution = self.softmax(base_logits)
+            total_prob = initial_distribution[:, first_token].item()
+            print(f"init = {total_prob * 100:.4f}%")
             # H
             running_inputs = input_tokens["input_ids"]
 
@@ -364,13 +369,15 @@ def _multiple_choice_prompts_multiply(self, input_tokens, target_tokens, time=False):
                 distribution = self.softmax(next_logits)
                 # self.yap("SOFTMAX:", distribution)
 
-                logit_score = next_logits[:, token].item()
+                # logit_score = next_logits[:, token].item()
                 prob = distribution[:, token].item()
-                self.yap(f"Token {j}: P={prob}, logit={logit_score}")
+                # self.yap(f"Token {j}: P={prob}, logit={logit_score}")
+                print(f"prob = {prob * 100:.4f}%")
                 total_prob *= prob
                 # self.yap("Running P:", total_prob)
 
             score = total_prob
+            print(f"overall = {total_prob * 100:.120f}%")
 
             if best_option_idx is None or score > best_score:
                 best_score = score
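For context, the updated loop scores each candidate continuation by the chain rule: P(target | H) is the product over positions j of the softmax probability of token a_j given the prompt plus the previously appended target tokens. A standalone sketch of this style of scoring, assuming a Hugging Face causal LM and tokenizer; the model name and example prompt are placeholders and are not taken from this repository:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "gpt2"  # placeholder; the repository wraps CodeGen-style models
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

def continuation_prob(prompt, continuation):
    # P(continuation | prompt) = 𝚷_j P(token_j | prompt, token_0 .. token_{j-1})
    running = tokenizer(prompt, return_tensors="pt")["input_ids"]
    target = tokenizer(continuation, add_special_tokens=False, return_tensors="pt")["input_ids"][0]
    total_prob = 1.0
    with torch.no_grad():
        for token in target:
            next_logits = model(running).logits[:, -1, :]       # logits for the next position
            distribution = torch.softmax(next_logits, dim=-1)
            total_prob *= distribution[0, token].item()          # P(token | history)
            running = torch.cat([running, token.view(1, 1)], dim=-1)
    return total_prob

print(f"{continuation_prob('The capital of France is', ' Paris') * 100:.4f}%")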
@@ -396,7 +403,7 @@ def multiple_choice_prompts(self, inputs, targets, time=False, strategy=Multiple
         - idx corresponding to the most likely prompt
         """
         if isinstance(inputs, str):
-            self.yap("Tokenizing input prompt...")
+            # self.yap("Tokenizing input prompt...")
             inputs = self.tokenize(inputs, time=time)
 
         # TODO: deduplicate testing for input target tokens
