
Commit

Add init weight step
alldefector committed Jan 6, 2017
1 parent 68cee20 commit 293f464
Showing 4 changed files with 80 additions and 10 deletions.
10 changes: 9 additions & 1 deletion doc/ps_example.py
@@ -33,7 +33,15 @@ def main():
msgtype = msgtype.decode('utf-8')
logging.info('got msg: <%s, %s>, len = %d' % (worker_id, msgtype, len(buf)))

if msgtype == 'GRADIENTS':
if msgtype == 'WEIGHTS':
nweights = next(unpacker)
if weights is None:
weights = np.zeros(nweights, dtype=np.float32)
packer = msgpack.Packer(use_single_float=True, use_bin_type=True)
bw = np.asarray(weights).tobytes()
socket.send(packer.pack(bw))
logging.info('\tsent wgts[%d] %s ... %s' % (len(weights), weights[:3], weights[-3:]))
elif msgtype == 'GRADIENTS':
epochs, grads = next(unpacker), next(unpacker)
logging.info('\tepochs = %d' % epochs)

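For reference, the request path this new WEIGHTS branch serves can be exercised end to end with a small standalone client. The sketch below is hypothetical: the endpoint, the REQ socket type, and the names fetch_weights / worker-0 are assumptions, not part of this commit. Only the wire format comes from the diff: the worker packs its id, the string "WEIGHTS", and the weight count back-to-back with msgpack, and gets back a single msgpack bin holding nweights float32 values.

    # Hypothetical client for the WEIGHTS branch above. Endpoint, socket type,
    # and all names here are illustrative; only the wire format comes from the diff.
    import msgpack
    import numpy as np
    import zmq

    def fetch_weights(endpoint='tcp://127.0.0.1:5556', worker_id='worker-0', nweights=1024):
        ctx = zmq.Context.instance()
        sock = ctx.socket(zmq.REQ)          # assumed socket type
        sock.connect(endpoint)              # assumed endpoint

        # REQUEST: STRING worker_id, STRING msgtype ("WEIGHTS"), INT nweights,
        # packed back-to-back into one ZeroMQ message (the PS reads them with a
        # streaming msgpack.Unpacker, as in ps_example.py).
        packer = msgpack.Packer(use_bin_type=True)
        sock.send(packer.pack(worker_id) + packer.pack('WEIGHTS') + packer.pack(nweights))

        # RESPONSE: one msgpack bin holding nweights float32 values.
        reply = sock.recv()
        weights = np.frombuffer(msgpack.unpackb(reply, raw=False), dtype=np.float32)
        assert weights.size == nweights
        return weights

    if __name__ == '__main__':
        print(fetch_weights()[:3])

Run against a parameter server speaking this protocol, it should print the first three weights, matching the wgts[...] log line emitted by the handler above.
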
69 changes: 63 additions & 6 deletions src/dimmwitted.cc
@@ -80,13 +80,15 @@ int gibbs(const CmdParser &args) {

// Initialize Gibbs sampling application.
DimmWitted dw(fg, fg->weights.get(), args);

if (dw.is_distributed) {
dw.connect_param_server();
dw.ps_send_msg("FG_LOADED");
dw.ps_reset_weights();
}

dw.learn();
if (!dw.is_distributed) {
if (!dw.is_distributed && dw.opts.n_learning_epoch > 0) {
// in distributed mode, PS manages weights
dw.dump_weights();
}
@@ -191,6 +193,61 @@ void inspect_vector(T *arr, size_t num) {
<< std::setprecision(ss) << std::endl;
}

void DimmWitted::ps_reset_weights() {
if (!is_distributed) return;
ps_fetch_weights();
// assign the fetched weights to all samplers and clear their gradients
InferenceResult &infrs = samplers[0].infrs;
for (size_t i = 1; i < n_samplers_; ++i) {
infrs.copy_weights_to(samplers[i].infrs);
samplers[i].infrs.reset_gradients();
}
}

void DimmWitted::ps_fetch_weights() {
if (!is_distributed) return;
InferenceResult &infrs = samplers[0].infrs;
const size_t nweight = infrs.nweights;

// REQUEST FORMAT:
// STRING worker_id
// STRING msgtype ("WEIGHTS")
// INT nweights

std::cout << "\tfetching weights[" << nweight << "]" << std::endl;
msgpack::sbuffer sbuf;
msgpack::packer<msgpack::sbuffer> pk(&sbuf);
pk.pack(opts.worker_id);
pk.pack("WEIGHTS");
pk.pack(nweight);
zmq::message_t msg(sbuf.data(), sbuf.size(), NULL /* no de-allocate */);
ps_socket->send(msg);

// RESPONSE FORMAT:
// FLOAT32[]::BYTES weights

zmq::message_t reply;
ps_socket->recv(&reply);
msgpack::unpacker pac;
pac.reserve_buffer(reply.size());
memcpy(pac.buffer(), reply.data(), reply.size());
pac.buffer_consumed(reply.size());

msgpack::object_handle oh;
pac.next(oh);
std::vector<char> bytes;
oh.get().convert(bytes);

// HACK: abusing grads to store weights
float *grads = infrs.weight_grads.get();
memcpy(grads, bytes.data(), bytes.size());
std::cout << "\treceived weights[" << nweight << "] ";
inspect_vector(grads, nweight);

COPY_ARRAY(grads, nweight, infrs.weight_values.get());
infrs.reset_gradients();
}

bool DimmWitted::ps_update_weights(InferenceResult &infrs, int epochs) {
if (!is_distributed) return false;

@@ -200,7 +257,7 @@ bool DimmWitted::ps_update_weights(InferenceResult &infrs, int epochs) {

// REQUEST FORMAT:
// STRING worker_id
// STRING msgtype
// STRING msgtype ("GRADIENTS")
// INT epochs
// FLOAT32[]::BYTES grads

@@ -244,6 +301,7 @@ bool DimmWitted::ps_update_weights(InferenceResult &infrs, int epochs) {
inspect_vector(grads, nweight);

COPY_ARRAY(grads, nweight, infrs.weight_values.get());
infrs.reset_gradients();

if (command == "STOP") {
std::cout << "\tSTOP." << std::endl;
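
The GRADIENTS leg of the protocol, as ps_update_weights packs it above, can be sketched the same way. The client below is hypothetical (socket type, endpoint, and function names are assumptions); the request layout matches the comment block above — worker_id, "GRADIENTS", the epoch count, and the raw float32 gradient bytes — while the reply is returned raw, since this hunk only shows that it carries fresh weights and a command such as "STOP", not its exact layout.

    # Hypothetical sketch of the GRADIENTS request; socket type, endpoint, and
    # names are assumptions. The reply is returned undecoded because its exact
    # layout (fresh weights plus a command such as "STOP") is not shown here.
    import msgpack
    import numpy as np
    import zmq

    def push_gradients(sock, worker_id, epochs, grads):
        # REQUEST: STRING worker_id, STRING msgtype ("GRADIENTS"),
        #          INT epochs, FLOAT32[]::BYTES grads
        packer = msgpack.Packer(use_single_float=True, use_bin_type=True)
        payload = (packer.pack(worker_id) + packer.pack('GRADIENTS') +
                   packer.pack(int(epochs)) +
                   packer.pack(np.asarray(grads, dtype=np.float32).tobytes()))
        sock.send(payload)
        return sock.recv()

    if __name__ == '__main__':
        ctx = zmq.Context.instance()
        sock = ctx.socket(zmq.REQ)               # assumed socket type
        sock.connect('tcp://127.0.0.1:5556')     # assumed endpoint
        reply = push_gradients(sock, 'worker-0', 4, np.zeros(16, dtype=np.float32))
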
@@ -310,14 +368,13 @@ void DimmWitted::learn() {

if (is_distributed) {
// distributed worker: sum the gradients
for (size_t i = 1; i < n_samplers_; ++i)
for (size_t i = 1; i < n_samplers_; ++i) {
infrs.merge_gradients_from(samplers[i].infrs);
samplers[i].infrs.reset_gradients();
}

// exchange local gradients for latest weights
stop = ps_update_weights(infrs, n_samplers_ * (i_epoch + 1));

// reset all gradients
for (size_t i = 0; i < n_samplers_; ++i) infrs.reset_gradients();
} else {
// standalone mode: sum the weights
stop = update_weights(infrs, t.elapsed(), current_stepsize, prev_weights);
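
Restructured this way, a distributed epoch folds every other sampler's gradients into sampler 0, clearing each buffer as it is merged, and relies on ps_update_weights (via the reset_gradients() call added above) to clear sampler 0's buffer after the exchange, instead of a separate reset pass at the end of the epoch. A toy version of that flow, with illustrative names and a fake parameter-server exchange standing in for ps_update_weights:

    # Toy version of the reshaped distributed epoch; all names are illustrative.
    import numpy as np

    def distributed_epoch(sampler_grads, exchange_with_ps, epochs_done):
        merged = sampler_grads[0]
        for g in sampler_grads[1:]:
            merged += g              # merge_gradients_from
            g[:] = 0.0               # reset_gradients on the sampler just merged
        # exchange local gradients for the latest weights; the real call also
        # resets sampler 0's gradients via the reset_gradients() added above
        stop = exchange_with_ps(merged, epochs_done)
        return stop

    if __name__ == '__main__':
        grads = [np.ones(4, dtype=np.float32) for _ in range(3)]

        def fake_ps_exchange(g, epochs):
            g[:] = 0.0               # stands in for infrs.reset_gradients()
            return epochs >= 8       # pretend the PS says STOP after 8 sampler-epochs

        print(distributed_epoch(grads, fake_ps_exchange, epochs_done=3))
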
6 changes: 6 additions & 0 deletions src/dimmwitted.h
@@ -90,6 +90,9 @@ class DimmWitted {
// Send a simple message (a string) to PS
void ps_send_msg(std::string message);

// Initialize weights from PS
void ps_reset_weights();

private:
/**
* Connection to parameter server
@@ -106,6 +109,9 @@
// Returns true if we can stop learning
bool ps_update_weights(InferenceResult& infrs, int epochs);

// Fetch latest weights from PS
void ps_fetch_weights();

size_t compute_n_epochs(size_t n_epoch);
};

5 changes: 2 additions & 3 deletions src/inference_result.h
@@ -64,10 +64,9 @@
void dump_marginals_in_text(std::ostream &text_output) const;

inline void update_weight(size_t wid, double stepsize, double gradient) {
double diff = stepsize * gradient;
if (is_distributed) {
// Distributed worker does batch GD and accumulates gradients
weight_grads[wid] += diff;
weight_grads[wid] += gradient;
} else {
// Standalone worker does SGD
double weight = weight_values[wid];
@@ -84,7 +83,7 @@
default:
std::abort();
}
weight -= diff;
weight -= stepsize * gradient;
weight_values[wid] = weight;
}
}
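
The net effect of this change is a clearer division of labor: a distributed worker now accumulates raw gradients and never applies the step size locally, while a standalone worker still does its own SGD step (the regularization switch above is omitted here). The toy sketch below illustrates that split; the parameter-server update rule shown (plain SGD on the summed gradients) is an assumption, since the PS-side step is not part of this diff.

    # Toy illustration of the split: workers accumulate raw gradients, and the
    # step size is applied only where the weights are actually updated. The
    # PS-side rule (plain SGD on the summed gradients) is an assumption.
    import numpy as np

    def distributed_accumulate(weight_grads, gradient, wid):
        # distributed worker: batch GD, just accumulate the raw gradient
        weight_grads[wid] += gradient
        return weight_grads

    def ps_apply(weights, summed_grads, stepsize=0.01):
        # parameter server: applies the step size once per exchange (assumed rule)
        return weights - stepsize * summed_grads

    def standalone_update(weights, wid, gradient, stepsize=0.01):
        # standalone worker: immediate SGD step, regularization omitted
        weights[wid] -= stepsize * gradient
        return weights

    if __name__ == '__main__':
        w = np.zeros(4, dtype=np.float32)
        g = distributed_accumulate(np.zeros(4, dtype=np.float32), gradient=2.0, wid=1)
        print(ps_apply(w, g))                                     # distributed path
        print(standalone_update(w.copy(), wid=1, gradient=2.0))   # standalone path
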
