-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy patharc_core.py
More file actions
112 lines (85 loc) · 3.52 KB
/
arc_core.py
File metadata and controls
112 lines (85 loc) · 3.52 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
"""
Core utilities for ARC testing - prompt generation and model inference
"""
import torch
from utils.tokenization_utils import tokenizer, format_message, tokenize
def question_tok_to_prompt(question_tok, system_message="Respond with just the letter corresponding to the correct answer."):
    """
    Build the complete chat prompt (as token IDs) for a tokenized question.

    The prompt layout is: BOS token, a system turn, a user turn consisting of
    the question plus a reminder to answer with a single letter, and an open
    (unterminated) assistant turn so the model's next token is its answer.

    Args:
        question_tok: Tokenized question text.
        system_message: System message to use (allows slight variations
            between scripts).

    Returns:
        List of token IDs for the complete prompt.
    """
    reminder = tokenize("\n\nRemember to respond with just the letter corresponding to the correct answer.")
    prompt = [tokenizer.bos_token_id]
    prompt.extend(format_message(message=system_message, who="system"))
    prompt.extend(format_message(message=question_tok + reminder, who="user"))
    # ended=False leaves the assistant turn open for generation.
    prompt.extend(format_message(message=[], who="assistant", ended=False))
    return prompt
def batch_inference_with_logits(model, prompts, batch_size=8, device=None):
    """
    Run batch inference on prompts and return last-token logits for each.

    Prompts are left-padded to a common length within each batch so that the
    final position of every row is the prompt's last real token.

    Args:
        model: The language model (HF-style: returns an object with .logits).
        prompts: List of prompts (each as a list of token IDs).
        batch_size: Batch size for inference.
        device: Device to run on (defaults to cuda if available).

    Returns:
        List of 1-D logit tensors (vocab-sized), one per prompt, taken at the
        last sequence position.
    """
    if device is None:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    all_logits = []
    for i in range(0, len(prompts), batch_size):
        batch = prompts[i:i+batch_size]
        # Left-pad batch to the same length.
        max_len = max(len(tokens) for tokens in batch)
        batch_tokens = [[tokenizer.pad_token_id] * (max_len - len(x)) + x for x in batch]
        input_ids = torch.tensor(batch_tokens, dtype=torch.long, device=device)
        # Build the mask from the known sequence lengths rather than comparing
        # tokens against pad_token_id: if the pad id also occurs inside a real
        # prompt (e.g. when pad is aliased to eos), a token-comparison mask
        # would wrongly hide those genuine tokens.
        attention_mask = torch.zeros_like(input_ids, dtype=torch.bool)
        for j, tokens in enumerate(batch):
            attention_mask[j, max_len - len(tokens):] = True
        with torch.no_grad():
            logits = model(input_ids, attention_mask=attention_mask).logits
        # Extract last-token logits for each item in the batch.
        for j in range(len(batch)):
            all_logits.append(logits[j, -1, :])
    return all_logits
def get_answer_token_ids():
    """
    Get token IDs for the answer letters A, B, C, D.

    Returns:
        Dict mapping each letter to a list of its possible token IDs: the bare
        letter and the word-initial ("Ġ"-prefixed) variant.
    """
    mapping = {}
    for letter in ("A", "B", "C", "D"):
        bare = tokenizer.convert_tokens_to_ids(letter)
        spaced = tokenizer.convert_tokens_to_ids("Ġ" + letter)
        mapping[letter] = [bare, spaced]
    return mapping
def extract_letter_probabilities(logits, tokens_for_each_letter=None):
    """
    Extract probabilities for answer letters from model logits.

    Args:
        logits: Model output logits for the last token (1-D over the vocab).
        tokens_for_each_letter: Dict mapping letters to lists of token IDs;
            when None, falls back to get_answer_token_ids().

    Returns:
        Dict mapping each letter to the summed softmax probability over all
        of its token variants (out-of-range IDs are skipped).
    """
    if tokens_for_each_letter is None:
        tokens_for_each_letter = get_answer_token_ids()
    probs = torch.softmax(logits, dim=-1)
    vocab_size = len(probs)
    return {
        # Sum probability mass across every in-vocab variant of the letter.
        letter: sum(probs[tid].item() for tid in token_ids if tid < vocab_size)
        for letter, token_ids in tokens_for_each_letter.items()
    }