SeqMRI train process

  • Post author:
  • Post category:其他




LOUPE大体流程

LOUPE(
  (samplers): ModuleList(
    (0): LOUPESampler(
      (gen_mask): LineConstrainedProbMask()
    )
  )
  (reconstructor): LOUPEUNet(
    (down_sample_layers): ModuleList(
      (0): ConvBlock(in_chans=2, out_chans=64, drop_prob=0.0)
      (1): ConvBlock(in_chans=64, out_chans=128, drop_prob=0.0)
      (2): ConvBlock(in_chans=128, out_chans=256, drop_prob=0.0)
      (3): ConvBlock(in_chans=256, out_chans=512, drop_prob=0.0)
    )
    (conv): ConvBlock(in_chans=512, out_chans=512, drop_prob=0.0)
    (up_sample_layers): ModuleList(
      (0): ConvBlock(in_chans=1024, out_chans=256, drop_prob=0.0)
      (1): ConvBlock(in_chans=512, out_chans=128, drop_prob=0.0)
      (2): ConvBlock(in_chans=256, out_chans=64, drop_prob=0.0)
      (3): ConvBlock(in_chans=128, out_chans=64, drop_prob=0.0)
    )
    (conv2): Sequential(
      (0): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))
      (1): Conv2d(32, 1, kernel_size=(1, 1), stride=(1, 1))
      (2): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1))
    )
  )
)

in loup_envs_ocmr.py

train_data, val_data, test_data = self._create_datasets()

在这里插入图片描述

display_data = [val_data[i] for i in range(0, len(val_data))]

in real_ocmr_data.py

return self.transform(

kspace,

torch.zeros(kspace.shape[1]),

target,

fname.name

)

在这里插入图片描述

in loup_envs_ocmr.py

train_loader, val_loader, test_loader, display_loader

在这里插入图片描述

train_loupe_ocme.py

policy = NonRLTrainer(args, env, torch.device(args.device))

在这里插入图片描述

non_rl.py

    def _train_loupe(self):
        # Excerpt: epoch loop of the LOUPE trainer. The remainder of the loop
        # body (scheduler step and evaluation) is quoted after train_epoch below.
        for epoch in range(self.start_epoch, self.end_epoch):
            # Stored on self so train_epoch()/evaluate() can read the current epoch.
            self.epoch = epoch
            # One pass over the training set; returns (mean loss, elapsed seconds).
            train_loss, train_time = self.train_epoch()
    def train_epoch(self):
        """Run one training epoch of the LOUPE model over ``self.train_loader``.

        For each batch the reconstruction target is the magnitude
        (``transforms.complex_abs``) of the complex input image with a channel
        axis inserted at dim 1. The model is called as
        ``self.model(target, kspace)``. The loss from ``self.model.loss`` is
        back-propagated and the optimizer is stepped. The per-iteration loss
        and the per-epoch MSE/NMSE/PSNR/SSIM means are written to
        ``self.writer``.

        Returns:
            tuple: (mean training loss of the epoch, epoch wall-clock time in seconds)
        """
        self.model.train()
        losses = []
        metrics = Metrics(METRIC_FUNCS)
        start_epoch = start_iter = time.perf_counter()
        global_step = self.epoch * len(self.train_loader)
        device = self.options.device

        # The loss weights do not change within an epoch, so build them once.
        # getattr falls back to 0 when the options object has no
        # ``uncertainty_weight`` (it also sees class-level attributes, which an
        # instance-__dict__ membership test would miss).
        meta = {
            'entropy_weight': self.options.entropy_weight,
            'recon_weight': self.options.recon_weight,
            'kspace_weight': self.options.kspace_weight,
            'uncertainty_weight': getattr(self.options, 'uncertainty_weight', 0),
        }

        for batch_idx, data in enumerate(self.train_loader):
            # Each batch unpacks as (kspace, <unused>, complex image, label, *extras).
            kspace, _, image, label, *_ = data

            # Reconstruction target: magnitude image with a channel axis,
            # e.g. (1, 1, 256, 256) for batch_size=1 in the author's run.
            target = transforms.complex_abs(image.clone().detach()).unsqueeze(1)
            target = target.to(device)
            kspace = kspace.to(device)
            # The complex image itself is not fed to the model, so it is not
            # copied to the device. Labels stay on the CPU.
            # Optional Gaussian k-space noise (transforms.add_gaussian_noise,
            # gated on options.noise_type == 'gaussian') is currently disabled.

            pred_dict = self.model(target, kspace)

            # Visualize the first batch of every epoch.
            if batch_idx == 0:
                data_for_vis_name = f'train_epoch={self.epoch + 1}_iter={batch_idx + 1}'
                self.model.visualize_and_save(self.options, self.epoch, data_for_vis_name)

            output = pred_dict['output']
            target_dict = {'target': target, 'label': label, 'kspace': kspace}
            loss, log_dict = self.model.loss(pred_dict, target_dict, meta, self.options.loss_type)

            self.optimizer.zero_grad()
            loss.backward()
            # Gradient clipping (clip_grad_norm_ with max norm 10) is currently disabled.
            self.optimizer.step()

            loss_value = loss.item()
            self.writer.add_scalar('Train_Loss', loss_value, global_step + batch_idx)
            losses.append(loss_value)

            # If the model returns a list of reconstructions, score only the last one.
            if isinstance(output, list):
                output = output[-1]

            target_np = target.detach().cpu().numpy()
            pred_np = output.detach().cpu().numpy()

            if batch_idx % self.options.report_interval == 0:
                self.logger.info(
                    f'Epoch = [{1 + self.epoch:3d}/{self.options.num_epochs:3d}] '
                    f'Iter = [{batch_idx:4d}/{len(self.train_loader):4d}] '
                    f'Time = {time.perf_counter() - start_iter:.4f}s',
                )
                for key, val in log_dict.items():
                    print('{} = {}'.format(key, val))

            start_iter = time.perf_counter()

            # Image metrics are accumulated per sample of the batch.
            for t, p in zip(target_np, pred_np):
                metrics.push(t, p)

        print(metrics)
        means = metrics.means()
        for name in ('MSE', 'NMSE', 'PSNR', 'SSIM'):
            self.writer.add_scalar(f'Train_{name}', means[name], self.epoch)

        return np.mean(np.array(losses)), time.perf_counter() - start_epoch
            # (continuation of the _train_loupe epoch loop)
            # Step the LR scheduler with the epoch index, then run evaluation.
            # NOTE(review): passing an explicit epoch to scheduler.step() is
            # deprecated for most schedulers in recent PyTorch versions — confirm.
            self.scheduler.step(epoch)
            dev_loss, mean_sparsity, dev_time = self.evaluate()
    def evaluate(self):
        """Evaluate the model on ``self.dev_loader`` with gradients disabled.

        The target is built as in training: the magnitude of the complex image
        with a channel axis. The uncertainty term of the loss is switched off
        (weight 0). Per-epoch MSE/NMSE/PSNR/SSIM means and the mean dev loss
        are written to ``self.writer``.

        Returns:
            tuple: (mean dev loss, mean value of pred_dict['mask'] over batches,
            elapsed seconds). The mask mean is the sampled fraction when the
            mask is binary.
        """
        self.model.eval()
        losses = []
        sparsity = []
        metrics = Metrics(METRIC_FUNCS)
        start = time.perf_counter()
        device = self.options.device

        # Loss weights are constant over the whole evaluation pass.
        meta = {'entropy_weight': self.options.entropy_weight, 'recon_weight': self.options.recon_weight,
                'uncertainty_weight': 0, 'kspace_weight': self.options.kspace_weight}

        with torch.no_grad():
            for batch_idx, data in enumerate(self.dev_loader):
                # Each batch unpacks as (kspace, <unused>, complex image, label, *extras).
                kspace, _, image, label, *_ = data

                # Reconstruction target: magnitude image with a channel axis.
                target = transforms.complex_abs(image.clone().detach()).unsqueeze(1)
                target = target.to(device)
                kspace = kspace.to(device)

                pred_dict = self.model(target, kspace)

                # Visualize the first batch of every evaluation pass.
                if batch_idx == 0:
                    self.model.visualize_and_save(self.options, self.epoch, f'eval_epoch={self.epoch + 1}')

                # Only the last reconstruction is scored when a list is returned.
                output = pred_dict['output']
                if isinstance(output, list):
                    output = output[-1]

                target_dict = {'target': target, 'label': label, 'kspace': kspace}
                loss, _ = self.model.loss(pred_dict, target_dict, meta, self.options.loss_type)

                sparsity.append(torch.mean(pred_dict['mask']).item())
                losses.append(loss.item())

                # Image metrics are accumulated per sample of the batch.
                for t, p in zip(target.cpu().numpy(), output.cpu().numpy()):
                    metrics.push(t, p)

            print(metrics)
            means = metrics.means()
            for name in ('MSE', 'NMSE', 'PSNR', 'SSIM'):
                self.writer.add_scalar(f'Dev_{name}', means[name], self.epoch)

            mean_loss = np.mean(losses)
            self.writer.add_scalar('Dev_Loss', mean_loss, self.epoch)

        return mean_loss, np.mean(sparsity), time.perf_counter() - start

在这里插入图片描述

loss 定义在 loupe.py 中:

    def loss(self, pred_dict, target_dict, meta, loss_type):
        """
        Combine a reconstruction loss with the mask-entropy regularizer.

        Args:
            pred_dict:
                output: reconstructed image from downsampled kspace measurement
                energy: negative entropy of the probability mask
                mask: the binarized sampling mask (used for visualization; not read here)

            target_dict:
                target: original fully sampled image

            meta:
                recon_weight: weight of reconstruction loss
                entropy_weight: weight of the entropy loss (to encourage exploration)
                Any other keys (e.g. kspace_weight, uncertainty_weight) are ignored.

            loss_type: 'l1', 'ssim' or 'psnr'. For 'ssim' and 'psnr' the negated
                mean score is minimized.

        Returns:
            (loss, log_dict): the scalar loss tensor and a dict of Python floats
            ('Total Loss', 'Entropy', 'Reconstruction') for logging.

        Raises:
            NotImplementedError: if ``loss_type`` is not supported.
        """
        target = target_dict['target']
        pred = pred_dict['output']
        energy = pred_dict['energy']

        if loss_type == 'l1':
            # ``size_average`` is deprecated; reduction='mean' is the equivalent.
            reconstruction_loss = F.l1_loss(pred, target, reduction='mean')
        elif loss_type == 'ssim':
            reconstruction_loss = -torch.mean(compute_ssim_torch(pred, target))
        elif loss_type == 'psnr':
            reconstruction_loss = -torch.mean(compute_psnr_torch(pred, target))
        else:
            raise NotImplementedError(f'Unsupported loss_type: {loss_type!r}')

        entropy_loss = torch.mean(energy)

        loss = entropy_loss * meta['entropy_weight'] + reconstruction_loss * meta['recon_weight']

        log_dict = {'Total Loss': loss.item(), 'Entropy': entropy_loss.item(), 'Reconstruction': reconstruction_loss.item()}

        return loss, log_dict

`_train_loupe(self)` 中的 epoch 循环会反复执行上述步骤(train_epoch → scheduler.step → evaluate),直到所有 epoch 结束。



train的过程



从初始化mask到得到zero-filled image

pred_dict = self.model(target, kspace)

我这里设的batch_size=1,

所以target(1, 1, 256, 256) , target.min()=0, target.max()=1

kspace(1, 256, 256, 2)

在这里插入图片描述

non_rl.py中的 pred_dict = self.model(target, kspace)跳到loupe.py中的 masked_kspace, mask, neg_entropy, data_to_vis_sampler = self.samplers[0](kspace, self.sparsity)

又跳到samplers.py中的prob_mask = self.gen_mask(kspace),再跳到layers.py中的 logits = self.mask

mask = torch.sigmoid(self.slope * logits).view(1, 1, self.mask.shape[0], 1)

在这里插入图片描述

        logits = self.mask
        # Per-line probabilities; 252 learnable lines in the author's run -> (1, 1, 252, 1).
        mask = torch.sigmoid(self.slope * logits).view(1, 1, self.mask.shape[0], 1)
        if self.preselect:
            if self.preselect_num % 2 ==0:
                # preselect_num // 2 zero-probability lines per side: (1, 1, 2, 1) in the author's run.
                zeros = torch.zeros(1, 1, self.preselect_num // 2, 1).to(input.device)
                # Padded to (1, 1, 256, 1): this is the prob_mask returned to the sampler.
                mask = torch.cat([zeros, mask, zeros], dim=2)

samplers.py

        else:
            # Rescale the probabilities to the target sparsity (implemented in layers.py);
            # shape (1, 1, 256, 1) in the author's run.
            rescaled_mask = self.rescale(prob_mask, sparsity)
            if self.training:
                # Stochastic binarization via ThresholdRandomMaskSigmoidV1 in layers.py;
                # in the author's run the result is (1, 1, 256, 1) with 60 ones.
                binarized_mask = self.binarize(rescaled_mask)

在这里插入图片描述

layers.py

def RescaleProbMap(batch_x, sparsity, eps=1e-8):
    """
        Rescale Probability Map
        Given a batch of probability maps, rescale each sample (slice along
        dim 0) so that its mean equals ``sparsity``.

        With xbar = mean(x) of one sample:
          if xbar >= sparsity: x' = x * sparsity / xbar
          if xbar <  sparsity: x' = 1 - (1 - x) * (1 - sparsity) / (1 - xbar)
        i.e. the second case rescales (1 - x) and takes 1 minus the result.
        For x in [0, 1] both cases stay in [0, 1] and give mean(x') == sparsity.

        Args:
            batch_x (torch.Tensor): probabilities, batch along dim 0
                (e.g. (1, 1, 256, 1) in the author's run).
            sparsity (float): desired mean of every sample, e.g. 0.234375.
            eps (float): floor applied to xbar and to 1 - xbar. Without it an
                all-zero or all-one sample divides by zero, and the result
                (and its gradient) becomes NaN.

        Returns:
            torch.Tensor: rescaled probabilities with the same shape as ``batch_x``.
    """
    batch_size = batch_x.shape[0]
    # Per-sample mean, reshaped so it broadcasts against batch_x.
    broadcast_shape = (batch_size,) + (1,) * (batch_x.dim() - 1)
    xbar = batch_x.reshape(batch_size, -1).mean(dim=1).view(broadcast_shape)

    r = sparsity / xbar.clamp(min=eps)
    beta = (1 - sparsity) / (1 - xbar).clamp(min=eps)

    # Select a branch with torch.where rather than blending
    # ``le * a + (1 - le) * b``: in the blend an infinite unused branch turns
    # into 0 * inf = NaN.
    return torch.where(r <= 1, batch_x * r, 1 - (1 - batch_x) * beta)
layers.py

class ThresholdRandomMaskSigmoidV1(Function):
    """
        Straight-through estimator for a stochastic binary sampling mask.

        The forward step stochastically binarizes the probability mask. For
        every sample (slice along dim 0) it draws u ~ U[0, 1) element-wise and
        returns (input > u) as floats. It redraws until the mean of the binary
        mask is within atol=1e-3 of the mean of the probabilities.
        The backward step is described as estimating the non-differentiable
        ``>`` operator with a sigmoid of large slope (10).
        NOTE(review): no ``backward`` method is defined in this excerpt; confirm
        it exists before relying on gradients through this Function.
    """

    @staticmethod
    def forward(ctx, input):
        """
        Args:
            ctx: autograd context; receives (input, probs) via save_for_backward.
            input (torch.Tensor): probabilities in [0, 1], batch along dim 0.

        Returns:
            torch.Tensor: float32 tensor of 0s and 1s with the same shape as ``input``.

        Raises:
            RuntimeError: if no matching mask is drawn within 1000 retries
                (e.g. the mean is not reachable with the available number of elements).
        """
        max_retries = 1000
        probs = []
        results = []

        for i in range(len(input)):
            x = input[i:i+1]
            expected_density = torch.mean(x)

            retries = 0
            while True:
                prob = torch.rand_like(x)
                result = (x > prob).float()

                if torch.isclose(torch.mean(result), expected_density, atol=1e-3):
                    break

                retries += 1
                if retries > max_retries:
                    # Explicit raise rather than ``assert``: asserts are stripped
                    # under ``python -O``, which would turn this into an endless loop.
                    raise RuntimeError(
                        'ThresholdRandomMaskSigmoidV1: could not draw a binary mask with mean '
                        f'{expected_density.item():.6f} (last draw mean '
                        f'{torch.mean(result).item():.6f}) after {retries} draws'
                    )

            probs.append(prob)
            results.append(result)

        results = torch.cat(results, dim=0)
        probs = torch.cat(probs, dim=0)
        ctx.save_for_backward(input, probs)

        return results

……………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………………

samplers.py

            else: # preselect_num_one_side=2 in the author's run
                # Force the preselected lines at both ends of the mask to be sampled
                # (first two and last two lines here: 60 + 4 = 64 ones in total).
                binarized_mask[..., :self.preselect_num_one_side, :] = 1
                binarized_mask[..., -self.preselect_num_one_side:, :] = 1
     # Shape (1, 1, 256, 1); max=0, min=-0.6931 (= -ln 2) in the author's run.
     neg_entropy = self._mask_neg_entropy(rescaled_mask)
     # Keep only the k-space entries selected by the binary mask (others become 0).
     masked_kspace = binarized_mask * kspace

在这里插入图片描述

在这里插入图片描述



reconstruction过程

loupe.py中的 recon = self.reconstructor(zero_filled_recon, 0)跳到reconstructor.py中的

    def forward(self, input, old_recon=None, eps=1e-8):
        # Excerpt: only the fallback branch of this forward pass is quoted here.
        else:
            # Fallback branch: the input is passed through unchanged as ``output``.
            output = input 
LOUPEUNet(
  (down_sample_layers): ModuleList(
    (0): ConvBlock(in_chans=2, out_chans=64, drop_prob=0.0)
    (1): ConvBlock(in_chans=64, out_chans=128, drop_prob=0.0)
    (2): ConvBlock(in_chans=128, out_chans=256, drop_prob=0.0)
    (3): ConvBlock(in_chans=256, out_chans=512, drop_prob=0.0)
  )
  (conv): ConvBlock(in_chans=512, out_chans=512, drop_prob=0.0)
  (up_sample_layers): ModuleList(
    (0): ConvBlock(in_chans=1024, out_chans=256, drop_prob=0.0)
    (1): ConvBlock(in_chans=512, out_chans=128, drop_prob=0.0)
    (2): ConvBlock(in_chans=256, out_chans=64, drop_prob=0.0)
    (3): ConvBlock(in_chans=128, out_chans=64, drop_prob=0.0)
  )
  (conv2): Sequential(
    (0): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))
    (1): Conv2d(32, 1, kernel_size=(1, 1), stride=(1, 1))
    (2): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1))
  )
)

在这里插入图片描述



怎么得到mask

loupe.py

        masked_kspace, mask, neg_entropy, data_to_vis_sampler = self.samplers[0](kspace, self.sparsity)
samplers.py

    def forward(self, kspace, sparsity): # sparsity=0.234375 (= 60/256) in the author's run
        # kspace: NHWC, e.g. (1, 256, 256, 2) in the author's run
        # sparsity (float): target mean used when rescaling the probability mask
        # Per-line sampling probabilities from the mask generator
        # (LineConstrainedProbMask in the printed LOUPE model).
        prob_mask = self.gen_mask(kspace)
layers.py

class LineConstrainedProbMask(nn.Module):
    """
    A learnable probabilistic mask over whole k-space lines.

    The mask is constrained to include whole kspace lines in the readout
    direction. It learns one logit per line, and ``forward`` returns per-line
    sampling probabilities of shape (1, 1, num_lines, 1).

    Args:
        shape: sequence whose first entry is the number of k-space lines.
        slope: slope of the sigmoid mapping logits to probabilities.
        preselect: if True, ``preselect_num // 2`` zero-probability lines are
            padded at each end in ``forward``; only the remaining lines are learnable.
        preselect_num: number of preselected lines (must be even when
            ``preselect`` is True).
    """
    def __init__(self, shape=(32,), slope=5, preselect=False, preselect_num=2):
        super(LineConstrainedProbMask, self).__init__()

        if preselect:
            # forward() pads preselect_num // 2 zero lines on each side, so
            # shape[0] - preselect_num learnable lines give shape[0] in total
            # (e.g. 256 - 4 = 252 in the author's run).
            # NOTE(review): this branch is elided in the quoted excerpt and is
            # reconstructed from forward(); confirm against the original file.
            length = shape[0] - preselect_num
        else:
            length = shape[0]

        self.preselect_num = preselect_num
        self.preselect = preselect
        self.slope = slope
        # slope must be set first: _slope_random_uniform divides by it.
        init_tensor = self._slope_random_uniform(length)
        self.mask = nn.Parameter(init_tensor)

    def forward(self, input, eps=1e-10):
        """
        Args:
            input (torch.Tensor): k-space tensor of shape NHWC. Its values are not
                read; the padding is created on the mask's own device and dtype.
            eps: unused; kept for interface compatibility.

        Returns:
            (torch.Tensor): per-line probabilities of shape (1, 1, num_lines, 1).

        Raises:
            NotImplementedError: if ``preselect`` is set with an odd ``preselect_num``.
        """
        logits = self.mask
        mask = torch.sigmoid(self.slope * logits).view(1, 1, self.mask.shape[0], 1)

        if self.preselect:
            if self.preselect_num % 2 != 0:
                raise NotImplementedError(
                    f'preselect_num must be even, got {self.preselect_num}')
            # Zero-probability padding at both ends, matching the mask's device/dtype.
            zeros = mask.new_zeros(1, 1, self.preselect_num // 2, 1)
            mask = torch.cat([zeros, mask, zeros], dim=2)

        return mask

    def _slope_random_uniform(self, shape, eps=1e-2):
        """
        Sample initial logits such that sigmoid(slope * logits) ~ U(eps, 1 - eps).

        Args:
            shape: number of lines as an int (what ``__init__`` passes) or a
                sequence. A one-entry sequence behaves like an int. With two or
                more entries the result has shape (shape[0], shape[1] // 2),
                i.e. half of the k-space width.
            eps: keeps the sampled probabilities away from 0 and 1 so the
                logits stay finite.

        Returns:
            torch.Tensor: initial logits.
        """
        try:
            dims = list(shape)
        except TypeError:  # a plain integer length, as passed by __init__
            dims = [shape]
        if len(dims) >= 2:
            dims = [dims[0], dims[1] // 2]
        temp = torch.zeros(dims).uniform_(eps, 1 - eps)

        # Inverse of sigmoid(slope * logit): logit = log(p / (1 - p)) / slope.
        return -torch.log(1. / temp - 1.) / self.slope



sequential网络

INFO:activemri.baselines.non_rl:SequentialUnet(
  (reconstructor): LOUPEUNet(
    (down_sample_layers): ModuleList(
      (0): ConvBlock(in_chans=2, out_chans=64, drop_prob=0.0)
      (1): ConvBlock(in_chans=64, out_chans=128, drop_prob=0.0)
      (2): ConvBlock(in_chans=128, out_chans=256, drop_prob=0.0)
      (3): ConvBlock(in_chans=256, out_chans=512, drop_prob=0.0)
    )
    (conv): ConvBlock(in_chans=512, out_chans=512, drop_prob=0.0)
    (up_sample_layers): ModuleList(
      (0): ConvBlock(in_chans=1024, out_chans=256, drop_prob=0.0)
      (1): ConvBlock(in_chans=512, out_chans=128, drop_prob=0.0)
      (2): ConvBlock(in_chans=256, out_chans=64, drop_prob=0.0)
      (3): ConvBlock(in_chans=128, out_chans=64, drop_prob=0.0)
    )
    (conv2): Sequential(
      (0): Conv2d(64, 32, kernel_size=(1, 1), stride=(1, 1))
      (1): Conv2d(32, 1, kernel_size=(1, 1), stride=(1, 1))
      (2): Conv2d(1, 1, kernel_size=(1, 1), stride=(1, 1))
    )
  )
  (sampler): Sampler(
    (mask_net): KspaceLineConstrainedSampler(
      (conv_last): Sequential(
        (0): Linear(in_features=327680, out_features=512, bias=True)
        (1): ReLU(inplace=True)
        (2): Linear(in_features=512, out_features=512, bias=True)
        (3): ReLU(inplace=True)
        (4): Linear(in_features=512, out_features=512, bias=True)
        (5): ReLU(inplace=True)
        (6): Linear(in_features=512, out_features=512, bias=True)
        (7): ReLU(inplace=True)
        (8): Linear(in_features=512, out_features=256, bias=True)
      )
    )
  )
)



版权声明:本文为xuru_0927原创文章,遵循 CC 4.0 BY-SA 版权协议,转载请附上原文出处链接和本声明。