Train object detection only (ONLY DETECTION) #226

Open
PengYuan0 opened this issue Jan 21, 2025 · 1 comment

@PengYuan0

I loaded the End-to-end pretrained model and want to update only the object detection (det) part of the model. To do this, I set _C.TRAIN.DET_ONLY = True, froze the networks related to the drivable area and lane lines, changed the dataset loading so that only the object detection dataset is loaded, modified loss.py so that the drivable-area and lane-line losses are 0, and modified the weight-loading code in train.py so that only the parameters with requires_grad=True are loaded. Why are the drivable-area and lane-line models still affected? Specifically, when I run demo.py with the checkpoint from the first epoch, the lane-line and drivable-area segmentation look normal, but the results gradually degrade: by epochs 30, 40 and 100 they are very poor. See the images below for details.
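A minimal sketch of freezing blocks by index, assuming (as the k.split(".")[1] indexing in the train.py excerpt suggests) that the network keeps its blocks in an nn.Sequential attribute named model. ToyNet, freeze_blocks and frozen_bn_to_eval are hypothetical names for illustration only. The sketch highlights one point relevant to the symptom above: requires_grad = False stops gradient updates, but BatchNorm layers in train mode still update running_mean and running_var on every forward pass, so frozen blocks containing BatchNorm keep drifting unless those layers are switched to eval mode.

```python
import torch
import torch.nn as nn

BN_TYPES = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d)


class ToyNet(nn.Module):
    """Hypothetical multi-task net: blocks live in self.model, so parameter
    names look like "model.<idx>.<...>" and name.split(".")[1] is the block index."""

    def __init__(self):
        super().__init__()
        self.model = nn.Sequential(
            nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU()),  # "0": shared encoder
            nn.Conv2d(8, 4, 1),                                     # "1": detection head
            nn.Sequential(nn.Conv2d(8, 2, 1), nn.BatchNorm2d(2)),   # "2": segmentation head
        )

    def forward(self, x):
        feat = self.model[0](x)
        return self.model[1](feat), self.model[2](feat)


def freeze_blocks(net, frozen_idx):
    """Turn off gradients for every parameter whose block index is in frozen_idx."""
    for name, p in net.named_parameters():
        if name.split(".")[1] in frozen_idx:
            p.requires_grad = False


def frozen_bn_to_eval(net, frozen_idx):
    """Switch BatchNorm layers inside frozen blocks to eval mode so their
    running statistics stop updating (requires_grad alone does not do this)."""
    for idx, block in net.model.named_children():
        if idx in frozen_idx:
            for m in block.modules():
                if isinstance(m, BN_TYPES):
                    m.eval()


torch.manual_seed(0)
net = ToyNet()
frozen = {"0", "2"}  # freeze encoder and segmentation head; train the detection head only
freeze_blocks(net, frozen)
opt = torch.optim.SGD([p for p in net.parameters() if p.requires_grad], lr=0.1)

enc_bn = net.model[0][1]
mean_before = enc_bn.running_mean.clone()

net.train()                     # puts every submodule, including frozen BatchNorm, in train mode
frozen_bn_to_eval(net, frozen)  # so switch the frozen BatchNorm layers back to eval mode
det_out, _ = net(torch.randn(4, 3, 16, 16) + 3.0)
det_out.mean().backward()
opt.step()

print(torch.equal(mean_before, enc_bn.running_mean))  # True: frozen BatchNorm statistics unchanged
```

Because model.train() resets every submodule to train mode, the BatchNorm switch has to be repeated after each such call, for example at the start of every epoch.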

train.py:
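import os

import torch

if os.path.exists(cfg.MODEL.PRETRAINED_DET):
    # Log the path that is actually loaded (PRETRAINED_DET, not PRETRAINED)
    logger.info("=> loading model weight in det branch from '{}'".format(cfg.MODEL.PRETRAINED_DET))
    det_idx_range = [str(i) for i in range(0, 25)]  # top-level block indices treated as the det branch
    model_dict = model.state_dict()
    checkpoint_file = cfg.MODEL.PRETRAINED_DET
    checkpoint = torch.load(checkpoint_file)
    begin_epoch = checkpoint['epoch']
    last_epoch = checkpoint['epoch']
    # Keep only the checkpoint entries that belong to the det-branch blocks
    checkpoint_dict = {k: v for k, v in checkpoint['state_dict'].items() if k.split(".")[1] in det_idx_range}
    # state_dict() tensors are detached, so model_dict[k].requires_grad is always False;
    # read trainability from the live parameters instead. This only has an effect
    # if the parameters were already frozen before this point.
    trainable = {name for name, p in model.named_parameters() if p.requires_grad}
    checkpoint_dict = {k: v for k, v in checkpoint_dict.items() if k in model_dict and k in trainable}
    model_dict.update(checkpoint_dict)
    model.load_state_dict(model_dict)
    logger.info("=> loaded det branch checkpoint '{}' ".format(checkpoint_file))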
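Filtering checkpoint keys on model.state_dict()[k].requires_grad selects nothing, because state_dict() returns detached tensors by default (keep_vars=False). The small self-contained check below uses a toy nn.Sequential in place of the real model and a zero-filled dict as a hypothetical stand-in for checkpoint['state_dict'].

```python
import torch
import torch.nn as nn

net = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))  # toy stand-in for the real model
net[0].weight.requires_grad = False                     # pretend this weight is frozen

# state_dict() detaches its tensors by default, so requires_grad is False for
# every entry, frozen or not: filtering on it selects nothing.
sd = net.state_dict()
print([k for k, v in sd.items() if v.requires_grad])  # []

# named_parameters() exposes the live parameters with their real requires_grad flags.
trainable = {k for k, p in net.named_parameters() if p.requires_grad}
print(sorted(trainable))  # ['0.bias', '1.bias', '1.weight']

# Load only the trainable entries; strict=False reports what was skipped.
ckpt_state = {k: torch.zeros_like(v) for k, v in sd.items()}  # hypothetical checkpoint['state_dict']
result = net.load_state_dict({k: v for k, v in ckpt_state.items() if k in trainable}, strict=False)
print(result.missing_keys)  # ['0.weight']
```

If the freezing loop runs after the checkpoint is loaded, every parameter is still trainable at load time, and the named_parameters() filter keeps all det-branch keys.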

loss.py
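def _forward_impl(self, predictions, targets, shapes, model):
    """
    Args:
        predictions: [[det_head1, det_head2, det_head3], drive_area_seg_head, lane_line_seg_head]
        targets: ground truths [det_targets, segment_targets, lane_targets]
        shapes: image shape/padding info (only read by the commented-out lane-line IoU code)
        model: the network; passed to build_targets, and its nc and gr attributes are used below

    Returns:
        total_loss: sum of all the losses
        head_losses: tuple of loss values (lbox, lobj, lcls, total)
    """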
    cfg = self.cfg
    device = targets[0].device
    lcls, lbox, lobj = torch.zeros(1, device=device), torch.zeros(1, device=device), torch.zeros(1, device=device)
    tcls, tbox, indices, anchors = build_targets(cfg, predictions[0], targets[0], model)  # targets

    # Class label smoothing https://arxiv.org/pdf/1902.04103.pdf eqn 3
    cp, cn = smooth_BCE(eps=0.0)

    BCEcls, BCEobj, BCEseg = self.losses

    # Calculate Losses
    nt = 0  # number of targets
    no = len(predictions[0])  # number of outputs
    balance = [4.0, 1.0, 0.4] if no == 3 else [4.0, 1.0, 0.4, 0.1]  # P3-5 or P3-6

    # calculate detection loss
    for i, pi in enumerate(predictions[0]):  # layer index, layer predictions
        b, a, gj, gi = indices[i]  # image, anchor, gridy, gridx
        tobj = torch.zeros_like(pi[..., 0], device=device)  # target obj

        n = b.shape[0]  # number of targets
        if n:
            nt += n  # cumulative targets
            ps = pi[b, a, gj, gi]  # prediction subset corresponding to targets

            # Regression
            pxy = ps[:, :2].sigmoid() * 2. - 0.5
            pwh = (ps[:, 2:4].sigmoid() * 2) ** 2 * anchors[i]
            pbox = torch.cat((pxy, pwh), 1).to(device)  # predicted box
            iou = bbox_iou(pbox.T, tbox[i], x1y1x2y2=False, CIoU=True)  # iou(prediction, target)
            lbox += (1.0 - iou).mean()  # iou loss

            # Objectness
            tobj[b, a, gj, gi] = (1.0 - model.gr) + model.gr * iou.detach().clamp(0).type(tobj.dtype)  # iou ratio

            # Classification
            # print(model.nc)
            if model.nc > 1:  # cls loss (only if multiple classes)
                t = torch.full_like(ps[:, 5:], cn, device=device)  # targets
                t[range(n), tcls[i]] = cp
                lcls += BCEcls(ps[:, 5:], t)  # BCE
        lobj += BCEobj(pi[..., 4], tobj) * balance[i]  # obj loss

    # drive_area_seg_predicts = predictions[1].view(-1)
    # drive_area_seg_targets = targets[1].view(-1)
    # lseg_da = BCEseg(drive_area_seg_predicts, drive_area_seg_targets)

    # lane_line_seg_predicts = predictions[2].view(-1)
    # lane_line_seg_targets = targets[2].view(-1)
    # lseg_ll = BCEseg(lane_line_seg_predicts, lane_line_seg_targets)

    # metric = SegmentationMetric(2)
    # nb, _, height, width = targets[1].shape
    # pad_w, pad_h = shapes[0][1][1]
    # pad_w = int(pad_w)
    # pad_h = int(pad_h)
    # _,lane_line_pred=torch.max(predictions[2], 1)
    # _,lane_line_gt=torch.max(targets[2], 1)
    # lane_line_pred = lane_line_pred[:, pad_h:height-pad_h, pad_w:width-pad_w]
    # lane_line_gt = lane_line_gt[:, pad_h:height-pad_h, pad_w:width-pad_w]
    # metric.reset()
    # metric.addBatch(lane_line_pred.cpu(), lane_line_gt.cpu())
    # IoU = metric.IntersectionOverUnion()
    # liou_ll = 1 - IoU

    s = 3 / no  # output count scaling
    lcls *= cfg.LOSS.CLS_GAIN * s * self.lambdas[0]
    lobj *= cfg.LOSS.OBJ_GAIN * s * (1.4 if no == 4 else 1.) * self.lambdas[1]
    lbox *= cfg.LOSS.BOX_GAIN * s * self.lambdas[2]

    # lseg_da *= cfg.LOSS.DA_SEG_GAIN * self.lambdas[3]
    # lseg_ll *= cfg.LOSS.LL_SEG_GAIN * self.lambdas[4]
    # liou_ll *= cfg.LOSS.LL_IOU_GAIN * self.lambdas[5]

    
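    # The segmentation losses are commented out above; define them as zero tensors
    # so the *_ONLY branches below do not raise NameError.
    lseg_da = torch.zeros(1, device=device)
    lseg_ll = torch.zeros(1, device=device)
    liou_ll = torch.zeros(1, device=device)

    if cfg.TRAIN.DET_ONLY or cfg.TRAIN.ENC_DET_ONLY:
        lseg_da = 0 * lseg_da
        lseg_ll = 0 * lseg_ll
        liou_ll = 0 * liou_ll

    if cfg.TRAIN.SEG_ONLY or cfg.TRAIN.ENC_SEG_ONLY:
        # Multiply by 0 (instead of assigning the int 0) so that .item() below still works
        lcls = 0 * lcls
        lobj = 0 * lobj
        lbox = 0 * lbox

    if cfg.TRAIN.LANE_ONLY:
        lcls = 0 * lcls
        lobj = 0 * lobj
        lbox = 0 * lbox
        lseg_da = 0 * lseg_da

    if cfg.TRAIN.DRIVABLE_ONLY:
        lcls = 0 * lcls
        lobj = 0 * lobj
        lbox = 0 * lbox
        lseg_ll = 0 * lseg_ll
        liou_ll = 0 * liou_ll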

    loss = lbox + lobj + lcls
    # loss = lseg
    # return loss * bs, torch.cat((lbox, lobj, lcls, loss)).detach()
    return loss, (lbox.item(), lobj.item(), lcls.item(), loss.item())
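Zeroing a head's loss only removes that head's own gradient contribution. The toy example below, with hypothetical encoder, det_head and seg_head modules standing in for the real branches, shows two consequences: the detection loss still sends gradients into a shared encoder (so the features the segmentation heads consume change unless the encoder is frozen), and a loss multiplied by 0 yields all-zero gradients rather than no gradients, so an optimizer with weight_decay still moves those parameters if they are left trainable.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
encoder = nn.Linear(8, 8)   # hypothetical shared backbone
det_head = nn.Linear(8, 3)  # head that is being trained
seg_head = nn.Linear(8, 2)  # head whose loss is "disabled" by multiplying by 0

x = torch.randn(5, 8)
feat = encoder(x)
det_loss = det_head(feat).pow(2).mean()
seg_loss = 0 * seg_head(feat).pow(2).mean()
(det_loss + seg_loss).backward()

# The zeroed loss leaves the seg head with an all-zero gradient (not None) ...
print(seg_head.weight.grad.abs().sum().item())  # 0.0
# ... while the detection loss still reaches the shared encoder.
print(encoder.weight.grad.abs().sum().item() > 0)  # True

# With weight decay, an all-zero gradient still changes a trainable parameter.
opt = torch.optim.SGD(seg_head.parameters(), lr=0.1, weight_decay=0.01)
w_before = seg_head.weight.detach().clone()
opt.step()
print(torch.equal(w_before, seg_head.weight.detach()))  # False
```

Parameters with requires_grad = False keep grad set to None, and PyTorch optimizers skip those, so freezing (rather than zeroing the loss) is what keeps weight decay away from a branch.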
@PengYuan0

The image upload failed, but the problem should be clear from the description above.
