import torch
from datasystem.ds_tensor_client import DsTensorClient
from torch_npu.npu import current_device
from vllm.distributed.parallel_state import (
get_tp_group, get_ep_group, get_pp_group
)
from vllm.logger import logger
from collections import OrderedDict
# Process-wide cached data system client; None until get_tensor_client()
# creates and stores one.
_global_tensor_client = None
# Timeout value handed to dev_mget when loading tensors; the name indicates
# milliseconds (10 * 1000 ms = 10 s).
_DEV_MGET_TIMEOUT_MS = 10 * 1000
# Create openYuanrong data system client
def get_tensor_client():
    """Return the shared DsTensorClient, creating it on first use.

    On the first call a client is constructed from ``host``, ``port`` (both
    expected to be defined elsewhere; they are not set in this snippet) and
    the current NPU device, then ``init()`` is called on it. The client is
    cached in ``_global_tensor_client`` and returned by every later call.

    Returns:
        The cached ``DsTensorClient`` instance.
    """
    global _global_tensor_client
    # Compare against None explicitly instead of relying on the client
    # object's truthiness to decide whether it has been created yet.
    if _global_tensor_client is not None:
        return _global_tensor_client
    client = DsTensorClient(host, port, current_device())
    # Cache the client only after init() has returned: if init() raises, the
    # global stays None and the next call retries instead of handing out a
    # client that was never initialized.
    client.init()
    _global_tensor_client = client
    return client
# Publish model parameter metadata to openYuanrong data system
def _publish_to_ds(model, key: str):
    """Publish every parameter tensor of ``model`` via one ``dev_mset`` call.

    Tensors are collected in ``model.named_parameters()`` order, and the i-th
    tensor is published under the key ``key + "index_<i>"``.

    Args:
        model: Object exposing ``named_parameters()`` that yields
            ``(name, param)`` pairs whose ``param.data`` is the tensor.
        key: Prefix prepended to each per-tensor key.
    """
    ds_client = get_tensor_client()
    # Collect by parameter name first (insertion-ordered), then flatten to the
    # positional list that dev_mset expects.
    params_by_name = OrderedDict(
        (param_name, parameter.data)
        for param_name, parameter in model.named_parameters()
    )
    tensors = list(params_by_name.values())
    keys = [key + f"index_{position}" for position, _ in enumerate(tensors)]
    ds_client.dev_mset(keys, tensors)
# Load specified model parameters from openYuanrong data system
def _load_from_ds(model, key: str):
    """Request the tensors for every parameter of ``model`` via ``dev_mget``.

    The ``.data`` tensor of each parameter, in ``model.named_parameters()``
    order, is passed to ``dev_mget`` together with the key
    ``key + "index_<i>"`` and the module-level timeout.
    NOTE(review): presumably dev_mget fills the passed tensors in place;
    confirm against the DsTensorClient documentation.

    Args:
        model: Object exposing ``named_parameters()`` that yields
            ``(name, param)`` pairs whose ``param.data`` is the tensor.
        key: Prefix prepended to each per-tensor key.

    Returns:
        The number of parameter tensors that were passed to ``dev_mget``.
    """
    ds_client = get_tensor_client()
    params_by_name = OrderedDict(
        (param_name, parameter.data)
        for param_name, parameter in model.named_parameters()
    )
    targets = list(params_by_name.values())
    keys = [key + f"index_{position}" for position, _ in enumerate(targets)]
    ds_client.dev_mget(keys, targets, _DEV_MGET_TIMEOUT_MS)
    # Any failure surfaces as an exception from dev_mget, so reaching this
    # line means every requested tensor is counted.
    return len(targets)
# Key step
# Accelerate model loading using openYuanrong data system
try:
    # Fast path: pull the parameters from the data system.
    _load_from_ds(model, key)
except Exception as load_error:
    # Best-effort fallback: any failure on the fast path is logged, then the
    # model is loaded the regular way and published under the same key.
    logger.info(f"Fallback to load user_callback. error:{load_error}")
    load_callback(model, model_config)
    _publish_to_ds(model, key)
// Original SpringBoot project code needs no modification
// Example: Application.java
package com.example.microservicedemo;
import org.springframework.boot.SpringApplication;
import org.springframework.boot.autoconfigure.SpringBootApplication;
/**
 * Spring Boot entry point of the demo microservice.
 */
@SpringBootApplication
public class Application {
    /**
     * Starts the Spring application with this class as its primary source.
     *
     * @param args command-line arguments, forwarded to SpringApplication.run
     */
    public static void main(String[] args) {
        SpringApplication.run(Application.class, args);
    }
}
// Example: MyController.java
package com.example.microservicedemo;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.RequestMapping;
import org.springframework.web.bind.annotation.RestController;
/**
 * REST controller whose handlers are mapped under the "/rest" path.
 */
@RequestMapping("/rest")
@RestController
public class MyController {
    /**
     * Handles GET requests to "/rest/hello".
     *
     * @param name value appended to the greeting; presumably bound by Spring
     *             from the "name" request parameter (confirm, since there is
     *             no explicit @RequestParam)
     * @return the string "Hello World, " followed by {@code name}
     */
    @GetMapping(value = "/hello")
    public String helloWorld(String name) {
        return "Hello World, " + name;
    }
}
Step 1: Add a dependency on the openYuanrong SpringBoot Adapter SDK to pom.xml
<dependency>
<groupId>org.yuanrong.m2s</groupId>
<artifactId>microservice-function-yuanrong</artifactId>
</dependency>
Step 2: After compiling and packaging, call the REST API to deploy the application on the openYuanrong cluster
curl -X POST -i /serverless/v1/functions -d @create_func.json
// create_func.json
{
"name": "0@microservice@demo",
"handler": "org.yuanrong.m2s.function.Handler.handleRestRequest",
"runtime": "java8",
"cpu": 600,
"memory": 512,
"timeout": 30,
"environment": {
"spring_start_class": "com.example.microservicedemo.Application"
},
"kind": "faas",
"extendedHandler": {
"initializer": "org.yuanrong.m2s.function.Handler.initializer"
},
"extendedTimeout": {
"initializer": 600
}
}
Step 3: Use the REST API to invoke the service
curl -X POST -i /serverless/v1/functions/${FUNCTION_VERSION_URN}/invocations -d @invoke_func.json
// invoke_func.json
{
"httpMethod": "GET",
"body": "",
"path":"/rest/hello",
"pathParameters":{},
"queryStringParameters": {
"name": "yuanrong"
}
}