torch.hub.loadでImportError: cannot import name 'amp'というエラー

Pythochにはtorch.hubという超簡単に物体検出などができる機能がある。コードはこれだけ。

import torch
model = torch.hub.load('ultralytics/yolov3', 'yolov3')

input_img = 'https://ultralytics.com/images/zidane.jpg'

# 推論
results = model(input_img)

# 結果の出力
results.print()

中身としてはGithubからコードを引っ張ってきて実行しているだけ。ただ、コードの中には特定の環境でしか動かないものがある。そのひとつがAMP。AMPは一部レイヤーの計算精度(FP32, PF16など)を落として計算することで、制度を保ちながら計算速度を向上させよう、という機能である。ただしこれはPytorchではLinuxにしか対応していない。そのため、AMPを使用しているコードはWindowsで実行しようとすると以下のようなエラーになる。

  File "C:\Users\hagehage/.cache\torch\hub\ultralytics_yolov3_master\utils\plots.py", line 61, in check_font
    torch.hub.download_url_to_file(url, str(font), progress=False)

AttributeError: module 'torch.hub' has no attribute 'download_url_to_file'

回避方法としては、コードの中からAMPを使用しているところを取り除いてしまえばいい。AMPは使いたいところで、以下のようにWith文でくくてつかう。なのでこのWith文をとってしまえばいい。

with amp.autocast(enabled=p.device.type != 'cpu'):
    # Inference
    y = self.model(x, augment, profile)[0]  # forward
    t.append(time_sync())

↑は↓になる。

# Inference
y = self.model(x, augment, profile)[0]  # forward
t.append(time_sync())

これでOK