JAX Training (JAXJob)

Using JAXJob to train a model with JAX

This page describes how to use JAXJob to train a machine learning model with JAX.

JAXJob is a Kubernetes custom resource for running JAX training jobs on Kubernetes. The Kubeflow implementation of JAXJob is in the training-operator.

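So far, the JAX custom resource has been tested only on CPU. On CPU it runs multiple processes that communicate with each other over gloo.

The worker with replica index 0 acts as the JAX coordinator. Process 0 starts a JAX coordinator service, and every other process in the cluster connects to it. The service is exposed at process 0's IP address within the cluster, on a port available on that process.

We are looking for user feedback on running JAXJob on GPUs and TPUs.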
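The coordinator handshake described above can be sketched as the entry point of a training script. This is a minimal sketch, not the script shipped in the example image. It rests on several assumptions:

- The operator passes the coordinator endpoint and process layout to each pod through the environment variables `COORDINATOR_ADDRESS`, `COORDINATOR_PORT`, `NUM_PROCESSES` and `PROCESS_ID`. Verify these names against your training-operator version.
- The JAX release in use accepts the `jax_cpu_collectives_implementation` option for selecting gloo.
- The helper name `coordinator_config` is hypothetical.

```python
import os


def coordinator_config(env):
    """Build keyword arguments for jax.distributed.initialize.

    ASSUMPTION: the operator injects COORDINATOR_ADDRESS, COORDINATOR_PORT,
    NUM_PROCESSES and PROCESS_ID into every worker pod.
    """
    address = env["COORDINATOR_ADDRESS"]
    port = int(env["COORDINATOR_PORT"])
    return {
        "coordinator_address": f"{address}:{port}",
        "num_processes": int(env["NUM_PROCESSES"]),
        "process_id": int(env["PROCESS_ID"]),
    }


def main():
    import jax  # imported lazily so the helper above works without JAX installed

    # ASSUMPTION: this JAX version selects gloo for CPU collectives via this option.
    jax.config.update("jax_cpu_collectives_implementation", "gloo")
    # Process 0 starts the coordinator service; the other processes connect to it.
    jax.distributed.initialize(**coordinator_config(os.environ))
    print(f"JAX process {jax.process_index()}/{jax.process_count()}")
    print(f"JAX global devices: {jax.devices()}")
    print(f"JAX local devices: {jax.local_devices()}")


if __name__ == "__main__":
    main()
```

Separating the environment parsing from the JAX call keeps the pure part testable outside the cluster.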

Creating a JAX training job

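You can create a training job by defining a JAXJob configuration file. See the manifest of the simple JAXJob example, and modify the job configuration file to fit your requirements.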
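As a sketch of what such a configuration contains, the helper below builds the same JAXJob spec that appears later in this page's status output: two `Worker` replicas, the `jaxjob-simple` image, and `python3 train.py`. The function name `jaxjob_manifest` and its parameters are hypothetical conveniences, not part of any Kubeflow API. The script prints JSON, which `kubectl create -f -` also accepts.

```python
import json


def jaxjob_manifest(name="jaxjob-simple", namespace="kubeflow", replicas=2,
                    image="docker.io/kubeflow/jaxjob-simple:latest"):
    """Return a JAXJob manifest as a plain dict (mirrors the example spec)."""
    return {
        "apiVersion": "kubeflow.org/v1",
        "kind": "JAXJob",
        "metadata": {"name": name, "namespace": namespace},
        "spec": {
            "jaxReplicaSpecs": {
                "Worker": {
                    "replicas": replicas,
                    "restartPolicy": "OnFailure",
                    "template": {
                        "spec": {
                            "containers": [{
                                "name": "jax",
                                "image": image,
                                "imagePullPolicy": "Always",
                                "command": ["python3", "train.py"],
                            }]
                        }
                    },
                }
            }
        },
    }


if __name__ == "__main__":
    # Usage: python3 make_jaxjob.py | kubectl create -f -
    print(json.dumps(jaxjob_manifest(), indent=2))
```

Changing `replicas` changes the number of worker pods, and therefore the number of JAX processes.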

Deploy the JAXJob resource to start training:

kubectl create -f https://raw.githubusercontent.com/kubeflow/training-operator/refs/heads/release-1.9/examples/jax/cpu-demo/demo.yaml

You should now see that pods have been created, one per specified replica.

kubectl get pods -n kubeflow -l training.kubeflow.org/job-name=jaxjob-simple
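The same check can be scripted. The sketch below assembles the label selector used by the `kubectl` commands on this page and counts matching pods with the official Kubernetes Python client (`CoreV1Api.list_namespaced_pod`). The helper names `job_selector` and `count_job_pods` are hypothetical, and the example assumes a local kubeconfig.

```python
def job_selector(job_name, replica_type=None, replica_index=None):
    """Build a label selector string for pods belonging to a training job."""
    labels = {"training.kubeflow.org/job-name": job_name}
    if replica_type is not None:
        labels["training.kubeflow.org/replica-type"] = replica_type
    if replica_index is not None:
        labels["training.kubeflow.org/replica-index"] = str(replica_index)
    return ",".join(f"{key}={value}" for key, value in labels.items())


def count_job_pods(job_name, namespace="kubeflow"):
    """Count the pods of a job; compare the result with the replica count."""
    from kubernetes import client, config  # lazy import: only needed on a cluster

    config.load_kube_config()  # assumes a kubeconfig on this machine
    pods = client.CoreV1Api().list_namespaced_pod(
        namespace, label_selector=job_selector(job_name))
    return len(pods.items)


if __name__ == "__main__":
    print(count_job_pods("jaxjob-simple"))
```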

On a CPU cluster, the distributed computation takes a few minutes. You can follow its progress in the logs of worker 0:

PODNAME=$(kubectl get pods -l training.kubeflow.org/job-name=jaxjob-simple,training.kubeflow.org/replica-type=worker,training.kubeflow.org/replica-index=0 -o name -n kubeflow)
kubectl logs -f ${PODNAME} -n kubeflow
I1016 14:30:28.956959 139643066051456 distributed.py:106] Starting JAX distributed service on [::]:6666
I1016 14:30:28.959352 139643066051456 distributed.py:119] Connecting to JAX distributed service on jaxjob-simple-worker-0:6666
I1016 14:30:30.633651 139643066051456 xla_bridge.py:895] Unable to initialize backend 'rocm': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig'
I1016 14:30:30.638316 139643066051456 xla_bridge.py:895] Unable to initialize backend 'tpu': INTERNAL: Failed to open libtpu.so: libtpu.so: cannot open shared object file: No such file or directory
JAX process 0/1 initialized on jaxjob-simple-worker-0
JAX global devices:[CpuDevice(id=0), CpuDevice(id=131072)]
JAX local devices:[CpuDevice(id=0)]
JAX device count:2
JAX local device count:1
[2.]
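In this output, the global device count (2) equals the number of worker replicas (2) times the local device count (1). That relation indicates that every process has joined the coordinator.

A small, hypothetical helper can perform this check on collected logs. It assumes the training script prints the `JAX device count:` and `JAX local device count:` lines exactly as shown above, and that every worker has the same number of local devices.

```python
import re

# Matches the summary lines printed by the example training script (assumed format).
_COUNT_LINE = re.compile(r"^JAX (device count|local device count):(\d+)$")


def parse_device_counts(log_text):
    """Extract {'device count': N, 'local device count': M} from log text."""
    counts = {}
    for line in log_text.splitlines():
        match = _COUNT_LINE.match(line.strip())
        if match:
            counts[match.group(1)] = int(match.group(2))
    return counts


def all_processes_joined(log_text, replicas):
    """True if global devices == replicas * local devices (equal devices per worker)."""
    counts = parse_device_counts(log_text)
    return counts.get("device count") == replicas * counts.get("local device count", 0)
```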

Monitoring a JAXJob

kubectl get -o yaml jaxjobs jaxjob-simple -n kubeflow

Check the status section of the output to monitor the job. Below is sample output after the job has completed successfully.

apiVersion: kubeflow.org/v1
kind: JAXJob
metadata:
  annotations:
    kubectl.kubernetes.io/last-applied-configuration: |
      {"apiVersion":"kubeflow.org/v1","kind":"JAXJob","metadata":{"annotations":{},"name":"jaxjob-simple","namespace":"kubeflow"},"spec":{"jaxReplicaSpecs":{"Worker":{"replicas":2,"restartPolicy":"OnFailure","template":{"spec":{"containers":[{"command":["python3","train.py"],"image":"docker.io/kubeflow/jaxjob-simple:latest","imagePullPolicy":"Always","name":"jax"}]}}}}}}      
  creationTimestamp: "2024-09-22T20:07:59Z"
  generation: 1
  name: jaxjob-simple
  namespace: kubeflow
  resourceVersion: "1972"
  uid: eb20c874-44fc-459b-b9a8-09f5c3ff46d3
spec:
  jaxReplicaSpecs:
    Worker:
      replicas: 2
      restartPolicy: OnFailure
      template:
        spec:
          containers:
            - command:
                - python3
                - train.py
              image: docker.io/kubeflow/jaxjob-simple:latest
              imagePullPolicy: Always
              name: jax
status:
  completionTime: "2024-09-22T20:11:34Z"
  conditions:
    - lastTransitionTime: "2024-09-22T20:07:59Z"
      lastUpdateTime: "2024-09-22T20:07:59Z"
      message: JAXJob jaxjob-simple is created.
      reason: JAXJobCreated
      status: "True"
      type: Created
    - lastTransitionTime: "2024-09-22T20:11:28Z"
      lastUpdateTime: "2024-09-22T20:11:28Z"
      message: JAXJob kubeflow/jaxjob-simple is running.
      reason: JAXJobRunning
      status: "False"
      type: Running
    - lastTransitionTime: "2024-09-22T20:11:34Z"
      lastUpdateTime: "2024-09-22T20:11:34Z"
      message: JAXJob kubeflow/jaxjob-simple successfully completed.
      reason: JAXJobSucceeded
      status: "True"
      type: Succeeded
  replicaStatuses:
    Worker:
      selector: training.kubeflow.org/job-name=jaxjob-simple,training.kubeflow.org/operator-name=jaxjob-controller,training.kubeflow.org/replica-type=worker
      succeeded: 2
  startTime: "2024-09-22T20:07:59Z"
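In the status above, the job is finished when a terminal condition has `status: "True"`. Here that condition is `Succeeded`; note that `Running` has flipped to `"False"`.

The sketch below polls the JAXJob with the Kubernetes Python client (`CustomObjectsApi.get_namespaced_custom_object`) until such a condition appears. It treats `Succeeded` and `Failed` as the terminal condition types, which is an assumption based on the Kubeflow training-operator job conditions. The helper names `job_finished` and `wait_for_jaxjob` are hypothetical.

```python
import time


def job_finished(jaxjob):
    """Return "Succeeded" or "Failed" if that condition is True, else None."""
    for cond in jaxjob.get("status", {}).get("conditions", []):
        if cond.get("type") in ("Succeeded", "Failed") and cond.get("status") == "True":
            return cond["type"]
    return None


def wait_for_jaxjob(name, namespace="kubeflow", interval=10):
    """Poll the JAXJob until it reaches a terminal condition."""
    from kubernetes import client, config  # lazy import: only needed on a cluster

    config.load_kube_config()  # assumes a kubeconfig on this machine
    api = client.CustomObjectsApi()
    while True:
        job = api.get_namespaced_custom_object(
            "kubeflow.org", "v1", namespace, "jaxjobs", name)
        result = job_finished(job)
        if result is not None:
            return result
        time.sleep(interval)


if __name__ == "__main__":
    print(wait_for_jaxjob("jaxjob-simple"))
```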
