import torch
from datasystem.ds_tensor_client import DsTensorClient
from torch_npu.npu import current_device
from vllm.distributed.parallel_state import (
get_tp_group, get_ep_group, get_pp_group
)
from vllm.logger import logger
from collections import OrderedDict
# Process-wide DsTensorClient, created lazily and cached by get_tensor_client().
_global_tensor_client = None
# Timeout passed to DsTensorClient.dev_mget, in milliseconds (10 s).
_DEV_MGET_TIMEOUT_MS = 10 * 1000
def get_tensor_client():
    """Return the process-wide openYuanrong data-system tensor client.

    The client is created on first use, bound to the current NPU device
    (``torch_npu.npu.current_device()``), initialized, and then cached in the
    module global ``_global_tensor_client`` for all later calls.

    NOTE(review): ``host`` and ``port`` are not defined in this snippet; they
    are presumably module-level settings defined elsewhere -- confirm.

    Returns:
        The initialized ``DsTensorClient`` instance.

    Raises:
        Whatever ``DsTensorClient(...)`` or its ``init()`` raises. Nothing is
        cached in that case, so the next call retries creation.
    """
    global _global_tensor_client
    if _global_tensor_client is not None:
        return _global_tensor_client
    client = DsTensorClient(host, port, current_device())
    # Publish the client to the cache only after init() succeeded; otherwise a
    # failed init would leave an uninitialized client that later calls reuse.
    client.init()
    _global_tensor_client = client
    return _global_tensor_client
def _publish_to_ds(model, key: str):
    """Publish every parameter tensor of *model* to the openYuanrong data system.

    Each tensor is stored under ``key + f"index_{i}"`` (note: no separator
    between *key* and ``index_``), where ``i`` follows the iteration order of
    ``model.named_parameters()``. ``_load_from_ds`` builds keys the same way,
    so both sides must see the parameters in the same order.

    Args:
        model: Object exposing ``named_parameters()`` (e.g. a ``torch.nn.Module``).
        key: Prefix for the per-tensor keys.
    """
    tensor_client = get_tensor_client()
    # The tensor data itself (param.data) is sent, not just metadata.
    tensor_list = [param.data for _, param in model.named_parameters()]
    key_list = [key + f"index_{index}" for index in range(len(tensor_list))]
    tensor_client.dev_mset(key_list, tensor_list)
def _load_from_ds(model, key: str):
    """Fetch *model*'s parameter tensors from the openYuanrong data system.

    Keys are built as ``key + f"index_{i}"`` in the iteration order of
    ``model.named_parameters()``, the same scheme ``_publish_to_ds`` uses.
    The destination tensors passed to ``dev_mget`` are the parameters'
    ``.data`` tensors; presumably ``dev_mget`` writes into them in place
    (its return value is not used) -- confirm against DsTensorClient docs.

    Args:
        model: Object exposing ``named_parameters()`` (e.g. a ``torch.nn.Module``).
        key: Prefix for the per-tensor keys.

    Returns:
        The number of parameter tensors requested. This is counted once
        ``dev_mget`` returns; no per-tensor success check is done here.

    Raises:
        Any exception raised by ``get_tensor_client()`` or ``dev_mget``
        (which is given a ``_DEV_MGET_TIMEOUT_MS`` timeout).
    """
    tensor_client = get_tensor_client()
    tensor_list = [param.data for _, param in model.named_parameters()]
    key_list = [key + f"index_{index}" for index in range(len(tensor_list))]
    tensor_client.dev_mget(key_list, tensor_list, _DEV_MGET_TIMEOUT_MS)
    return len(tensor_list)
# Key step: try to load the model weights through the openYuanrong data system
# first. On any failure, load them with the regular loader and then publish
# them to the data system (presumably so that later loads can take the fast
# path).
# NOTE(review): `model`, `key`, `load_callback` and `model_config` are not
# defined in this snippet; presumably this runs inside a model-loader function.
try:
    _load_from_ds(model, key)
except Exception as e:  # broad on purpose: any failure selects the fallback path
    logger.warning("Fallback to load user_callback. error:%s", e)
    # Load the model with the default method.
    load_callback(model, model_config)
    try:
        _publish_to_ds(model, key)
    except Exception as publish_err:
        # The model is already loaded at this point; failing to populate the
        # data system must not abort model loading.
        logger.warning("Failed to publish model to data system. error:%s", publish_err)
// 原SpringBoot项目代码无需修改
// 例如：Application.java
package com.example.microservicedemo;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;

/**
 * Spring Boot entry point of the demo microservice.
 */
@SpringBootApplication
public class Application {
    /**
     * Starts the Spring application context with this class as its primary source.
     *
     * @param args command-line arguments, forwarded to Spring Boot
     */
    public static void main(final String[] args) {
        final Class<?>[] primarySources = {Application.class};
        SpringApplication.run(primarySources, args);
    }
}
// 例如：MyController.java
package com.example.microservicedemo;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RequestParam;
import org.springframework.web.bind.annotation.RestController;

/**
 * Demo REST controller mounted under {@code /rest}.
 */
@RequestMapping("/rest")
@RestController
public class MyController {
    /**
     * Handles {@code GET /rest/hello?name=...}.
     *
     * <p>The query parameter is bound explicitly by name rather than through the
     * compiled method-parameter name, which is only available with debug info or
     * the {@code -parameters} compiler flag. It stays optional, matching Spring's
     * implicit binding for simple types, so a missing {@code name} still yields
     * {@code "Hello World, null"}.
     *
     * @param name value of the {@code name} query parameter, or {@code null} if absent
     * @return the greeting text
     */
    @GetMapping(value = "/hello")
    public String helloWorld(@RequestParam(value = "name", required = false) String name) {
        return "Hello World, " + name;
    }
}
第一步：在 pom.xml 文件中加入对 openYuanrong SpringBoot Adapter SDK 的依赖
<dependency>
<!-- NOTE: no <version> is given here. Maven rejects a dependency without a
     version unless it is managed by a parent POM or <dependencyManagement>/BOM;
     add the SDK version you build against. -->
<groupId>org.yuanrong.m2s</groupId>
<artifactId>microservice-function-yuanrong</artifactId>
</dependency>
第二步：编译打包后，调用 REST API 在 openYuanrong 集群上部署应用
# ${YR_SERVER_ADDRESS} is a placeholder: replace it with the scheme://host:port of
# your openYuanrong REST endpoint. The explicit Content-Type is needed because
# `curl -d` otherwise sends application/x-www-form-urlencoded.
curl -X POST -i "${YR_SERVER_ADDRESS}/serverless/v1/functions" -H "Content-Type: application/json" -d @create_func.json
// create_func.json
{
"name": "0@microservice@demo",
"handler": "org.yuanrong.m2s.function.Handler.handleRestRequest",
"runtime": "java8",
"cpu": 600,
"memory": 512,
"timeout": 30,
"environment": {
"spring_start_class": "com.example.microservicedemo.Application"
},
"kind": "faas",
"extendedHandler": {
"initializer": "org.yuanrong.m2s.function.Handler.initializer"
},
"extendedTimeout": {
"initializer": 600
}
}
第三步：使用 REST API 调用服务
# ${YR_SERVER_ADDRESS} is a placeholder: replace it with the scheme://host:port of
# your openYuanrong REST endpoint. The explicit Content-Type is needed because
# `curl -d` otherwise sends application/x-www-form-urlencoded.
curl -X POST -i "${YR_SERVER_ADDRESS}/serverless/v1/functions/${FUNCTION_VERSION_URN}/invocations" -H "Content-Type: application/json" -d @invoke_func.json
// invoke_func.json
{
"httpMethod": "GET",
"body": "",
"path":"/rest/hello",
"pathParameters":{},
"queryStringParameters": {
"name": "yuanrong"
}
}