From 09eeb78a9f9ad25184dba84ba16398995b966702 Mon Sep 17 00:00:00 2001 From: Wenwen Qu Date: Tue, 26 Nov 2024 15:56:50 +0800 Subject: [PATCH 1/4] add async communicator --- .../model/moe/ampipe/async_communicator.py | 169 ++++++++++++++++++ 1 file changed, 169 insertions(+) create mode 100644 internlm/model/moe/ampipe/async_communicator.py diff --git a/internlm/model/moe/ampipe/async_communicator.py b/internlm/model/moe/ampipe/async_communicator.py new file mode 100644 index 00000000..abf02cd2 --- /dev/null +++ b/internlm/model/moe/ampipe/async_communicator.py @@ -0,0 +1,169 @@ +# Copyright (c) 2024, Huawei Technologies Co., Ltd. All rights reserved. +# http://www.apache.org/licenses/LICENSE-2.0 +import torch + +from internlm.core.context import ParallelMode +from internlm.core.context import global_context as gpc +from internlm.model.moe.utils import all_to_all + + +class AsyncCommunication: + def __init__(self, fwd_args, bwd_args=None): + self.bwd_args = bwd_args + self.fwd_args = fwd_args + + def comm_before_moe_mlp_fwd(self, ctx, dispatched_input): + cur_degree = self.fwd_args.cur_degree + a2a_events = self.fwd_args.a2a_events + mlp_inputs = self.fwd_args.mlp_inputs + a2a_inputs = self.fwd_args.a2a_inputs + + model_args = gpc.config.model + pipe_experts = model_args.use_pipe_experts + num_experts = model_args.num_experts + num_local_experts = num_experts // gpc.get_world_size(ParallelMode.EXPERT) + + # 不开启ampipe_tp_sp_comm_overlap时,不切分专家维度,直接做全量专家的all2all + if not model_args.ampipe_tp_sp_comm_overlap: + a2a_tokens, a2a_handle = dispatched_inputs, _ = all_to_all( + dispatched_input, group=self.ep_group, async_op=True + ) + a2a_events.append(a2a_handle) + mlp_inputs[cur_degree] = a2a_tokens + return mlp_inputs + + # TODO: 这一段似乎和tp的特定优化有关,interevo似乎还没实现,待定 + # 开启ampipe_tp_sp_comm_overlap时,按照专家切分token后再all2all + # chunk_list = dispatched_input.chunk(num_experts) + # for exp_index in range(num_local_experts): + # chunks = chunk_list[exp_index:num_experts:num_local_experts] + # a2a_tokens = torch.cat(chunks) + # # pipe-experts适配 + # if pipe_experts: + # comm_result = self._pipe_expert_comm_before_moe_mlp_fwd(ctx, exp_index, a2a_tokens) + # if comm_result is not None: + # continue + # # 不开启pipe_experts或者pipe_experts_multi_data < ampipe_degree时不再切分token,直接all2all + # output, a2a_handle = all_to_all(a2a_tokens, group=self.ep_group, async_op=True) + # index = cur_degree * num_local_experts + exp_index + # mlp_inputs[index] = output + # a2a_events[index] = a2a_handle + # # 不提前析构通信tensor,保证正常释放通信后tensor内存 + # a2a_inputs.append(a2a_tokens) + # return mlp_inputs + + def comm_before_moe_mlp_bwd(self, ctx, grad_moe_out_chunk): + cur_degree = self.bwd_args.cur_degree + a2a_events = self.bwd_args.a2a_events + grad_mlp_input_list = self.bwd_args.grad_mlp_input_list + grad_a2a_input_list = self.bwd_args.grad_a2a_input_list + # 反向第一次all2all + # 纯ep通信隐藏 + if not gpc.config.model.ampipe_tp_sp_comm_overlap: + grad_mlp_input_list[cur_degree], a2a_handle = all_to_all( + grad_moe_out_chunk, group=self.ep_group, async_op=True + ) + a2a_events.insert(0, a2a_handle) + return grad_mlp_input_list + + # tp-sp域&ep域通信隐藏适配 + # chunk_list = grad_moe_out_chunk.chunk(ctx.num_experts) + # for exp_index in range(ctx.num_local_experts): + # chunks = chunk_list[exp_index:ctx.num_experts:ctx.num_local_experts] + # grad_mlp_tokens = torch.cat(chunks) + # # pipe-experts适配 + # if ctx.pipe_experts: + # comm_result = self._pipe_expert_comm_before_moe_mlp_bwd(ctx, exp_index, grad_mlp_tokens) + # if comm_result is not None: + # 
continue + # # 不开启pipe_experts或者pipe_experts_multi_data < ampipe_degree时不再切分token,直接all2all + # grad_a2a_tokens, a2a_handle = async_all_to_all(grad_mlp_tokens) + # index = (ctx.pipe_degree - 1 - cur_degree) * ctx.num_local_experts + exp_index + # grad_mlp_input_list[index] = grad_a2a_tokens + # a2a_events[index] = a2a_handle + # # 不提前析构通信tensor,保证正常释放通信后tensor内存 + # grad_a2a_input_list[index] = grad_mlp_tokens + # return grad_mlp_input_list + + # TODO 跟多专家pipe有关,目前internevo尚未合入相关逻辑 + # def _pipe_expert_comm_before_moe_mlp_fwd(self, ctx, exp_index, input_tokens): + # cur_degree = self.fwd_args.cur_degree + # a2a_events = self.fwd_args.a2a_events + # mlp_inputs = self.fwd_args.mlp_inputs + # a2a_inputs = self.fwd_args.a2a_inputs + # ag_events = self.fwd_args.ag_events + # model_args = gpc.config.model + # pipe_degree = model_args.ampipe_degree + # pipe_experts_multi_data = model_args.pipe_experts_multi_data + # pipe_experts_multi_stream = model_args.pipe_experts_multi_stream + # # pipe_experts_multi_data > ampipe_degree时, 对token的C维度再切分 + # ctx.slice_size = slice_size = pipe_experts_multi_data // pipe_degree + # a2a_token_chunk = input_tokens.chunk(slice_size, dim=1) + # # 多流场景下pipe_experts_multi_data必须大于等于ampipe_degree + # if pipe_experts_multi_data >= pipe_degree and pipe_experts_multi_stream: + # for i in range(slice_size): + # # 计算列表中索引适配pipe_experts + # index = cur_degree * slice_size + exp_index * pipe_experts_multi_data + i + # if (cur_degree + exp_index + i) == 0 and gpc.config.parallel.get("sequence_parallel", False): + # a2a_token, a2a_handle = async_all_to_all(a2a_token_chunk[i]) + # else: + # a2a_token, a2a_handle = async_all_to_all(a2a_token_chunk[i], ag_events[index]) + # a2a_events[index] = a2a_handle + # mlp_inputs[index] = a2a_token + # if args.sequence_parallel: + # ag_token, ag_handle = async_fw_all_gather(a2a_token, a2a_handle, ampipe_with_mlp_multistream=True, + # index=index) + # ag_events[index] = ag_handle + # mlp_inputs[index] = ag_token + # return mlp_inputs + # # 非多流场景下pipe_experts_multi_data必须大于ampipe_degree + # elif pipe_experts_multi_data > pipe_degree and not pipe_experts_multi_stream: + # for i in range(slice_size): + # a2a_token, a2a_handle = async_all_to_all(a2a_token_chunk[i]) + # index = cur_degree * slice_size + exp_index * pipe_experts_multi_data + i + # a2a_events[index] = a2a_handle + # mlp_inputs[index] = a2a_token + # a2a_inputs.append(a2a_token_chunk[i]) + # return mlp_inputs + # return None + + # def _pipe_expert_comm_before_moe_mlp_bwd(self, ctx, exp_index, grad_tokens): + # cur_degree = self.bwd_args.cur_degree + # a2a_events = self.bwd_args.a2a_events + # grad_mlp_input_list = self.bwd_args.grad_mlp_input_list + # ag_events = self.bwd_args.ag_events + # args = get_args() + # pipe_degree = args.ampipe_degree + # grad_token_list = grad_tokens.chunk(ctx.slice_size, dim=1) + # # 多流场景下pipe_experts_multi_data必须大于等于ampipe_degree + # if ctx.pipe_experts_multi_data >= pipe_degree and ctx.pipe_experts_multi_stream: + # for i in range(ctx.slice_size): + # # 计算列表中索引适配pipe_experts + # index = (pipe_degree - 1 - cur_degree) * ctx.slice_size + exp_index * ctx.pipe_experts_multi_data + i + # if cur_degree == pipe_degree - 1 and (exp_index + i) == 0 and args.sequence_parallel: + # a2a_token, a2a_handle = async_all_to_all(grad_token_list[i]) + # else: + # a2a_token, a2a_handle = async_all_to_all(grad_token_list[i], ag_events[index]) + # a2a_events[index] = a2a_handle + # grad_mlp_input_list[index] = a2a_token + # if args.sequence_parallel: + # ag_token, ag_handle 
= async_all_gather(a2a_token, a2a_handle, is_bwd=True) + # ag_events[index] = ag_handle + # grad_mlp_input_list[index] = ag_token + # return grad_mlp_input_list + # # 非多流场景下pipe_experts_multi_data必须大于ampipe_degree + # elif ctx.pipe_experts_multi_data > pipe_degree and not ctx.pipe_experts_multi_stream: + # for i in range(ctx.slice_size): + # a2a_token, a2a_handle = async_all_to_all(grad_token_list[i]) + # index = (pipe_degree - 1 - cur_degree) * ctx.slice_size + exp_index * ctx.pipe_experts_multi_data + i + # a2a_events[index] = a2a_handle + # grad_mlp_input_list[index] = a2a_token + # return grad_mlp_input_list + # return None + + # def fw_all_gather_not_multistream(self): + # self.fwd_args.a2a_events[0].wait() + # # 释放通信内存 + # self.fwd_args.a2a_inputs.pop() + # _, ag_handle = async_fw_all_gather(self.fwd_args.mlp_inputs[0]) + # self.fwd_args.ag_events.append(ag_handle) From 191061fddf90fc7b4b8abc260dab64b06412ac3c Mon Sep 17 00:00:00 2001 From: Wenwen Qu Date: Thu, 5 Dec 2024 16:14:00 +0800 Subject: [PATCH 2/4] add explicit fwd bwd for attn --- internlm/model/modules/linear.py | 80 ++++ internlm/model/moe/ampipe/ampipe.py | 626 +++++++++++++++++++++++++ internlm/model/moe/ampipe/fa_helper.py | 261 +++++++++++ internlm/model/ops/norm.py | 18 + 4 files changed, 985 insertions(+) create mode 100644 internlm/model/moe/ampipe/ampipe.py create mode 100644 internlm/model/moe/ampipe/fa_helper.py diff --git a/internlm/model/modules/linear.py b/internlm/model/modules/linear.py index 29070b42..05c7ae43 100644 --- a/internlm/model/modules/linear.py +++ b/internlm/model/modules/linear.py @@ -555,6 +555,86 @@ def fused_dense_func( ) +def explicit_fused_dense_forward( + ctx, + x: torch.Tensor, + weight: torch.Tensor, + communicator: Union[TPCommunicator, WPCommunicator], + module: Optional[nn.Module] = None, + bias: Optional[torch.Tensor] = None, + return_residual: bool = False, + use_grouped_linear: bool = False, + **kwargs, +): + if communicator.communication_mode() == "wp": + if not use_grouped_linear: + return WPFusedDenseFunc.forward( + ctx, + x, + weight, + bias, + module, + communicator, + return_residual, + ) + else: + batch_sizes = kwargs.pop("batch_sizes", None) + backend = kwargs.pop("backend", "gmm") + full_weight_shape = kwargs.pop("full_weight_shape", None) + return GroupedGemmWPFusedDenseFunc.forward( + ctx, + x, + weight, + module, + communicator, + batch_sizes, + backend, + full_weight_shape, + ) + else: # mtp, msp, and fsp + if not use_grouped_linear: + return SPFusedDenseFunc.forward( + ctx, + x, + weight, + bias, + communicator, + return_residual, + ) + else: + # TODO: support grouped linear for mtp, msp, and fsp + batch_sizes = kwargs.pop("batch_sizes", None) + backend = kwargs.pop("backend", "gmm") + return GroupedGemmSPFusedDenseFunc.forward( + ctx, + x, + weight, + batch_sizes, + backend, + ) + + +def explicit_fused_dense_backward( + ctx, + grad_output: torch.Tensor, +): + if communicator.communication_mode() == "wp": + if not use_grouped_linear: + grad_input, grad_weight, grad_bias, *_ = WPFusedDenseFunc.backward(ctx, grad_output) + else: + grad_input, grad_weight = GroupedGemmWPFusedDenseFunc.backward(ctx, grad_output) + grad_bias = None + else: # mtp, msp, and fsp + if not use_grouped_linear: + grad_input, grad_weight, grad_bias, *_ = SPFusedDenseFunc.backward(ctx, grad_output) + else: + # TODO: support grouped linear for mtp, msp, and fsp + grad_input, grad_weight = GroupedGemmSPFusedDenseFunc.backward(ctx, grad_output) + grad_bias = None + + return grad_input, grad_weight, 
grad_bias + + class ParallelLinearWithCommExt(nn.Linear): """ Parallel linear with commuication extention. diff --git a/internlm/model/moe/ampipe/ampipe.py b/internlm/model/moe/ampipe/ampipe.py new file mode 100644 index 00000000..8e77015c --- /dev/null +++ b/internlm/model/moe/ampipe/ampipe.py @@ -0,0 +1,626 @@ +import torch +from einops import rearrange +from torch.autograd.function import NestedIOFunction + +from internlm.core.context import random, ParallelMode + +from .fa_helper import flash_attn_bwd, flash_attn_fwd +import os +DEBUG=int(os.environ.get('DEBUG', 1)) + +class PrepareMoE(torch.autograd.Function): + def forward(ctx, tokens): + pass + + def backward(ctx, grad_tokens): + pass + +class FakeContext(): + def save_for_backward(self, *tensors: torch.Tensor): + self.to_save = tensors + @property + def saved_tensors(self): + return self.to_save # type: ignore[misc] + + +def bias_dropout_add(x, bias, residual, prob, training): + # type: (Tensor, Tensor, Tensor, float, bool) -> Tensor + out = torch.nn.functional.dropout(x + bias, p=prob, training=training) + out = residual + out + return out + +@torch.jit.script +def bias_dropout_add_fused_train(x: torch.Tensor, + bias: torch.Tensor, + residual: torch.Tensor, + prob: float) -> torch.Tensor: + return bias_dropout_add(x, bias, residual, prob, True) + +def bias_dropout_add_ln_fwd(ctx, inp, residual, bias, prob, ln, bias_dropout_add_exec_handler): + ctx.inp = inp + ctx.residual = residual + ctx.bias = bias.detach() + + inp.requires_grad = True + ctx.bias.requires_grad = True + residual.requires_grad = True + with torch.enable_grad(): + ln_input = bias_dropout_add_fused_train(inp, ctx.bias, residual, prob) + ctx.ln_input = ln_input + #TODO + output = ln.explicit_fwd(ctx, ln_input) + return output, ln_input + +def bias_dropout_add_ln_bwd(ctx, grad_ln_outs, grad_ln_ins, ln): + grad_fusion, grad_ln_weight = ln.explicit_bwd(ctx, grad_ln_outs) + with torch.enable_grad(): + ctx.ln_input.backward(grad_fusion + grad_ln_ins) + + return grad_ln_weight, ctx.inp.grad, ctx.residual.grad, ctx.bias.grad + +streams = {} +def get_current(dev=None): + return torch.cuda.current_stream() + +def get_comp0(dev=None): + #return torch.cuda.current_stream() + if 'h' in streams: + return streams['h'] + streams['h'] = torch.cuda.Stream() + return streams['h'] + +def get_comm(dev=None): + #return torch.cuda.current_stream() + if 'm' in streams: + return streams['m'] + streams['m'] = torch.cuda.Stream() + return streams['m'] + +DISABLEPiPE=int(os.environ.get('DISABLEPiPE', 0)) +if DISABLEPiPE: + get_comp0 = get_current + get_comm = get_current + +class AttMoEPipe(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, v, hidden_states, ln_weight, ln_bias, proj_bias, non_params): + #torch.cuda.synchronize() + #t0 = time.time() + + ctx.non_params = non_params + flash, attn, dense_layer, pipe_degree, ln, hidden_dropout, bias_dropout_add_exec_handler, moe \ + = non_params + + ctx.batch_size, seqlen, ctx.head = q.size(0), q.size(1), q.size(2) + + assert seqlen % pipe_degree == 0 + + pipe_degree = pipe_degree + context_layers = [] + chunk_len = seqlen // pipe_degree + base = 0 + cu_seqlens = torch.arange(0, (ctx.batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, + device=q.device) + cu_seqlens_q = torch.arange(0, (ctx.batch_size + 1) * chunk_len, step=chunk_len, dtype=torch.int32, + device=q.device) + + ctx.flash_ctx = [] + ctx.dense_ctx = [] + ctx.bdal_ctx = [] + ctx.prepare_ctx = [] + ctx.mlp_ctx = [] + ctx.post_ctx = [] + ctx.loss_ctx = [] + + 
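# Editor's note: the pipeline above drives autograd.Function kernels manually (explicit_fwd /
# explicit_bwd, explicit_fused_dense_forward/backward) through lightweight context objects such
# as FakeContext, instead of letting torch.autograd record the graph. The stand-alone sketch
# below illustrates only that pattern; ScaleFunc and FakeCtx are hypothetical names, not part of
# this patch, and the only assumption is that the context exposes save_for_backward /
# saved_tensors, which is all the explicit calls above rely on.

import torch

class ScaleFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, w):
        ctx.save_for_backward(x, w)
        return x * w

    @staticmethod
    def backward(ctx, grad_out):
        x, w = ctx.saved_tensors
        return grad_out * w, (grad_out * x).sum()

class FakeCtx:
    def save_for_backward(self, *tensors):
        self.to_save = tensors

    @property
    def saved_tensors(self):
        return self.to_save

fake = FakeCtx()
x = torch.randn(4)
w = torch.tensor(2.0)
y = ScaleFunc.forward(fake, x, w)                       # run forward now, on the stream we choose
gx, gw = ScaleFunc.backward(fake, torch.ones_like(y))   # replay backward later, in any order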
hidden_states_chunks = hidden_states.chunk(pipe_degree, dim=0) + + ln_outs = [] + ln_ins = [] + moe_outs = [] + + intermediate = [None] * pipe_degree + dispatchers = [None] * pipe_degree + scoreses = [None] * pipe_degree + tokens_per_experts = [None] * pipe_degree + + + attn_events = [] + a2a1_events = [] + comp_events = [] + a2a2_events = [] + + get_comp0().wait_stream(torch.cuda.current_stream()) + + + + for c in range(pipe_degree): + with torch.cuda.stream(get_comp0()): + + q_use = rearrange(q[:,base:base+chunk_len], 'b s ... -> (b s) ...') + flash_ctx = FakeContext() + + with random.seed(ParallelMode.Tensor): + output_chunk = flash_attn_fwd(flash_ctx, + q_use, k, v, cu_seqlens_q, cu_seqlens, chunk_len, seqlen, + flash.dropout_p if flash.training else 0.0, + softmax_scale=flash.softmax_scale, causal=True, + causal_q_offset=base, #fixed , + version=1 + ) + ctx.flash_ctx.append(flash_ctx) + context_layers.append(rearrange(output_chunk, '(b s) h d -> s b (h d)', b=ctx.batch_size).contiguous()) + base += chunk_len + + dense_ctx = FakeContext() + context_layers[-1] = dense_layer.explicit_fwd(dense_ctx, context_layers[-1]) + ctx.dense_ctx.append(dense_ctx) + + #if DEBUG: + # tensor_list = [torch.empty_like(context_layers[-1] ) for _ in range(mpu.get_tensor_model_parallel_world_size())] + # torch.distributed.all_gather(tensor_list, context_layers[-1] , group=mpu.get_tensor_model_parallel_group()) + # for t in tensor_list: + # assert (t == context_layers[-1] ).all().item(), "not same mlp input across tp ranks" + + bdal_ctx = FakeContext() + ln_output, ln_input = bias_dropout_add_ln_fwd(bdal_ctx, context_layers[-1], \ + hidden_states_chunks[c], dense_layer.bias, hidden_dropout, ln, bias_dropout_add_exec_handler) + ctx.bdal_ctx.append(bdal_ctx) + ln_ins.append(ln_input) + #ln_outs.append(ln_output) + #if DEBUG: + # tensor_list = [torch.empty_like(ln_output) for _ in range(mpu.get_tensor_model_parallel_world_size())] + # torch.distributed.all_gather(tensor_list, ln_output, group=mpu.get_tensor_model_parallel_group()) + # for t in tensor_list: + # assert (t == ln_output).all().item(), "not same a2a input across tp ranks" + + + prepare_ctx = FakeContext() + a2a_tokens, dispatcher, origin_shape, scores, tokens_per_expert\ + = moe.tutel_prepare_fwd(prepare_ctx, ln_output) + ctx.prepare_ctx.append(prepare_ctx) + + dispatchers[c] = dispatcher + intermediate[c] = a2a_tokens + tokens_per_experts[c] = tokens_per_expert + scoreses[c] = scores + + attn_events.append(torch.cuda.current_stream().record_event()) + + for c in range(pipe_degree): + with torch.cuda.stream(get_comm()): + torch.cuda.current_stream().wait_event(attn_events[c]) + #size [4, 256, 512] + #print("input moe: ", intermediate[c].mean(), torch.distributed.get_rank(), mpu.get_tensor_model_parallel_world_size()) + #if DEBUG: + # tensor_list = [torch.empty_like(intermediate[c]) for _ in range(mpu.get_tensor_model_parallel_world_size())] + # torch.distributed.all_gather(tensor_list, intermediate[c], group=mpu.get_tensor_model_parallel_group()) + # for t in tensor_list: + # assert (t == intermediate[c]).all().item(), "not same across tp ranks" + + a2a_tokens = moe.tutel_a2a_scatter(intermediate[c], [mpu.get_tensor_model_parallel_world_size(), mpu.get_tensor_model_parallel_group()]) + intermediate[c] = a2a_tokens + + a2a1_events.append(torch.cuda.current_stream().record_event()) + + + #get_comp0().wait_stream(get_comm()) + #t1 = time.time() + torch.cuda.synchronize() + + for c in range(pipe_degree): + with torch.cuda.stream(get_comp0()): + 
torch.cuda.current_stream().wait_event(a2a1_events[c]) + + mlp_ctx = FakeContext() + mlp_out = moe.tutel_mlp_fwd(mlp_ctx, intermediate[c]) + ctx.mlp_ctx.append(mlp_ctx) + intermediate[c] = mlp_out + + comp_events.append(torch.cuda.current_stream().record_event()) + #print("mlp: ", mlp_out.numel() * 4 * 6 * 8) + + for c in range(pipe_degree): + with torch.cuda.stream(get_comm()): + torch.cuda.current_stream().wait_event(comp_events[c]) + + a2a_tokens = moe.tutel_a2a_gather(intermediate[c], [mpu.get_tensor_model_parallel_world_size(), mpu.get_tensor_model_parallel_group()]) + intermediate[c] = a2a_tokens + + a2a2_events.append(torch.cuda.current_stream().record_event()) + + + for c in range(pipe_degree): + with torch.cuda.stream(get_comp0()): + + torch.cuda.current_stream().wait_event(a2a2_events[c]) + + post_ctx = FakeContext() + post_out = moe.tutel_post_fwd(post_ctx, intermediate[c], dispatchers[c]) + + + + ctx.post_ctx.append(post_ctx) + + post_out = post_out.view(origin_shape) + moe_outs.append(post_out) + + loss_ctx = FakeContext() + moe.tutel_loss(loss_ctx, scoreses[c], tokens_per_experts[c]) + ctx.loss_ctx.append(loss_ctx) + + #time.sleep(1000) + torch.cuda.current_stream().wait_stream(get_comp0()) + + + ret = torch.cat(moe_outs), torch.cat(ln_ins) + #torch.cuda.synchronize() + #te = time.time() + #if torch.distributed.get_rank() == 0: + # print("elapsed fwd: ", te - t0, te - t1, t1 - t0) + return ret + + @staticmethod + def backward(ctx, grad_mlp_outs, grad_ln_ins): + flash, attn, dense_layer, pipe_degree, ln, hidden_dropout, bias_dropout_add_exec_handler, moe\ + = ctx.non_params + + grad_mlp_outs = grad_mlp_outs.chunk(pipe_degree) + grad_ln_ins = grad_ln_ins.chunk(pipe_degree) + + grad_k, grad_v = None, None + grad_ln_weight, grad_ln_bias, bias_grad = None, None, None + grad_q = [] + grad_h = [] + intermediate = [None] * pipe_degree + gates_s_grads = [None] * pipe_degree + + post_events = [] + a2a2_events = [] + comp_events = [] + a2a1_events = [] + + get_comp0().wait_stream(torch.cuda.current_stream()) + + for c in range(0, pipe_degree): + with torch.cuda.stream(get_comp0()): + intermediate[c] = grad_mlp_outs[c].view(-1, grad_mlp_outs[c].size(-1)) + + intermediate[c], gates_s_grads[c] = moe.tutel_post_bwd(ctx.post_ctx[c], intermediate[c]) + + post_events.append(torch.cuda.current_stream().record_event()) + + for c in range(0, pipe_degree): + with torch.cuda.stream(get_comm()): + torch.cuda.current_stream().wait_event(post_events[c]) + intermediate[c] = moe.tutel_a2a_scatter(intermediate[c], [mpu.get_tensor_model_parallel_world_size(), mpu.get_tensor_model_parallel_group()]) + a2a2_events.append(torch.cuda.current_stream().record_event()) + + torch.cuda.synchronize() + + for c in range(0, pipe_degree): + with torch.cuda.stream(get_comp0()): + torch.cuda.current_stream().wait_event(a2a2_events[c]) + intermediate[c] = moe.tutel_mlp_bwd(ctx.mlp_ctx[c], intermediate[c]) + comp_events.append(torch.cuda.current_stream().record_event()) + + + + for c in range(0, pipe_degree): + with torch.cuda.stream(get_comm()): + torch.cuda.current_stream().wait_event(comp_events[c]) + intermediate[c] = moe.tutel_a2a_gather(intermediate[c], [mpu.get_tensor_model_parallel_world_size(), mpu.get_tensor_model_parallel_group()]) + a2a1_events.append(torch.cuda.current_stream().record_event()) + + for c in range(0, pipe_degree): + with torch.cuda.stream(get_comp0()): + torch.cuda.current_stream().wait_event(a2a1_events[c]) + + + #if DEBUG: + # tensor_list = [torch.empty_like(intermediate[c]) for _ in 
range(mpu.get_tensor_model_parallel_world_size())] + # torch.distributed.all_gather(tensor_list, intermediate[c], group=mpu.get_tensor_model_parallel_group()) + # for t in tensor_list: + # assert (t == intermediate[c]).all().item(), "not same across tp ranks" + + grad_ln_out = moe.tutel_prepare_bwd(ctx.prepare_ctx[c], moe.get_loss_grad(ctx.loss_ctx[c]), \ + intermediate[c], gates_s_grads[c]) + + + #if DEBUG: + # tensor_list = [torch.empty_like(grad_ln_out) for _ in range(mpu.get_tensor_model_parallel_world_size())] + # torch.distributed.all_gather(tensor_list, grad_ln_out, group=mpu.get_tensor_model_parallel_group()) + # for t in tensor_list: + # assert (t == grad_ln_out).all().item(), "not same across tp ranks" + + #if mpu.get_tensor_model_parallel_world_size() > 1: + # torch.distributed.all_reduce(grad_ln_out, op=torch.distributed.ReduceOp.SUM, group=mpu.get_tensor_model_parallel_group()) + + d_grad_ln_weight, d_grad_ln_bias, grad_dense, d_hidden_grad, d_bias_grad = \ + bias_dropout_add_ln_bwd(ctx.bdal_ctx[c], grad_ln_out, grad_ln_ins[c], ln) + grad_h.append(d_hidden_grad) + grad_ln_weight = grad_ln_weight + d_grad_ln_weight if grad_ln_weight is not None else d_grad_ln_weight + grad_ln_bias = grad_ln_bias + d_grad_ln_bias if grad_ln_bias is not None else d_grad_ln_bias + bias_grad = bias_grad + d_bias_grad if bias_grad is not None else d_bias_grad + + feed_flash = dense_layer.explicit_bwd(ctx.dense_ctx[c], grad_dense) + + d_q, d_k, d_v = flash_attn_bwd(ctx.flash_ctx[c], rearrange(feed_flash, 's b (h d) -> (b s) h d', h=ctx.head)) + grad_k = grad_k + d_k if grad_k is not None else d_k + grad_v = grad_v + d_v if grad_v is not None else d_v + grad_q.append(d_q) + + torch.cuda.current_stream().wait_stream(get_comp0()) + + return torch.cat([rearrange(gq, '(b s) ... -> b s ...', b=ctx.batch_size) for gq in grad_q], dim=1), \ + grad_k, grad_v, torch.cat(grad_h), grad_ln_weight, grad_ln_bias, bias_grad, None + + + + +class XAttMoEPipe(torch.autograd.Function): + @staticmethod + def forward(ctx, q, k, v, hidden_states, ln_weight, ln_bias, proj_bias, non_params): + ctx.non_params = non_params + flash, attn, dense_layer, pipe_degree, ln, hidden_dropout, bias_dropout_add_exec_handler, moe \ + = non_params + + ctx.batch_size, seqlen, ctx.head = q.size(0), q.size(1), q.size(2) + + assert seqlen % pipe_degree == 0 + + pipe_degree = pipe_degree + context_layers = [] + chunk_len = seqlen // pipe_degree + base = 0 + cu_seqlens = torch.arange(0, (ctx.batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, + device=q.device) + cu_seqlens_q = torch.arange(0, (ctx.batch_size + 1) * chunk_len, step=chunk_len, dtype=torch.int32, + device=q.device) + + ctx.flash_ctx = [] + ctx.dense_ctx = [] + ctx.bdal_ctx = [] + ctx.prepare_ctx = [] + ctx.mlp_ctx = [] + ctx.post_ctx = [] + ctx.loss_ctx = [] + + hidden_states_chunks = hidden_states.chunk(pipe_degree, dim=0) + + ln_outs = [] + ln_ins = [] + moe_outs = [] + + intermediate = [None] * pipe_degree + dispatchers = [None] * pipe_degree + scoreses = [None] * pipe_degree + tokens_per_experts = [None] * pipe_degree + origin_shapes = [None] * pipe_degree + + attn_events = [] + a2a1_events = [] + comp_events = [] + a2a2_events = [] + + + + def pre_comp(c, base): + q_use = rearrange(q[:,base:base+chunk_len], 'b s ... 
-> (b s) ...') + flash_ctx = FakeContext() + with tensor_parallel.get_cuda_rng_tracker().fork(): + output_chunk = flash_attn_fwd(flash_ctx, + q_use, k, v, cu_seqlens_q, cu_seqlens, chunk_len, seqlen, + flash.dropout_p if flash.training else 0.0, + softmax_scale=flash.softmax_scale, causal=True, + causal_q_offset=base, #fixed , + version=1 + ) + + + ctx.flash_ctx.append(flash_ctx) + context_layers.append(rearrange(output_chunk, '(b s) h d -> s b (h d)', b=ctx.batch_size).contiguous()) + + dense_ctx = FakeContext() + context_layers[-1] = dense_layer.explicit_fwd(dense_ctx, context_layers[-1]) + ctx.dense_ctx.append(dense_ctx) + + bdal_ctx = FakeContext() + ln_output, ln_input = bias_dropout_add_ln_fwd(bdal_ctx, context_layers[-1], \ + hidden_states_chunks[c], dense_layer.bias, hidden_dropout, ln, bias_dropout_add_exec_handler) + ctx.bdal_ctx.append(bdal_ctx) + ln_ins.append(ln_input) + #ln_outs.append(ln_output) + + prepare_ctx = FakeContext() + a2a_tokens, dispatcher, origin_shapes[c], scores, tokens_per_expert\ + = moe.tutel_prepare_fwd(prepare_ctx, ln_output) + ctx.prepare_ctx.append(prepare_ctx) + + dispatchers[c] = dispatcher + intermediate[c] = a2a_tokens + tokens_per_experts[c] = tokens_per_expert + scoreses[c] = scores + + def a2a1(c): + a2a_tokens = moe.tutel_a2a_scatter(intermediate[c], [mpu.get_tensor_model_parallel_world_size(), mpu.get_tensor_model_parallel_group()]) + intermediate[c] = a2a_tokens + + def mlpx(c): + mlp_ctx = FakeContext() + mlp_out = moe.tutel_mlp_fwd(mlp_ctx, intermediate[c]) + ctx.mlp_ctx.append(mlp_ctx) + intermediate[c] = mlp_out + + def a2a2(c): + a2a_tokens = moe.tutel_a2a_gather(intermediate[c], [mpu.get_tensor_model_parallel_world_size(), mpu.get_tensor_model_parallel_group()]) + intermediate[c] = a2a_tokens + + def post_comp(c): + post_ctx = FakeContext() + post_out = moe.tutel_post_fwd(post_ctx, intermediate[c], dispatchers[c]) + ctx.post_ctx.append(post_ctx) + + post_out = post_out.view(origin_shapes[c]) + moe_outs.append(post_out) + + loss_ctx = FakeContext() + moe.tutel_loss(loss_ctx, scoreses[c], tokens_per_experts[c]) + ctx.loss_ctx.append(loss_ctx) + + get_comp0().wait_stream(torch.cuda.current_stream()) + + with torch.cuda.stream(get_comp0()): + pre_comp(0, base) + base += chunk_len + attn_events.append(torch.cuda.current_stream().record_event()) + + + for c in range(1, pipe_degree): + with torch.cuda.stream(get_comp0()): + pre_comp(c, base) + base += chunk_len + attn_events.append(torch.cuda.current_stream().record_event()) + + with torch.cuda.stream(get_comm()): + torch.cuda.current_stream().wait_event(attn_events[c - 1]) + a2a1(c - 1) + a2a1_events.append(torch.cuda.current_stream().record_event()) + + + with torch.cuda.stream(get_comp0()): + torch.cuda.current_stream().wait_event(a2a1_events[c - 1]) + mlpx(c - 1) + comp_events.append(torch.cuda.current_stream().record_event()) + + with torch.cuda.stream(get_comm()): + torch.cuda.current_stream().wait_event(comp_events[c - 1]) + a2a2(c - 1) + a2a2_events.append(torch.cuda.current_stream().record_event()) + + for c in range(0, pipe_degree - 1): + with torch.cuda.stream(get_comp0()): + torch.cuda.current_stream().wait_event(a2a2_events[c]) + post_comp(c) + + with torch.cuda.stream(get_comm()): + torch.cuda.current_stream().wait_event(attn_events[pipe_degree - 1]) + a2a1(pipe_degree - 1) + a2a1_events.append(torch.cuda.current_stream().record_event()) + + with torch.cuda.stream(get_comp0()): + torch.cuda.current_stream().wait_event(a2a1_events[pipe_degree - 1]) + mlpx(pipe_degree - 1) + 
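# Editor's note: the schedule above overlaps chunk c's all-to-all (issued on the "comm" stream)
# with chunk c+1's attention/MLP compute (on the "comp" stream), using CUDA events as the only
# synchronization. Below is a minimal, self-contained sketch of that producer/consumer pattern
# under the same conventions; the names (pipelined, compute, communicate) are illustrative and
# not part of this patch, and real code must also keep the produced tensors alive until the
# consuming stream has used them.

import torch

comp_stream = torch.cuda.Stream()
comm_stream = torch.cuda.Stream()

def pipelined(chunks, compute, communicate):
    outputs, comp_done = [None] * len(chunks), []
    comp_stream.wait_stream(torch.cuda.current_stream())
    for i, x in enumerate(chunks):
        with torch.cuda.stream(comp_stream):
            outputs[i] = compute(x)                     # compute chunk i on the compute stream
            comp_done.append(comp_stream.record_event())
    for i in range(len(chunks)):
        with torch.cuda.stream(comm_stream):
            comm_stream.wait_event(comp_done[i])        # wait only for chunk i, not the whole stream
            outputs[i] = communicate(outputs[i])        # overlap with compute of later chunks
    torch.cuda.current_stream().wait_stream(comm_stream)
    return outputs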
comp_events.append(torch.cuda.current_stream().record_event()) + + with torch.cuda.stream(get_comm()): + torch.cuda.current_stream().wait_event(comp_events[pipe_degree - 1]) + a2a2(pipe_degree - 1) + a2a2_events.append(torch.cuda.current_stream().record_event()) + + with torch.cuda.stream(get_comp0()): + torch.cuda.current_stream().wait_event(a2a2_events[pipe_degree - 1]) + post_comp(pipe_degree - 1) + ''' + for c in range(pipe_degree): + with torch.cuda.stream(get_comp0()): + pre_comp(c, base) + base += chunk_len + attn_events.append(torch.cuda.current_stream().record_event()) + + for c in range(pipe_degree): + with torch.cuda.stream(get_comm()): + torch.cuda.current_stream().wait_event(attn_events[c]) + a2a1(c) + a2a1_events.append(torch.cuda.current_stream().record_event()) + + for c in range(pipe_degree): + with torch.cuda.stream(get_comp0()): + torch.cuda.current_stream().wait_event(a2a1_events[c]) + mlpx(c) + comp_events.append(torch.cuda.current_stream().record_event()) + #print("mlp: ", mlp_out.numel() * 4 * 6 * 8) + + for c in range(pipe_degree): + with torch.cuda.stream(get_comm()): + torch.cuda.current_stream().wait_event(comp_events[c]) + + a2a2(c) + a2a2_events.append(torch.cuda.current_stream().record_event()) + + for c in range(pipe_degree): + with torch.cuda.stream(get_comp0()): + + torch.cuda.current_stream().wait_event(a2a2_events[c]) + + post_comp(c) + ''' + torch.cuda.current_stream().wait_stream(get_comp0()) + + + return torch.cat(moe_outs), torch.cat(ln_ins) + + @staticmethod + def backward(ctx, grad_mlp_outs, grad_ln_ins): + flash, attn, dense_layer, pipe_degree, ln, hidden_dropout, bias_dropout_add_exec_handler, moe\ + = ctx.non_params + + grad_mlp_outs = grad_mlp_outs.chunk(pipe_degree) + grad_ln_ins = grad_ln_ins.chunk(pipe_degree) + + grad_k, grad_v = None, None + grad_q = [] + grad_h = [] + + tokens_grad = grad_mlp_outs[0].view(-1, grad_mlp_outs[0].size(-1)) + tokens_grad, gates_s_grad = moe.tutel_post_bwd(ctx.post_ctx[0], tokens_grad) + + tokens_grad = moe.tutel_a2a_scatter(tokens_grad, [mpu.get_tensor_model_parallel_world_size(), mpu.get_tensor_model_parallel_group()]) + + g_mlp_in = moe.tutel_mlp_bwd(ctx.mlp_ctx[0], tokens_grad) + + tokens_grad = moe.tutel_a2a_gather(g_mlp_in, [mpu.get_tensor_model_parallel_world_size(), mpu.get_tensor_model_parallel_group()]) + + grad_ln_out = moe.tutel_prepare_bwd(ctx.prepare_ctx[0], moe.get_loss_grad(ctx.loss_ctx[0]), tokens_grad, gates_s_grad) + + + grad_ln_weight, grad_ln_bias, grad_dense, hidden_grad, bias_grad = \ + bias_dropout_add_ln_bwd(ctx.bdal_ctx[0], grad_ln_out, grad_ln_ins[0], ln) + grad_h.append(hidden_grad) + + #gradients of proj_weight is directly accumulated + feed_flash = dense_layer.explicit_bwd(ctx.dense_ctx[0], grad_dense) + + d_q, grad_k, grad_v = flash_attn_bwd(ctx.flash_ctx[0], rearrange(feed_flash, 's b (h d) -> (b s) h d', h=ctx.head)) + grad_q.append(d_q) + for c in range(1, pipe_degree): + + tokens_grad = grad_mlp_outs[c].view(-1, grad_mlp_outs[c].size(-1)) + + tokens_grad, gates_s_grad = moe.tutel_post_bwd(ctx.post_ctx[c], tokens_grad) + + tokens_grad = moe.tutel_a2a_scatter(tokens_grad, [mpu.get_tensor_model_parallel_world_size(), mpu.get_tensor_model_parallel_group()]) + + g_mlp_in = moe.tutel_mlp_bwd(ctx.mlp_ctx[c], tokens_grad) + + tokens_grad = moe.tutel_a2a_gather(g_mlp_in, [mpu.get_tensor_model_parallel_world_size(), mpu.get_tensor_model_parallel_group()]) + + grad_ln_out = moe.tutel_prepare_bwd(ctx.prepare_ctx[c], moe.get_loss_grad(ctx.loss_ctx[c]), \ + tokens_grad, gates_s_grad) + + + 
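# Editor's note: the backward pass above re-runs the same a2a scatter/gather steps as forward,
# just in reverse order, because the gradient of an all-to-all is again an all-to-all with the
# source and destination roles exchanged. A minimal autograd wrapper showing that identity,
# assuming equal split sizes across the expert-parallel group (the class name is illustrative
# and not part of this patch):

import torch
import torch.distributed as dist

class AllToAll(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, group):
        ctx.group = group
        out = torch.empty_like(x)
        dist.all_to_all_single(out, x.contiguous(), group=group)
        return out

    @staticmethod
    def backward(ctx, grad_out):
        # sending gradients back to the ranks the activations came from is itself an all-to-all
        grad_in = torch.empty_like(grad_out)
        dist.all_to_all_single(grad_in, grad_out.contiguous(), group=ctx.group)
        return grad_in, None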
d_grad_ln_weight, d_grad_ln_bias, grad_dense, d_hidden_grad, d_bias_grad = \ + bias_dropout_add_ln_bwd(ctx.bdal_ctx[c], grad_ln_out, grad_ln_ins[c], ln) + grad_h.append(d_hidden_grad) + grad_ln_weight += d_grad_ln_weight + grad_ln_bias += d_grad_ln_bias + bias_grad += d_bias_grad + + feed_flash = dense_layer.explicit_bwd(ctx.dense_ctx[c], grad_dense) + + d_q, d_k, d_v = flash_attn_bwd(ctx.flash_ctx[c], rearrange(feed_flash, 's b (h d) -> (b s) h d', h=ctx.head)) + grad_k += d_k + grad_v += d_v + grad_q.append(d_q) + + return torch.cat([rearrange(gq, '(b s) ... -> b s ...', b=ctx.batch_size) for gq in grad_q], dim=1), \ + grad_k, grad_v, torch.cat(grad_h), grad_ln_weight, grad_ln_bias, bias_grad, None \ No newline at end of file diff --git a/internlm/model/moe/ampipe/fa_helper.py b/internlm/model/moe/ampipe/fa_helper.py new file mode 100644 index 00000000..4b343a20 --- /dev/null +++ b/internlm/model/moe/ampipe/fa_helper.py @@ -0,0 +1,261 @@ +import torch +import flash_attn_2_cuda +import flash_attn_cuda + +def _flash_attn1_forward(q, k, v, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, + dropout_p, softmax_scale, causal, return_softmax, num_splits=0, + generator=None, causal_q_offset=0): + """ + num_splits: how much to parallelize over the seqlen_q dimension. num_splits=0 means + it will be set by an internal heuristic. We're exposing num_splits mostly for benchmarking. + Don't change it unless you know what you're doing. + """ + softmax_lse, rng_state, *rest = flash_attn_cuda.fwd( + q, k, v, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, + softmax_scale, False, causal, return_softmax, num_splits, causal_q_offset, generator + ) + + S_dmask = rest[0] if return_softmax else None + return out, softmax_lse, rng_state, S_dmask + + +def _flash_attn1_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal, + rng_state=None, num_splits=0, generator=None, causal_q_offset=0): + """ + num_splits: whether to parallelize over the seqlen_k dimension (num_splits > 1) or + not (num_splits = 1). num_splits=0 means it will be set by an internal heuristic. + Any value above 1 will call the same kernel (i.e. num_splits=2 would call the same kernel + as num_splits=3), so effectively the choices are 0, 1, and 2. + This hyperparameter can be tuned for performance, but default value (heuristic) should work fine. 
+ """ + dout = dout.contiguous() # CUDA code assumes that dout is contiguous + _, _, _, softmax_d = flash_attn_cuda.bwd( + dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal, + num_splits, causal_q_offset, generator, rng_state) + + return dq, dk, dv, softmax_d + + + +def _flash_attn_varlen_forward(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, + dropout_p, softmax_scale, causal, return_softmax, causal_q_offset): + maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x + q, k, v = [maybe_contiguous(x) for x in (q, k, v)] + out, q, k, v, out_padded, softmax_lse, S_dmask = flash_attn_2_cuda.varlen_fwd( + q, k, v, None, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, + softmax_scale, False, causal, return_softmax, causal_q_offset, None + ) + + return out, q, k, v, out_padded, softmax_lse, S_dmask + +def _flash_attn_varlen_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, + cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, + dropout_p, softmax_scale, causal, causal_q_offset): + maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x + # dq, dk, dv are allocated by us so they should already be contiguous + dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)] + dq, dk, dv, softmax_d, = flash_attn_2_cuda.varlen_bwd( + dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k, + max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, False, causal, causal_q_offset, None + ) + + return dq, dk, dv, softmax_d + +import time +TIMERS = {} +SELECT = {} +TRYS = {} + +class FlashAttnFuncMerge(torch.autograd.Function): + + @staticmethod + def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, + softmax_scale, causal, return_softmax, deterministic, causal_q_offset, version): + ctx.version = version + timeit = False + assert version != 0, "RNG STATE IS DIFFERENT, CAN NOT MIX UP USING VERSION 1 AND VERSION 2" + if version == 0: + key = (max_seqlen_q, max_seqlen_k, causal, causal_q_offset, 0) + if key in SELECT: + version = SELECT[key] + elif key in TRYS: + tried = TRYS[key] + for i in [1, 2]: + if i not in tried: + tried.append(i) + version = i + timeit = True + break + if version == 0: + min_id = 0 + real_time = 10000 + for i in [1, 2]: + if TIMERS[key][i] < real_time: + real_time = TIMERS[key][i] + min_id = i + SELECT[key] = min_id + version = min_id + else: + TRYS[key] = [1] + version = 1 + timeit = True + + if timeit: + torch.cuda.synchronize() + t0 = time.time() + + if version == 1: + if softmax_scale is None: + softmax_scale = q.shape[-1] ** (-0.5) + rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None + out, softmax_lse, _, S_dmask = _flash_attn1_forward( + q, k, v, torch.empty_like(q), cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, + dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax, + causal_q_offset=causal_q_offset + ) + ctx.save_for_backward(q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state) + ctx.dropout_p = dropout_p + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_k = max_seqlen_k + ctx.softmax_scale = softmax_scale + ctx.causal = causal + ctx.deterministic = deterministic + ctx.causal_q_offset = causal_q_offset + + elif version == 2: + ctx.deterministic = deterministic + rng_state = torch.cuda.get_rng_state() if dropout_p > 0 else None + ctx.rng_state = rng_state + if softmax_scale is None: + 
softmax_scale = q.shape[-1] ** (-0.5) + ctx.causal_q_offset = causal_q_offset + out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_varlen_forward( + q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, + dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax and dropout_p > 0, causal_q_offset=causal_q_offset + ) + ctx.save_for_backward(q, k, v, out_padded, softmax_lse, + cu_seqlens_q, cu_seqlens_k, rng_state) + ctx.dropout_p = dropout_p + ctx.max_seqlen_q = max_seqlen_q + ctx.max_seqlen_k = max_seqlen_k + ctx.softmax_scale = softmax_scale + ctx.causal = causal + else: + assert False + + if timeit: + torch.cuda.synchronize() + t1 = time.time() + if key in TIMERS: + TIMERS[key][version] = t1 - t0 + else: + TIMERS[key] = {version: t1 - t0} + return out if not return_softmax else (out, softmax_lse, S_dmask) + + @staticmethod + def backward(ctx, dout, *args): + version = ctx.version + timeit = False + if version == 0: + key = (ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.causal, ctx.causal_q_offset, 1) + if key in SELECT: + version = SELECT[key] + elif key in TRYS: + tried = TRYS[key] + for i in [1, 2]: + if i not in tried: + tried.append(i) + version = i + timeit = True + break + if version == 0: + min_id = 0 + real_time = 10000 + for i in [1, 2]: + if TIMERS[key][i] < real_time: + real_time = TIMERS[key][i] + min_id = i + SELECT[key] = min_id + version = min_id + else: + TRYS[key] = [1] + version = 1 + timeit = True + + if timeit: + + torch.cuda.synchronize() + t0 = time.time() + + if version == 1: + q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors + + if rng_state is not None: + cur_rng_state = torch.cuda.get_rng_state() + torch.cuda.set_rng_state(rng_state) + dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) + _flash_attn1_backward( + dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k, + ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, causal= ctx.causal, num_splits=1 if ctx.deterministic else 0, + causal_q_offset=ctx.causal_q_offset, rng_state=None + ) + if rng_state is not None: + torch.cuda.set_rng_state(cur_rng_state) + elif version == 2: + q, k, v, out, softmax_lse, cu_seqlens_q, cu_seqlens_k, rng_state = ctx.saved_tensors + + if rng_state is not None: + cur_rng_state = torch.cuda.get_rng_state() + torch.cuda.set_rng_state(rng_state) + + dq, dk, dv = torch.empty_like(q), torch.empty_like(k), torch.empty_like(v) + _flash_attn_varlen_backward( + dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k, + ctx.max_seqlen_q, ctx.max_seqlen_k, ctx.dropout_p, ctx.softmax_scale, ctx.causal, causal_q_offset=ctx.causal_q_offset + ) + dq = dq[..., :dout.shape[-1]] # We could have padded the head dimension + dk = dk[..., :dout.shape[-1]] + dv = dv[..., :dout.shape[-1]] + if rng_state is not None: + torch.cuda.set_rng_state(cur_rng_state) + else: + assert False + + + if timeit: + torch.cuda.synchronize() + t1 = time.time() + if key in TIMERS: + TIMERS[key][version] = t1 - t0 + else: + TIMERS[key] = {version: t1 - t0} + return dq, dk, dv, None, None, None, None, None, None, None, None, None, None, None + + + + + + +def flash_attn_megablock_call(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, + dropout_p=0.0, softmax_scale=None, causal=False, + return_attn_probs=False, deterministic=False, causal_q_offset=0, version=1): + + return FlashAttnFuncMerge.apply( + q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, + dropout_p, 
softmax_scale, causal, return_attn_probs, deterministic, causal_q_offset, version + ) + +def flash_attn_fwd(ctx, q_use, k, v, cu_seqlens_q, cu_seqlens, chunk_len, seqlen, dropout_p, softmax_scale, + causal, causal_q_offset, version): + assert causal and version == 1 + out = FlashAttnFuncMerge.forward(ctx, q_use, k, v, cu_seqlens_q, cu_seqlens, chunk_len, seqlen, + dropout_p, softmax_scale, True, False, False, causal_q_offset, 1 + ) + return out + + +def flash_attn_bwd(ctx, dout): + ret = FlashAttnFuncMerge.backward(ctx, dout) + return ret[0], ret[1], ret[2] \ No newline at end of file diff --git a/internlm/model/ops/norm.py b/internlm/model/ops/norm.py index 8565db4c..86280681 100644 --- a/internlm/model/ops/norm.py +++ b/internlm/model/ops/norm.py @@ -76,6 +76,24 @@ def forward(self, _input: torch.Tensor): _norm_func = manual_rms_norm return _norm_func(_input, self.weight, self.normalized_shape, self.eps, self.add_unit_offset) + + def explicit_forward(self, ctx, _input: torch.Tensor): + if apex_rmsnorm_impl: + args = _cast_if_autocast_enabled(input, self.weight, self.normalized_shape, self.eps) + with torch.amp.autocast('cuda', enabled=False): + return FusedRMSNormAffineMixedDtypesFunction.forward(ctx, *args) + else: + assert False + + def explicit_forward(self, ctx, grad_output: torch.Tensor): + if apex_rmsnorm_impl: + with torch.amp.autocast('cuda', enabled=False): + grad_input, grad_weight, *_ = FusedRMSNormAffineMixedDtypesFunction.backward(ctx, grad_output) + else: + assert False + + return grad_input, grad_weight + def reset_parameters(self): if self.add_unit_offset: init.zeros_(self.weight) From 09701139c633c801a376004b7edf83849b2981e7 Mon Sep 17 00:00:00 2001 From: Wenwen Qu Date: Fri, 13 Dec 2024 11:21:08 +0800 Subject: [PATCH 3/4] add test code --- internlm/model/modeling_moe.py | 54 +- internlm/model/modules/mha.py | 4 + internlm/model/moe/ampipe/ampipe.py | 279 +------- internlm/model/moe/ampipe/moe_layer.py | 782 +++++++++++++++++++++ internlm/model/moe/ampipe/tutel_adapter.py | 99 +++ 5 files changed, 925 insertions(+), 293 deletions(-) create mode 100644 internlm/model/moe/ampipe/moe_layer.py create mode 100644 internlm/model/moe/ampipe/tutel_adapter.py diff --git a/internlm/model/modeling_moe.py b/internlm/model/modeling_moe.py index cdb1084b..125dcbf6 100644 --- a/internlm/model/modeling_moe.py +++ b/internlm/model/modeling_moe.py @@ -26,6 +26,8 @@ from internlm.solver.activation_checkpoint import activation_checkpoint from internlm.utils.logger import get_logger +from internlm.model.moe.ampipe.ampipe import AttMoEPipe, bias_dropout_add_fused_train + logger = get_logger(__file__) @@ -217,30 +219,46 @@ def _dropout_and_norm_attn(_hidden_states): residual = residual.to(torch.float32) mixer_kwargs = convert_attn_args_to_kwargs(args, kwargs) - hidden_states = self.mixer(hidden_states, **mixer_kwargs) - def _dropout_and_norm_ffn(_residual, _hidden_states): - _dropped = self.dropout2(_hidden_states) - _residual = (_dropped + _residual) if _residual is not None else _dropped - _hidden_states = self.norm2(_residual.float()) - return _residual, _hidden_states + if gpc.config.model.ampipe_degree < 1: + hidden_states = self.mixer(hidden_states, **mixer_kwargs) + + def _dropout_and_norm_ffn(_residual, _hidden_states): + _dropped = self.dropout2(_hidden_states) + _residual = (_dropped + _residual) if _residual is not None else _dropped + _hidden_states = self.norm2(_residual.float()) + return _residual, _hidden_states + + if self.dropout_selective_checkpoint: + residual, hidden_states 
= activation_checkpoint(_dropout_and_norm_ffn, False, residual, hidden_states) + else: + residual, hidden_states = _dropout_and_norm_ffn(residual, hidden_states) + + if self.residual_in_fp32: + residual = residual.to(torch.float32) + + # MLP. + if self.num_experts <= 1: # dense mlp output + hidden_states = self.mlp(hidden_states) + moe_loss = torch.tensor(0.0, device=hidden_states.device, dtype=hidden_states.dtype) + else: # MoE output + hidden_states, moe_loss, _ = self.mlp(hidden_states) - if self.dropout_selective_checkpoint: - residual, hidden_states = activation_checkpoint(_dropout_and_norm_ffn, False, residual, hidden_states) else: - residual, hidden_states = _dropout_and_norm_ffn(residual, hidden_states) + mixer_kwargs["skip_score"] = True + q, k, v = self.mixer(hidden_states, **mixer_kwargs) - if self.residual_in_fp32: - residual = residual.to(torch.float32) + flash = self.mixer.inner_attn + dense_layer = self.mixer.out_proj + ln = self.norm2 - # MLP. - if self.num_experts <= 1: # dense mlp output - hidden_states = self.mlp(hidden_states) - moe_loss = torch.tensor(0.0, device=hidden_states.device, dtype=hidden_states.dtype) - else: # MoE output - hidden_states, moe_loss, _ = self.mlp(hidden_states) + hidden_states, residual = AttMoEPipe.apply(q, k, v, hidden_states, + ln.weight, ln.bias, dense_layer.bias, + [flash, dense_layer, + gpc.config.model.ampipe_degree, ln, self.dropout2.p, \ + self.mlp.moe_layer]) - return hidden_states + residual, moe_loss + return hidden_states + residual, None class Internlm1MoE(BaseModel): diff --git a/internlm/model/modules/mha.py b/internlm/model/modules/mha.py index 42418a21..5413b8ce 100644 --- a/internlm/model/modules/mha.py +++ b/internlm/model/modules/mha.py @@ -206,9 +206,13 @@ def _training(self, x, **kwargs): # rotary embedding indexes = kwargs.pop("indexes", 0) max_seqlen = kwargs.get("max_seqlen", None) + skip_score = kwargs.get("skip_score", False) q = self.rotary_emb(q, offsets=indexes, cache_type="query", interleaved=self.interleaved, max_seqlen=max_seqlen) k = self.rotary_emb(k, offsets=indexes, cache_type="key", interleaved=self.interleaved, max_seqlen=max_seqlen) + if skip_score: + return q, k, v + # self attention kwargs = _convert_cu_seqlens_for_qksplited(kwargs) if gpc.config.data.use_packed_dataset is False or self.training is False: diff --git a/internlm/model/moe/ampipe/ampipe.py b/internlm/model/moe/ampipe/ampipe.py index 8e77015c..3fc2ec33 100644 --- a/internlm/model/moe/ampipe/ampipe.py +++ b/internlm/model/moe/ampipe/ampipe.py @@ -88,7 +88,7 @@ def forward(ctx, q, k, v, hidden_states, ln_weight, ln_bias, proj_bias, non_para #t0 = time.time() ctx.non_params = non_params - flash, attn, dense_layer, pipe_degree, ln, hidden_dropout, bias_dropout_add_exec_handler, moe \ + flash, dense_layer, pipe_degree, ln, hidden_dropout, bias_dropout_add_exec_handler, moe \ = non_params ctx.batch_size, seqlen, ctx.head = q.size(0), q.size(1), q.size(2) @@ -244,9 +244,9 @@ def forward(ctx, q, k, v, hidden_states, ln_weight, ln_bias, proj_bias, non_para post_out = post_out.view(origin_shape) moe_outs.append(post_out) - loss_ctx = FakeContext() - moe.tutel_loss(loss_ctx, scoreses[c], tokens_per_experts[c]) - ctx.loss_ctx.append(loss_ctx) + # loss_ctx = FakeContext() + # moe.dummy_moe_loss(loss_ctx, scoreses[c], tokens_per_experts[c]) + # ctx.loss_ctx.append(loss_ctx) #time.sleep(1000) torch.cuda.current_stream().wait_stream(get_comp0()) @@ -351,276 +351,5 @@ def backward(ctx, grad_mlp_outs, grad_ln_ins): 
torch.cuda.current_stream().wait_stream(get_comp0()) - return torch.cat([rearrange(gq, '(b s) ... -> b s ...', b=ctx.batch_size) for gq in grad_q], dim=1), \ - grad_k, grad_v, torch.cat(grad_h), grad_ln_weight, grad_ln_bias, bias_grad, None - - - - -class XAttMoEPipe(torch.autograd.Function): - @staticmethod - def forward(ctx, q, k, v, hidden_states, ln_weight, ln_bias, proj_bias, non_params): - ctx.non_params = non_params - flash, attn, dense_layer, pipe_degree, ln, hidden_dropout, bias_dropout_add_exec_handler, moe \ - = non_params - - ctx.batch_size, seqlen, ctx.head = q.size(0), q.size(1), q.size(2) - - assert seqlen % pipe_degree == 0 - - pipe_degree = pipe_degree - context_layers = [] - chunk_len = seqlen // pipe_degree - base = 0 - cu_seqlens = torch.arange(0, (ctx.batch_size + 1) * seqlen, step=seqlen, dtype=torch.int32, - device=q.device) - cu_seqlens_q = torch.arange(0, (ctx.batch_size + 1) * chunk_len, step=chunk_len, dtype=torch.int32, - device=q.device) - - ctx.flash_ctx = [] - ctx.dense_ctx = [] - ctx.bdal_ctx = [] - ctx.prepare_ctx = [] - ctx.mlp_ctx = [] - ctx.post_ctx = [] - ctx.loss_ctx = [] - - hidden_states_chunks = hidden_states.chunk(pipe_degree, dim=0) - - ln_outs = [] - ln_ins = [] - moe_outs = [] - - intermediate = [None] * pipe_degree - dispatchers = [None] * pipe_degree - scoreses = [None] * pipe_degree - tokens_per_experts = [None] * pipe_degree - origin_shapes = [None] * pipe_degree - - attn_events = [] - a2a1_events = [] - comp_events = [] - a2a2_events = [] - - - - def pre_comp(c, base): - q_use = rearrange(q[:,base:base+chunk_len], 'b s ... -> (b s) ...') - flash_ctx = FakeContext() - with tensor_parallel.get_cuda_rng_tracker().fork(): - output_chunk = flash_attn_fwd(flash_ctx, - q_use, k, v, cu_seqlens_q, cu_seqlens, chunk_len, seqlen, - flash.dropout_p if flash.training else 0.0, - softmax_scale=flash.softmax_scale, causal=True, - causal_q_offset=base, #fixed , - version=1 - ) - - - ctx.flash_ctx.append(flash_ctx) - context_layers.append(rearrange(output_chunk, '(b s) h d -> s b (h d)', b=ctx.batch_size).contiguous()) - - dense_ctx = FakeContext() - context_layers[-1] = dense_layer.explicit_fwd(dense_ctx, context_layers[-1]) - ctx.dense_ctx.append(dense_ctx) - - bdal_ctx = FakeContext() - ln_output, ln_input = bias_dropout_add_ln_fwd(bdal_ctx, context_layers[-1], \ - hidden_states_chunks[c], dense_layer.bias, hidden_dropout, ln, bias_dropout_add_exec_handler) - ctx.bdal_ctx.append(bdal_ctx) - ln_ins.append(ln_input) - #ln_outs.append(ln_output) - - prepare_ctx = FakeContext() - a2a_tokens, dispatcher, origin_shapes[c], scores, tokens_per_expert\ - = moe.tutel_prepare_fwd(prepare_ctx, ln_output) - ctx.prepare_ctx.append(prepare_ctx) - - dispatchers[c] = dispatcher - intermediate[c] = a2a_tokens - tokens_per_experts[c] = tokens_per_expert - scoreses[c] = scores - - def a2a1(c): - a2a_tokens = moe.tutel_a2a_scatter(intermediate[c], [mpu.get_tensor_model_parallel_world_size(), mpu.get_tensor_model_parallel_group()]) - intermediate[c] = a2a_tokens - - def mlpx(c): - mlp_ctx = FakeContext() - mlp_out = moe.tutel_mlp_fwd(mlp_ctx, intermediate[c]) - ctx.mlp_ctx.append(mlp_ctx) - intermediate[c] = mlp_out - - def a2a2(c): - a2a_tokens = moe.tutel_a2a_gather(intermediate[c], [mpu.get_tensor_model_parallel_world_size(), mpu.get_tensor_model_parallel_group()]) - intermediate[c] = a2a_tokens - - def post_comp(c): - post_ctx = FakeContext() - post_out = moe.tutel_post_fwd(post_ctx, intermediate[c], dispatchers[c]) - ctx.post_ctx.append(post_ctx) - - post_out = 
post_out.view(origin_shapes[c]) - moe_outs.append(post_out) - - loss_ctx = FakeContext() - moe.tutel_loss(loss_ctx, scoreses[c], tokens_per_experts[c]) - ctx.loss_ctx.append(loss_ctx) - - get_comp0().wait_stream(torch.cuda.current_stream()) - - with torch.cuda.stream(get_comp0()): - pre_comp(0, base) - base += chunk_len - attn_events.append(torch.cuda.current_stream().record_event()) - - - for c in range(1, pipe_degree): - with torch.cuda.stream(get_comp0()): - pre_comp(c, base) - base += chunk_len - attn_events.append(torch.cuda.current_stream().record_event()) - - with torch.cuda.stream(get_comm()): - torch.cuda.current_stream().wait_event(attn_events[c - 1]) - a2a1(c - 1) - a2a1_events.append(torch.cuda.current_stream().record_event()) - - - with torch.cuda.stream(get_comp0()): - torch.cuda.current_stream().wait_event(a2a1_events[c - 1]) - mlpx(c - 1) - comp_events.append(torch.cuda.current_stream().record_event()) - - with torch.cuda.stream(get_comm()): - torch.cuda.current_stream().wait_event(comp_events[c - 1]) - a2a2(c - 1) - a2a2_events.append(torch.cuda.current_stream().record_event()) - - for c in range(0, pipe_degree - 1): - with torch.cuda.stream(get_comp0()): - torch.cuda.current_stream().wait_event(a2a2_events[c]) - post_comp(c) - - with torch.cuda.stream(get_comm()): - torch.cuda.current_stream().wait_event(attn_events[pipe_degree - 1]) - a2a1(pipe_degree - 1) - a2a1_events.append(torch.cuda.current_stream().record_event()) - - with torch.cuda.stream(get_comp0()): - torch.cuda.current_stream().wait_event(a2a1_events[pipe_degree - 1]) - mlpx(pipe_degree - 1) - comp_events.append(torch.cuda.current_stream().record_event()) - - with torch.cuda.stream(get_comm()): - torch.cuda.current_stream().wait_event(comp_events[pipe_degree - 1]) - a2a2(pipe_degree - 1) - a2a2_events.append(torch.cuda.current_stream().record_event()) - - with torch.cuda.stream(get_comp0()): - torch.cuda.current_stream().wait_event(a2a2_events[pipe_degree - 1]) - post_comp(pipe_degree - 1) - ''' - for c in range(pipe_degree): - with torch.cuda.stream(get_comp0()): - pre_comp(c, base) - base += chunk_len - attn_events.append(torch.cuda.current_stream().record_event()) - - for c in range(pipe_degree): - with torch.cuda.stream(get_comm()): - torch.cuda.current_stream().wait_event(attn_events[c]) - a2a1(c) - a2a1_events.append(torch.cuda.current_stream().record_event()) - - for c in range(pipe_degree): - with torch.cuda.stream(get_comp0()): - torch.cuda.current_stream().wait_event(a2a1_events[c]) - mlpx(c) - comp_events.append(torch.cuda.current_stream().record_event()) - #print("mlp: ", mlp_out.numel() * 4 * 6 * 8) - - for c in range(pipe_degree): - with torch.cuda.stream(get_comm()): - torch.cuda.current_stream().wait_event(comp_events[c]) - - a2a2(c) - a2a2_events.append(torch.cuda.current_stream().record_event()) - - for c in range(pipe_degree): - with torch.cuda.stream(get_comp0()): - - torch.cuda.current_stream().wait_event(a2a2_events[c]) - - post_comp(c) - ''' - torch.cuda.current_stream().wait_stream(get_comp0()) - - - return torch.cat(moe_outs), torch.cat(ln_ins) - - @staticmethod - def backward(ctx, grad_mlp_outs, grad_ln_ins): - flash, attn, dense_layer, pipe_degree, ln, hidden_dropout, bias_dropout_add_exec_handler, moe\ - = ctx.non_params - - grad_mlp_outs = grad_mlp_outs.chunk(pipe_degree) - grad_ln_ins = grad_ln_ins.chunk(pipe_degree) - - grad_k, grad_v = None, None - grad_q = [] - grad_h = [] - - tokens_grad = grad_mlp_outs[0].view(-1, grad_mlp_outs[0].size(-1)) - tokens_grad, gates_s_grad = 
moe.tutel_post_bwd(ctx.post_ctx[0], tokens_grad) - - tokens_grad = moe.tutel_a2a_scatter(tokens_grad, [mpu.get_tensor_model_parallel_world_size(), mpu.get_tensor_model_parallel_group()]) - - g_mlp_in = moe.tutel_mlp_bwd(ctx.mlp_ctx[0], tokens_grad) - - tokens_grad = moe.tutel_a2a_gather(g_mlp_in, [mpu.get_tensor_model_parallel_world_size(), mpu.get_tensor_model_parallel_group()]) - - grad_ln_out = moe.tutel_prepare_bwd(ctx.prepare_ctx[0], moe.get_loss_grad(ctx.loss_ctx[0]), tokens_grad, gates_s_grad) - - - grad_ln_weight, grad_ln_bias, grad_dense, hidden_grad, bias_grad = \ - bias_dropout_add_ln_bwd(ctx.bdal_ctx[0], grad_ln_out, grad_ln_ins[0], ln) - grad_h.append(hidden_grad) - - #gradients of proj_weight is directly accumulated - feed_flash = dense_layer.explicit_bwd(ctx.dense_ctx[0], grad_dense) - - d_q, grad_k, grad_v = flash_attn_bwd(ctx.flash_ctx[0], rearrange(feed_flash, 's b (h d) -> (b s) h d', h=ctx.head)) - grad_q.append(d_q) - for c in range(1, pipe_degree): - - tokens_grad = grad_mlp_outs[c].view(-1, grad_mlp_outs[c].size(-1)) - - tokens_grad, gates_s_grad = moe.tutel_post_bwd(ctx.post_ctx[c], tokens_grad) - - tokens_grad = moe.tutel_a2a_scatter(tokens_grad, [mpu.get_tensor_model_parallel_world_size(), mpu.get_tensor_model_parallel_group()]) - - g_mlp_in = moe.tutel_mlp_bwd(ctx.mlp_ctx[c], tokens_grad) - - tokens_grad = moe.tutel_a2a_gather(g_mlp_in, [mpu.get_tensor_model_parallel_world_size(), mpu.get_tensor_model_parallel_group()]) - - grad_ln_out = moe.tutel_prepare_bwd(ctx.prepare_ctx[c], moe.get_loss_grad(ctx.loss_ctx[c]), \ - tokens_grad, gates_s_grad) - - - d_grad_ln_weight, d_grad_ln_bias, grad_dense, d_hidden_grad, d_bias_grad = \ - bias_dropout_add_ln_bwd(ctx.bdal_ctx[c], grad_ln_out, grad_ln_ins[c], ln) - grad_h.append(d_hidden_grad) - grad_ln_weight += d_grad_ln_weight - grad_ln_bias += d_grad_ln_bias - bias_grad += d_bias_grad - - feed_flash = dense_layer.explicit_bwd(ctx.dense_ctx[c], grad_dense) - - d_q, d_k, d_v = flash_attn_bwd(ctx.flash_ctx[c], rearrange(feed_flash, 's b (h d) -> (b s) h d', h=ctx.head)) - grad_k += d_k - grad_v += d_v - grad_q.append(d_q) - return torch.cat([rearrange(gq, '(b s) ... 
-> b s ...', b=ctx.batch_size) for gq in grad_q], dim=1), \ grad_k, grad_v, torch.cat(grad_h), grad_ln_weight, grad_ln_bias, bias_grad, None \ No newline at end of file diff --git a/internlm/model/moe/ampipe/moe_layer.py b/internlm/model/moe/ampipe/moe_layer.py new file mode 100644 index 00000000..aa84d776 --- /dev/null +++ b/internlm/model/moe/ampipe/moe_layer.py @@ -0,0 +1,782 @@ +from typing import Optional + +import numpy as np +import torch +import torch.nn.functional as F + +import tutel + +from internlm.core.context import ParallelMode +from internlm.core.context import global_context as gpc +from internlm.model.moe.base_layer import BaseMoELayer +from internlm.model.moe.megablocks.mlp import MegaBlockFeedForward +from internlm.model.moe.utils import all_to_all + +try: + from megablocks import ops +except (ModuleNotFoundError, ImportError): + ops = None + +import tutel.impls.communicate as C +import tutel_custom_kernel + +class BWDDEBUG(torch.autograd.Function): + @staticmethod + def forward(ctx, inp, info): + ctx.info = info + return inp + + @staticmethod + def backward(ctx, grad_inp): + if torch.distributed.get_rank() == 0: + print("CALLING BWD: ", ctx.info) + return grad_inp, None + + + +def create_fake(x): + return megablocks_ops.fake_tensor(x) + return x.detach() #if last line reports error + +class NoBuffer(torch.autograd.Function): + @staticmethod + def forward(ctx, x): + fake = create_fake(x) #watch out, do not access to x's data_ptr + return fake + @staticmethod + def backward(ctx, g): + return g + +class NoBufferAssist(torch.autograd.Function): + @staticmethod + def forward(ctx, x, assistant): + #x.size() == assistant + assert x.size() == assistant.size() + assert assistant.requires_grad == False + return assistant + @staticmethod + def backward(ctx, g): + return g, None + +class MLP_TP_F(torch.autograd.Function): + @staticmethod + def forward(ctx, tokens, group): + ctx.group = group + return tokens + @staticmethod + def backward(ctx, g_tokens): + torch.distributed.all_reduce(g_tokens, op=torch.distributed.ReduceOp.SUM, group=ctx.group) + return g_tokens, None + +class MLP_TP_G(torch.autograd.Function): + @staticmethod + def forward(ctx, tokens, group): + torch.distributed.all_reduce(tokens, op=torch.distributed.ReduceOp.SUM, group=group) + return tokens + @staticmethod + def backward(ctx, g_tokens): + return g_tokens, None + +_LOAD_BALANCING_LOSS = [] + +_MoE_Layer = [] + +import os +FAKE_A2A_SCALE=int(os.environ.get('FAKE_A2A_SCALE', 1)) + +def save_load_balancing_loss(loss, idx=-1): + global _LOAD_BALANCING_LOSS + if idx == -1 or len(_LOAD_BALANCING_LOSS) <= idx: + _LOAD_BALANCING_LOSS.append(loss) + else: + _LOAD_BALANCING_LOSS[idx] = (_LOAD_BALANCING_LOSS[idx][0] + loss[0], torch.cat([_LOAD_BALANCING_LOSS[idx][1], loss[1]])) + +def get_load_balancing_loss(): + global _LOAD_BALANCING_LOSS + return _LOAD_BALANCING_LOSS + + +def clear_load_balancing_loss(): + global _LOAD_BALANCING_LOSS + _LOAD_BALANCING_LOSS.clear() + +def get_world_size(group=None): + try: + return torch.distributed.get_world_size(group) + except: + return 1 + +# def batched_load_balancing_loss(args : Arguments): +# # tokens_per_expert[i].shape = (num_experts) +# # expert_scores[i].shape = (tokens, num_experts) +# tokens_per_expert, expert_scores = zip(*get_load_balancing_loss()) +# num_layers_per_pipeline_stage = ( +# gpc.config.model.num_layers // gpc.get_world_rank(ParallelMode.PIPELINE)) +# if args.num_layers_per_virtual_pipeline_stage is not None: +# num_layers_per_pipeline_stage = 
args.num_layers_per_virtual_pipeline_stage + +# if len(tokens_per_expert) != num_layers_per_pipeline_stage: +# raise ValueError( +# f"Expected {num_layers_per_pipeline_stage} token_per_experts " +# f"but found {len(tokens_per_expert)}.\nnum_layers = " +# f"{args.num_layers}\npipeline_model_parallel_size = " +# f"{args.pipeline_model_parallel_size}\n" +# "num_layers_per_virtual_pipeline_stage" +# f" = {args.num_layers_per_virtual_pipeline_stage}") +# if len(expert_scores) != num_layers_per_pipeline_stage: +# raise ValueError( +# f"Expected {num_layers_per_pipeline_stage} expert_scores " +# f"but found {len(tokens_per_expert)}.\nnum_layers = " +# f"{args.num_layers}\npipeline_model_parallel_size = " +# f"{args.pipeline_model_parallel_size}\n" +# "num_layers_per_virtual_pipeline_stage" +# f" = {args.num_layers_per_virtual_pipeline_stage}") + +# # Verify the shape of the tokens_per_expert and expert_scores tensors. +# assert all([ +# x.ndim == 1 and x.numel() == args.moe_num_experts +# for x in tokens_per_expert +# ]) + +# tokens = expert_scores[0].shape[0] +# assert all([ +# (x.ndim == 2 and x.shape[1] == args.moe_num_experts and +# x.shape[0] == tokens) for x in expert_scores +# ]) + + +# # Concatenate the contributions of each layer and convert to +# # the correct types and formats for the dot product. +# if args.moe_lbl_in_fp32: +# expert_scores = torch.cat(expert_scores, dim=1).float().mean(dim=0) +# else: +# expert_scores = torch.cat(expert_scores, dim=1).mean(dim=0) +# tokens_per_expert = torch.cat(tokens_per_expert).to(expert_scores.dtype) + +# expected_values = num_layers_per_pipeline_stage * args.moe_num_experts +# assert tokens_per_expert.numel() == expected_values +# assert expert_scores.numel() == expected_values + +# # Calculate the total scale across all factors. +# # +# # loss_weight * num_experts / (num_layers * tokens * top_k) +# scale_numerator = ( +# args.moe_num_experts * +# args.moe_loss_weight +# ) +# scale_denominator = ( +# args.num_layers * +# tokens * +# args.moe_top_k +# ) +# scale = scale_numerator / scale_denominator +# return scale * torch.dot(tokens_per_expert, expert_scores) + +class AmpipeMegaBlockMoE(BaseMoELayer): + """ + Built on the paper and library Megablocks as described in + https://arxiv.org/abs/2211.15841. This implementation is + strictly equivalent to standard MoE with full capacity (no + dropped tokens). It's faster since it formulates MoE operations + in terms of block-sparse operations to accomodate imbalanced + assignments of tokens to experts, whereas standard MoE either + (1) drop tokens at the cost of reduced performance or (2) set + capacity factor to number of experts and thus waste computation + and memory on padding. + """ + + def __init__( + self, + in_features: int, + hidden_features: int, + out_features: int, + num_experts: int, + top_k: int, + ep_group: Optional[torch.distributed.ProcessGroup], + ep_size: int, + device: Optional[torch.device] = None, + dtype: Optional[torch.device] = None, + mlp_layer_fusion: bool = False, # pylint: disable=W0613 + multiple_of: int = 256, + activation_type: str = "swiglu", # pylint: disable=W0613 + capacity_factor: float = 1.0, + drop_tokens: bool = True, + ) -> None: + assert not gpc.config.parallel.sequence_parallel, "do not support sequence parallel" + assert ops is not None, 'MegaBlocks not found, please run "pip install megablocks".' 
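For reference, the commented-out batched_load_balancing_loss above reduces, for a single layer, to a dot product between the per-expert token counts and the token-averaged gate probabilities, scaled by num_experts / (tokens * top_k); the batched variant only adds a further moe_loss_weight / num_layers factor and concatenates the per-layer contributions. This mirrors the load_balancing_loss method defined further down in this class. A minimal sketch under those assumptions (function and variable names are illustrative, not part of the patch):

import torch

def single_layer_load_balancing_loss(tokens_per_expert: torch.Tensor,
                                     expert_scores: torch.Tensor,
                                     top_k: int) -> torch.Tensor:
    # tokens_per_expert: (num_experts,) router token counts for this layer
    # expert_scores:     (tokens, num_experts) softmax gate probabilities
    tokens, num_experts = expert_scores.shape
    scale = num_experts / (tokens * top_k)
    return scale * torch.dot(tokens_per_expert.to(expert_scores.dtype),
                             expert_scores.mean(dim=0))

# toy usage: 8 tokens routed over 4 experts with top-1 gating
scores = torch.softmax(torch.randn(8, 4), dim=-1)
counts = torch.bincount(scores.argmax(dim=-1), minlength=4).to(scores.dtype)
loss = single_layer_load_balancing_loss(counts, scores, top_k=1)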
+ self.top_k = top_k + self.num_experts = num_experts + + tp_size = gpc.get_world_size(ParallelMode.TENSOR) + self.ffn_dim = multiple_of * ((hidden_features + multiple_of - 1) // multiple_of) + self.capacity_factor = capacity_factor + self.drop_tokens = drop_tokens + assert self.ffn_dim % tp_size == 0 + super().__init__( + torch.nn.Linear(in_features, num_experts, bias=False), + MegaBlockFeedForward( + in_features, + self.ffn_dim // tp_size, + out_features, + num_experts // ep_size, + device, + dtype, + ), + ep_group, + ep_size, + 1, + ) + + # Calculate the number of bits needed to represent the expert indices + # so that we can pass it to radix sort. + self.sort_end_bit = max(int(np.ceil(np.log2(self.num_experts))), 1) + self.quantize_scatter_num_bits = -1 + # re-init the number of experts in each device + self.num_local_experts = num_experts // ep_size + + self.forward_fn = self._parallel_forward if gpc.expert_parallel_size > 1 else self._forward + + def expert_capacity(self, tokens, top_k): + world_size = gpc.get_world_size(ParallelMode.EXPERT) # mpu.get_expert_parallel_world_size(self.args) + tokens_per_expert = top_k * tokens * world_size / self.num_experts + return int(self.capacity_factor * tokens_per_expert) + + def indices_and_bins(self, top_expert): + # Sort the expert ids to produce the scatter/gather + # indices for the permutation. + # + # TODO(tgale): Is it worth doing this conversion to 32-bit + # prior? Could we place the `torch.max` operation to return + # 32-bit expert indices? + top_expert = top_expert.int() + bin_ids, indices = ops.sort(top_expert, self.sort_end_bit) + + # Histogram the expert ids to identify the number of + # tokens routed to each expert. + # + # TODO(tgale): Does the sorted data produce a more favorable + # data distribution for histogram? Or is the op parallelism + # worth more? + tokens_per_expert = ops.histogram(top_expert, self.num_experts) + + # Calculate the bin bounds for the sorted tokens. + bins = ops.inclusive_cumsum(tokens_per_expert, 0) + bins = bins.view(1) if len(bins.size()) == 0 else bins + return indices, bin_ids, bins, tokens_per_expert + + def _forward(self, x, expert_weights, top_experts) -> torch.Tensor: + """ + x: (sequence_length, model_dim) + gate_logits: (sequence_length, n_experts) + """ + with torch.no_grad(): + indices, _, bins, tokens_per_expert = self.indices_and_bins(top_experts) + # If expert_capacity is set to zero, set the number of tokens + # per expert to the maximum we need to avoid dropping tokens. + tokens, _ = x.size() + expert_capacity = self.expert_capacity(tokens, top_k=self.top_k) + if not self.drop_tokens: + expert_capacity = torch.max(tokens_per_expert).item() + + out = self.permute_and_compute(x, indices, expert_weights, bins, expert_capacity, top_k=self.top_k) + + return out, tokens_per_expert.flatten() + + def _parallel_forward(self, x, expert_weights, top_experts): + # NOTE: This function implements the same computation as forward_once + # but with expert model parallelism. + # + # 1. Permute the tokens locally so that they are grouped by their + # expert assignments. This allows us to transfer all of the tokens + # for a remote device in one communication primitive. + # + # 2. Permute the tokens across the expert parallel devices. After + # this is completed each device has all of the tokens assigned to + # its set of experts in its local HBM. + # + # 3. Permute the tokens locally so that they are grouped by their + # expert assignement. 
After the distributed permutation the tokens + # are grouped by which device they came from. We re-order them + # locally to allow for efficient computation. + # + # After this series of permutations we compute the linear layers + # and then repeat these three steps in reverse to produce the final + # output. + # + # Compute the mapping of local tokens to experts. + """ + x: (sequence_length, model_dim) + gate_logits: (sequence_length, n_experts) + """ + with torch.no_grad(): + indices, bin_ids, bins, tokens_per_expert = self.indices_and_bins(top_experts) + + # Pass token count information to the device on which the + # target expert resides. + # e.g. tokens_per_expert = (1,2,1,0) in g1 + # tokens_per_expert = (2,0,2,0) in g2 + # then:parallel_tokens_per_expert = (1,2,2,0) in g1 + # parallel_tokens_per_expert = (1,0,2,0) in g2 + parallel_tokens_per_expert = torch.empty_like(tokens_per_expert) + tpe_handle = torch.distributed.all_to_all_single( + parallel_tokens_per_expert, tokens_per_expert, group=gpc.get_group(ParallelMode.EXPERT), async_op=True + ) + + # Permute locally and without any padding so that tokens for each + # parallel device are stored contiguously. + # + # This view updates the shape of the tensor from [sl, bs, hs] to + # [sl * bs, hs] prior to the permutation. + x = x.view(-1, x.shape[-1]) # TODO can be deleted + x = ops.gather(x, indices, bin_ids, bins, self.top_k) + + # Compute the number of tokens that will be received from each + # device and permute the input data across the devices. + with torch.no_grad(): + tpe_handle.wait() + experts_per_rank = self.num_local_experts # mpu.experts_per_rank(self.args) + + # Reshape to [world_size, num_experts_per_rank]. + world_size = gpc.get_world_size(ParallelMode.EXPERT) # mpu.get_expert_parallel_world_size(self.args) + tokens_per_expert = tokens_per_expert.view( + world_size, experts_per_rank + ) # ((1,2), (1,0)) in g1, ((2,0),(2,0)) in g2 + parallel_tokens_per_expert = parallel_tokens_per_expert.view( + world_size, experts_per_rank + ) # ((1,2), (2,0)) in g1, ((1,0),(2,0)) in g2 + + # TODO(tgale): It might be faster to do this on the GPU and + # then communicate the results back to the host. + send_counts = tokens_per_expert.cpu().sum(dim=-1) + parallel_tokens_per_expert_cpu = parallel_tokens_per_expert.cpu() + recv_counts = parallel_tokens_per_expert_cpu.sum(dim=-1) + + # Convert the send/recv counts to lists. + send_counts = send_counts.tolist() + recv_counts = recv_counts.tolist() + tokens_received = sum(recv_counts) + + # Start the cross-device permutation asynchronously so we can + # overlap communication with computation. + parallel_x, parallel_x_handle = all_to_all( + x, recv_counts, send_counts, gpc.get_group(ParallelMode.EXPERT), async_op=True + ) + + with torch.no_grad(): + # After we do the cross-device permutation we have the tokens on the + # correct device but not yet grouped by expert because we received + # tokens from each device as contiguous chunks. To group the tokens + # for expert computation we'll do one more local permutation. The + # rest of this torch.no_grad() scope sets up the indices and bins + # for this permutation. + + replicate_bins = ops.inclusive_cumsum(parallel_tokens_per_expert.flatten(), 0) + replicate_bins = replicate_bins.view(1) if len(replicate_bins.size()) == 0 else replicate_bins + + # Construct the expert indices for the permuted tokens. 
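    # e.g. with num_experts = 4 and num_local_experts = 2,
    # torch.remainder(torch.arange(4), 2) -> [0, 1, 0, 1]: every received chunk is
    # relabeled with its *local* expert id, and ops.replicate then repeats each id
    # once per token in that chunk, using replicate_bins as the chunk boundaries.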
+ parallel_top_expert = torch.remainder( + torch.arange(self.num_experts, dtype=torch.int32, device=indices.device), + self.num_local_experts, # mpu.experts_per_rank(self.args), + ) + parallel_top_expert = ops.replicate( + parallel_top_expert.unsqueeze(dim=0), replicate_bins, tokens_received + ).flatten() + + # TODO(tgale): The sort_end_bit here can be reduced. + _, parallel_indices = ops.sort(parallel_top_expert, self.sort_end_bit) + + # Calculate the bins boundaries from the token counts. + parallel_tokens_per_expert = parallel_tokens_per_expert.sum(dim=0, dtype=torch.int) + parallel_bins = ops.inclusive_cumsum(parallel_tokens_per_expert, 0) + parallel_bins = parallel_bins.view(1) if len(parallel_bins.size()) == 0 else parallel_bins + + # If expert_capacity is set to zero, set the number of tokens + # per expert to the maximum we need to avoid dropping tokens. + tokens, _ = x.size() + expert_capacity = self.expert_capacity(tokens, top_k=1) + if not self.drop_tokens: + expert_capacity = torch.max(parallel_tokens_per_expert).item() + + # Locally permute the tokens and perform the expert computation. + # Block to make sure that the cross-device permutation is complete. + parallel_x_handle.wait() + parallel_x = self.permute_and_compute( + parallel_x, + parallel_indices, + None, # expert_weights + parallel_bins, + expert_capacity, + top_k=1, + ) + + # Un-permute the tokens across the devices. + x, _ = all_to_all(parallel_x, send_counts, recv_counts, gpc.get_group(ParallelMode.EXPERT)) + + # Un-permute locally to setup for the next series of operations. + x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k, self.quantize_scatter_num_bits) + return x, tokens_per_expert.flatten() + + def permute_and_compute(self, x, indices, expert_weights, bins, expert_capacity, top_k): # unused # unused + # Route the tokens for MoE computation. + x = x.view(-1, x.shape[-1]) + x = ops.binned_gather(x, indices, bins, expert_capacity, top_k) + + # Perform the expert computation + # First Dense x Dense -> Sparse for w1 and w3, + # (top_k * sequence_length + padding, ffn_dim * n_experts) + x = self.experts(x) + + # Un-route the data for the MoE output. + return ops.binned_scatter(x, indices, expert_weights, bins, top_k) + + def load_balancing_loss(self, tokens_per_expert, expert_scores): + """Calculate the load balancing loss contribution.""" + assert len(expert_scores.size()) == 2 + tokens, num_experts = expert_scores.size() + assert num_experts == self.num_experts + assert len(tokens_per_expert.size()) == 1 + (num_experts,) = tokens_per_expert.size() + assert num_experts == self.num_experts + scale = self.num_experts / (tokens * self.top_k) + return scale * torch.dot(tokens_per_expert.to(expert_scores.dtype), expert_scores.mean(dim=0)) + + def dummy_moe_loss(self, *args, **kwargs): + return None + + + def parallel_forward_prepare(self, x, top_expert): + # NOTE: This function implements the same computation as forward_once + # but with expert model parallelism. + # + # 1. Permute the tokens locally so that they are grouped by their + # expert assignments. This allows us to transfer all of the tokens + # for a remote device in one communication primitive. + # + # 2. Permute the tokens across the expert parallel devices. After + # this is completed each device has all of the tokens assigned to + # its set of experts in its local HBM. + # + # 3. Permute the tokens locally so that they are grouped by their + # expert assignement. 
After the distributed permutation the tokens + # are grouped by which device they came from. We re-order them + # locally to allow for efficient computation. + # + # After this series of permutations we compute the linear layers + # and then repeat these three steps in reverse to produce the final + # output. + # + # Compute the mapping of local tokens to experts. + with torch.no_grad(): + indices, bin_ids, bins, tokens_per_expert = ( + self.indices_and_bins(top_expert)) + + # Pass token count information to the device on which the + # target expert resides. + parallel_tokens_per_expert = torch.empty_like( + tokens_per_expert) + tpe_handle = torch.distributed.all_to_all_single( + parallel_tokens_per_expert, + tokens_per_expert, + group=gpc.get_group(ParallelMode.EXPERT), + async_op=True) + + # Permute locally and without any padding so that tokens for each + # parallel device are stored contiguously. + # + # TODO(tgale): We can tune these kernels for this special case by + # skipping the memset if tokens == padded_tokens and also taking + # in an optional padded_tokens rather than copying it from the + # device. + # + # This view updates the shape of the tensor from [sl, bs, hs] to + # [sl * bs, hs] prior to the permutation. + x = x.view(-1, x.shape[-1]) + x = ops.gather(x, indices, bin_ids, bins, self.top_k) + + # Compute the number of tokens that will be received from each + # device and permute the input data across the devices. + with torch.no_grad(): + tpe_handle.wait() + world_size = gpc.get_world_size(ParallelMode.EXPERT) + + # Reshape to [world_size, num_experts_per_rank]. + tokens_per_expert = tokens_per_expert.view(world_size, -1) + parallel_tokens_per_expert = ( + parallel_tokens_per_expert.view(world_size, -1)) + + # TODO(tgale): It might be faster to do this on the GPU and + # then communicate the results back to the host. + send_counts = tokens_per_expert.cpu().sum(dim=-1) + recv_counts = parallel_tokens_per_expert.cpu().sum(dim=-1) + + # Convert the send/recv counts to lists. + send_counts = send_counts.tolist() + recv_counts = recv_counts.tolist() + tokens_received = sum(recv_counts) + + # After we do the cross-device permutation we have the tokens on the + # correct device but not yet grouped by expert because we received + # tokens from each device as contiguous chunks. To group the tokens + # for expert computation we'll do one more local permutation. The + # rest of this torch.no_grad() scope sets up the indices and bins + # for this permutation. + replicate_bins = ops.inclusive_cumsum( + parallel_tokens_per_expert.flatten(), 0) + replicate_bins = ( + replicate_bins.view(1) + if not len(replicate_bins.size()) + else replicate_bins + ) + + # Construct the expert indices for the permuted tokens. + parallel_top_expert = torch.remainder( + torch.arange( + self.num_experts, dtype=torch.int32, device=indices.device), + self.num_local_experts, + ) + parallel_top_expert = ops.replicate( + parallel_top_expert.unsqueeze(dim=0), + replicate_bins, tokens_received).flatten() + + # TODO(tgale): The sort_end_bit here can be reduced. + parallel_bin_ids, parallel_indices = ops.sort( + parallel_top_expert, self.sort_end_bit) + + # Calculate the bins boundaries from the token counts. 
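    # e.g. if the per-local-expert counts summed over source ranks are [3, 1, 4],
    # inclusive_cumsum gives parallel_bins = [3, 4, 8]: local expert i owns the
    # sorted token positions [bins[i-1], bins[i]), starting at 0 for the first expert.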
+ parallel_tokens_per_expert = parallel_tokens_per_expert.sum( + dim=0, dtype=torch.int) + parallel_bins = ops.inclusive_cumsum( + parallel_tokens_per_expert, 0) + parallel_bins = ( + parallel_bins.view(1) + if not len(parallel_bins.size()) + else parallel_bins + ) + + # If expert_capacity is set to zero, set the number of tokens + # per expert to the maximum we need to avoid dropping tokens. + tokens, hs = x.size() + expert_capacity = self.expert_capacity(tokens, top_k=1) + if not self.drop_tokens: + expert_capacity = torch.max(parallel_tokens_per_expert).item() + + return x, recv_counts, send_counts, parallel_tokens_per_expert, \ + parallel_indices, parallel_bin_ids, parallel_bins, expert_capacity, \ + indices, bin_ids, bins, tokens_per_expert + + def parallel_forward_a2a1(self, x, recv_counts, send_counts): + # Permute the tokens across the devices. + parallel_x = all_to_all( + x, recv_counts, send_counts, + gpc.get_group(ParallelMode.EXPERT)) + return parallel_x + # Locally permute the tokens and perform the expert computation. + def parallel_forward_compute(self, parallel_x, + parallel_tokens_per_expert, + parallel_indices, + parallel_bin_ids, + parallel_bins, + expert_capacity): + + parallel_x = self.permute_and_compute( + parallel_x, + parallel_indices, + None, # expert_weights + parallel_bins, + expert_capacity, + top_k=1, + ) + return parallel_x + + def parallel_forward_a2a2(self, parallel_x, send_counts, recv_counts): + # Un-permute the tokens across the devices. + x = all_to_all( + parallel_x, send_counts, recv_counts, + gpc.get_group(ParallelMode.EXPERT)) + return x + + def parallel_forward_post(self, x, indices, bin_ids, bins, tokens_per_expert): + + # Un-permute locally to setup for the next series of operations. + x = ops.scatter(x, indices, bin_ids, expert_weights, bins, self.top_k, self.quantize_scatter_num_bits) + return x, tokens_per_expert.flatten() + + def tutel_prepare(self, x, scores): + origin_shape = x.shape + x = x.view(-1, origin_shape[-1]) + crit, top_experts = tutel.tutel_moe.extract_critical(scores, + top_k = self.args.moe_top_k, + loss_fn = None, + capacity_factor = self.args.moe_capacity_factor + ) + + tokens_per_expert = ops.histogram(top_experts.view(-1), self.num_experts) + + y = tutel.tutel_moe.fast_encode(x.to(scores.dtype), crit, True).to(x.dtype) + return y, tokens_per_expert, crit + + def tutel_a2a1(self, x): + return tutel.impls.communicate.all_to_all(x, 1, 0, use_2dh=False, group=gpc.get_group(ParallelMode.EXPERT)) + + def tutel_a2a2(self, x): + return tutel.impls.communicate.all_to_all(x, 0, 1, use_2dh=False, group=gpc.get_group(ParallelMode.EXPERT)) + + def tutel_post(self, x, crit, dtype): + y = tutel.tutel_moe.fast_decode(x.to(dtype), crit, True) + return y + + def hash_forward(self, x, timers, start, stop): + pass + + def tutel_prepare_fwd(self, ctx, x): + ctx.x0 = x.detach() + ctx.x0.requires_grad = True + with torch.enable_grad(): + scores = self.router.tutel_forward(ctx.x0) + ctx.scores = scores + origin_shape = x.shape + x = x.view(-1, origin_shape[-1]) + + y, tokens_per_expert, dispatcher = tutel.impls.fast_dispatch.extract_critical_encode(ctx, x, scores, + top_k = self.args.moe_top_k, + loss_fn = None, + capacity_factor = self.args.moe_capacity_factor + ) + + return y, dispatcher, origin_shape, scores, tokens_per_expert + #y, crit, dispatcher = tutel.tutel_moe.fast_encode(x.to(scores.dtype), crit, True).to(x.dtype) + + + def tutel_prepare_bwd(self, ctx, g_score, g_tokens, g_gates): + + grad_x = tutel.impls.fast_dispatch.encode_bwd(ctx, 
g_tokens) + for g_gate, gate in zip(g_gates, ctx.gates_s): + gate.backward(g_gate) + + #print("score0:", ctx.scores0.grad) + ctx.scores.backward(g_score + ctx.scores0.grad) + #print("bwd: ", g_tokens.size(), grad_x.size(), ctx.x0.size(), ctx.x0.grad.size()) + grad_x = grad_x.view(ctx.x0.grad.size()) + return grad_x + ctx.x0.grad + + def tutel_mlp_fwd(self, ctx, tokens): + ctx.tokens = tokens.detach() + ctx.tokens.requires_grad = True + with torch.enable_grad(): + y = self.mlp(ctx.tokens) + ctx.y = NoBuffer.apply(y) + return y + + def tutel_mlp_bwd(self, ctx, g_tokens): + ctx.y.backward(g_tokens) + return ctx.tokens.grad + + def tutel_a2a_scatter(self, tokens, tp_info): + group = self.args.expert_parallel_group + world_size = get_world_size(group) #world size not include TP ranks + if world_size == 1: + return tokens + + tokens = tokens.contiguous() + output = torch.empty_like(tokens) + + C.AllToAllStatus.init(group, -1, -1) + tutel_custom_kernel.all_to_all_with_scale(tokens, output, FAKE_A2A_SCALE) + ''' + torch.distributed.all_to_all_single(output, tokens, group=group) + if FAKE_A2A_SCALE > 1: + for i in range(FAKE_A2A_SCALE - 1): + torch.distributed.all_to_all_single(output, tokens, group=group) + ''' + + + output = output.view([world_size, -1] + list(output.shape[1:])) + output = output.permute([1, 0] + list(range(2, output.dim()))) + #print("o0.size: ", output.size()) #torch.Size([1, 8, 1280, 512]) + output = output.contiguous().view(list(output.shape[:1]) + [-1] + list(output.shape[3:])) + #[1, 10240, 512] + #y = tutel.impls.communicate.all_to_all(y, 1, 0, use_2dh=False, group=self.args.expert_parallel_group) + return output + + def tutel_a2a_scatter_p0(self, tokens): + world_size = get_world_size(self.args.expert_parallel_group) + if world_size == 1: + return tokens + tokens = tokens.contiguous() + output = torch.empty_like(tokens) + return tokens, output + + def tutel_a2a_scatter_p1(self, tokens, output): + C.AllToAllStatus.init(self.args.expert_parallel_group, -1, -1) + tutel_custom_kernel.all_to_all_with_scale(tokens, output, FAKE_A2A_SCALE) + + def tutel_a2a_scatter_p2(self, output): + output = output.view([world_size, -1] + list(output.shape[1:])) + output = output.permute([1, 0] + list(range(2, output.dim()))) + #print("o0.size: ", output.size()) #torch.Size([1, 8, 1280, 512]) + output = output.contiguous().view(list(output.shape[:1]) + [-1] + list(output.shape[3:])) + return output + + def tutel_a2a_gather(self, tokens, tp_info): + group = self.args.expert_parallel_group + world_size = get_world_size(group) + if world_size == 1: + return tokens + + + + reshaped_input = tokens.view(list(tokens.shape[:1]) + [world_size, -1] + list(tokens.shape[2:])) + reshaped_input = reshaped_input.permute([1, 0] + list(range(2, reshaped_input.dim()))).contiguous() + #simple_all_to_all(reshaped_input, group, background=True) + local_input = torch.empty_like(reshaped_input) + + C.AllToAllStatus.init(group, -1, -1) + tutel_custom_kernel.all_to_all_with_scale(reshaped_input, local_input, FAKE_A2A_SCALE) + + + ''' + torch.distributed.all_to_all_single(local_input, reshaped_input, group=group) + + if FAKE_A2A_SCALE > 1: + for i in range(FAKE_A2A_SCALE - 1): + torch.distributed.all_to_all_single(local_input, reshaped_input, group=group) + ''' + local_input = local_input.view([-1] + list(local_input.shape[2:])) + + if tp_info[0] > 1 : + torch.distributed.all_reduce(local_input, op=torch.distributed.ReduceOp.SUM, group=tp_info[1]) + + return local_input + + def tutel_post_fwd(self, ctx, tokens, 
dispatcher): + + tokens = tutel.impls.fast_dispatch.decode_fwd(ctx, tokens, dispatcher) + + return tokens + + def tutel_post_bwd(self, ctx, g_tokens): + tokens_grad, scores_grad = tutel.impls.fast_dispatch.decode_bwd(ctx, g_tokens) + return tokens_grad, scores_grad + + def forward(self, *inputs) -> torch.Tensor: + # optional reshape + x = inputs[0] + input_shape = x.shape + x = x.view(-1, input_shape[-1]) + + # gate_logits: (sequence_length, n_experts) + gate_logits = self.gate(x) + + # all_probs: (sequence_length, n_experts) and upcast for softmax + all_probs = F.softmax(gate_logits, dim=-1, dtype=torch.float) + # weights, selected_experts: (sequence_length, top-k) + expert_weights, top_experts = torch.topk(all_probs, self.top_k, dim=-1) + expert_weights /= expert_weights.sum(dim=-1, keepdim=True) + expert_weights = expert_weights.flatten() + top_experts = top_experts.flatten() + + x, tokens_per_expert = self.forward_fn(x, expert_weights, top_experts) + + self.l_aux = self.dummy_moe_loss(tokens_per_expert, all_probs) + + return x.view(*input_shape) \ No newline at end of file diff --git a/internlm/model/moe/ampipe/tutel_adapter.py b/internlm/model/moe/ampipe/tutel_adapter.py new file mode 100644 index 00000000..26ad1bc4 --- /dev/null +++ b/internlm/model/moe/ampipe/tutel_adapter.py @@ -0,0 +1,99 @@ +import torch + +import tutel.impls.losses +from tutel.impls.fast_dispatch import compute_sorted_location, GatingDecoder, GatingEncoder, TutelMoeFastDispatcher +from tutel.jit_kernels.gating import fast_cumsum_sub_one +from tutel.impls.communicate import simple_all_reduce + +def extract_critical_encode(ctx, x, scores, top_k, loss_fn=losses.gshard_loss, capacity_factor=1.0, batch_prioritized_routing=False, normalize_gate=True, alignment=1, group=None, inequivalent_tokens=False): + num_global_experts = int(scores.size(1)) + top_k, top_k_original = min(top_k, num_global_experts), top_k + topk_indices = torch.topk(scores, top_k, dim=1).indices + + indices_s = [x.view(-1) for x in topk_indices.chunk(top_k, dim=1)] + + masks_se = [losses._one_hot_with_dtype(x, num_classes=num_global_experts, dtype=x.dtype) for x in indices_s] + ctx.scores0 = scores.detach() + ctx.scores0.requires_grad = True + with torch.enable_grad(): + gates_s = [(ctx.scores0 * x).sum(dim=1) for x in masks_se] + ctx.gates_s = gates_s + + l_loss = loss_fn(scores, topk_indices) if loss_fn is not None else None + + if batch_prioritized_routing: + importance_scores = -1 * scores.max(dim=1)[0] + compute_location = lambda x: compute_sorted_location(x, importance_scores) + else: + compute_location = fast_cumsum_sub_one + + locations1 = compute_location(masks_se[0]) + + locations_s = [torch.sum(locations1 * masks_se[0], dim=1).to(torch.int32)] + + if top_k > 1: + acc_base = None + for k in range(1, top_k): + acc_base = torch.sum(masks_se[k - 1], dim=0, keepdim=True) if acc_base is None else acc_base + torch.sum(masks_se[k - 1], dim=0, keepdim=True) + locations2 = compute_location(masks_se[k]) + locations2 += acc_base + locations_s.append(torch.sum(locations2 * masks_se[k], dim=1).to(torch.int32)) + + if normalize_gate: + denom_s = torch.clamp(sum(gates_s), min=torch.finfo(gates_s[0].dtype).eps) + gates_s = [x / denom_s for x in gates_s] + + indices_s = [x.to(torch.int32) for x in indices_s] + + if inequivalent_tokens: + num_samples = torch.tensor(scores.size(0), device=scores.device) + if + num_samples = int(simple_all_reduce(num_samples, group=group, op=torch.distributed.ReduceOp.MAX)) + else: + num_samples = int(scores.size(0)) + + 
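To make the bookkeeping in extract_critical_encode easier to follow, here is a small pure-PyTorch paraphrase (not part of the patch) of the per-slot expert indices, in-expert locations, normalized gates, and capacity it produces for the capacity_factor > 0 path. Tutel's fused helpers (fast_cumsum_sub_one, the one-hot kernel) are replaced with plain torch ops, and the loss, alignment, inequivalent-token, and dispatcher/GatingEncoder steps are omitted; all names below are illustrative.

import torch

def topk_dispatch_bookkeeping(scores: torch.Tensor, top_k: int, capacity_factor: float = 1.0):
    # scores: (tokens, num_experts) float gate probabilities
    tokens, num_experts = scores.shape
    topk_indices = torch.topk(scores, top_k, dim=1).indices                  # (tokens, top_k)
    indices_s = [idx.view(-1) for idx in topk_indices.chunk(top_k, dim=1)]   # per-slot expert ids

    masks_se = [torch.nn.functional.one_hot(idx, num_experts).to(scores.dtype) for idx in indices_s]
    gates_s = [(scores * mask).sum(dim=1) for mask in masks_se]              # routing weight per slot

    locations_s = []
    acc_base = torch.zeros(1, num_experts, dtype=scores.dtype, device=scores.device)
    for mask in masks_se:
        # position of each token inside its chosen expert's buffer, offset by
        # the tokens already claimed by earlier top-k slots (acc_base)
        loc = torch.cumsum(mask, dim=0) - 1 + acc_base
        locations_s.append((loc * mask).sum(dim=1).to(torch.int32))
        acc_base = acc_base + mask.sum(dim=0, keepdim=True)

    denom = torch.clamp(sum(gates_s), min=torch.finfo(scores.dtype).eps)     # normalize_gate=True
    gates_s = [g / denom for g in gates_s]

    samples_per_expert = (tokens + num_experts - 1) // num_experts
    capacity = top_k * int(capacity_factor * samples_per_expert)
    return [i.to(torch.int32) for i in indices_s], locations_s, gates_s, capacity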
samples_per_expert = (num_samples + num_global_experts - 1) // num_global_experts + if capacity_factor > 0: + capacity = top_k * int(capacity_factor * samples_per_expert) + else: + capacity = torch.max(torch.cat(locations_s, dim=0)) + capacity = int(simple_all_reduce(capacity, group=group, op=torch.distributed.ReduceOp.MAX)) + 1 + if capacity_factor < 0: + capacity = min(capacity, top_k * int(-capacity_factor * samples_per_expert)) + + remainder = capacity % alignment + if remainder > 0: + capacity = capacity + alignment - remainder + + if get_world_rank(group) == 0: + logging.info(f"Capacity = {capacity}, real-time capacity-factor for top-{top_k_original} = {capacity / (top_k * samples_per_expert)}") + + crit = (num_global_experts, indices_s, locations_s, gates_s, capacity) + top_experts = topk_indices + + tokens_per_expert = torch.histc(top_experts, bins=num_global_experts, min=0, max=num_global_experts) + + dispatcher = TutelMoeFastDispatcher(num_global_experts, 0, x.size(-1), x.dtype) + dispatcher.update(indices_s, locations_s, gates_s, capacity, is_postscore=True) + + assert dispatcher.dtype == torch.float16 and x.dtype == torch.float16 and torch.float16 == dispatcher.original_dtype + x = GatingEncoder.forward(ctx, dispatcher, x) + ctx.original_x_shape = x.size() + return x.view(num_global_experts, -1, x.size(-1)), tokens_per_expert, dispatcher + + +def encode_bwd(ctx, grad_y): + grad_y = grad_y.view(ctx.original_x_shape) + grad_xs = GatingEncoder.backward(ctx, grad_y) + return grad_xs[1] + +def decode_fwd(ctx, x, dispatcher): + #dispatcher.decode(x).view(-1, x.size(-1)) + assert dispatcher.dtype == torch.float16 and x.dtype == torch.float16 and torch.float16 == dispatcher.original_dtype + out = GatingDecoder.forward(ctx, dispatcher, x, *dispatcher.gates_) + + return out + +def decode_bwd(ctx, grad_y): + grads = GatingDecoder.backward(ctx, grad_y) + return grads[1], grads[2:] #scores grad \ No newline at end of file From 30b64b46c1a38ced7476c2a760b8ae529b5de5e0 Mon Sep 17 00:00:00 2001 From: Wenwen Qu Date: Wed, 18 Dec 2024 15:43:24 +0800 Subject: [PATCH 4/4] test --- internlm/model/modeling_moe.py | 10 +- internlm/model/modules/linear.py | 27 ++++++ internlm/model/moe/ampipe/ampipe.py | 35 +++---- internlm/model/moe/ampipe/fa_helper.py | 17 ++-- internlm/model/moe/ampipe/moe_layer.py | 104 +++++++++++++++------ internlm/model/moe/ampipe/tutel_adapter.py | 11 +-- internlm/model/moe/moe.py | 3 + internlm/model/ops/norm.py | 72 ++++++++++++-- 8 files changed, 203 insertions(+), 76 deletions(-) diff --git a/internlm/model/modeling_moe.py b/internlm/model/modeling_moe.py index 125dcbf6..f5eea3fc 100644 --- a/internlm/model/modeling_moe.py +++ b/internlm/model/modeling_moe.py @@ -3,6 +3,7 @@ import math from typing import Optional +from einops import rearrange import torch from torch import nn @@ -219,7 +220,6 @@ def _dropout_and_norm_attn(_hidden_states): residual = residual.to(torch.float32) mixer_kwargs = convert_attn_args_to_kwargs(args, kwargs) - if gpc.config.model.ampipe_degree < 1: hidden_states = self.mixer(hidden_states, **mixer_kwargs) @@ -251,9 +251,12 @@ def _dropout_and_norm_ffn(_residual, _hidden_states): flash = self.mixer.inner_attn dense_layer = self.mixer.out_proj ln = self.norm2 - + k, v = [rearrange(x, 'b s ... 
-> (b s) ...') for x in [k, v]] + # torch.cuda.synchronize() + # breakpoint() + # print(q, flush=True) hidden_states, residual = AttMoEPipe.apply(q, k, v, hidden_states, - ln.weight, ln.bias, dense_layer.bias, + ln.weight, None, dense_layer.bias, [flash, dense_layer, gpc.config.model.ampipe_degree, ln, self.dropout2.p, \ self.mlp.moe_layer]) @@ -333,6 +336,7 @@ def __init__( top_k: int = 1, num_shared_experts: int = 0, moe_layer_kwargs: dict = None, + ampipe_degree: str = None, # pylint: disable=W0613 ): super().__init__() diff --git a/internlm/model/modules/linear.py b/internlm/model/modules/linear.py index 05c7ae43..1f117208 100644 --- a/internlm/model/modules/linear.py +++ b/internlm/model/modules/linear.py @@ -717,6 +717,33 @@ def forward(self, input: torch.Tensor, batch_sizes: torch.Tensor = None) -> torc **mixer_kwargs, ) + def explicit_fwd(self, ctx, input: torch.Tensor, batch_sizes: torch.Tensor = None) -> torch.Tensor: # pylint: disable=W0622 + _class_name = self.__class__.__name__ + assert self._communicator is not None, f"{_class_name} should register with a communicator first." + + mixer_kwargs = {} + use_grouped_linear = getattr(self, "is_grouped_linear", False) + if use_grouped_linear: + mixer_kwargs = { + "batch_sizes": batch_sizes, + "backend": self.backend, + "full_weight_shape": self.full_weight_shape if hasattr(self, "full_weight_shape") else None, + } + + return explicit_fused_dense_forward( + ctx, + input, + self.weight, + communicator=self._communicator, + module=self, + bias=self.bias, + use_grouped_linear=use_grouped_linear, + **mixer_kwargs, + ) + + def explicit_bwd(self, ctx, grad_output: torch.Tensor): + return explicit_fused_dense_backward(ctx, grad_output) + class ColumnParallelLinear(ParallelLinearWithCommExt): """ diff --git a/internlm/model/moe/ampipe/ampipe.py b/internlm/model/moe/ampipe/ampipe.py index 3fc2ec33..e5f78c9e 100644 --- a/internlm/model/moe/ampipe/ampipe.py +++ b/internlm/model/moe/ampipe/ampipe.py @@ -36,7 +36,7 @@ def bias_dropout_add_fused_train(x: torch.Tensor, prob: float) -> torch.Tensor: return bias_dropout_add(x, bias, residual, prob, True) -def bias_dropout_add_ln_fwd(ctx, inp, residual, bias, prob, ln, bias_dropout_add_exec_handler): +def bias_dropout_add_ln_fwd(ctx, inp, residual, bias, prob, ln): ctx.inp = inp ctx.residual = residual ctx.bias = bias.detach() @@ -86,10 +86,8 @@ class AttMoEPipe(torch.autograd.Function): def forward(ctx, q, k, v, hidden_states, ln_weight, ln_bias, proj_bias, non_params): #torch.cuda.synchronize() #t0 = time.time() - ctx.non_params = non_params - flash, dense_layer, pipe_degree, ln, hidden_dropout, bias_dropout_add_exec_handler, moe \ - = non_params + flash, dense_layer, pipe_degree, ln, hidden_dropout, moe = non_params ctx.batch_size, seqlen, ctx.head = q.size(0), q.size(1), q.size(2) @@ -139,22 +137,21 @@ def forward(ctx, q, k, v, hidden_states, ln_weight, ln_bias, proj_bias, non_para q_use = rearrange(q[:,base:base+chunk_len], 'b s ... 
-> (b s) ...') flash_ctx = FakeContext() - with random.seed(ParallelMode.Tensor): + with random.seed(ParallelMode.TENSOR): output_chunk = flash_attn_fwd(flash_ctx, q_use, k, v, cu_seqlens_q, cu_seqlens, chunk_len, seqlen, - flash.dropout_p if flash.training else 0.0, + flash.dropout.p if flash.training else 0.0, softmax_scale=flash.softmax_scale, causal=True, causal_q_offset=base, #fixed , - version=1 + version=2 ) ctx.flash_ctx.append(flash_ctx) - context_layers.append(rearrange(output_chunk, '(b s) h d -> s b (h d)', b=ctx.batch_size).contiguous()) + context_layers.append(rearrange(output_chunk, '(b s) h d -> b s (h d)', b=ctx.batch_size).contiguous()) base += chunk_len dense_ctx = FakeContext() context_layers[-1] = dense_layer.explicit_fwd(dense_ctx, context_layers[-1]) ctx.dense_ctx.append(dense_ctx) - #if DEBUG: # tensor_list = [torch.empty_like(context_layers[-1] ) for _ in range(mpu.get_tensor_model_parallel_world_size())] # torch.distributed.all_gather(tensor_list, context_layers[-1] , group=mpu.get_tensor_model_parallel_group()) @@ -163,9 +160,10 @@ def forward(ctx, q, k, v, hidden_states, ln_weight, ln_bias, proj_bias, non_para bdal_ctx = FakeContext() ln_output, ln_input = bias_dropout_add_ln_fwd(bdal_ctx, context_layers[-1], \ - hidden_states_chunks[c], dense_layer.bias, hidden_dropout, ln, bias_dropout_add_exec_handler) + hidden_states_chunks[c], dense_layer.bias, hidden_dropout, ln) ctx.bdal_ctx.append(bdal_ctx) ln_ins.append(ln_input) + #ln_outs.append(ln_output) #if DEBUG: # tensor_list = [torch.empty_like(ln_output) for _ in range(mpu.get_tensor_model_parallel_world_size())] @@ -173,7 +171,6 @@ def forward(ctx, q, k, v, hidden_states, ln_weight, ln_bias, proj_bias, non_para # for t in tensor_list: # assert (t == ln_output).all().item(), "not same a2a input across tp ranks" - prepare_ctx = FakeContext() a2a_tokens, dispatcher, origin_shape, scores, tokens_per_expert\ = moe.tutel_prepare_fwd(prepare_ctx, ln_output) @@ -197,7 +194,7 @@ def forward(ctx, q, k, v, hidden_states, ln_weight, ln_bias, proj_bias, non_para # for t in tensor_list: # assert (t == intermediate[c]).all().item(), "not same across tp ranks" - a2a_tokens = moe.tutel_a2a_scatter(intermediate[c], [mpu.get_tensor_model_parallel_world_size(), mpu.get_tensor_model_parallel_group()]) + a2a_tokens = moe.tutel_a2a_scatter(intermediate[c]) intermediate[c] = a2a_tokens a2a1_events.append(torch.cuda.current_stream().record_event()) @@ -223,7 +220,7 @@ def forward(ctx, q, k, v, hidden_states, ln_weight, ln_bias, proj_bias, non_para with torch.cuda.stream(get_comm()): torch.cuda.current_stream().wait_event(comp_events[c]) - a2a_tokens = moe.tutel_a2a_gather(intermediate[c], [mpu.get_tensor_model_parallel_world_size(), mpu.get_tensor_model_parallel_group()]) + a2a_tokens = moe.tutel_a2a_gather(intermediate[c]) intermediate[c] = a2a_tokens a2a2_events.append(torch.cuda.current_stream().record_event()) @@ -236,9 +233,6 @@ def forward(ctx, q, k, v, hidden_states, ln_weight, ln_bias, proj_bias, non_para post_ctx = FakeContext() post_out = moe.tutel_post_fwd(post_ctx, intermediate[c], dispatchers[c]) - - - ctx.post_ctx.append(post_ctx) post_out = post_out.view(origin_shape) @@ -252,7 +246,7 @@ def forward(ctx, q, k, v, hidden_states, ln_weight, ln_bias, proj_bias, non_para torch.cuda.current_stream().wait_stream(get_comp0()) - ret = torch.cat(moe_outs), torch.cat(ln_ins) + ret = torch.cat(moe_outs), torch.cat(context_layers) #torch.cuda.synchronize() #te = time.time() #if torch.distributed.get_rank() == 0: @@ -261,8 +255,7 @@ 
def forward(ctx, q, k, v, hidden_states, ln_weight, ln_bias, proj_bias, non_para @staticmethod def backward(ctx, grad_mlp_outs, grad_ln_ins): - flash, attn, dense_layer, pipe_degree, ln, hidden_dropout, bias_dropout_add_exec_handler, moe\ - = ctx.non_params + flash, attn, dense_layer, pipe_degree, ln, hidden_dropout, moe = ctx.non_params grad_mlp_outs = grad_mlp_outs.chunk(pipe_degree) grad_ln_ins = grad_ln_ins.chunk(pipe_degree) @@ -292,7 +285,7 @@ def backward(ctx, grad_mlp_outs, grad_ln_ins): for c in range(0, pipe_degree): with torch.cuda.stream(get_comm()): torch.cuda.current_stream().wait_event(post_events[c]) - intermediate[c] = moe.tutel_a2a_scatter(intermediate[c], [mpu.get_tensor_model_parallel_world_size(), mpu.get_tensor_model_parallel_group()]) + intermediate[c] = moe.tutel_a2a_scatter(intermediate[c]) a2a2_events.append(torch.cuda.current_stream().record_event()) torch.cuda.synchronize() @@ -308,7 +301,7 @@ def backward(ctx, grad_mlp_outs, grad_ln_ins): for c in range(0, pipe_degree): with torch.cuda.stream(get_comm()): torch.cuda.current_stream().wait_event(comp_events[c]) - intermediate[c] = moe.tutel_a2a_gather(intermediate[c], [mpu.get_tensor_model_parallel_world_size(), mpu.get_tensor_model_parallel_group()]) + intermediate[c] = moe.tutel_a2a_gather(intermediate[c]) a2a1_events.append(torch.cuda.current_stream().record_event()) for c in range(0, pipe_degree): diff --git a/internlm/model/moe/ampipe/fa_helper.py b/internlm/model/moe/ampipe/fa_helper.py index 4b343a20..77e0708f 100644 --- a/internlm/model/moe/ampipe/fa_helper.py +++ b/internlm/model/moe/ampipe/fa_helper.py @@ -1,6 +1,7 @@ import torch import flash_attn_2_cuda -import flash_attn_cuda +# import flash_attn_cuda +flash_attn_cuda = None def _flash_attn1_forward(q, k, v, out, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal, return_softmax, num_splits=0, @@ -43,12 +44,12 @@ def _flash_attn_varlen_forward(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q dropout_p, softmax_scale, causal, return_softmax, causal_q_offset): maybe_contiguous = lambda x: x.contiguous() if x.stride(-1) != 1 else x q, k, v = [maybe_contiguous(x) for x in (q, k, v)] - out, q, k, v, out_padded, softmax_lse, S_dmask = flash_attn_2_cuda.varlen_fwd( + out, q, k, v, out_padded, softmax_lse, p, rng_state = flash_attn_2_cuda.varlen_fwd( q, k, v, None, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, - softmax_scale, False, causal, return_softmax, causal_q_offset, None + softmax_scale, False, causal, return_softmax, None ) - return out, q, k, v, out_padded, softmax_lse, S_dmask + return out, q, k, v, out_padded, softmax_lse, p, rng_state def _flash_attn_varlen_backward(dout, q, k, v, out, softmax_lse, dq, dk, dv, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, @@ -131,7 +132,7 @@ def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) ctx.causal_q_offset = causal_q_offset - out, q, k, v, out_padded, softmax_lse, S_dmask = _flash_attn_varlen_forward( + out, q, k, v, out_padded, softmax_lse, p, rng_state = _flash_attn_varlen_forward( q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, dropout_p, softmax_scale, causal=causal, return_softmax=return_softmax and dropout_p > 0, causal_q_offset=causal_q_offset ) @@ -152,7 +153,7 @@ def forward(ctx, q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k TIMERS[key][version] = t1 - t0 else: TIMERS[key] = {version: t1 - t0} - return 
out if not return_softmax else (out, softmax_lse, S_dmask) + return out if not return_softmax else (out, softmax_lse, p, rng_state) @staticmethod def backward(ctx, dout, *args): @@ -249,9 +250,9 @@ def flash_attn_megablock_call(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, def flash_attn_fwd(ctx, q_use, k, v, cu_seqlens_q, cu_seqlens, chunk_len, seqlen, dropout_p, softmax_scale, causal, causal_q_offset, version): - assert causal and version == 1 + # assert causal and version == 1 out = FlashAttnFuncMerge.forward(ctx, q_use, k, v, cu_seqlens_q, cu_seqlens, chunk_len, seqlen, - dropout_p, softmax_scale, True, False, False, causal_q_offset, 1 + dropout_p, softmax_scale, True, False, False, causal_q_offset, version ) return out diff --git a/internlm/model/moe/ampipe/moe_layer.py b/internlm/model/moe/ampipe/moe_layer.py index aa84d776..5e65a917 100644 --- a/internlm/model/moe/ampipe/moe_layer.py +++ b/internlm/model/moe/ampipe/moe_layer.py @@ -5,6 +5,7 @@ import torch.nn.functional as F import tutel +import megablocks_ops from internlm.core.context import ParallelMode from internlm.core.context import global_context as gpc @@ -12,6 +13,8 @@ from internlm.model.moe.megablocks.mlp import MegaBlockFeedForward from internlm.model.moe.utils import all_to_all +from internlm.model.moe.ampipe.tutel_adapter import extract_critical_encode, encode_bwd, decode_fwd, decode_bwd + try: from megablocks import ops except (ModuleNotFoundError, ImportError): @@ -172,6 +175,47 @@ def get_world_size(group=None): # scale = scale_numerator / scale_denominator # return scale * torch.dot(tokens_per_expert, expert_scores) + +class TopKGate(torch.nn.Module): + """Gate module which implements Top2Gating as described in Gshard_. + :: + gate = TopKGate(model_dim, num_experts) + l_aux, combine_weights, dispatch_mask = gate(input) + .. 
Gshard_: https://arxiv.org/pdf/2006.16668.pdf + Args: + model_dim (int): + size of model embedding dimension + num_experts (ints): + number of experts in model + """ + + wg: torch.nn.Linear + + def __init__( + self, + model_dim: int, + num_experts: int, + topk: int = 1, + noisy_gate_policy: Optional[str] = None, + ) -> None: + super().__init__() + + # Deepspeed's mechisms, alway use fp32 + self.wg = torch.nn.Linear(model_dim, num_experts, bias=False) + self.k = topk + + self.noisy_gate_policy = noisy_gate_policy + + def forward(self, inputs: torch.Tensor): + # input jittering + if self.noisy_gate_policy == "Jitter" and self.training: + inputs = multiplicative_jitter(inputs, device=inputs.device) + logits = self.wg(inputs) + gates = F.softmax(logits, dim=1) + + return gates + + class AmpipeMegaBlockMoE(BaseMoELayer): """ Built on the paper and library Megablocks as described in @@ -213,7 +257,11 @@ def __init__( self.drop_tokens = drop_tokens assert self.ffn_dim % tp_size == 0 super().__init__( - torch.nn.Linear(in_features, num_experts, bias=False), + TopKGate( + in_features, + num_experts, + top_k, + ), MegaBlockFeedForward( in_features, self.ffn_dim // tp_size, @@ -607,7 +655,7 @@ def tutel_prepare(self, x, scores): crit, top_experts = tutel.tutel_moe.extract_critical(scores, top_k = self.args.moe_top_k, loss_fn = None, - capacity_factor = self.args.moe_capacity_factor + capacity_factor = self.capacity_factor ) tokens_per_expert = ops.histogram(top_experts.view(-1), self.num_experts) @@ -632,24 +680,23 @@ def tutel_prepare_fwd(self, ctx, x): ctx.x0 = x.detach() ctx.x0.requires_grad = True with torch.enable_grad(): - scores = self.router.tutel_forward(ctx.x0) + scores = self.gate(ctx.x0.view(-1, x.shape[-1])) ctx.scores = scores origin_shape = x.shape x = x.view(-1, origin_shape[-1]) - y, tokens_per_expert, dispatcher = tutel.impls.fast_dispatch.extract_critical_encode(ctx, x, scores, - top_k = self.args.moe_top_k, + y, tokens_per_expert, dispatcher = extract_critical_encode(ctx, x, scores, + top_k = self.top_k, loss_fn = None, - capacity_factor = self.args.moe_capacity_factor + capacity_factor = self.capacity_factor ) - return y, dispatcher, origin_shape, scores, tokens_per_expert #y, crit, dispatcher = tutel.tutel_moe.fast_encode(x.to(scores.dtype), crit, True).to(x.dtype) def tutel_prepare_bwd(self, ctx, g_score, g_tokens, g_gates): - grad_x = tutel.impls.fast_dispatch.encode_bwd(ctx, g_tokens) + grad_x = encode_bwd(ctx, g_tokens) for g_gate, gate in zip(g_gates, ctx.gates_s): gate.backward(g_gate) @@ -663,7 +710,7 @@ def tutel_mlp_fwd(self, ctx, tokens): ctx.tokens = tokens.detach() ctx.tokens.requires_grad = True with torch.enable_grad(): - y = self.mlp(ctx.tokens) + y = self.experts(ctx.tokens) ctx.y = NoBuffer.apply(y) return y @@ -671,30 +718,28 @@ def tutel_mlp_bwd(self, ctx, g_tokens): ctx.y.backward(g_tokens) return ctx.tokens.grad - def tutel_a2a_scatter(self, tokens, tp_info): - group = self.args.expert_parallel_group - world_size = get_world_size(group) #world size not include TP ranks + def tutel_a2a_scatter(self, tokens): + # group = gpc.get_group(ParallelMode.EXPERT) + world_size = gpc.get_world_size(ParallelMode.EXPERT) #world size not include TP ranks if world_size == 1: return tokens - tokens = tokens.contiguous() - output = torch.empty_like(tokens) + # tokens = tokens.contiguous() + # output = torch.empty_like(tokens) - C.AllToAllStatus.init(group, -1, -1) - tutel_custom_kernel.all_to_all_with_scale(tokens, output, FAKE_A2A_SCALE) + # C.AllToAllStatus.init(group, -1, 
-1) + # tutel_custom_kernel.all_to_all_with_scale(tokens, output, FAKE_A2A_SCALE) ''' torch.distributed.all_to_all_single(output, tokens, group=group) if FAKE_A2A_SCALE > 1: for i in range(FAKE_A2A_SCALE - 1): torch.distributed.all_to_all_single(output, tokens, group=group) ''' - + output, _ = all_to_all(tokens, group=gpc.get_group(ParallelMode.EXPERT)) output = output.view([world_size, -1] + list(output.shape[1:])) output = output.permute([1, 0] + list(range(2, output.dim()))) - #print("o0.size: ", output.size()) #torch.Size([1, 8, 1280, 512]) output = output.contiguous().view(list(output.shape[:1]) + [-1] + list(output.shape[3:])) - #[1, 10240, 512] #y = tutel.impls.communicate.all_to_all(y, 1, 0, use_2dh=False, group=self.args.expert_parallel_group) return output @@ -717,9 +762,9 @@ def tutel_a2a_scatter_p2(self, output): output = output.contiguous().view(list(output.shape[:1]) + [-1] + list(output.shape[3:])) return output - def tutel_a2a_gather(self, tokens, tp_info): - group = self.args.expert_parallel_group - world_size = get_world_size(group) + def tutel_a2a_gather(self, tokens): + # group = gpc.get_group(ParallelMode.EXPERT) + world_size = gpc.get_world_size(ParallelMode.EXPERT) if world_size == 1: return tokens @@ -728,11 +773,12 @@ def tutel_a2a_gather(self, tokens, tp_info): reshaped_input = tokens.view(list(tokens.shape[:1]) + [world_size, -1] + list(tokens.shape[2:])) reshaped_input = reshaped_input.permute([1, 0] + list(range(2, reshaped_input.dim()))).contiguous() #simple_all_to_all(reshaped_input, group, background=True) - local_input = torch.empty_like(reshaped_input) + # local_input = torch.empty_like(reshaped_input) - C.AllToAllStatus.init(group, -1, -1) - tutel_custom_kernel.all_to_all_with_scale(reshaped_input, local_input, FAKE_A2A_SCALE) + # C.AllToAllStatus.init(group, -1, -1) + # tutel_custom_kernel.all_to_all_with_scale(reshaped_input, local_input, FAKE_A2A_SCALE) + local_input, _ = all_to_all(reshaped_input, group=gpc.get_group(ParallelMode.EXPERT)) ''' torch.distributed.all_to_all_single(local_input, reshaped_input, group=group) @@ -743,19 +789,19 @@ def tutel_a2a_gather(self, tokens, tp_info): ''' local_input = local_input.view([-1] + list(local_input.shape[2:])) - if tp_info[0] > 1 : - torch.distributed.all_reduce(local_input, op=torch.distributed.ReduceOp.SUM, group=tp_info[1]) + # if tp_info[0] > 1 : + # torch.distributed.all_reduce(local_input, op=torch.distributed.ReduceOp.SUM, group=tp_info[1]) return local_input def tutel_post_fwd(self, ctx, tokens, dispatcher): - tokens = tutel.impls.fast_dispatch.decode_fwd(ctx, tokens, dispatcher) + tokens = decode_fwd(ctx, tokens.view(-1, tokens.shape(-1)), dispatcher) return tokens def tutel_post_bwd(self, ctx, g_tokens): - tokens_grad, scores_grad = tutel.impls.fast_dispatch.decode_bwd(ctx, g_tokens) + tokens_grad, scores_grad = decode_bwd(ctx, g_tokens) return tokens_grad, scores_grad def forward(self, *inputs) -> torch.Tensor: diff --git a/internlm/model/moe/ampipe/tutel_adapter.py b/internlm/model/moe/ampipe/tutel_adapter.py index 26ad1bc4..8fb92c33 100644 --- a/internlm/model/moe/ampipe/tutel_adapter.py +++ b/internlm/model/moe/ampipe/tutel_adapter.py @@ -1,6 +1,6 @@ import torch -import tutel.impls.losses +from tutel.impls import losses from tutel.impls.fast_dispatch import compute_sorted_location, GatingDecoder, GatingEncoder, TutelMoeFastDispatcher from tutel.jit_kernels.gating import fast_cumsum_sub_one from tutel.impls.communicate import simple_all_reduce @@ -47,7 +47,6 @@ def extract_critical_encode(ctx, 
x, scores, top_k, loss_fn=losses.gshard_loss, c if inequivalent_tokens: num_samples = torch.tensor(scores.size(0), device=scores.device) - if num_samples = int(simple_all_reduce(num_samples, group=group, op=torch.distributed.ReduceOp.MAX)) else: num_samples = int(scores.size(0)) @@ -65,8 +64,8 @@ def extract_critical_encode(ctx, x, scores, top_k, loss_fn=losses.gshard_loss, c if remainder > 0: capacity = capacity + alignment - remainder - if get_world_rank(group) == 0: - logging.info(f"Capacity = {capacity}, real-time capacity-factor for top-{top_k_original} = {capacity / (top_k * samples_per_expert)}") + # if get_world_rank(group) == 0: + # logging.info(f"Capacity = {capacity}, real-time capacity-factor for top-{top_k_original} = {capacity / (top_k * samples_per_expert)}") crit = (num_global_experts, indices_s, locations_s, gates_s, capacity) top_experts = topk_indices @@ -76,7 +75,7 @@ def extract_critical_encode(ctx, x, scores, top_k, loss_fn=losses.gshard_loss, c dispatcher = TutelMoeFastDispatcher(num_global_experts, 0, x.size(-1), x.dtype) dispatcher.update(indices_s, locations_s, gates_s, capacity, is_postscore=True) - assert dispatcher.dtype == torch.float16 and x.dtype == torch.float16 and torch.float16 == dispatcher.original_dtype + # assert dispatcher.dtype == torch.float16 and x.dtype == torch.float16 and torch.float16 == dispatcher.original_dtype x = GatingEncoder.forward(ctx, dispatcher, x) ctx.original_x_shape = x.size() return x.view(num_global_experts, -1, x.size(-1)), tokens_per_expert, dispatcher @@ -89,7 +88,7 @@ def encode_bwd(ctx, grad_y): def decode_fwd(ctx, x, dispatcher): #dispatcher.decode(x).view(-1, x.size(-1)) - assert dispatcher.dtype == torch.float16 and x.dtype == torch.float16 and torch.float16 == dispatcher.original_dtype + # assert dispatcher.dtype == torch.float16 and x.dtype == torch.float16 and torch.float16 == dispatcher.original_dtype out = GatingDecoder.forward(ctx, dispatcher, x, *dispatcher.gates_) return out diff --git a/internlm/model/moe/moe.py b/internlm/model/moe/moe.py index 0bd35e5b..ea1072d7 100644 --- a/internlm/model/moe/moe.py +++ b/internlm/model/moe/moe.py @@ -9,6 +9,7 @@ from internlm.model.moe.gshard_layer import GShardMoELayer from internlm.model.moe.megablocks.megablock_dmoe import MegaBlockdMoE from internlm.model.moe.megablocks.megablock_moe import MegaBlockMoE +from internlm.model.moe.ampipe.moe_layer import AmpipeMegaBlockMoE from internlm.utils.logger import get_logger # global llm logger @@ -24,6 +25,8 @@ def new_moe_layer(moe_type: str, **kwargs): return MegaBlockMoE(**kwargs) elif moe_type == "MegaBlock-Dropless": return MegaBlockdMoE(**kwargs) + elif moe_type == "AMPipe": + return AmpipeMegaBlockMoE(**kwargs) else: raise ValueError(f"Unsupported model type: {moe_type}") diff --git a/internlm/model/ops/norm.py b/internlm/model/ops/norm.py index 86280681..c4f451ee 100644 --- a/internlm/model/ops/norm.py +++ b/internlm/model/ops/norm.py @@ -1,6 +1,7 @@ # adopted from https://github.com/NVIDIA/apex/blob/master/apex/normalization/fused_layer_norm import numbers +import importlib import torch from torch.nn import init @@ -12,13 +13,14 @@ logger = get_logger(__file__) internlm_accelerator = get_accelerator() -try: - from apex.normalization.fused_layer_norm import mixed_dtype_fused_rms_norm_affine +# try: +from apex.normalization.fused_layer_norm import mixed_dtype_fused_rms_norm_affine +from apex._autocast_utils import _cast_if_autocast_enabled - apex_rmsnorm_impl = True -except (ModuleNotFoundError, ImportError): - 
logger.warning("The torch implementation for MixFusedRMSNorm is slower than apex. Please note this!") - apex_rmsnorm_impl = False +apex_rmsnorm_impl = True +# except (ModuleNotFoundError, ImportError): +# logger.warning("The torch implementation for MixFusedRMSNorm is slower than apex. Please note this!") +# apex_rmsnorm_impl = False try: from deeplink_ext.internevo_ops import MixedFusedRMSNorm as _RMSNormDIPU @@ -53,6 +55,58 @@ def manual_rms_norm(my_input, weight, normalized_shape, eps, add_unit_offset=Fal else: return weight * my_input +global fused_layer_norm_cuda +fused_layer_norm_cuda = None + +class FusedRMSNormAffineFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, input, weight, normalized_shape, eps, memory_efficient=False): + global fused_layer_norm_cuda + if fused_layer_norm_cuda is None: + fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") + ctx.normalized_shape = normalized_shape + ctx.eps = eps + ctx.memory_efficient = memory_efficient + input_ = input.contiguous() + weight_ = weight.contiguous() + output, invvar = fused_layer_norm_cuda.rms_forward_affine( + input_, ctx.normalized_shape, weight_, ctx.eps) + if ctx.memory_efficient: + ctx.save_for_backward(output, weight_, invvar) + else: + ctx.save_for_backward(input_, weight_, invvar) + return output + + @staticmethod + def backward(ctx, grad_output): + input_or_output, weight_, invvar = ctx.saved_tensors + grad_input = grad_weight = None + grad_input, grad_weight = fused_layer_norm_cuda.rms_backward_affine( + grad_output.contiguous(), invvar, input_or_output, + ctx.normalized_shape, weight_, ctx.eps, ctx.memory_efficient + ) + return grad_input, grad_weight, None, None, None + +class FusedRMSNormAffineMixedDtypesFunction(FusedRMSNormAffineFunction): + + @staticmethod + def forward(ctx, input, weight, normalized_shape, eps, memory_efficient=False): + global fused_layer_norm_cuda + if fused_layer_norm_cuda is None: + fused_layer_norm_cuda = importlib.import_module("fused_layer_norm_cuda") + ctx.normalized_shape = normalized_shape + ctx.eps = eps + ctx.memory_efficient = memory_efficient + input_ = input.contiguous() + weight_ = weight.contiguous() + output, invvar = fused_layer_norm_cuda.rms_forward_affine_mixed_dtypes( + input_, ctx.normalized_shape, weight_, ctx.eps + ) + if ctx.memory_efficient: + ctx.save_for_backward(output, weight_, invvar) + else: + ctx.save_for_backward(input_, weight_, invvar) + return output class _RMSNorm(torch.nn.Module): """A generic module for RMS normalization.""" @@ -77,15 +131,15 @@ def forward(self, _input: torch.Tensor): return _norm_func(_input, self.weight, self.normalized_shape, self.eps, self.add_unit_offset) - def explicit_forward(self, ctx, _input: torch.Tensor): + def explicit_fwd(self, ctx, _input: torch.Tensor): if apex_rmsnorm_impl: - args = _cast_if_autocast_enabled(input, self.weight, self.normalized_shape, self.eps) + args = _cast_if_autocast_enabled(_input, self.weight, self.normalized_shape, self.eps) with torch.amp.autocast('cuda', enabled=False): return FusedRMSNormAffineMixedDtypesFunction.forward(ctx, *args) else: assert False - def explicit_forward(self, ctx, grad_output: torch.Tensor): + def explicit_bwd(self, ctx, grad_output: torch.Tensor): if apex_rmsnorm_impl: with torch.amp.autocast('cuda', enabled=False): grad_input, grad_weight, *_ = FusedRMSNormAffineMixedDtypesFunction.backward(ctx, grad_output)
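The explicit_fwd / explicit_bwd pairs added above for the parallel linear and for RMSNorm follow the same convention used throughout ampipe.py: run an autograd.Function's forward by hand with a lightweight context object (FakeContext) and call its backward later, at whatever point the pipeline schedule chooses, instead of relying on autograd's recorded graph. FakeContext itself is not shown in these hunks; the sketch below uses a hypothetical stand-in (StandInCtx) and a toy Function purely to illustrate the calling convention, not the patch's actual classes.

import torch

class StandInCtx:
    # hypothetical stand-in for ampipe's FakeContext: just enough surface for
    # calling a torch.autograd.Function's forward/backward staticmethods by hand
    def save_for_backward(self, *tensors):
        self.saved_tensors = tensors

class Square(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_out

# explicit fwd/bwd: forward runs now (no autograd graph is recorded for a direct
# staticmethod call), the context is kept, and backward is invoked manually later
ctx = StandInCtx()
x = torch.randn(4)
y = Square.forward(ctx, x)
grad_x = Square.backward(ctx, torch.ones_like(y))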