JAX Training (JAXJob)
Using JAXJob to train a model with JAX
Old Version
This page is about Kubeflow Training Operator V1. For the latest information, see the Kubeflow Trainer V2 documentation.
This page describes how to use JAXJob to train a machine learning model with JAX.
JAXJob is a Kubernetes custom resource for running JAX training jobs on Kubernetes. The Kubeflow implementation of JAXJob is in training-operator.
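The JAX custom resource has so far been tested on CPUs, running multiple processes that use gloo for communication between CPUs. The worker with replica index 0 is recognized as the JAX coordinator. Process 0 starts a JAX coordinator service, which is exposed through the IP address of process 0 in your cluster together with a port available on that process. The other processes in the cluster connect to that service. We are looking for user feedback on running JAXJob on GPUs and TPUs.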
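The coordinator handshake described above can be sketched from the worker's side. This is a minimal sketch, not the operator's or the example image's actual code. It assumes that each worker container receives the environment variables `COORDINATOR_ADDRESS`, `COORDINATOR_PORT`, `NUM_PROCESSES`, and `PROCESS_ID` (these names are not stated on this page) and that `jax` is installed in the image. The helpers `coordinator_args` and `init_distributed` are hypothetical names used only for illustration.

```python
import os


def coordinator_args(env):
    """Build keyword arguments for jax.distributed.initialize.

    The environment variable names are an assumption about what the
    operator injects into each worker container.
    """
    address = env["COORDINATOR_ADDRESS"]
    port = int(env["COORDINATOR_PORT"])
    return {
        "coordinator_address": f"{address}:{port}",
        "num_processes": int(env["NUM_PROCESSES"]),
        "process_id": int(env["PROCESS_ID"]),
    }


def init_distributed():
    # Imported lazily so coordinator_args can be used without JAX installed.
    import jax

    # Process 0 starts the coordinator service; the others connect to it.
    jax.distributed.initialize(**coordinator_args(os.environ))
    print(f"JAX process index: {jax.process_index()}")
    print(f"JAX global devices: {jax.devices()}")
    print(f"JAX local devices: {jax.local_devices()}")
```

A training script would call `init_distributed()` once, before any other JAX computation, so that every process joins the same coordinator.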
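Creating a JAX training job
You can create a training job by defining a JAXJob config file. See the manifest for the simple JAXJob example. You can modify the job config file to suit your requirements.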
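As an illustration of what such a config file contains, the sketch below builds a JAXJob manifest as a Python dictionary and prints it as JSON. The field values mirror the job spec shown in the sample output on this page. The function `build_jaxjob` and the file name `jaxjob.json` are hypothetical, and the sketch is not a replacement for the published example manifest.

```python
import json


def build_jaxjob(name, namespace, image, command, replicas):
    """Return a JAXJob manifest (kubeflow.org/v1) as a plain dict."""
    container = {
        "name": "jax",
        "image": image,
        "imagePullPolicy": "Always",
        "command": list(command),
    }
    return {
        "apiVersion": "kubeflow.org/v1",
        "kind": "JAXJob",
        "metadata": {"name": name, "namespace": namespace},
        "spec": {
            "jaxReplicaSpecs": {
                "Worker": {
                    "replicas": replicas,
                    "restartPolicy": "OnFailure",
                    "template": {"spec": {"containers": [container]}},
                }
            }
        },
    }


manifest = build_jaxjob(
    name="jaxjob-simple",
    namespace="kubeflow",
    image="docker.io/kubeflow/jaxjob-simple:latest",
    command=["python3", "train.py"],
    replicas=2,
)
# kubectl accepts JSON manifests as well as YAML ones.
print(json.dumps(manifest, indent=2))
```

Saving the printed output to a file such as `jaxjob.json` and passing it to `kubectl create -f` should create the same kind of job, with `replicas` controlling how many worker Pods are started.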
Deploy the JAXJob resource to start training:
kubectl create -f https://raw.githubusercontent.com/kubeflow/training-operator/refs/heads/release-1.9/examples/jax/cpu-demo/demo.yaml
You should now see the created Pods, matching the specified number of replicas:
kubectl get pods -n kubeflow -l training.kubeflow.org/job-name=jaxjob-simple
On a CPU cluster, the distributed computation takes a few minutes. Check the logs to follow its progress:
PODNAME=$(kubectl get pods -l training.kubeflow.org/job-name=jaxjob-simple,training.kubeflow.org/replica-type=worker,training.kubeflow.org/replica-index=0 -o name -n kubeflow)
kubectl logs -f ${PODNAME} -n kubeflow
I1016 14:30:28.956959 139643066051456 distributed.py:106] Starting JAX distributed service on [::]:6666
I1016 14:30:28.959352 139643066051456 distributed.py:119] Connecting to JAX distributed service on jaxjob-simple-worker-0:6666
I1016 14:30:30.633651 139643066051456 xla_bridge.py:895] Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I1016 14:30:30.638316 139643066051456 xla_bridge.py:895] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
JAX process 0/1 initialized on jaxjob-simple-worker-0
JAX global devices:[CpuDevice(id=0), CpuDevice(id=131072)]
JAX local devices:[CpuDevice(id=0)]
JAX device count:2
JAX local device count:1
[2.]
Monitoring a JAXJob
kubectl get -o yaml jaxjobs jaxjob-simple -n kubeflow
Check the status section to monitor the job status. The following is sample output from a job that completed successfully:
apiVersion: kubeflow.org/v1
kind: JAXJob
metadata:
  annotations:
    kubectl.kubernetes.io/last-applied-configuration: |
      {"apiVersion":"kubeflow.org/v1","kind":"JAXJob","metadata":{"annotations":{},"name":"jaxjob-simple","namespace":"kubeflow"},"spec":{"jaxReplicaSpecs":{"Worker":{"replicas":2,"restartPolicy":"OnFailure","template":{"spec":{"containers":[{"command":["python3","train.py"],"image":"docker.io/kubeflow/jaxjob-simple:latest","imagePullPolicy":"Always","name":"jax"}]}}}}}}
  creationTimestamp: "2024-09-22T20:07:59Z"
  generation: 1
  name: jaxjob-simple
  namespace: kubeflow
  resourceVersion: "1972"
  uid: eb20c874-44fc-459b-b9a8-09f5c3ff46d3
spec:
  jaxReplicaSpecs:
    Worker:
      replicas: 2
      restartPolicy: OnFailure
      template:
        spec:
          containers:
            - command:
                - python3
                - train.py
              image: docker.io/kubeflow/jaxjob-simple:latest
              imagePullPolicy: Always
              name: jax
status:
  completionTime: "2024-09-22T20:11:34Z"
  conditions:
    - lastTransitionTime: "2024-09-22T20:07:59Z"
      lastUpdateTime: "2024-09-22T20:07:59Z"
      message: JAXJob jaxjob-simple is created.
      reason: JAXJobCreated
      status: "True"
      type: Created
    - lastTransitionTime: "2024-09-22T20:11:28Z"
      lastUpdateTime: "2024-09-22T20:11:28Z"
      message: JAXJob kubeflow/jaxjob-simple is running.
      reason: JAXJobRunning
      status: "False"
      type: Running
    - lastTransitionTime: "2024-09-22T20:11:34Z"
      lastUpdateTime: "2024-09-22T20:11:34Z"
      message: JAXJob kubeflow/jaxjob-simple successfully completed.
      reason: JAXJobSucceeded
      status: "True"
      type: Succeeded
  replicaStatuses:
    Worker:
      selector: training.kubeflow.org/job-name=jaxjob-simple,training.kubeflow.org/operator-name=jaxjob-controller,training.kubeflow.org/replica-type=worker
      succeeded: 2
  startTime: "2024-09-22T20:07:59Z"
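The status conditions above can also be read programmatically. The following is a rough sketch, assuming that `kubectl` is on the PATH and has access to the cluster. The helpers `job_state` and `fetch_job` are hypothetical, and treating the most recent condition with status "True" as the job's current state is a simplification for illustration rather than the operator's documented semantics.

```python
import json
import subprocess


def job_state(job):
    """Return the type of the most recent condition whose status is "True"."""
    conditions = job.get("status", {}).get("conditions", [])
    active = [c for c in conditions if c.get("status") == "True"]
    if not active:
        return "Unknown"
    # Timestamps share one UTC format, so string order matches time order.
    return max(active, key=lambda c: c["lastTransitionTime"])["type"]


def fetch_job(name, namespace="kubeflow"):
    """Fetch the JAXJob as a dict; requires kubectl and cluster access."""
    result = subprocess.run(
        ["kubectl", "get", "jaxjobs", name, "-n", namespace, "-o", "json"],
        check=True,
        capture_output=True,
        text=True,
    )
    return json.loads(result.stdout)
```

For example, `job_state(fetch_job("jaxjob-simple"))` applied to the job shown above would pick the `Succeeded` condition, since `Running` has status "False" and `Created` has an earlier transition time.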