aar_watermark.py (from chenchenygu/watermark-learnability)
from typing import Optional

import scipy.stats
import torch
from transformers import AutoTokenizer

DEFAULT_SEED = 42
class AarWatermark:
    def __init__(
        self,
        vocab_size: int,
        k: int,
        seed: int = DEFAULT_SEED,
        eps: float = 1e-20,
        device: Optional[str] = None,
    ):
        if not device:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        generator = torch.Generator()  # generator is always cpu for reproducibility
        generator.manual_seed(seed)

        # clamp to avoid NaNs
        uniform = torch.clamp(
            torch.rand((vocab_size * k, vocab_size), generator=generator, dtype=torch.float32),
            min=eps,
        )
        self.gumbel = (-torch.log(torch.clamp(-torch.log(uniform), min=eps))).to(device)

        self.k = k
        self.vocab_size = vocab_size
        self.seed = seed
        self.eps = eps
        self.device = device

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        if input_ids.shape[-1] < self.k:
            return scores
        prev_token = torch.sum(input_ids[:, -self.k:], dim=-1)  # (batch_size,)
        gumbel = self.gumbel[prev_token]  # (batch_size, vocab_size)
        return scores[..., :gumbel.shape[-1]] + gumbel

    def watermark_logits(
        self,
        input_ids: torch.LongTensor,  # (batch, seq_len)
        logits: torch.FloatTensor,  # (batch, seq_len, vocab_size)
    ) -> torch.FloatTensor:
        """Returns watermarked logits to be used as distillation target."""
        hashes = torch.sum(input_ids.unfold(-1, self.k, 1), dim=-1)  # (batch, seq_len - k + 1)
        gumbel = self.gumbel[hashes]  # (batch, seq_len - k + 1, vocab_size)
        # tokenizer vocab size and model outputs vocab size may be different
        logits[..., self.k - 1:, :gumbel.shape[-1]] += gumbel
        return logits

    def watermark_logits_argmax(
        self,
        input_ids: torch.LongTensor,  # (batch, seq_len)
        logits: torch.FloatTensor,  # (batch, seq_len, vocab_size)
    ) -> torch.LongTensor:
        """Finds argmax token for watermark, returns token indexes to be used for cross-entropy loss.

        Returns tensor of shape (batch, seq_len), where each element is a token index.
        """
        hashes = torch.sum(input_ids.unfold(-1, self.k, 1), dim=-1)  # (batch, seq_len - k + 1)
        gumbel = self.gumbel[hashes]  # (batch, seq_len - k + 1, vocab_size)
        # tokenizer vocab size and model outputs vocab size may be different
        logits[..., self.k - 1:, :gumbel.shape[-1]] += gumbel  # (batch, seq_len, vocab_size)
        tokens = torch.argmax(logits, dim=-1)  # (batch, seq_len)
        return tokens
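
# Note (not part of the original module): `watermark_logits_argmax` is documented
# above as producing cross-entropy targets for watermark distillation. A minimal
# sketch of that use, assuming placeholder tensors `teacher_logits` and
# `student_logits` of shape (batch, seq_len, vocab_size) and `input_ids` of shape
# (batch, seq_len), might look like:
#
#   watermark = AarWatermark(vocab_size=teacher_logits.shape[-1], k=2)
#   targets = watermark.watermark_logits_argmax(input_ids, teacher_logits)
#   loss = torch.nn.functional.cross_entropy(
#       student_logits[:, watermark.k - 1:].transpose(1, 2),  # (batch, vocab, positions)
#       targets[:, watermark.k - 1:],  # positions before k-1 carry no watermark signal
#   )
#
# The exact masking/alignment used in the original training pipeline may differ;
# this only illustrates how the returned token indexes plug into a loss.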
class AarWatermarkDetector:
    def __init__(
        self,
        tokenizer: AutoTokenizer,
        k: int = 1,
        seed: int = DEFAULT_SEED,
        eps: float = 1e-20,
    ):
        generator = torch.Generator()  # generator is always cpu for reproducibility
        generator.manual_seed(seed)

        vocab_size = len(tokenizer)
        self.uniform = torch.clamp(
            torch.rand((vocab_size * k, vocab_size), generator=generator, dtype=torch.float32),
            min=eps,
            max=1 - eps,
        )

        self.tokenizer = tokenizer
        self.k = k
        self.seed = seed
        self.eps = eps
        self.vocab_size = vocab_size

    def detect(self, text: str) -> float:
        """
        Returns p-value, where the null hypothesis is that the text is not watermarked.

        Under the null hypothesis, each u is Uniform(0, 1), so each score -log(1 - u) is Exp(1).
        The sum of the scores is therefore distributed as Gamma(n_tokens, 1).
        """
        tokens = self.tokenizer.encode(text, return_tensors="pt", add_special_tokens=False)[0]  # (seq_len,)
        seq_len = tokens.shape[0]

        score = 0
        # TODO tensorize
        for i in range(self.k, seq_len):
            prev_tokens_sum = torch.sum(tokens[i - self.k:i], dim=-1)
            token = tokens[i]
            u = self.uniform[prev_tokens_sum, token]
            score += -torch.log(1 - u)
        p_value = scipy.stats.gamma.sf(score, seq_len - self.k, loc=0, scale=1)
        return p_value
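

# Minimal end-to-end sketch (not part of the original module). Since
# AarWatermark.__call__ has the same (input_ids, scores) -> scores signature as a
# Hugging Face LogitsProcessor, it can be passed to `generate` directly; greedy
# decoding over the Gumbel-perturbed logits then performs the watermarked
# (Gumbel-max) sampling. The model name "gpt2" and the prompt are placeholders.
if __name__ == "__main__":
    from transformers import AutoModelForCausalLM, LogitsProcessorList

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    k = 2
    watermark = AarWatermark(vocab_size=len(tokenizer), k=k, device="cpu")
    detector = AarWatermarkDetector(tokenizer, k=k)

    prompt_ids = tokenizer("The watermark", return_tensors="pt").input_ids
    output_ids = model.generate(
        prompt_ids,
        max_new_tokens=50,
        do_sample=False,  # greedy over perturbed logits == watermarked sampling
        logits_processor=LogitsProcessorList([watermark]),
    )
    # Score only the generated continuation, not the prompt.
    text = tokenizer.decode(output_ids[0, prompt_ids.shape[-1]:], skip_special_tokens=True)
    print("p-value:", detector.detect(text))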