Added early stopping capability with multibleu.perl (extra commits removed) #19

Open · wants to merge 4 commits into base: master
19 changes: 19 additions & 0 deletions experiments/nmt/state.py
@@ -121,6 +121,25 @@ def prototype_state():
state['weight_init_fn'] = 'sample_weights_classic'
state['weight_scale'] = 0.01

# ----- BLEU VALIDATION OPTIONS ----

# Location of the evaluation script
state['bleu_script'] = None
# Location of the validation set
state['validation_set'] = None
# boolean, whether or not to write the validation set to file
state['output_validation_set'] = False
# Location of the validation set output, if different
# from the default
state['validation_set_out'] = None
# Location of the reference translations (ground truth) to compare the output against
state['validation_set_grndtruth'] = None
# Beam size during sampling
state['beam_size'] = None
# Number of steps between every validation
state['bleu_val_frequency'] = None


# ---- REGULARIZATION -----

# WARNING: dropout is not tested and probably does not work.
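To make the new options concrete, here is a minimal sketch of how they might be set in a user's state prototype. All paths and values below are hypothetical placeholders, not taken from this pull request:

# Hypothetical configuration enabling BLEU-based early stopping
state['bleu_script'] = '/path/to/multi-bleu.perl'      # evaluation script
state['validation_set'] = '/data/valid.src'            # source side of the validation set
state['validation_set_grndtruth'] = '/data/valid.ref'  # reference (ground-truth) translations
state['output_validation_set'] = True                  # also write translations to disk
state['validation_set_out'] = '/data/valid.out'        # where to write them
state['beam_size'] = 12                                # beam size during sampling
state['bleu_val_frequency'] = 5000                     # validate every 5000 steps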
182 changes: 177 additions & 5 deletions experiments/nmt/train.py
@@ -4,17 +4,21 @@
import cPickle
import logging
import pprint

import re
import numpy
import time

from groundhog.trainer.SGD_adadelta import SGD as SGD_adadelta
from groundhog.trainer.SGD import SGD as SGD
from groundhog.trainer.SGD_momentum import SGD as SGD_momentum
from groundhog.mainLoop import MainLoop
from experiments.nmt import\
RNNEncoderDecoder, prototype_state, get_batch_iterator
RNNEncoderDecoder, prototype_state, get_batch_iterator, sample,\
BeamSearch, parse_input
import experiments.nmt

from subprocess import Popen, PIPE

logger = logging.getLogger(__name__)

class RandomSamplePrinter(object):
@@ -45,10 +49,162 @@ def cut_eol(words):
if len(x_words) == 0:
continue

print "Input: {}".format(" ".join(x_words))
print "Target: {}".format(" ".join(y_words))
self.__print_samples("Input", x_words, self.state['source_encoding'])
self.__print_samples("Target", y_words, self.state['target_encoding'])

self.model.get_samples(self.state['seqlen'] + 1, self.state['n_samples'], x[:len(x_words)])
sample_idx += 1

def __print_samples(self, output_name, words, encoding):
if encoding == "utf8":
print u"{}: {}".format(output_name, " ".join(words))
elif encoding == "ascii":
print "{}: {}".format(output_name, " ".join(words))
else:
print "Unknown encoding {}".format(encoding)

class BleuValidator(object):
"""
Object that evaluates the BLEU score on the validation set.
Opens a subprocess to run the validation script, and
keeps track of the BLEU scores over time.
"""
def __init__(self, state, lm_model,
beam_search, ignore_unk=False,
normalize=False, verbose=False):
"""
Handles normal book-keeping of state variables,
but also reloads the BLEU scores if they exist

:param state:
a state in the usual groundhog sense
:param lm_model:
a groundhog language model
:param beam_search:
beamsearch object used for sampling
:param ignore_unk:
whether or not to ignore unknown tokens
:param normalize:
whether or not to normalize the score by the length
of the sentence
:param verbose:
whether or not to also write the translation to the file
specified by validation_set_out

"""

args = dict(locals())
args.pop('self')
self.__dict__.update(**args)

self.indx_word = cPickle.load(open(state['word_indx'],'rb'))
self.idict_src = cPickle.load(open(state['indx_word'],'rb'))
self.n_samples = state['beam_size']
self.best_bleu = 0

self.val_bleu_curve = []
if state['reload']:
try:
bleu_score = numpy.load(state['prefix'] + 'val_bleu_scores.npz')
self.val_bleu_curve = bleu_score['bleu_scores'].tolist()
print "BleuScores Reloaded"
except:
print "BleuScores not Found"

# for utf8 we assume that character-based BLEU is of interest; this might not be the desired behaviour in all cases
if state['target_encoding'] == 'utf8':
self.multibleu_cmd = ['perl', state['bleu_script'], '-char', state['validation_set_grndtruth'], '<']
else:
self.multibleu_cmd = ['perl', state['bleu_script'], state['validation_set_grndtruth'], '<']

if verbose:
# the prototype state sets 'validation_set_out' to None, so check the value, not the key
if state.get('validation_set_out') is None:
self.state['validation_set_out'] = state['prefix'] + 'validation_out.txt'

def __call__(self):
"""
Opens the file for the validation set and creates a subprocess
for the multi-bleu script.

Returns a boolean indicating whether the current model should
be saved.
"""

print "Started Validation: "
val_start_time = time.time()
fsrc = open(self.state['validation_set'], 'r')
mb_subprocess = Popen(self.multibleu_cmd, stdin=PIPE, stdout=PIPE)
total_cost = 0.0

if self.verbose:
ftrans = open(self.state['validation_set_out'], 'w')

for i, line in enumerate(fsrc):
"""
Load the sentence, retrieve the sample, write to file
"""
if self.state['source_encoding'] == 'utf8':
seqin = line.strip().decode('utf-8')
else:
seqin = line.strip()
seq, parsed_in = parse_input(self.state, self.indx_word, seqin, idx2word=self.idict_src)

# draw sample, checking to ensure we don't get an empty string back
trans, costs, _ = sample(self.lm_model, seq, self.n_samples,
beam_search=self.beam_search, ignore_unk=self.ignore_unk, normalize=self.normalize)
try:
best = numpy.argmin(costs)
total_cost += costs[best]
trans_out = trans[best]
except ValueError:
print "Could not fine a translation for line: {}".format(i+1)
trans_out = u'UNK' if self.state['target_encoding'] == 'utf8' else 'UNK'

# Write to subprocess and file if it exists
if self.state['target_encoding'] == 'utf8':
print >> mb_subprocess.stdin, trans_out.encode('utf8').replace(" ","")
if self.verbose:
print >> ftrans, trans_out.encode('utf8').replace(" ","")
else:
print >> mb_subprocess.stdin, trans_out
if self.verbose:
print >> ftrans, trans_out

if i != 0 and i % 50 == 0:
print "Translated {} lines of validation set...".format(i)
mb_subprocess.stdin.flush()

print "Total cost of the validation: {}".format(total_cost)
fsrc.close()
if self.verbose:
ftrans.close()

# send end of file, read output.
mb_subprocess.stdin.close()
out_parse = re.match(r'BLEU = [-.0-9]+', mb_subprocess.stdout.readline())
print "Validation Took: {} minutes".format(float(time.time() - val_start_time)/60.)
assert out_parse is not None

# extract the score
bleu_score = float(out_parse.group()[6:])
self.val_bleu_curve.append(bleu_score)
print bleu_score
mb_subprocess.terminate()

# Determine whether or not we should save
if self.best_bleu < bleu_score:
self.best_bleu = bleu_score
return True
return False

def parse_args():
parser = argparse.ArgumentParser()
@@ -63,7 +219,9 @@ def parse_args():
def main():
args = parse_args()

# this loads the state specified in the prototype
state = getattr(experiments.nmt, args.proto)()
# this is based on the suggestion in the README.md in this folder
if args.state:
if args.state.endswith(".py"):
state.update(eval(open(args.state).read()))
Expand All @@ -81,16 +239,30 @@ def main():
enc_dec.build()
lm_model = enc_dec.create_lm_model()

# If a BLEU script and validation data are configured, set up
# BLEU validation so it can drive early stopping
bleu_validator = None
if state['bleu_script'] is not None and state['validation_set'] is not None\
and state['validation_set_grndtruth'] is not None:
# make beam search
beam_search = BeamSearch(enc_dec)
beam_search.compile()
bleu_validator = BleuValidator(state, lm_model, beam_search, verbose=state['output_validation_set'])

logger.debug("Load data")
train_data = get_batch_iterator(state)
logger.debug("Compile trainer")

algo = eval(state['algo'])(lm_model, state, train_data)
logger.debug("Run training")

main = MainLoop(train_data, None, None, lm_model, algo, state, None,
reset=state['reset'],
bleu_val_fn = bleu_validator,
hooks=[RandomSamplePrinter(state, lm_model, train_data)]
if state['hookFreq'] >= 0
if state['hookFreq'] >= 0 and state['validation_set'] is not None
else None)

if state['reload']:
main.load()
if state['loopIters'] > 0:
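As a quick sanity check of the score extraction in BleuValidator.__call__ above, here is a standalone sketch of the same regex logic; the sample line only mimics the usual multi-bleu.perl output format, with hypothetical numbers, and is not from a real run:

import re

# Illustrative line in the usual multi-bleu.perl output format (hypothetical numbers)
line = "BLEU = 23.41, 58.2/31.7/19.3/12.1 (BP=0.987, ratio=0.987, hyp_len=62307, ref_len=63128)"

out_parse = re.match(r'BLEU = [-.0-9]+', line)
assert out_parse is not None
bleu_score = float(out_parse.group()[6:])  # strips the leading "BLEU =", leaving " 23.41"
print bleu_score  # 23.41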
15 changes: 15 additions & 0 deletions groundhog/mainLoop.py
@@ -46,6 +46,7 @@ def __init__(self,
state,
channel,
hooks=None,
bleu_val_fn=None,
reset=-1,
train_cost=False,
validate_postprocess=None,
@@ -79,6 +80,10 @@
:param hooks: list of functions that are called every `hookFreq`
steps to carry on various diagnostics

:type bleu_val_fn: function
:param bleu_val_fn: function that is called every `bleu_val_frequency`
steps; it generates translations for the validation set and calls the evaluation script

:type reset: int
:param reset: if larger than 0, the train_data iterator position is
reseted to 0 every `reset` number of updates
@@ -149,7 +154,9 @@
if self.channel is not None:
self.channel.save()

self.bleu_val_fn = bleu_val_fn
self.hooks = hooks

self.reset = reset

self.start_time = time.time()
@@ -336,6 +343,14 @@ def main(self):
self.step % self.state['hookFreq'] == 0 and \
self.hooks:
[fn() for fn in self.hooks]

if self.state['bleu_val_frequency'] is not None and \
self.step % self.state['bleu_val_frequency'] == 0 \
and self.bleu_val_fn is not None and self.step > 0:
if self.bleu_val_fn():
self.model.save(self.state['prefix'] + 'best_bleu_model.npz')
numpy.savez(self.state['prefix'] + 'val_bleu_scores.npz', bleu_scores=self.bleu_val_fn.val_bleu_curve)

if self.reset > 0 and self.step > 1 and \
self.step % self.reset == 0:
print 'Resetting the data iterator'

Review discussion on this change:

Reviewer: This is not good: Groundhog's main_loop should keep being task-agnostic. Can you call your BLEU score computation with the standard validation callback?

Author: I agree this is not ideal. But initially I thought it'd be nice to both track the standard cross entropy in the validation callback and compute BLEU (occasionally) so we can compare how well we are optimizing both. Do you think we should have both in one validation call?

Reviewer: I think it makes sense, as both are validation costs. Alternatively you can do BLEU-related things using the hook mechanism (which is currently used to print samples).

Author: Alright. I'm going to implement a normal validation function using cross entropy, since that seems to be missing. Might be interesting to see how well it correlates with BLEU.

Kelvin

Reviewer: Nice! But if you just moved the code that calls the BLEU computation from main_loop into the validation callback, we could merge this pull request.

Author: I see, I'll move this by the end of the week.
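Following the reviewer's suggestion above, here is a rough sketch of how the BLEU validation could be driven through the existing hook mechanism instead of a dedicated bleu_val_fn branch in main_loop, keeping the main loop task-agnostic. The class name and wiring are hypothetical illustrations of the proposed refactor, not code from this pull request:

import numpy

class BleuValidationHook(object):
    """Hypothetical hook wrapping a BleuValidator so that it can run
    through MainLoop's generic `hooks` list. Since hooks only fire every
    `hookFreq` steps, this assumes `bleu_val_frequency` is a multiple of
    `hookFreq`."""

    def __init__(self, state, main_loop, bleu_validator):
        self.state = state
        self.main_loop = main_loop
        self.bleu_validator = bleu_validator

    def __call__(self):
        step = self.main_loop.step
        freq = self.state['bleu_val_frequency']
        if freq is None or step == 0 or step % freq != 0:
            return
        # BleuValidator returns True when the new score beats the best so far
        if self.bleu_validator():
            self.main_loop.model.save(self.state['prefix'] + 'best_bleu_model.npz')
        numpy.savez(self.state['prefix'] + 'val_bleu_scores.npz',
                    bleu_scores=self.bleu_validator.val_bleu_curve)

Because the hook needs the loop's step counter, it would have to be attached after MainLoop is constructed, e.g. main.hooks.append(BleuValidationHook(state, main, bleu_validator)); the bleu_val_fn parameter and the extra branch in MainLoop.main() could then be dropped.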