PyTorch Notes: Quantization Aware Training (QAT)

Notes on Quantization Aware Training (QAT) in PyTorch

Posted by imprld01 on Friday, December 10, 2021



Natively Supported Backends

Content from the PyTorch official documentation: When preparing a quantized model, it is necessary to ensure that qconfig and the engine used for quantized computations match the backend on which the model will be executed. The qconfig controls the type of observers used during the quantization passes. The qengine controls whether fbgemm or qnnpack specific packing function is used when packing weights for linear and convolution functions and modules.

PyTorch offers two backend choices for quantized computation, fbgemm and qnnpack, corresponding to x86 and ARM respectively (a quick way to check which engines your build actually supports is shown right after this list):

  • x86 CPUs with AVX2 support or higher (w/o AVX2 some ops have inefficient implementations)

    # set the qconfig for QAT
    qconfig = torch.quantization.get_default_qat_qconfig('fbgemm')
    # set the qengine to control weight packing
    torch.backends.quantized.engine = 'fbgemm'
    
  • ARM CPUs (typically found in mobile/embedded devices)

    # set the qconfig for QAT
    qconfig = torch.quantization.get_default_qat_qconfig('qnnpack')
    # set the qengine to control weight packing
    torch.backends.quantized.engine = 'qnnpack'
    
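If you are unsure which backend a given PyTorch build supports, you can query it at runtime. Below is a small check using torch.backends.quantized.supported_engines; the exact list printed depends on how your PyTorch was built:

    import torch

    # engines compiled into this PyTorch build, e.g. ['none', 'fbgemm', 'qnnpack']
    print(torch.backends.quantized.supported_engines)
    # the engine currently selected for quantized kernels
    print(torch.backends.quantized.engine)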

Procedure of QAT

  1. Manually replace ops such as add, mul, and cat with the corresponding ops in torch.nn.quantized.FloatFunctional. This step is mainly to accommodate the TinyNeuralNetwork package; ordinary PyTorch QAT otherwise does not require it.

    self.fadd = torch.nn.quantized.FloatFunctional()
    
    #y = x + b
    y = self.fadd.add(x, b)
    
  2. Build the quantizable network, adding a QuantStub at the input and a DeQuantStub at the output:

    import torch
    import torch.nn as nn

    class MyNN(nn.Module):

        def __init__(self):
            super(MyNN, self).__init__()
            self.conv1 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
            self.bn1 = nn.BatchNorm2d(64)
            self.relu1 = nn.ReLU(inplace=True)
            self.conv2 = nn.Conv2d(64, 64, kernel_size=3, stride=1, padding=1)
            self.bn2 = nn.BatchNorm2d(64)
            self.relu2 = nn.ReLU(inplace=True)
            self.quant = torch.quantization.QuantStub()
            self.deqnt = torch.quantization.DeQuantStub()
            self.ffadd = torch.nn.quantized.FloatFunctional()

        def forward(self, x):
            x = self.quant(x)        # fp32 -> quantized domain
            # first conv-bn-relu block
            x = self.conv1(x)
            x = self.bn1(x)
            a = self.relu1(x)
            # second conv-bn-relu block, fed by the first one
            x = self.conv2(x)
            x = self.bn2(x)
            b = self.relu2(x)
            # element-wise add through FloatFunctional
            x = self.ffadd.add(a, b)
            x = self.deqnt(x)        # quantized -> fp32 domain
            return x

    # create nn
    mynn = MyNN()
    
  3. Set the QAT parameter: the quantized engine.

    torch.backends.quantized.engine = 'qnnpack'
    
  4. Specify the module groups to fuse:

    fuse_list = [['conv1', 'bn1', 'relu1'],
                 ['conv2', 'bn2', 'relu2']]
    mynn = torch.quantization.fuse_modules(mynn, fuse_list, inplace=False)
    
  5. Set the QAT parameter: qconfig. If you specifically want activations quantized as affine asymmetric UINT8, see the example inside the if block below.

    qcfg = torch.quantization.get_default_qat_qconfig('qnnpack')
    if USE_UINT8:  # USE_UINT8: user-defined flag enabling the custom activation fake-quant below
        qact = torch.quantization.FakeQuantize.with_args(observer=torch.quantization.MovingAverageMinMaxObserver,
                                                         quant_min=0, quant_max=255, dtype=torch.quint8,
                                                         qscheme=torch.per_tensor_affine, reduce_range=False)
        qcfg = torch.quantization.QConfig(activation=qact, weight=qcfg.weight)
    
    mynn.qconfig = qcfg
    
  6. Apply the configured qconfig and observers; at this point you can print the model and compare it with the original network. This step mainly inserts the quantization operators used during training (fake-quantization).

    mynn = torch.quantization.prepare_qat(mynn, inplace=False)
    print(mynn)
    

    Note the inplace argument here. I once set it to True and then hit problems during the convert in step 7. A large part of the reason is that inplace mode may discard earlier intermediate results to save memory, and the model structure may be merged and simplified accordingly; if convert's compatibility handling does not cover such a case, it breaks.

  7. Finally, start training. Once training is done, use convert to turn the model into a quantized model. This step produces the final quantized model; essentially, the fake-quant modules are replaced by real quantized ones and training-only data such as qconfig is removed. (One possible definition of my_train_loops is sketched after this list.)

    my_train_loops(mynn)

    # convert() returns the quantized model; with inplace=False the result must be re-assigned
    mynn.eval()
    mynn = torch.quantization.convert(mynn, inplace=False)
    print(mynn)
    
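For reference, here is one possible shape of the my_train_loops call used in step 7. It is only a minimal sketch: the dummy data, MSE loss, SGD optimizer, and epoch count are illustrative assumptions rather than part of the original procedure. The only QAT-specific point is that the forward pass runs through the fake-quant modules inserted by prepare_qat.

    import torch

    def my_train_loops(model, num_epochs=5):
        # dummy 64-channel batches sized to match MyNN; replace with a real DataLoader
        loader = [(torch.randn(8, 64, 32, 32), torch.randn(8, 64, 32, 32)) for _ in range(10)]
        optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
        criterion = torch.nn.MSELoss()
        model.train()
        for _ in range(num_epochs):
            for inputs, targets in loader:
                optimizer.zero_grad()
                outputs = model(inputs)   # forward pass goes through the fake-quant observers
                loss = criterion(outputs, targets)
                loss.backward()
                optimizer.step()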

A Good Tool to Convert a Quantized Torch Model to a Quantized TFLite Model

This package automates PyTorch's rather tedious QAT setup steps programmatically, saving you the work of modifying the code yourself; if you peel back the tool's internals, it essentially just performs the steps described above (a rough usage sketch follows the link below).

GitHub - alibaba/TinyNeuralNetwork

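For a rough idea of the automated flow, the project's examples use a QATQuantizer that rewrites an fp32 model into a QAT-ready one. The sketch below is based on my reading of those examples and may not match every version of the package; the 32x32 input size is an arbitrary assumption:

import torch
from tinynn.graph.quantization.quantizer import QATQuantizer

model = MyNN()                             # the plain fp32 model defined above
dummy_input = torch.randn(1, 64, 32, 32)   # assumed input shape

# the quantizer inserts stubs, fuses modules, and sets qconfig automatically
quantizer = QATQuantizer(model, dummy_input, work_dir='out')
qat_model = quantizer.quantize()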
The tool also provides the ability to convert a PyTorch QAT model into the tflite format, which makes deployment to edge devices easier:

import copy
import torch
from tinynn.converter import TFLiteConverter

with torch.no_grad():
    qmodel = copy.deepcopy(mynn)
    # convert returns the quantized model when inplace=False
    qmodel = torch.quantization.convert(qmodel.eval(), inplace=False)
    #
    torch.backends.quantized.engine = 'qnnpack'
    # nn_h / nn_w are the model's expected input height and width;
    # use qmodel.module instead if the model is wrapped in DataParallel
    converter = TFLiteConverter(qmodel,
                                torch.randn(1, 64, nn_h, nn_w),
                                tflite_path="qmodel.tflite")
    converter.convert()

References

  1. Quantization 的那些事

  2. Quantization - PyTorch 1.10.0 documentation

  3. (beta) Static Quantization with Eager Mode in PyTorch - PyTorch Tutorials 1.10.0+cu102 documentation

