Python 3.x - torch を使用するとマルチプロセッシングが遅くなる

okwaves2024-01-25  7

実行に平均 9 秒かかる関数 func があります。しかし、マルチプロセッシングを使用して並列化しようとすると (torch.multiprocessing を使用した場合でも)、各推論に平均 20 秒かかるのはなぜでしょうか?

func は、patient_name を受け取り、その患者のデータの推論でトーチ モデルを実行する推論関数です。

device = torch.device(torch.device('cpu'))

def func(patient_name):
    
    data = np.load(my_dict[system_name]['data_path'])
    model_state = torch.load(my_dict[system_name]['model_state_path'],map_location='cpu')
    
    model = my_net(my_dict[system_name]['HPs'])
    model = model.to(device)

    model.load_state_dict(model_state)
    model.eval()
    
    result = model(torch.FloatTensor(data).to(device))
    return result

from torch.multiprocessing import pool

core_cnt = 10

pool = Pool(core_cnt)
out = pool.starmap(func, pool_args)

1

提案を含めて回答しましたが、multiprocessing と torch.multiprocessing をどのように使用したかを示していただければ幸いです。あなたがそうするとき、私は答えを更新します、私は該当する場合

– プロコ

2020 年 9 月 3 日 12:59



------------------------

提供されたデータを使用したモデル アーキテクチャの推論がすでにかなりの計算能力を使用しているかどうかを確認してください。これにより、OS プロセス スケジューラが各プロセス間で切り替えることになり、さらに時間がかかります。

また、関数内で毎回モデルをロードします。データ オブジェクトをディスクから読み取るよりも、プロセス間でデータ オブジェクトをコピーするほうが常に高速です (これがマルチプロセッシングの戦略、または torch.multiprocessing でモデルを完全に共有します)

1

マルチプロセッシングを実行する方法を質問に追加しました。 pytorch を含めずにマルチプロセッシングを実行するためにこの関数をテストしました。データの読み込みなどは正常に実行されます。実行ごとにかかる時間は同じです。このモデル(torch.FloatTensor(data).to(device)) を実際に実行しているときにのみ速度が低下します。これは、患者ごとにモデルを構築する特定のケースであることに注意してください。この選択は問題により関連しているため、ここではモデルが異なります。

– アナルキ

2020 年 9 月 3 日 13:06

総合生活情報サイト - OKWAVES
総合生活情報サイト - OKWAVES
生活総合情報サイトokwaves(オールアバウト)。その道のプロ(専門家)が、日常生活をより豊かに快適にするノウハウから業界の最新動向、読み物コラムまで、多彩なコンテンツを発信。