PyTorch
ウィキペディアから
PyTorchは、コンピュータビジョンや自然言語処理で利用されている[2]Torchを元に作られた、Pythonのオープンソースの機械学習ライブラリである[3][4][5]。最初はFacebookの人工知能研究グループAI Research lab(FAIR)により開発された[6][7][8]。PyTorchはフリーでオープンソースのソフトウェアであり、修正BSDライセンスで公開されている。
さまざまなディープラーニングのソフトウェアがPyTorchを利用して構築されており、その中には、UberのPyro[9]、HuggingFaceのTransformers[10]、Catalyst[11][12]などがある。
PyTorchは以下の機能を備えている:
歴史
FacebookはPyTorchとConvolutional Architecture for Fast Feature Embedding(Caffe2)をメンテナンスしていた。しかし、互換性が無いためPyTorchで定義されたモデルのCaffe2への移行やまたその逆の作業が困難であった。これら2つのフレームワークでモデルを変換することができるように、2017年9月にFacebookとマイクロソフトがOpen Neural Network Exchange(ONNX)プロジェクトを作成した。2018年3月下旬に、Caffe2はPyTorchに併合された[14]。
2019年12月、Preferred Networksは自社開発していたChainerのバージョン7をもって、PyTorchによる研究開発へ順次移行して行くことを発表した [15]。
PyTorchのテンソル
要約
視点
PyTorch はテンソルに Tensor (torch.Tensor
)と呼ばれるクラスを定義しており、それを均質(homogeneous)な多次元の長方形の数値配列の保存と演算に利用している。PyTorch の Tensor は NumPy の多次元配列 (numpy.ndarray
) に似ているが、CUDA が有効な Nvidia のGPU上での演算も可能になっている。NumPyの配列からPyTorchのテンソルへと変換するための専用APIも存在する (torch.from_numpy)。なおPyTorch 2.1以降はPyTorchだけでなくNumPyを使ったコードのGPU向けコンパイルにも対応している[16]。
PyTorch には 32bit 浮動小数点数用の FloatTensor や 16bit 浮動小数点数用の HalfTensor 、32bit 整数用の IntTensor など、さまざまな型のTensorサブタイプが存在する[17][18]。またテンソルにはCPUに配置する torch.*Tensor とGPUに配置する torch.cuda.*Tensor が存在している[18]。それぞれは Tensor.to(...)
メソッドなどを用いることで変換することができる[18]。
また PyTorch のテンソルは機械学習の逆伝播に使われる微分のためのパラメータ(勾配データ (Tensor.grad
) や 微分関数 (Tensor.grad_fn
))を持つことが出来る (requires_grad=True
の場合、詳しくは後述の#autogradモジュールを参照)[19]。勾配データ (grad) は追加メモリを必要とするため、不要な場合に取り除くことが可能となっている(torch.no_grad()
や Tensor.detach()
など)。
複素数テンソル(dtype=torch.cfloat
のテンソル)[20]や量子化テンソル(dtype=torch.quint8
やdtype=torch.qint8
などのテンソル)[21]も存在している。量子化テンソルはスケールやゼロポイントなどの量子化パラメータを持っている[21]。なお4bit量子化向けのdtype=torch.quint4x2
のテンソルもあるが、2023年現在その対応は一部(EmbeddingBag命令)のみとなっている[18]。また8bit浮動小数点数向けの torch.float8_e4m3fn
及び torch.float8_e5m2
のテンソルにも対応しているが対応命令はごく一部となっている[18][(8bit浮動小数点数はNVIDIA H100などの一部GPUのTransformer Engineがネイティブに対応している[22])。
基本のテンソルは長方形であり、非長方形のデータを扱う際には padding で穴埋めする必要があるが、穴の多いデータを穴埋めするのは無駄となるため、非長方形のデータを扱うための疎テンソル (torch.sparse) や ネストされたテンソル (torch.nested) も用意されている。なお、疎テンソル構造はNVIDIA GPUのAmpere以降などでも直接サポートされるようになっているものの、その対応は2:4スパース(半構造化スパース)に限られている[23]が、PyTorch 2.1以降はその疎半構造化テンソルにも試験的に対応している (SparseSemiStructuredTensor)[24][16]。
PyTorchのモデル
![]() | この節の加筆が望まれています。 |
PyTorch のモデルは基本的に torch.nn.Module
の派生クラスとする必要があり[25]、そのtrain()
メソッドとeval()
メソッドにより、トレーニングモードと評価モードの切り替えが可能となっている[25]。
また to(...)
メソッドなどによってモデル全体のパラメータ型の変換が可能であり[25]、それによってモデルパラメータの半精度化(専用のhalf()
メソッドもある[25])やbfloat16化(専用のbfloat16()
メソッドもある[25])を行うことでメモリ使用量を減らすことが可能となっている。ただしモデルパラメータのより高い圧縮を行うためには torch.quantization などで量子化を行う必要がある。
フロントエンド
PyTorch のフロントエンドには Pythonインターフェイスだけでなく C++インターフェイスも存在している[26]。
autogradモジュール
PyTorchは自動微分と呼ばれるメソッドを利用する。recorderは実行された演算を記録し、その後、勾配の計算を行うときに記録した演算を逆方向にリプレイする。このメソッドは、ニューラルネットワークの構築時に特に強力であるため、順方向のパラメータの微分を計算することで、1エポックの計算にかかる時間を節約することができる。
optim
モジュール
torch.optim
は、ニューラルネットワークの構築時に使用されるさまざまな最適化アルゴリズムを実装したモジュールである。通常使用されるメソッドのほとんどはすでに対応しているため、スクラッチで構築する必要がない。
nn
モジュール
PyTorchのautogradは、簡単に計算グラフを定義して勾配を得られるようになっているが、生のautogradは複雑なニューラルネットワークを定義するにはすこし低レベルすぎる場合がある。そのような場合のサポートとしてnn
モジュールが提供されている。
jit
モジュール
TorchScriptとして後述。
バックエンド
要約
視点
![]() | この節の加筆が望まれています。 |
PyTorch のバックエンドは主に C++ で実装されている (ATen)。ATenは外部から直接使うことも可能となっている (C++ではATenライブラリ[27]、Pythonではtorch.ops.aten)。
また PyTorch にはCPUだけでなくGPUバックエンドもあり、NVIDIA製GPU向けにはCUDAで、AMD製GPU向けにはROCmのHIPで、Intel製GPU向けにはSYCLで[28][29][注釈 1]、macOS向けにはMetal Performance Shaders (MPS) で実装されている。CUDAにバージョン毎の互換性の問題があることもあって、PyTorch のバイナリはGPUプラットフォーム毎に別々で提供されている。
AI専用チップ向けのバックエンドでは外部プロジェクトとして Intel Gaudi 向けの habana_frameworks.torch [30]、Huawei Ascend 向けの torch_npu などが存在する[31]。PyTorchではこれら外部バックエンドを登録するための仕組み「PrivateUse1」が用意されている[31][32]。
PyTorch 2.0以降にはOpenAIのTritonなどをバックエンドとして使用できるTorchDynamoも統合されている(#TorchDynamoを参照)。TorchDynamo ではバックエンドの実装を簡略化するために、複合的な Aten 命令を単純な Prims 命令 (torch._prims) へと低下させる機能も提供している[33]。
また PyTorch では別パッケージとして Torch-TensorRT が用意されており、これを使うことにより PyTorch のバックエンドとして NVIDIA の TensorRT を使うことも可能となっている[34][35]。
並列処理
PyTorch は標準で並列処理機能を内蔵しているものの一部のみとなっている。外部プロジェクトの DeepSpeed はより高度な並列処理を実装している。並列処理を簡単に使うための補助ライブラリとして HuggingFace Accelerate や Lightning Fabric(旧Lightning Lite)が存在する。
また既存モデルでは推論の並列処理を簡単に使うための DeepSpeed MII も存在する[36][37]。
またコンピュータ・クラスターを使った学習向けでは外部プロジェクトに MosaicML の Composer が存在し、この Composer はクラウドコンピューティングを使った学習とも相性が良いとされる[38][39][40]。
データ並列 (DP)
PyTorchは以下のデータ並列処理に対応している:
- マルチスレッドデータ並列処理 (torch.nn.DataParallel) - 2024年現在Pythonのグローバルインタプリタロック (GIL) の問題がありスケールしにくい[41]。
- マルチプロセス/マルチマシンデータ並列処理 (DDP; torch.nn.parallel.DistributedDataParallel)
- 完全共有データ並列処理 (FSDP; torch.distributed.fsdp.FullyShardedDataParallel) - FairScale の FSDP を統合した[42]。
前述の FairScale 及び DeepSpeed は、どちらも最適化状態分割 (OSS; ZeRO Stage 1)、最適化および勾配状態分割 (SDP; ZeRO Stage 2)、完全共有データ並列処理 (FSDP; ZeRO Stage 3) の3つに対応している[43][44]。加えて DeepSpeed は ZeRO Stage 3 よりもノード間通信効率の良い ZeRO++ にも対応している[45]。
その他 DeepSpeed にはデータをCPUメモリやNVMeなどへとオフロードするための ZeRO-Infinity[46](及びそのサブセットでホストCPUメモリのみ対応の旧来の ZeRO-Offload[47])も搭載されている。
パイプライン並列 (PP)
PyTorchは以下のパイプライン並列処理に対応している:
- パイプライン並列処理 (torch.distributed.pipeline) - nn.Sequentialにのみ対応[48]。FairScaleの実装 (Fairscale.nn.Pipe) を統合したもので、元々は torchgpipe の GPipe実装に由来する[48][49]。
PyTorchモデルの自動パイプライン並列処理に向けては PiPPy が開発中となっている[50]。また DeepSpeed もパイプライン並列処理に対応している (deepspeed.pipe)[51]。
また推論のパイプライン並列処理にはPetals[52]やFlexGen[53]も存在する。
テンソル並列 (TP)
PyTorchは以下のテンソル並列処理に対応している:
- テンソル並列処理 (torch.distributed.tensor.parallel)[54]
また DeepSpeed も DeepSpeed-Inference で推論におけるテンソル並列処理に対応している[55]。
エキスパート並列
エキスパート並列は複数エキスパート(専門家)の混合 (MoE) レイヤー(Switch Transformerなど)を用いたモデルの並列処理を行う[56]。
このエキスパート並列には DeepSpeed MoE が対応している[56]。
推論の自動ウェイトオフロード
![]() | この節の加筆が望まれています。 |
HuggingFaceのAccelerateは大規模モデルの推論を処理するためにウェイトの自動オフロードに対応している[57][58]。
またDeepSpeed も推論でのデータオフロードにも対応している(DeepSpeed ZeRO-Inference)[59]。
その他
PyTorch は CUDA の複数ストリームを使った明示的なGPU並列処理にも対応している[60]。
ファイル形式
![]() | この節の加筆が望まれています。 |
PyTorch で良く使われているファイル形式は Python における Python オブジェクトの標準シリアライズ(直列化)である Pickle形式 (*.pkl) [61]及びウェイトやバイアスなどのパラメータを無圧縮ZIPアーカイブでまとめたpt形式となっている。PyTorchでのこの形式の読み書きはtorch.load/torch.save
やtorch.jit.load/torch.jit.save
(TorchScript向け)などを使うことで可能となっている。この形式はPyTorch内外を含め広く用いられているものの、Pickle形式は仮想マシンの命令コードを記述したものであり[62]様々な操作が実行できるため、セキュリティのリスクが存在している[61]。
そのため PyTorch 標準形式よりも安全で効率的なパラメータ保存形式が開発されてきており、その中の一つ HuggingFace の safetensors 形式は広く普及してきている。
ONNX出力
![]() | この節の加筆が望まれています。 |
ONNX出力 (torch.onnx) には従来のTorchScript経由によるもの (torch.onnx.export) と新しいTorchDynamo経由によるもの (torch.onnx.dynamo_export) がある[63]。前者は#TorchScriptを参照。
TorchScript
要約
視点
TorchScriptは推論モデル生成のための静的型付け言語である[64]。Pythonのサブセット[65]。TorchScriptプログラムはPythonに依存しない環境(例: C++)で実行可能であり[66]、JITコンパイルにより動作環境での最適化がおこなわれる。JITコンパイラでは演算fusion等がおこなわれる[67]。
なおTorchScriptモデルの保存にはTorchScriptモデル形式が使われている (torch.jit.load / torch.jit.save) が、この形式はPyTorchモデル形式と同じく拡張子に「.pt」が使われている[68]。またTorchScriptを介したONNX形式への変換 (torch.onnx.export) も可能となっている[69]。また外部の LLVM プロジェクトの Torch-MLIR では TorchScript を介してPyTorchモデルのコンパイルを行うことができる[70]。
実行
TorchScriptはJITオプティマイザ・JITコンパイラ付きのインタプリタにより実行される[71]。内部ではGraphExecutorによる最適化とinterpreterによる実行がおこなわれる。対応している最適化には、入力型に基づくGraphの特殊化[72]、演算子の融合 (Op fusion)、定数の事前計算[73]、不要ノード削除[74]などがある。
生成
TorchScriptコードは直接記述 (torch.jit.script) およびPyTorch実行トレース (torch.jit.trace) により生成される。
以下は前向き (forward) 関数のスクリプト変換の例である:
import torch # PyTorchモジュールを読み込み
def f(x, y): return x + y # 前向き関数を作る
ts_converted = torch.jit.script(f) # Pythonコードから直接TorchScriptに変換
print("Converted:\n%s" % ts_converted.code) # 変換したTorchScriptを表示
ts_traced = torch.jit.trace(f, (torch.tensor([1,2]), torch.tensor([3,4]))) # PyTorchの実行をトレースしてTorchScriptを生成
print("Traced:\n%s" % ts_traced.code) # 生成したTorchScriptを表示
torch.jit.save(ts_converted, "converted.pt") # 変換したTorchScriptをTorchScriptモデル形式で保存
torch.onnx.export(ts_converted, (torch.tensor([1,2]), torch.tensor([3,4])), "converted.onnx") # 上記をONNX形式で出力
torch.jit.script
torch.jit.script
はTorchScriptコンパイラのPythonラッパー関数である[75]。
torch.jit.script
はPythonオブジェクトとして書かれたTorchScriptコードをコンパイラへ渡し、コンパイル結果へPythonからアクセスするためのラッパーオブジェクトを返す。Pythonデコレータとしても利用できる。
TorchScriptコードのコンパイルをおこなうため、全てのTorchScirpt機能を利用できる(c.f. torch.jit.trace
によるトレース)。例えば動的条件分岐を記述できる。
torch.jit.trace
torch.jit.trace
は実行時トレースに基づくPython-to-TorchScriptトランスパイラである[76]。
torch.jit.trace
は実行可能Pythonオブジェクトとその仮入力を引数に取り、これをトレースしながら実行し、トレース結果に基づいてTorchScriptコードを生成する。すなわち実行時トレースに基づいてPythonコードをTorchScriptコードへ変換する。
仮入力を用いたトレースに基づいて生成をおこなうため、利用できるTorchScript機能に制限がある(c.f. torch.jit.script
によるコンパイル)。例えば条件分岐は仮入力が満たす片方のパスしか記録されないため、TorchScriptの動的条件分岐として変換されない。すなわちTorchScriptの動的条件分岐は利用できない[77]。
FX
要約
視点
torch.fx はモデルの高レベルグラフ化・変換・コード生成に関するモジュールである[78][79]。モデルを入力/出力/演算の有向非巡回グラフと見なすIRを定義し[80]、モデル→IR→モデルの変換とIR編集をサポートする。主に次の3つの機能を提供する[81]。
- Model-to-IR: シンボリックトレースによるPython/PyTorchモデルからのIR生成
- IR-to-IR: IRで表現されたグラフの変換
- IR-to-Model: IRからのPython/PyTorchコード生成
FXはPythonをホスト言語とするIRグラフのメタプログラミングであり[82]、グラフの分析・可視化[83]、演算子の書き換えによる量子化[84]、グラフ全体を考慮したop fusion[85]、IRからのアクセラレータコード直接生成[86]など様々な用途に利用される。
モデルエクスポートのための幅広い演算サポートを目指すTorchScript IR[87]と異なり、FX IRは動的な制御フローを含まず演算ノード単位の操作を前提とした高レベルな表現と、それに付随するシンプルなIR生成・IR操作・モデル生成実装を掲げている。
以下は前向き (forward) 関数の変換の例である:
import torch # PyTorchモジュールを読み込み
def f(x, y): return x + y # 前向き関数を作る
fx_traced = torch.fx.symbolic_trace(f) # PyTorchの実行をトレースしてFX Graphを生成
print(fx_traced.graph) # FX GraphのIRを表示
print(fx_traced.code) # FX GraphをPythonコードに再変換して表示
AOT Autograd
AOT Autograd (functorch.compile) は前向き関数と後向き関数をコンパイルするための実験的サブシステムである[88]。FX GraphをTorchScriptのコンパイラによってコンパイルするための functorch.compile.ts_compile も実験的に提供されている[89]。
以下は前向き (forward) 関数の変換の例である:
import torch, functorch # PyTorch 及び functorch モジュールを読み込み
def f(x, y): return x + y # 前向き関数を作る
fx_traced = torch.fx.symbolic_trace(f) # PyTorchの実行をトレースしてFX Graphを生成
ts_compiled = functorch.compile.ts_compile(fx_traced, [torch.Tensor([1,2]), torch.Tensor([3,4])]) # TorchScriptコンパイラでFXグラフをコンパイル
ts_compiled([torch.Tensor([1,2]), torch.Tensor([3,4])]) # コンパイルされた関数を実行
TorchDynamo
TorchDynamo (torch._dynamo) は任意のPythonコードをコンパイルのためのサブシステムである。JITコンパイルするためのtorch.compile
API、事前コンパイルするためのtorch.export
APIもある(後者は#事前コンパイルと量子化参照)。コンパイルできないコードは自動的にPythonインタプリタへとフォールバックされる。2023年にリリースされたPyTorch 2.0で本体へと統合された。
狭義のTorchDynamoはPythonのバイトコードから前述のFX Graphを作るサブシステムであり[33]、FX Graphを各実行環境向けIRへと変換するための各バックエンド(TorchInductor、onnxrtバックエンドなど)と組み合わせてPythonコードのコンパイルを行うことができる[33]。バックエンドを実装しやすくするためのものとして、FX Graphの複合的なAten命令を単純なPrim命令へと置き換えるPrimTorchも存在する[33]。
以下は前向き (forward) 関数の最適化の例である:
import torch
def f(x, y): return x + y # 前向き関数を作る
optimized_f = torch.compile(f, backend="inductor") # 前向き関数をTorchDynamo(+ TorchInductor)で最適化
optimized_f(1, 2) # 最適化された前向き関数を実行
また派生プロジェクトとしてバックエンドにOpenXLAのXLAコンパイラを使うためのPyTorch/XLAも開発中となっているが、こちらも試験的にTorchDynamoへと対応しており[90][91]、場合によってはTorchDynamo+Inductorよりも高速であるとされる[91]。
事前コンパイルと量子化
![]() | この節の加筆が望まれています。 |
PyTorch 2.1以降には事前コンパイル (torch.export) が試験的に実装されている[16]。また量子化 (torch.ao.quantization[注釈 2]) も試験的にtorch.exportに対応している[16]。
量子化では訓練後量子化(Post Training Quantization; PTQ)と量子化意識訓練 (Quantization Aware Training; QAT) の両方に対応している[92]。
モバイル向け展開
![]() | この節の加筆が望まれています。 |
PyTorchモデルのモバイル向け展開にはPyTorch MobileとExecutorchが存在する[93][94]。
PyTorch MobileはTorchScriptを使った軽量インタプリタとなっている[93]が、最適化 (torch.utils.mobile_optimizer) [95]にも対応している[93]。
一方、Executorchはtorch.compileベースとなっている[94]。
またPyTorchのモデルから別の機械学習フレームワークTensorFlowのモバイル向け実装LiteRT(旧称TensorFlow Lite)のモデルへと変換することも可能であり、これには前述のONNX出力でONNX形式を出力した後にそのファイルをonnx2tfでLiteRTモデルへと変換する方法[96]や、GoogleのAI Edge Torchを使う方法が存在する[96]。
モデル / ライブラリ
要約
視点
Transformer モデル
![]() | この節の加筆が望まれています。 |
性能の高い機械学習モデル「Transformer」が人気となるにつれ、PyTorchでも内部プロジェクト外部プロジェクト合わせ数多くの実装が登場した。その代表的なものには以下が存在する:
- torch.nn.Transformer[97] - PyTorch 1.2で実装された[98]。
- Transformers (Hugging Face) - 様々な応用モデルを実装している。PyTorch以外にも対応している。
- Fairseq (Meta AI) - Transformer派生を含む様々な応用モデルを実装している。またNVIDIAによりこのFairseqに基づくTransformerの実装が「Transformer for PyTorch」として提供されている[102]。
- xFormers (Meta AI) - 高速かつ省メモリな複数の実装を含んでいる(CUTLASSを使った実装やOpenAI Tritonを使った実装など)。前述のFairseqや後述のDiffusersでも使えるようになっている[103][104]。
拡散モデル
拡散モデル(確率拡散モデル; Diffusion model)は特に生成AI(ジェネレーティブAI)において人気となり、PyTorchでもその実装が登場した。その代表的なものには以下が存在する:
- Diffusers (Hugging Face) - 単純なU-Netベースの拡散モデルだけでなく、変分オートエンコーダー (VAE) による潜在空間を使った潜在拡散モデル (LDM) やその応用の Stable Diffusion (SD)、埋め込み表現にCLIPの代わりにXLM-RobertaベースのAltCLIPを使ってマルチリンガル化した AltDiffusion[105]、ベクトル量子化VAE (VQ-VAE) による潜在空間を使った VQ-Diffusion、CLIPのテキスト埋め込みを直接使わず画像埋め込みへと変換してからガイドとして使用するDALL-E 2を模したKarlo unCLIP、U-Net部分までTransformerに置き換えたDiT (Diffusion Transformer)、動画生成に対応する Text-to-Video / Text2Video-Zero、オーディオに特化した Audio Diffusion / Latent Audio Diffusion、 MIDIからオーディオを生成する Spectrogram Diffusion、汎用拡散モデルの Versatile Diffusion などを実装している。また追加学習ではTextual Inversion、Text-to-imageファインチューニング、DreamBooth、LoRAなどに対応している。
torchvision
torchvision は PyTorch 向けの画像処理向けのライブラリである。画像分類 (Swin Transformer / MaxVit / Swin Transformer V2 / EfficientNet / ConvNeXt / EfficientNetV2 / RegNet / VisionTransformer等[注釈 3])、動画分類 (Video ResNet / Video S3D / Video MViT / Video SwinTransformer)、セマンティックセグメンテーション (LRASPP / FCN / DeepLabV3)、物体検出 (SSDlite / SSD / FCOS / RetinaNet / Faster R-CNN)、インスタンスセグメンテーション (Mask R-CNN)、キーポイント検出 (Keypoint R-CNN)、オプティカルフロー (RAFT) などに対応している[106]。
torchaudio
torchaudio は PyTorch 向けの音声処理向けのライブラリである。音声認識 (Emformer RNN-T / XLS-R / Conformer RNN-T[107])、視覚音声認識 (AV-ASR[107])、テキスト音声合成 (Tacotron2)、音声強調 (MVDR)、ビームフォーミング (DNNBeamformer[107])、音源分離 (Hybrid Demucs)、音声品質測定 (TorchAudio-Squim[107]) 、テキストへの強制位置合わせ (CTC forced alignment[107]) などに対応している。
その他、音声のベクトル化では wav2vec 2.0、HuBERT、WavLM に対応している[108]。GPUベースのCTCビームサーチ (CUCTCDecoder) にも対応している[107]。
torchtext
![]() | この節の加筆が望まれています。 |
torchtext は PyTorch 向けのテキスト処理ライブラリである。大規模言語モデル (LLM) の XLM-RoBERTa[109] や T5[110] に対応している。なお大規模言語モデルで文書分類、言語モデリング、質疑応答、機械翻訳などの多様なタスクを行うには、モデルをそれぞれへと適応するためのファインチューニングやプロンプティングが必要となる。
その他、テキストのトークン化では正規表現での置換によるトークン化 (RegexTokenizer)、文字バイト対符号化 (BPE) によるトークン化 (CharBPETokenizer)、バイトレベルBPEによるトークン化(CLIPで使われるCLIPTokenizerやGPT-2で使われるGPT2BPETokenizer)、WordPieceによるトークン化(BERTで使われるBERTTokenizer)、事前学習モデルによるトークン化(T5などに使われるSentencePieceTokenizer)に対応している[111][112]。また単語の埋め込みベクトルへの変換では GloVe、FastText、CharNGram に対応している[113]。
TorchRL
TorchRL は PyTorch 向けの強化学習ライブラリである。暗黙的Q学習 (IQL)、深層Q学習 (DQN)、近傍方策最適化 (PPO)、深層決定論的方策勾配法 (DDPG)、双生遅延DDPG (TD3)、Advantage Actor-Critic (A2C)、Soft-Actor-Critic (SAC)、保守的Q学習 (CQL)[107]、ランダム化アンサンブル化ダブルQ学習 (REDQ)、Dreamer、Decision Transformer[107]、人間のフィードバックによる強化学習 (RLHF)に対応している[115]。
TensorDict
![]() | この節の加筆が望まれています。 |
TensorDict は PyTorch 向けの辞書構造ライブラリである。
TorchRec
![]() | この節の加筆が望まれています。 |
tensordict は PyTorch 向けのレコメンデーション(推薦)ライブラリである。
torchtune
![]() | この節の加筆が望まれています。 |
torchtuneは PyTorch 向けの大規模言語モデル (LLM) ファインチューニングライブラリである。
torchchat / gpt-fast
![]() | この節の加筆が望まれています。 |
torchchatはPyTorch向けのLLMを使ったチャット実装であり[116]、gpt-fastはPyTorchを使ったLLM実装である[117]。
PyTorch Hub
PyTorch Hub (torch.hub) は GitHub のリポジトリ上に存在するPyTorchの事前学習モデルを簡単に使用できるようにするための仕組みである[118]。リポジトリをこの PyTorch Hub で使えるようにするためにはリポジトリ上に「hubconf.py」を設置する必要がある[118]。
PyTorch Hub に対応しているリポジトリには例えば以下が存在する:
- torchvision ("pytorch/vision")[106]
- Fairseq ("pytorch/fairseq")[119]
- YOLOv5 ("ultralytics/yolov5")[120]
PyTorch Hubでリポジトリ上のモデルを読み込むコードの例:
import torch
yolov5s_model = torch.hub.load('ultralytics/yolov5', 'yolov5s')
なお外部プロジェクトには GitHub 及び PyTorch Hubを使わずに、GitHubに似た独自リポジトリシステムとその読み込みスクリプトを提供するものも存在する(Hugging Face の Hugging Face Hub[121][122]など)。
TorchMetrics
![]() | この節の加筆が望まれています。 |
Lightning AI製の評価指標(メトリクス)用ライブラリ。
学習
![]() | この節の加筆が望まれています。 |
データローダ
PyTorchに使うことのできるデータローダには以下が存在する:
- DataLoader (torch.utils.data.DataLoader) - PyTorch標準のデータローダ。
- DataLoader2 (torchdata.dataloader2.DataLoader2) - 別パッケージのTorchDataで開発されているDataPipesを基にした実験的なデータローダ[123]。
- FFCV (Fast Forward Computer Vision; ffcv.loader)
- NVIDIA DALI (Data Loading Library; nvidia.dali.plugin.pytorch) - GPUDirect Storageに対応したデータローダ[124]。PyTorch以外にも対応している[124]。実験的なGPUデータ解凍にも対応している (fn.experimental.inflate)。
またデータセットの自動ダウンロードに対応するデータローダとして HuggingFace Datasets Hub のデータセットが使用可能な HuggingFace 製の Datasets も存在する。
PyTorch Lightning
→「PyTorch Lightning」を参照
PyTorch Lightning は学習から展開までの簡易化のための PyTorch 向けサードパーティー製ライブラリである。
高速化
より高い効率で学習を行うための様々な標準機能・外部ライブラリが提供されている。以下はその一例である。
名称 | 手法 |
---|---|
torch.jit.script | TorchScript言語での記述 → TorchScriptによる学習の最適化 |
torch.jit.trace | 仮入力トレースによるTorchScriptへの変換 → TorchScriptによる学習の最適化 |
torch.fx.symbolic_trace | FXのTrace機能によるGraphIR化→各バックエンドによる学習の最適化 |
torch._dynamo | 部分的GraphIR化→PythonJITコンパイラでの部分的バックエンド利用による学習の最適化 |
それぞれの手法で特徴と適用範囲が異なる。例えばDynamoでは動的条件分岐を扱えるがfx.symbolic_traceでは扱えない[125]。
脚注
出典
参考文献
関連項目
外部リンク
Wikiwand - on
Seamless Wikipedia browsing. On steroids.