Skip to content

Commit

Permalink
reformat code using Black
Browse files Browse the repository at this point in the history
  • Loading branch information
Kaiyu Yang committed Jan 22, 2021
1 parent 37f93ce commit 97f1098
Show file tree
Hide file tree
Showing 33 changed files with 3,096 additions and 2,009 deletions.
368 changes: 216 additions & 152 deletions ASTactic/agent.py

Large diffs are not rendered by default.

58 changes: 36 additions & 22 deletions ASTactic/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from progressbar import ProgressBar
import os
import sys

sys.setrecursionlimit(100000)
import pickle
from collections import defaultdict
Expand All @@ -15,26 +16,24 @@


class ProofStepsData(Dataset):
    """Dataset of individual proof steps, one pickle file per step.

    NOTE(review): this span was reconstructed from an interleaved
    before/after diff rendering; it reproduces the post-reformat
    (Black-style) version of the class.
    """

    def __init__(self, split, opts):
        """Collect the pickle files for `split` and shuffle their order.

        split: "train", "valid", or "train_valid" (train + valid merged).
        opts: parsed options; only `opts.datapath` is read here.
        """
        super().__init__()
        self.opts = opts

        if split in ["train", "valid"]:
            self.proof_steps = glob(os.path.join(opts.datapath, split, "*.pickle"))
        elif split == "train_valid":
            self.proof_steps = glob(
                os.path.join(opts.datapath, "train/*.pickle")
            ) + glob(os.path.join(opts.datapath, "valid/*.pickle"))
        random.shuffle(self.proof_steps)
        print("%d proof steps in %s" % (len(self), split))

    def __len__(self):
        return len(self.proof_steps)

    def __getitem__(self, idx):
        """Load one proof step and flatten its nested fields.

        The unpickled dict contains at least 'goal' (a dict with an 'ast'
        entry) and 'tactic' (a dict with 'actions' and 'text'); the middle
        of the original field-by-field schema docstring was collapsed in
        the reviewed diff, so the full schema is not restated here —
        presumably it matches the `fields` list used by the collate
        function in `create_dataloader` (verify against the data files).

        Returns the dict with 'goal' replaced by its AST, 'tactic'
        split into 'tactic_actions'/'tactic_str', and the raw 'tactic'
        key removed.
        """
        # Use a context manager so the file handle is closed promptly
        # (the original `pickle.load(open(...))` leaked the handle).
        with open(self.proof_steps[idx], "rb") as f:
            proof_step = pickle.load(f)
        proof_step["goal"] = proof_step["goal"]["ast"]
        proof_step["tactic_actions"] = proof_step["tactic"]["actions"]
        proof_step["tactic_str"] = proof_step["tactic"]["text"]
        del proof_step["tactic"]

        return proof_step


def create_dataloader(split, opts):
    """Build a DataLoader over ProofStepsData for `split`.

    NOTE(review): reconstructed from an interleaved before/after diff
    rendering; this is the post-reformat (Black-style) version.

    split: dataset split name; shuffling is enabled for any split whose
        name starts with "train" ("train", "train_valid").
    opts: parsed options; reads `opts.batchsize` and `opts.num_workers`
        in addition to what ProofStepsData consumes.

    Returns a torch DataLoader whose batches are dicts of lists (see
    `merge` below), not collated tensors.
    """

    def merge(batch):
        # Collate a list of proof-step dicts into one dict of lists,
        # keeping only the known fields and silently dropping the rest.
        fields = [
            "file",
            "proof_name",
            "n_step",
            "env",
            "local_context",
            "goal",
            "is_synthetic",
            "tactic_actions",
            "tactic_str",
        ]
        data_batch = {key: [] for key in fields}
        for example in batch:
            for key, value in example.items():
                if key not in fields:
                    continue
                data_batch[key].append(value)
        return data_batch

    ds = ProofStepsData(split, opts)
    return DataLoader(
        ds,
        opts.batchsize,
        shuffle=split.startswith("train"),
        collate_fn=merge,
        num_workers=opts.num_workers,
    )


if __name__ == '__main__':
if __name__ == "__main__":
opts = parse_args()
loader = create_dataloader('train', opts)
loader = create_dataloader("train", opts)
bar = ProgressBar(max_value=len(loader))
for i, data_batch in enumerate(loader):
if i == 0:
Expand Down
136 changes: 89 additions & 47 deletions ASTactic/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,12 @@
import json
import os
import sys

sys.setrecursionlimit(100000)
sys.path.append(os.path.normpath(os.path.dirname(os.path.realpath(__file__))))
sys.path.append(os.path.normpath(os.path.join(os.path.dirname(os.path.realpath(__file__)), '../')))
sys.path.append(
os.path.normpath(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../"))
)
from hashlib import md5
from utils import log
from progressbar import ProgressBar
Expand All @@ -17,52 +20,70 @@
import pdb


if __name__ == '__main__':
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument('method', type=str)
parser.add_argument('eval_id', type=str)
parser.add_argument('--datapath', type=str, default='../data')
parser.add_argument('--projs_split', type=str, default='../projs_split.json')
parser.add_argument('--split', choices=['train', 'valid', 'test'], type=str, default='test')
parser.add_argument('--file', type=str)
parser.add_argument('--proof', type=str)
parser.add_argument('--filter', type=str)
parser.add_argument('--path', type=str)
parser.add_argument('--output_dir', type=str, default='evaluation')
parser.add_argument('--max_num_tactics', type=int, default=300)
parser.add_argument('--timeout', type=int, default=600)
parser.add_argument('--hammer_timeout', type=int, default=100)
parser.add_argument('--depth_limit', type=int, default=50)
parser.add_argument('--beam_width', type=int, default=20) # lots of timeout when >200
parser.add_argument('--num_tactic_candidates', type=int, default=20)
parser.add_argument('--lens_norm', type=float, default=0.5, help='lengths normalization')
parser.add_argument('--tac_grammar', type=str, default='tactics.ebnf')
parser.add_argument('--term_embedding_dim', type=int, default=256)
parser.add_argument('--size_limit', type=int, default=50)
parser.add_argument('--embedding_dim', type=int, default=256, help='dimension of the grammar embeddings')
parser.add_argument('--symbol_dim', type=int, default=256, help='dimension of the terminal/nonterminal symbol embeddings')
parser.add_argument('--hidden_dim', type=int, default=256, help='dimension of the LSTM controller')
parser.add_argument('--seed', type=int, default=0)
parser.add_argument("method", type=str)
parser.add_argument("eval_id", type=str)
parser.add_argument("--datapath", type=str, default="../data")
parser.add_argument("--projs_split", type=str, default="../projs_split.json")
parser.add_argument(
"--split", choices=["train", "valid", "test"], type=str, default="test"
)
parser.add_argument("--file", type=str)
parser.add_argument("--proof", type=str)
parser.add_argument("--filter", type=str)
parser.add_argument("--path", type=str)
parser.add_argument("--output_dir", type=str, default="evaluation")
parser.add_argument("--max_num_tactics", type=int, default=300)
parser.add_argument("--timeout", type=int, default=600)
parser.add_argument("--hammer_timeout", type=int, default=100)
parser.add_argument("--depth_limit", type=int, default=50)
parser.add_argument(
"--beam_width", type=int, default=20
) # lots of timeout when >200
parser.add_argument("--num_tactic_candidates", type=int, default=20)
parser.add_argument(
"--lens_norm", type=float, default=0.5, help="lengths normalization"
)
parser.add_argument("--tac_grammar", type=str, default="tactics.ebnf")
parser.add_argument("--term_embedding_dim", type=int, default=256)
parser.add_argument("--size_limit", type=int, default=50)
parser.add_argument(
"--embedding_dim",
type=int,
default=256,
help="dimension of the grammar embeddings",
)
parser.add_argument(
"--symbol_dim",
type=int,
default=256,
help="dimension of the terminal/nonterminal symbol embeddings",
)
parser.add_argument(
"--hidden_dim", type=int, default=256, help="dimension of the LSTM controller"
)
parser.add_argument("--seed", type=int, default=0)
opts = parser.parse_args()
log(opts)
opts.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
if opts.device.type == 'cpu':
log('using CPU', 'WARNING')
opts.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
if opts.device.type == "cpu":
log("using CPU", "WARNING")

torch.manual_seed(opts.seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
np.random.seed(opts.seed)
random.seed(opts.seed)

if 'ours' in opts.method:
if "ours" in opts.method:
model = Prover(opts)
log('loading model checkpoint from %s..' % opts.path)
if opts.device.type == 'cpu':
checkpoint = torch.load(opts.path, map_location='cpu')
log("loading model checkpoint from %s.." % opts.path)
if opts.device.type == "cpu":
checkpoint = torch.load(opts.path, map_location="cpu")
else:
checkpoint = torch.load(opts.path)
model.load_state_dict(checkpoint['state_dict'])
model.load_state_dict(checkpoint["state_dict"])
model.to(opts.device)
else:
model = None
Expand All @@ -73,34 +94,55 @@
files = [opts.file]
else:
files = []
projs = json.load(open(opts.projs_split))['projs_' + opts.split]
projs = json.load(open(opts.projs_split))["projs_" + opts.split]
for proj in projs:
files.extend(glob(os.path.join(opts.datapath, '%s/**/*.json' % proj), recursive=True))
files.extend(
glob(os.path.join(opts.datapath, "%s/**/*.json" % proj), recursive=True)
)

if opts.filter:
files = [f for f in files if md5(f.encode('utf-8')).hexdigest().startswith(opts.filter)]
files = [
f
for f in files
if md5(f.encode("utf-8")).hexdigest().startswith(opts.filter)
]

print(files)
results = []
bar = ProgressBar(max_value=len(files))
for i, f in enumerate(files):
print('file: ', f)
#print('cuda memory allocated before file: ', torch.cuda.memory_allocated(opts.device), file=sys.stderr)
print("file: ", f)
# print('cuda memory allocated before file: ', torch.cuda.memory_allocated(opts.device), file=sys.stderr)
results.extend(agent.evaluate(f, opts.proof))
bar.update(i)

oup_dir = os.path.join(opts.output_dir, opts.eval_id)
if not os.path.exists(oup_dir):
os.makedirs(oup_dir)
os.makedirs(oup_dir)
if opts.filter is None and opts.file is None:
oup_file = os.path.join(oup_dir, 'results.json')
oup_file = os.path.join(oup_dir, "results.json")
elif opts.file is None:
oup_file = os.path.join(oup_dir, '%s.json' % opts.filter)
oup_file = os.path.join(oup_dir, "%s.json" % opts.filter)
elif opts.proof is None:
oup_file = os.path.join(oup_dir, '%s.json' % os.path.sep.join(opts.file.split(os.path.sep)[2:]).replace(os.path.sep, '-'))
oup_file = os.path.join(
oup_dir,
"%s.json"
% os.path.sep.join(opts.file.split(os.path.sep)[2:]).replace(
os.path.sep, "-"
),
)
else:
oup_file = os.path.join(oup_dir, '%s-%s.json' % (os.path.sep.join(opts.file.split(os.path.sep)[2:]).replace(os.path.sep, '-'), opts.proof))
oup_file = os.path.join(
oup_dir,
"%s-%s.json"
% (
os.path.sep.join(opts.file.split(os.path.sep)[2:]).replace(
os.path.sep, "-"
),
opts.proof,
),
)
opts = vars(opts)
del opts['device']
json.dump({'options': opts, 'results': results}, open(oup_file, 'wt'))
log('results saved to ' + oup_file)
del opts["device"]
json.dump({"options": opts, "results": results}, open(oup_file, "wt"))
log("results saved to " + oup_file)
27 changes: 19 additions & 8 deletions ASTactic/evaluation/display.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,22 +12,33 @@
num_fail = defaultdict(int)
avg_time = 0

for f in glob(os.path.join(sys.argv[1], '*.json')):
results = json.load(open(f))['results']
for f in glob(os.path.join(sys.argv[1], "*.json")):
results = json.load(open(f))["results"]
for r in results:
proj = r['filename'].split(os.path.sep)[2]
if r['success'] and r['time'] <= TIME_LIMIT:
proj = r["filename"].split(os.path.sep)[2]
if r["success"] and r["time"] <= TIME_LIMIT:
num_success[proj] += 1
avg_time += r['time']
avg_time += r["time"]
else:
num_fail[proj] += 1

total_success = 0
total_fail = 0
for proj in set(num_success.keys()).union(set(num_fail.keys())):
print('%50s:\t%.04f\t%d/%d' % (proj, num_success[proj] / (num_success[proj] + num_fail[proj]), num_success[proj], num_fail[proj]))
print(
"%50s:\t%.04f\t%d/%d"
% (
proj,
num_success[proj] / (num_success[proj] + num_fail[proj]),
num_success[proj],
num_fail[proj],
)
)
total_success += num_success[proj]
total_fail += num_fail[proj]

print('\nIN TOTAL:\t%.04f\t%d/%d' % (total_success / (total_success + total_fail), total_success, total_fail))
print('avg_time', avg_time / total_success)
print(
"\nIN TOTAL:\t%.04f\t%d/%d"
% (total_success / (total_success + total_fail), total_success, total_fail)
)
print("avg_time", avg_time / total_success)
Loading

0 comments on commit 97f1098

Please sign in to comment.