Source code for yr.decorator.instance_proxy

#!/usr/bin/env python3
# coding=UTF-8
# Copyright (c) Huawei Technologies Co., Ltd. 2025. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
Instance decorator
"""

import contextlib
import inspect
import logging
import os
import threading
import weakref
import uuid
from typing import List
import yr
from yr import signature
from yr.code_manager import CodeManager
from yr.fnruntime import buffer_from_bytes
from yr.serialization.serialization import Serialization
from yr.generator import ObjectRefGenerator
from yr.common import constants, utils
from yr.common.types import GroupInfo
from yr.config import InvokeOptions, function_group_enabled
from yr.executor.instance_manager import InstanceManager
from yr.libruntime_pb2 import FunctionMeta, LanguageType
from yr.object_ref import ObjectRef
from yr.runtime_holder import global_runtime, save_real_instance_id
from yr.serialization import register_pack_hook, register_unpack_hook
from yr.accelerate.shm_broadcast import MessageQueue, STOP_EVENT
from yr.serialization import Serialization

_logger = logging.getLogger(__name__)


class InstanceCreator:
    """
    User instance creator.

    Wraps a user class (a Python class, a cpp class or another cross-language
    class) and creates instances of it through the global runtime. Every
    successful creation returns an ``InstanceProxy``.
    """

    def __init__(self):
        """
        Initialize the InstanceCreator instance.
        """
        self.__user_class__ = None
        self.__user_class_descriptor__ = None
        self.__user_class_methods__ = {}
        self.__base_cls__ = None
        self.__target_function_key__ = None
        self._yr_embedded_code_ref = None
        self._yr_embedded_code = None
        self._lock = threading.Lock()
        self.__invoke_options__ = InvokeOptions()
        self.__is_async__ = False
        self.function_group_size = 0

    def __getstate__(self):
        """
        Return the state to pickle, without the lock and the cached code.
        """
        attrs = self.__dict__.copy()
        # These three entries are recreated by __setstate__, so they are never shipped.
        for transient in ("_lock", "_yr_embedded_code_ref", "_yr_embedded_code"):
            attrs.pop(transient, None)
        return attrs

    def __setstate__(self, state):
        """
        Restore the pickled state and recreate the lock and the empty code cache.
        """
        self.__dict__.update(state)
        self.__dict__["_lock"] = threading.Lock()
        self.__dict__["_yr_embedded_code_ref"] = None
        self.__dict__["_yr_embedded_code"] = None

    @property
    def user_class_descriptor(self):
        """
        Get the user class descriptor.

        Returns:
            The user class descriptor.
        """
        return self.__user_class_descriptor__

    @classmethod
    def create_from_user_class(cls, user_class, invoke_options):
        """
        Create from user class.

        Args:
            user_class (class): The user class.
            invoke_options (InvokeOptions): The invoke options. When ``None``, a new
                ``InvokeOptions()`` is used.

        Returns:
            The InstanceCreator object itself. Data type is InstanceCreator.
        """

        class DerivedInstanceCreator(cls, user_class):
            """derived instance creator"""

        name = f"YRInstance({user_class.__name__})"
        DerivedInstanceCreator.__module__ = user_class.__module__
        DerivedInstanceCreator.__name__ = name
        DerivedInstanceCreator.__qualname__ = name
        # __new__ is called directly, so neither __init__ of the creator nor of the
        # user class runs; every attribute is assigned explicitly below.
        self = DerivedInstanceCreator.__new__(DerivedInstanceCreator)
        self.__user_class__ = user_class
        # Save original class info for skip_serialize mode (after self is assigned)
        self.__original_class_name__ = user_class.__name__
        self.__original_qualname__ = user_class.__qualname__
        self.__invoke_options__ = invoke_options if invoke_options is not None else InvokeOptions()

        class_methods = inspect.getmembers(user_class, utils.is_function_or_method)
        self.__is_async__ = any(
            inspect.iscoroutinefunction(method) or inspect.isasyncgenfunction(method)
            for _, method in class_methods
        )
        concurrency = self.__invoke_options__.concurrency
        if (concurrency is None or concurrency == 1) and not self.__is_async__:
            self.__invoke_options__.need_order = True

        self.__user_class_descriptor__ = utils.ObjectDescriptor.get_from_class(user_class)
        self.__user_class_descriptor__.target_language = LanguageType.Python
        self.__target_function_key__ = ""
        self.__user_class_methods__ = dict(class_methods)
        self.__base_cls__ = inspect.getmro(user_class)
        object.__setattr__(self, "_yr_embedded_code_ref", None)
        object.__setattr__(self, "_yr_embedded_code", None)
        self._lock = threading.Lock()
        self.function_group_size = 0
        # Pre-register in CodeManager so get_instance_by_name can resolve local/nested classes
        # (classes defined inside functions have <locals> in their qualname, which getattr cannot
        # traverse on the module). Registration uses module%%qualname to match the key used by
        # load_code_from_local when function_meta has no code or codeID.
        CodeManager().register(f"{user_class.__module__}%%{user_class.__qualname__}", user_class)
        return self

    @classmethod
    def create_cpp_user_class(cls, cpp_class):
        """
        Create a cpp user class.

        Args:
            cpp_class (class): The cpp class.

        Returns:
            The InstanceCreator object itself. Data type is InstanceCreator.
        """
        self = cls()
        self.__user_class_descriptor__ = utils.ObjectDescriptor(
            class_name=cpp_class.get_class_name(),
            function_name=cpp_class.get_factory_name(),
            target_language=LanguageType.Cpp)
        self.__user_class_methods__ = None
        self.__target_function_key__ = cpp_class.get_function_key()
        return self

    @classmethod
    def create_cross_user_class(cls, user_class):
        """
        Create a java user class

        Args:
            user_class (class): The java class.

        Returns:
            The InstanceCreator object itself. Data type is InstanceCreator.
        """
        self = cls()
        self.__user_class_descriptor__ = utils.ObjectDescriptor(
            class_name=user_class.class_name,
            target_language=user_class.target_language)
        self.__user_class_methods__ = None
        self.__target_function_key__ = user_class.function_key
        return self

    def options(self, *args, **kwargs):
        """
        Options YR.

        Args:
            *args: Variable position parameters. YR mode is triggered only when a single
                parameter is of type InvokeOptions.
            **kwargs: Variable keyword arguments. Pass execution parameters.

        Returns:
            Instance option wrapper.
        """
        if len(args) == 1 and isinstance(args[0], InvokeOptions):
            return self._options_yr(args[0])
        return self._options_wrapper(**kwargs)

    def set_function_group_size(self, function_group_size: int):
        """
        Set function group size.

        Args:
            function_group_size (int): The function group size.

        Returns:
            The InstanceCreator object itself. Data type is InstanceCreator.
        """
        self.function_group_size = function_group_size
        return self

    def get_original_cls(self):
        """
        Get the original class.

        Returns:
            The original user class type before encapsulation.
        """
        return self.__user_class__

    def create_instance_for_testing(self, invoke_options, function_id, name, args_list, group_info):
        """
        Create instance for testing purposes.

        This is a public wrapper around _inner_create_instance for testing.

        Args:
            invoke_options: The invoke options.
            function_id: The function ID.
            name: The instance name.
            args_list: The arguments list.
            group_info: The group info.

        Returns:
            The instance ID.
        """
        return self._inner_create_instance(invoke_options, function_id, name, args_list, group_info)

    def invoke(self, *args, **kwargs):
        """
        Create an instance in cluster.

        Args:
            *args: Variable arguments, used to pass non-keyword arguments.
            **kwargs: Variable arguments, used to pass keyword arguments.

        Returns:
            InstanceProxy.
        """
        return self._invoke(args=args, kwargs=kwargs)

    def get_instance(self, name, *args, **kwargs):
        """
        Get an instance in cluster.

        Args:
            name (str): The instance name.
            *args: Variable arguments, used to pass non-keyword arguments.
            **kwargs: Variable arguments, used to pass keyword arguments.

        Returns:
            InstanceProxy.
        """
        return self._invoke(name=name, args=args, kwargs=kwargs)

    def snapstart(self, checkpoint_id: str) -> "InstanceProxy":
        """
        Start a new instance from a checkpoint snapshot.

        This method creates a new instance by restoring from a previously created snapshot.
        The new instance will have the same state as when the snapshot was taken.

        Args:
            checkpoint_id (str): The checkpoint ID returned by a previous snapshot() call.
                Format: {instanceID}-{functionID}-{uuid}

        Returns:
            InstanceProxy: A new instance proxy for the restored instance.

        Raises:
            RuntimeError: If the checkpoint does not exist or restore operation fails.

        Example:
            >>> import yr
            >>> yr.init()
            >>>
            >>> @yr.instance
            ... class Counter:
            ...     def __init__(self):
            ...         self.value = 0
            ...
            ...     def increment(self):
            ...         self.value += 1
            ...         return self.value
            ...
            ...     def __yr_after_snapstart__(self):
            ...         print(f"Restored with value={self.value}")
            ...
            >>> # Create instance and snapshot
            >>> ins = Counter.invoke()
            >>> yr.get(ins.increment.invoke())  # value = 1
            >>> checkpoint_id = ins.snapshot(leave_running=False)
            >>>
            >>> # Restore from snapshot using the class
            >>> restored_ins = Counter.snapstart(checkpoint_id)
            >>> result = yr.get(restored_ins.increment.invoke())  # value = 2
            >>> print(f"Value after restore: {result}")
            >>>
            >>> yr.finalize()
        """
        _logger.info("Starting instance from snapshot: %s", checkpoint_id)
        new_instance_id = global_runtime.get_runtime().snapstart_instance(checkpoint_id)
        # Create a new InstanceProxy for the restored instance
        restored_proxy = InstanceProxy(
            instance_id=new_instance_id,
            class_descriptor=self.__user_class_descriptor__,
            class_methods=self.__user_class_methods__,
            base_cls=self.__base_cls__,
            function_id="",
            need_order=self.__invoke_options__.need_order,
            group_name="",
            is_async=self.__is_async__,
            instance_name=self.__invoke_options__.name,
            namespace=self.__invoke_options__.namespace,
            code_ref=self._yr_embedded_code_ref
        )
        _logger.info("Instance restored from snapshot %s: %s", checkpoint_id, new_instance_id)
        return restored_proxy

    def _register_python_class_lookup_keys(self, invoke_options):
        """
        Register the wrapped Python class in CodeManager under its lookup keys.

        Does nothing for non-Python creators or when no user class is attached. In
        skip_serialize mode the class is additionally registered under the key built
        from the original (unwrapped) class name when that key differs.
        """
        descriptor = self.__user_class_descriptor__
        if descriptor.target_language != LanguageType.Python or not self.__user_class__:
            return
        lookup_key = descriptor.module_name + "%%" + descriptor.class_name
        CodeManager().register(lookup_key, self.__user_class__)
        if (getattr(invoke_options, "skip_serialize", False)
                and hasattr(self, "__original_class_name__")):
            short_key = descriptor.module_name + "%%" + self.__original_class_name__
            if short_key != lookup_key:
                CodeManager().register(short_key, self.__user_class__)

    def _inner_create_instance(self, invoke_options, function_id, name, args_list, group_info):
        """
        Build the FunctionMeta and ask the runtime to create the instance.

        Args:
            invoke_options: The invoke options.
            function_id: The function ID.
            name: The instance name; ``invoke_options.name`` is used when it is ``None``.
            args_list: The packaged arguments.
            group_info: The group info, or ``None``.

        Returns:
            The instance ID returned by the runtime.
        """
        descriptor = self.__user_class_descriptor__
        # For skip_serialize mode, use original class name instead of YRInstance(...) wrapper name
        class_name = descriptor.class_name
        if getattr(invoke_options, "skip_serialize", False) and hasattr(self, "__original_class_name__"):
            class_name = self.__original_class_name__
        code_ref = self._yr_embedded_code_ref
        code = self._yr_embedded_code
        func_meta = FunctionMeta(functionID=function_id,
                                 moduleName=descriptor.module_name,
                                 className=class_name,
                                 functionName=descriptor.function_name,
                                 language=descriptor.target_language,
                                 codeID=code_ref.id if code_ref is not None else "",
                                 name=name if name is not None else invoke_options.name,
                                 ns=invoke_options.namespace,
                                 isAsync=self.__is_async__,
                                 code=code if code is not None else b"",)
        runtime = global_runtime.get_runtime()
        if invoke_options.name:
            # Best-effort lookup of the named instance: the result is not used and any
            # failure is deliberately ignored.
            with contextlib.suppress(Exception):
                runtime.get_instance_by_name(
                    invoke_options.name, invoke_options.namespace, timeout=30)
        instance_id = runtime.create_instance(func_meta=func_meta, args=args_list, opt=invoke_options,
                                              group_info=group_info)
        self._register_python_class_lookup_keys(invoke_options)
        return instance_id

    def _invoke(self, name=None, args=None, kwargs=None, invoke_options=None):
        """
        Create (or, with ``get_if_exists``, first try to fetch) an instance.

        Args:
            name: The instance name, or ``None`` to use ``invoke_options.name``.
            args: Positional arguments for the user constructor.
            kwargs: Keyword arguments for the user constructor.
            invoke_options: The invoke options; the creator's own options are used when ``None``.

        Returns:
            InstanceProxy.

        Raises:
            ValueError: If ``get_if_exists`` is set without an instance name.
        """
        if invoke_options is None:
            invoke_options = self.__invoke_options__
        invoke_options.check_options_valid()
        if invoke_options.get_if_exists:
            if invoke_options.name is None or len(invoke_options.name) == 0:
                raise ValueError("The actor name must be specified to use `get_if_exists`.")
            try:
                return get_instance_by_name(
                    invoke_options.name,
                    "" if invoke_options.namespace is None else invoke_options.namespace,
                    60)
            except Exception as e:
                # Any lookup failure falls through to creating a new instance below.
                _logger.debug("instance: %s not exist, current get instance err is : %s", invoke_options.name, e)

        is_cross_invoke = self.__user_class_descriptor__.target_language != LanguageType.Python
        skip_serialize = getattr(invoke_options, "skip_serialize", False)
        need_embedded_serialization = (
            not is_cross_invoke
            and not skip_serialize
            and (
                self._yr_embedded_code_ref is None
                or not global_runtime.get_runtime().is_object_existing_in_local(
                    self._yr_embedded_code_ref.id)
            )
        )
        with self._lock:
            # Skip serialization for pre-deployed classes when skip_serialize=True
            if need_embedded_serialization:
                serialized_object = Serialization().serialize(self.__user_class__)
                if len(serialized_object) <= 102400:
                    self._yr_embedded_code = serialized_object.to_bytes()
                    _logger.debug(
                        "[Reference Counting] pass code by request, functionName = %s",
                        self.__user_class__.__qualname__
                    )
                else:
                    self._yr_embedded_code_ref = ObjectRef(
                        global_runtime.get_runtime().put_serialized(serialized_object),
                        need_incre=False
                    )
                    _logger.info(
                        "[Reference Counting] put code with id = %s, className = %s",
                        self._yr_embedded_code_ref.id, self.__user_class_descriptor__.class_name
                    )
            elif skip_serialize and self.__user_class__ is not None:
                # For pre-deployed classes, skip serialization. Cross-language creators
                # carry no Python class, so there is nothing to name in the log for them.
                class_path = f"{self.__user_class__.__module__}.{self.__user_class__.__qualname__}"
                _logger.debug("[Reference Counting] skip serialization for pre-deployed class: %s", class_path)

        # __init__ existed when user-defined
        if self.__user_class_methods__ is not None and '__init__' in self.__user_class_methods__:
            sig = signature.get_signature(self.__user_class_methods__.get('__init__'),
                                          ignore_first=True)
        else:
            sig = None

        function_id = ""
        if self.__user_class_descriptor__.target_language == LanguageType.Python:
            args_list = signature.package_args(sig, args, kwargs)
        else:
            args_list = utils.make_cross_language_args(args, kwargs)
            function_id = self.__target_function_key__

        group_info = None
        if function_group_enabled(invoke_options.function_group_options, self.function_group_size):
            group_info = GroupInfo()
            group_info.group_name = global_runtime.get_runtime().generate_group_name()
            group_info.group_size = self.function_group_size

        instance_id = self._inner_create_instance(invoke_options, function_id, name, args_list, group_info)
        group_name = "" if group_info is None else group_info.group_name
        return InstanceProxy(instance_id, self.__user_class_descriptor__, self.__user_class_methods__,
                             self.__base_cls__, function_id, invoke_options.need_order, group_name,
                             self.__is_async__,
                             name if name is not None else invoke_options.name,
                             invoke_options.namespace,
                             self._yr_embedded_code_ref)

    def _options_wrapper(self, **actor_options):
        """
        options wrapper, Set user invoke options

        Recognised keys are ``name``, ``namespace``, ``lifetime``, ``get_if_exists``,
        ``runtime_env`` (only its ``env_vars`` entry is read) and ``resources``. The values
        are written into this creator's own invoke options.

        Raises:
            TypeError: If ``name``, ``namespace`` or ``lifetime`` is not a string, or
                ``resources`` is not a dict.
            ValueError: If ``name`` or ``namespace`` is empty, or ``lifetime`` is not
                ``"detached"``.
        """
        name = actor_options.get("name")
        namespace = actor_options.get("namespace")
        lifecycle = actor_options.get("lifetime")
        get_if_exists = actor_options.get("get_if_exists")
        if name is not None:
            if not isinstance(name, str):
                raise TypeError(
                    f"name must be None or a string, got: '{type(name)}'.")
            if name == "":
                raise ValueError("stateful function name cannot be an empty string.")
        if namespace is not None:
            if not isinstance(namespace, str):
                raise TypeError("namespace must be None or a string.")
            if namespace == "":
                raise ValueError('"" is not a valid namespace. '
                                 "Pass None to not specify a namespace.")
        if lifecycle is not None:
            if not isinstance(lifecycle, str):
                raise TypeError(
                    f"lifetime must be None or a string, got: '{type(lifecycle)}'.")
            if lifecycle != "detached":
                raise ValueError("lifetime is only support detached")
            self.__invoke_options__.custom_extensions["lifecycle"] = lifecycle
        self.__invoke_options__.name = name
        self.__invoke_options__.namespace = namespace
        self.__invoke_options__.get_if_exists = False if get_if_exists is None else get_if_exists
        # An explicit runtime_env=None is treated the same as leaving it out.
        runtime_env = actor_options.get("runtime_env")
        if runtime_env and "env_vars" in runtime_env:
            self.__invoke_options__.env_vars = runtime_env["env_vars"]
        # An explicit resources=None is treated the same as leaving it out.
        resources = actor_options.get("resources")
        if resources is not None:
            if not isinstance(resources, dict):
                raise TypeError("resources must be None or a dict.")
            self.__invoke_options__.custom_resources.update(resources)
        return self._options_yr(self.__invoke_options__)

    def _options_yr(self, invoke_options: InvokeOptions):
        """
        Set user invoke options

        Args:
            invoke_options: invoke options for users to set resources

        Returns:
            A wrapper whose ``invoke`` and ``get_instance`` create instances with
            ``invoke_options``.
        """
        instance_cls = self
        invoke_options.check_options_valid()
        single_concurrency = invoke_options.concurrency is None or invoke_options.concurrency == 1
        invoke_options.need_order = single_concurrency and not self.__is_async__

        class InstanceOptionWrapper:
            """instance option wrapper"""

            def invoke(self, *args, **kwargs):
                """invoke"""
                return instance_cls._invoke(args=args, kwargs=kwargs, invoke_options=invoke_options)

            def get_instance(self, name: str, *args, **kwargs):
                """
                Create an instance in cluster
                name: str, the instance name
                """
                return instance_cls._invoke(name=name, args=args, kwargs=kwargs,
                                            invoke_options=invoke_options)

        return InstanceOptionWrapper()
class InstanceProxy:
    """
    Use to decorate a user class.

    Client-side handle of one instance. For every entry of ``class_methods`` a
    ``MethodProxy`` is attached as an attribute of the same name.
    """

    def __init__(self, instance_id, class_descriptor, class_methods, base_cls, function_id, need_order=True,
                 group_name="", is_async=False, instance_name=None, namespace=None, code_ref=None):
        """
        Initialize the InstanceProxy instance.

        Args:
            instance_id: The instance ID used for runtime calls.
            class_descriptor: The descriptor of the user class.
            class_methods (dict): Method name to function mapping, or ``None``.
            base_cls: The base classes of the user class.
            function_id: The function ID.
            need_order (bool, optional): Stored as ``need_order``. Default to ``True``.
            group_name (str, optional): The group name. Default to ``""``.
            is_async (bool, optional): Whether the instance is async. Default to ``False``.
            instance_name (str, optional): The instance name.
            namespace (str, optional): The namespace.
            code_ref (optional): The reference of the embedded class code.
        """
        self._class_descriptor = class_descriptor
        self.instance_id = instance_id
        self._class_methods = class_methods
        self._base_cls = base_cls
        self._method_descriptor = {}
        self.__instance_activate__ = True
        self._function_id = function_id
        self.need_order = need_order
        self.group_name = group_name
        self._is_async = is_async
        self._instance_name = instance_name
        self._ns = namespace
        self._yr_embedded_code_ref = code_ref

        if self._class_methods is None:
            return
        for method_name, value in self._class_methods.items():
            is_sync_generator = inspect.isgeneratorfunction(value)
            # A plain generator method is never marked async, even on an async instance.
            is_async_ = False if is_sync_generator else self._is_async
            function_descriptor = utils.ObjectDescriptor(
                module_name=self._class_descriptor.module_name,
                function_name=method_name,
                class_name=self._class_descriptor.class_name,
                is_generator=is_sync_generator or inspect.isasyncgenfunction(value),
                is_async=is_async_)
            self._method_descriptor[method_name] = function_descriptor
            # Class methods and static methods keep their first parameter in the signature.
            is_bound = utils.is_class_method(value) or utils.is_static_method(self._base_cls, method_name)
            sig = signature.get_signature(value, ignore_first=not is_bound)
            return_nums = getattr(value, "__return_nums__", 1)
            tensor_transport_target = getattr(value, "__tensor_transport_target__", "")
            enable_tensor_transport = getattr(value, "__enable_tensor_transport__", False)
            method = MethodProxy(self, self.instance_id, function_descriptor, sig, return_nums,
                                 function_id, is_async_, self._instance_name, self._ns,
                                 tensor_transport_target, enable_tensor_transport)
            setattr(self, method_name, method)

    def __getattr__(self, method_name):
        """
        Build a MethodProxy on demand for non-Python instances.

        Only reached when normal attribute lookup fails. For Python instances every
        method was attached in ``__init__``, so a miss is a genuine AttributeError.
        """
        if self._class_descriptor.target_language == LanguageType.Python:
            raise AttributeError(f"'{self._class_descriptor.class_name}' object has "
                                 f"no attribute '{method_name}'")
        function_name = method_name
        if self._class_descriptor.target_language == LanguageType.Cpp:
            function_name = "&" + self._class_descriptor.class_name + "::" + method_name
        method_descriptor = utils.ObjectDescriptor(
            module_name=self._class_descriptor.module_name,
            function_name=function_name,
            class_name=self._class_descriptor.class_name,
            target_language=self._class_descriptor.target_language,
            is_generator=self.__get_method_generator(method_name))
        return MethodProxy(
            self, self.instance_id, method_descriptor, None, 1,
            self._function_id, self._is_async, self._instance_name, self._ns)

    def __reduce__(self):
        """
        Pickle the proxy as its serialized state plus ``deserialization_``.
        """
        state = self.serialization_(False)
        return InstanceProxy.deserialization_, (state,)

    @property
    def real_id(self) -> str:
        """
        The real instance ID assigned by the runtime.

        ``instance_id`` is a logical key used internally. This property blocks until
        the runtime finishes scheduling the actor (default timeout: 30 seconds), then
        resolves it to the physical instance ID.

        Raises:
            TimeoutError: If the actor is not ready within 30 seconds.

        Returns:
            The real instance ID. Data type is str.

        Examples:
            >>> ins = MyActor.invoke()
            >>> print(ins.real_id)
        """
        runtime = global_runtime.get_runtime()
        runtime.wait([self.instance_id], 1, 30)
        return runtime.get_real_instance_id(self.instance_id)

    @classmethod
    def deserialization_(cls, state):
        """
        Deserialization to rebuild instance proxy.

        Args:
            state (dict): Contains serialized state information.

        Returns:
            Returns a class instance. Data type is InstanceProxy.
        """
        # The cross-language state produced by serialization_(True) carries none of
        # the optional keys below, so each of them falls back to a default.
        class_method = state.get(constants.CLASS_METHOD)
        function_name = state.get(constants.FUNC_NAME, "")
        need_order = state.get(constants.NEED_ORDER, False)
        is_async = state.get(constants.IS_ASYNC, False)
        base_cls = state.get(constants.BASE_CLS)
        instance_id = state[constants.INSTANCE_ID]
        save_real_instance_id(instance_id, need_order)
        return cls(instance_id=instance_id,
                   class_descriptor=utils.ObjectDescriptor(state[constants.MODULE_NAME],
                                                           state[constants.CLASS_NAME],
                                                           function_name,
                                                           state[constants.TARGET_LANGUAGE]),
                   class_methods=class_method,
                   base_cls=base_cls,
                   function_id="",
                   need_order=need_order,
                   is_async=is_async)

    def serialization_(self, is_cross_language: bool = False):
        """
        Serialization of instance proxy.

        Args:
            is_cross_language (bool, optional): Whether cross-language, default to ``False``.
                The Python-only fields (class methods, need_order, base classes, is_async)
                are written only when this is ``False``.

        Returns:
            Serialized instance proxy information. Data type is dict.
        """
        runtime = global_runtime.get_runtime()
        info_ = {constants.INSTANCE_ID: runtime.get_real_instance_id(self.instance_id)}
        if is_cross_language is False:
            info_[constants.CLASS_METHOD] = self._class_methods
            info_[constants.NEED_ORDER] = self.need_order
            info_[constants.BASE_CLS] = self._base_cls
            info_[constants.IS_ASYNC] = self._is_async
        state = {**info_, **self._class_descriptor.to_dict()}
        # Timeout -1 here; presumably this waits without limit until the instance is
        # ready -- confirm against the runtime.
        runtime.wait([self.instance_id], 1, -1)
        return state

    def terminate(self, is_sync: bool = False):
        """
        Terminate the instance.

        Supports synchronous or asynchronous termination. When synchronous termination is not
        enabled, the default timeout for the current kill request is 30 seconds. In scenarios
        such as high disk load or etcd failure, the kill request processing time may exceed
        30 seconds, causing the interface to throw a timeout exception. Since the kill request
        has a retry mechanism, users can choose not to handle or retry after capturing the
        timeout exception. When synchronous termination is enabled, the interface will block
        until the instance completely exits.

        Calling this on a proxy that is no longer active does nothing. When the proxy belongs
        to a group, the whole group is terminated and ``is_sync`` is not consulted.

        Args:
            is_sync (bool, optional): Whether to enable synchronization. If true, it indicates
                sending a kill request with the signal quantity of killInstanceSync to the
                function-proxy, and the kernel synchronously kills the instance. If false, it
                indicates sending a kill request with the signal quantity of killInstance to
                the function-proxy, and the kernel asynchronously kills the instance.
                Default to ``False``.
        """
        if not self.is_activate():
            return
        runtime = global_runtime.get_runtime()
        if self.group_name != "":
            runtime.terminate_group(self.group_name)
        elif is_sync:
            runtime.terminate_instance_sync(self.instance_id)
        else:
            runtime.terminate_instance(self.instance_id)
        self.__instance_activate__ = False
        _logger.info("%s is terminated", self.instance_id)

    def is_activate(self):
        """
        Return the instance status.

        Returns:
            The instance status. Data type is bool.
        """
        return self.__instance_activate__

    def get_function_group_handler(self) -> "FunctionGroupHandler":
        """
        Get the FunctionGroupHandler.

        Returns:
            The function group handler. Data type is FunctionGroupHandler.

        Raises:
            RuntimeError: If this proxy does not belong to a group.
        """
        if self.group_name == "":
            raise RuntimeError(
                "unsupported function type: this function can only be used for group instance handler"
            )
        instance_ids = global_runtime.get_runtime().get_instances(
            self.instance_id, self.group_name)
        handler = FunctionGroupHandler(instance_ids, self._class_descriptor, self._class_methods,
                                       self._base_cls, self._function_id, self.need_order,
                                       self.group_name, self._is_async, self._instance_name, self._ns)
        return handler

    def snapshot(self, ttl: int = -1, leave_running: bool = False) -> str:
        """
        Create instance snapshot.

        This method triggers a snapshot of the current instance state, sending signal 18
        (INSTANCE_SNAPSHOT_SIGNAL) through the Kill interface. The snapshot can be used later
        to restore the instance to this exact state.

        Args:
            ttl (int, optional): Time-to-live for the snapshot in seconds. Default to ``-1``,
                which is passed to the runtime unchanged (presumably selecting the runtime's
                own default TTL -- confirm against the runtime).
            leave_running (bool, optional): Whether to keep the instance running after snapshot.
                - If True: Instance continues running after snapshot (online snapshot)
                - If False: Instance will be terminated after snapshot (offline snapshot);
                  this proxy is then marked inactive.
                Default to ``False``.

        Returns:
            str: The checkpoint ID that uniquely identifies this snapshot.
                Format: {instanceID}-{functionID}-{uuid}

        Raises:
            RuntimeError: If the instance is not active or snapshot operation fails.

        Example:
            >>> import yr
            >>> yr.init()
            >>>
            >>> @yr.instance
            ... class MyInstance:
            ...     def __init__(self):
            ...         self.counter = 0
            ...
            ...     def increment(self):
            ...         self.counter += 1
            ...
            ...     def __yr_before_snapshot__(self):
            ...         print(f"Preparing snapshot, counter={self.counter}")
            ...
            >>> ins = MyInstance.invoke()
            >>> yr.get(ins.increment.invoke())
            >>> checkpoint_id = ins.snapshot(leave_running=False)
            >>> print(f"Snapshot created: {checkpoint_id}")
            >>>
            >>> yr.finalize()
        """
        if not self.is_activate():
            raise RuntimeError(f"Instance {self.instance_id} is not active")
        _logger.info("Creating snapshot for instance %s, leave_running=%s",
                     self.instance_id, leave_running)
        checkpoint_id = global_runtime.get_runtime().snapshot_instance(
            self.instance_id, ttl, leave_running)
        if not leave_running:
            self.__instance_activate__ = False
        _logger.info("Snapshot created for instance %s: %s", self.instance_id, checkpoint_id)
        return checkpoint_id

    def snapstart(self, checkpoint_id: str) -> "InstanceProxy":
        """
        Start a new instance from a snapshot.

        This method creates a new instance by restoring from a previously created snapshot,
        sending signal 19 (INSTANCE_SNAPSTART_SIGNAL) through the Kill interface. The new
        instance will have the same state as when the snapshot was taken.

        Args:
            checkpoint_id (str): The checkpoint ID returned by the snapshot() method.
                Format: {instanceID}-{functionID}-{uuid}

        Returns:
            InstanceProxy: A new instance proxy for the restored instance.

        Raises:
            RuntimeError: If the checkpoint does not exist or restore operation fails.

        Example:
            >>> import yr
            >>> yr.init()
            >>>
            >>> @yr.instance
            ... class MyInstance:
            ...     def __init__(self):
            ...         self.counter = 0
            ...
            ...     def increment(self):
            ...         self.counter += 1
            ...         return self.counter
            ...
            ...     def __yr_after_snapstart__(self):
            ...         print(f"Instance restored, counter={self.counter}")
            ...
            >>> # Create instance and snapshot
            >>> ins = MyInstance.invoke()
            >>> yr.get(ins.increment.invoke())  # counter = 1
            >>> checkpoint_id = ins.snapshot(leave_running=False)
            >>>
            >>> # Restore from snapshot
            >>> restored_ins = MyInstance.snapstart(checkpoint_id)
            >>> result = yr.get(restored_ins.increment.invoke())  # counter = 2
            >>> print(f"Counter after restore: {result}")
            >>>
            >>> yr.finalize()
        """
        _logger.info("Starting instance from snapshot: %s", checkpoint_id)
        new_instance_id = global_runtime.get_runtime().snapstart_instance(checkpoint_id)
        # Create a new InstanceProxy for the restored instance
        restored_proxy = InstanceProxy(
            instance_id=new_instance_id,
            class_descriptor=self._class_descriptor,
            class_methods=self._class_methods,
            base_cls=self._base_cls,
            function_id=self._function_id,
            need_order=self.need_order,
            group_name=self.group_name,
            is_async=self._is_async,
            instance_name=self._instance_name,
            namespace=self._ns,
            code_ref=self._yr_embedded_code_ref
        )
        _logger.info("Instance restored from snapshot %s: %s", checkpoint_id, new_instance_id)
        return restored_proxy

    def __get_method_generator(self, method_name):
        """
        get method generator feature

        Returns the ``is_generator`` flag of a known method descriptor, ``False`` otherwise.
        """
        descriptor = self._method_descriptor.get(method_name)
        return descriptor.is_generator if descriptor is not None else False
@register_pack_hook
def msgpack_encode_hook(obj):
    """
    register msgpack encode hook

    An InstanceProxy is replaced by its cross-language serialized state; every
    other object is returned unchanged.
    """
    if isinstance(obj, InstanceProxy):
        return obj.serialization_(True)
    return obj


@register_unpack_hook
def msgpack_decode_hook(obj):
    """
    register msgpack decode hook

    A dict that carries the instance-id key is rebuilt into an InstanceProxy;
    every other object is returned unchanged.
    """
    # Only a mapping can be a serialized proxy. Without the type check a str would
    # get a substring match and an int would raise TypeError on the `in` test.
    if isinstance(obj, dict) and constants.INSTANCE_ID in obj:
        return InstanceProxy.deserialization_(obj)
    return obj
class MethodProxy:
    """
    Use to decorate a user class method.

    Holds only a weak reference to the owning InstanceProxy and forwards calls to
    the runtime through ``invoke``.
    """

    def __init__(self, instance, instance_id, method_descriptor, sig, return_nums=1, function_id="",
                 is_async=False, instance_name="", namespace="", tensor_transport_target="",
                 enable_tensor_transport=False):
        """
        Initialize the MethodProxy instance.

        Args:
            instance: The owning InstanceProxy; kept as a weak reference.
            instance_id: The instance ID passed to the runtime.
            method_descriptor: The descriptor of the method.
            sig: The method signature used to package Python arguments.
            return_nums (int, optional): Number of return values. Default to ``1``.
            function_id (str, optional): The function ID. Default to ``""``.
            is_async (bool, optional): Whether the method is async. Default to ``False``.
            instance_name (str, optional): The instance name. Default to ``""``.
            namespace (str, optional): The namespace. Default to ``""``.
            tensor_transport_target (str, optional): Forwarded in the function meta.
            enable_tensor_transport (bool, optional): Forwarded in the function meta and
                to the returned ObjectRef objects.

        Raises:
            RuntimeError: If ``return_nums`` is outside the range 0 to 100.
        """
        self._instance_ref = weakref.ref(instance)
        self._instance_id = instance_id
        self._method_descriptor = method_descriptor
        self._function_id = function_id
        self._signature = sig
        self._return_nums = return_nums
        self._is_async = is_async
        self._instance_name = instance_name
        self._ns = namespace
        self._tensor_transport_target = tensor_transport_target
        self._enable_tensor_transport = enable_tensor_transport
        if return_nums < 0 or return_nums > 100:
            raise RuntimeError(f"invalid return_nums: {return_nums}, should be an integer between 0 and 100")

    def invoke(self, *args, **kwargs) -> "yr.ObjectRef":
        """
        Execute remote invoke to user functions.

        Args:
            *args: Variable arguments, used to pass non-keyword arguments.
            **kwargs: Variable arguments, used to pass keyword arguments.

        Returns:
            Reference to a data object. Data type is ObjectRef. When the method declares
            ``return_nums`` of 0 the result is ``None``, when it declares more than one
            a list of ObjectRef is returned, and a generator method returns an
            ObjectRefGenerator.

        Raises:
            TypeError: If the parameter type is incorrect.
            RuntimeError: If the instance is terminated or its proxy has been released.

        Example:
            >>> import yr
            >>> yr.init()
            >>>
            >>> @yr.instance
            ... class Instance:
            ...     sum = 0
            ...
            ...     def add(self, a):
            ...         self.sum += a
            ...
            ...     def get(self):
            ...         return self.sum
            ...
            >>> ins = Instance.invoke()
            >>> yr.get(ins.add.invoke(1))
            >>>
            >>> print(yr.get(ins.get.invoke()))
            >>>
            >>> ins.terminate()
            >>>
            >>> yr.finalize()
        """
        return self._invoke(args, kwargs)

    def options(self, invoke_options: InvokeOptions):
        """
        Set user invoke options.

        Args:
            invoke_options (InvokeOptions): Invoke options for users to set resources.

        Returns:
            Method proxy wrapper. Data type is FuncWrapper.
        """
        func_cls = self
        invoke_options.check_options_valid()

        class FuncWrapper:
            """
            FuncWrapper wrapper method proxy
            """

            @classmethod
            def invoke(cls, *args, **kwargs):
                """
                invoke a class method in cluster
                """
                return func_cls._invoke(args, kwargs, invoke_options)

        return FuncWrapper()

    def _invoke(self, args, kwargs, invoke_options=None):
        """
        Package the arguments and invoke the method on the instance.

        Args:
            args: Positional arguments of the call.
            kwargs: Keyword arguments of the call.
            invoke_options (InvokeOptions, optional): Options of this call. A new
                ``InvokeOptions()`` is created when ``None``.

        Returns:
            ``None``, an ObjectRef, a list of ObjectRef or an ObjectRefGenerator,
            depending on ``return_nums`` and on whether the method is a generator.

        Raises:
            RuntimeError: If the instance is terminated or its proxy has been released.
        """
        # each invoke should have its own InvokeOptions, so the default is built per
        # call instead of once at definition time
        if invoke_options is None:
            invoke_options = InvokeOptions()

        instance = self._instance_ref()
        if instance is None:
            # Only a weak reference is held, so the proxy may already be collected,
            # e.g. in ``Cls.invoke().method.invoke()``.
            raise RuntimeError(
                "the instance proxy of this method has been released, keep a reference to the instance")
        if not instance.is_activate():
            raise RuntimeError("this instance is terminated")

        descriptor = self._method_descriptor
        if descriptor.target_language == LanguageType.Python:
            args_list = signature.package_args(self._signature, args, kwargs)
        else:
            args_list = utils.make_cross_language_args(args, kwargs)

        func_meta = FunctionMeta(moduleName=descriptor.module_name,
                                 className=descriptor.class_name,
                                 functionName=descriptor.function_name,
                                 language=descriptor.target_language,
                                 functionID=self._function_id,
                                 isGenerator=descriptor.is_generator,
                                 isAsync=self._is_async,
                                 name=self._instance_name,
                                 ns=self._ns,
                                 tensorTransportTarget=self._tensor_transport_target,
                                 enableTensorTransport=self._enable_tensor_transport)
        runtime = global_runtime.get_runtime()
        # One object is always requested for a generator or a zero-return method.
        return_nums = 1 if (self._return_nums == 0 or descriptor.is_generator) else self._return_nums
        obj_list = runtime.invoke_instance(func_meta=func_meta,
                                           instance_id=self._instance_id,
                                           args=args_list,
                                           opt=invoke_options,
                                           return_nums=return_nums)
        if self._return_nums == 0:
            return None
        objref_list = [
            ObjectRef(obj_id, need_incre=False, enable_tensor_transport=self._enable_tensor_transport)
            for obj_id in obj_list
        ]
        if descriptor.is_generator:
            return ObjectRefGenerator(objref_list[0])
        return objref_list[0] if self._return_nums == 1 else objref_list
def make_decorator(invoke_options=None):
    """
    Build the class decorator used to turn a user class into an instance creator.

    Args:
        invoke_options: Options forwarded to ``InstanceCreator.create_from_user_class``.

    Returns:
        A decorator that accepts a class and returns its instance creator.
        The decorator raises ``RuntimeError`` when applied to a non-class.
    """

    def decorator(cls):
        if not inspect.isclass(cls):
            raise RuntimeError("@yr.instance decorator must be applied to a class")
        return InstanceCreator.create_from_user_class(cls, invoke_options)

    return decorator


def make_cpp_instance_creator(cpp_class):
    """
    Build an instance creator for a cpp user class.

    Args:
        cpp_class: The cpp class description handed to ``InstanceCreator.create_cpp_user_class``.

    Returns:
        The result of ``InstanceCreator.create_cpp_user_class(cpp_class)``.
    """
    return InstanceCreator.create_cpp_user_class(cpp_class)


def get_instance_by_name(name, namespace, timeout) -> InstanceProxy:
    """
    Look up an existing instance by name and build a proxy for it.

    For Python instances the user class is first recovered from the payload
    carried by the function meta; when the payload is empty or cannot be
    deserialized, the class is loaded through ``CodeManager`` instead.

    Args:
        name: Name of the instance.
        namespace: Namespace of the instance; may be empty.
        timeout: Timeout forwarded to the runtime lookup.

    Returns:
        An ``InstanceProxy`` bound to the named instance.
    """
    runtime = global_runtime.get_runtime()
    function_meta = runtime.get_instance_by_name(name, namespace, timeout)

    # Non-Python instances: only the function meta is available to describe the class.
    if function_meta.language != LanguageType.Python:
        descriptor = utils.ObjectDescriptor.get_from_func_meta(function_meta)
        if namespace:
            proxy_id = namespace + "-" + name
        else:
            proxy_id = name
        return InstanceProxy(proxy_id, descriptor, None, None,
                             function_id=function_meta.functionID,
                             instance_name=name,
                             namespace=namespace,
                             is_async=function_meta.isAsync)

    user_class = None
    if function_meta.payload:
        try:
            payload_buffer = buffer_from_bytes(bytes(function_meta.payload))
            ins_package = Serialization().deserialize(payload_buffer)
            InstanceManager().init_from_inspackage(ins_package)
            user_class = InstanceManager().class_code
        except Exception as exc:
            # Best effort: a broken payload must not prevent the lookup.
            _logger.warning(
                "deserialize recovered instance payload failed, fall back to CodeManager: %s",
                exc,
                exc_info=True,
            )
    if user_class is None:
        _logger.debug(f"pay load of class code is empty, load code of instance: {name} from function meta")
        user_class = CodeManager().load_code(function_meta, True)

    descriptor = utils.ObjectDescriptor.get_from_class(user_class)
    method_table = dict(inspect.getmembers(user_class, utils.is_function_or_method))
    base_cls = inspect.getmro(user_class)
    yr_ns = runtime.get_namespace()
    # Fall back to the runtime namespace when the caller gave none.
    proxy_id = (namespace or yr_ns) + "-" + name
    return InstanceProxy(proxy_id, descriptor, method_table, base_cls, "",
                         is_async=function_meta.isAsync,
                         instance_name=name,
                         namespace=namespace)
class FunctionGroupMethodProxy:
    """
    Proxy for one member function shared by every instance of a function group.
    """

    #: Whether calls go through the shared memory message queue. Default is ``False``.
    use_shared_memory: bool = False
    #: Message queue used to broadcast RPC requests to the group.
    rpc_broadcast_mq: "MessageQueue"

    def __init__(self, method_name: str, class_descriptor: utils.ObjectDescriptor, proxy_list: List[InstanceProxy]):
        """
        Store the method name, the class descriptor and the instance proxies to call.
        """
        self._instance_proxy_list = proxy_list
        self._class_descriptor = class_descriptor
        self._method_name = method_name

    def invoke(self, *args, **kwargs) -> List[ObjectRef]:
        """
        Perform remote calls to the user function on every instance of the group.

        With shared memory enabled, a single request is enqueued on the
        broadcast queue and one pre-registered object reference per instance
        is returned. Otherwise the method is invoked on each instance proxy
        in turn and the results are collected into one flat list.

        Returns:
            A reference to a group of data objects.

        Raises:
            AttributeError: If an instance proxy does not expose the method
                (only checked when shared memory is not used).
        """
        if self.use_shared_memory:
            task_id = str(uuid.uuid4())
            # One return object per instance, identified as "<task id>_<index>".
            obj_ids = [f"{task_id}_{index}" for index in range(len(self._instance_proxy_list))]
            return_objs = [ObjectRef(obj_id, need_incre=False) for obj_id in obj_ids]
            global_runtime.get_runtime().add_return_object(obj_ids)
            _logger.debug(f"start to invoke member function: {self._method_name}")
            self.rpc_broadcast_mq.enqueue((task_id, self._method_name, args, kwargs))
            _logger.debug(f"finish to send member function request: {self._method_name}, task id {task_id}")
            return return_objs

        # Validate every proxy before issuing the first remote call.
        if not all(hasattr(proxy, self._method_name) for proxy in self._instance_proxy_list):
            raise AttributeError(f"'{self._class_descriptor.class_name}' object has "
                                 f"no attribute '{self._method_name}'")

        collected = []
        for proxy in self._instance_proxy_list:
            returned = getattr(proxy, self._method_name).invoke(*args, **kwargs)
            if not isinstance(returned, List):
                collected.append(returned)
                continue
            collected.extend(returned)
        return collected

    def set_rpc_broadcast_mq(self, rpc_broadcast_mq: "MessageQueue"):
        """
        Attach the rpc broadcast message queue and switch to shared memory mode.
        """
        self.rpc_broadcast_mq = rpc_broadcast_mq
        self.use_shared_memory = True
class FunctionGroupHandler:
    """
    Use to decorate a list of instance proxy.

    Every user method is exposed as an attribute holding a
    ``FunctionGroupMethodProxy`` that fans the call out to all instances.
    """

    def __init__(self, instance_ids, class_descriptor, class_methods, base_cls, function_id, need_order=True,
                 group_name="", is_async=False, instance_name=None, namespace=None):
        """
        Initialization method, used to create instances of a class.

        Args:
            instance_ids: Ids of the instances that form the group.
            class_descriptor: Descriptor of the user class.
            class_methods: Mapping of method name to method, or ``None``.
            base_cls: Base classes of the user class.
            function_id: Function id forwarded to every ``InstanceProxy``.
            need_order: Forwarded to every ``InstanceProxy``.
            group_name: Name of the group, used by ``terminate`` and ``accelerate``.
            is_async: Forwarded to every ``InstanceProxy``.
            instance_name: Forwarded to every ``InstanceProxy``.
            namespace: Forwarded to every ``InstanceProxy``.
        """
        self._instance_ids = instance_ids
        self._group_name = group_name
        self._class_descriptor = class_descriptor
        self.__instance_activate__ = True
        self._class_methods = class_methods
        self.executor = None
        self.rpc_broadcast_mq = None
        if class_methods is not None:
            for method_name in class_methods:
                # Each method proxy gets its own list of instance proxies.
                method = FunctionGroupMethodProxy(method_name, class_descriptor, [
                    InstanceProxy(instance_id, class_descriptor, class_methods, base_cls, function_id, need_order,
                                  group_name, is_async, instance_name, namespace)
                    for instance_id in instance_ids
                ])
                setattr(self, method_name, method)

    @staticmethod
    def _sync_loop_requested() -> bool:
        """
        Tell whether ``FCC_USE_SYNC_LOOP`` asks for the synchronous loop.

        Returns:
            ``True`` only when the variable is set to ``1``, ``true``,
            ``True`` or ``yes``. An unset variable or any other value
            (for example ``0`` or ``false``) yields ``False``.
        """
        return os.environ.get("FCC_USE_SYNC_LOOP") in ('1', 'true', 'True', 'yes')

    def terminate(self):
        """
        Terminate the function group.

        Does nothing when the group is already terminated. Otherwise sets
        ``STOP_EVENT``, shuts the executor down without waiting (when one is
        attached) and asks the runtime to terminate the group.
        """
        if not self.__instance_activate__:
            return
        STOP_EVENT.set()
        if self.executor is not None:
            self.executor.shutdown(wait=False)
        global_runtime.get_runtime().terminate_group(self._group_name)
        self.__instance_activate__ = False

    def accelerate(self):
        """
        Acceleration method, used to perform acceleration operations on local instances.

        When the runtime reports that the instances of the group are local,
        a ``MessageQueue`` is created, its handle is passed to the runtime
        and every method proxy is switched to that queue. Otherwise nothing
        is done.

        Environment variables:
            FCC_MAX_CHUNK_BYTES: Passed to ``MessageQueue`` as an int (default 10 * 1024 * 1024).
            FCC_MAX_CHUNKS: Passed to ``MessageQueue`` as an int (default 10).
            FCC_USE_SYNC_LOOP: ``1``/``true``/``True``/``yes`` disables the async loop.
        """
        is_local = global_runtime.get_runtime().is_local_instances(self._instance_ids)
        if not is_local:
            return
        _logger.debug(f"group all are {len(self._instance_ids)} local instances")
        fcc_max_chunk_bytes = int(os.environ.get("FCC_MAX_CHUNK_BYTES", 1024 * 1024 * 10))
        fcc_max_chunks = int(os.environ.get("FCC_MAX_CHUNKS", 10))
        # Only an explicit truthy value selects the sync loop; previously any
        # non-empty value (even "0" or "false") disabled the async loop.
        fcc_use_async_loop = not self._sync_loop_requested()
        self.rpc_broadcast_mq = MessageQueue(len(self._instance_ids), fcc_max_chunk_bytes, fcc_max_chunks,
                                             fcc_use_async_loop)
        handle = self.rpc_broadcast_mq.export_handle()
        global_runtime.get_runtime().accelerate(self._group_name, handle)
        _logger.debug("finish accelerate")
        if self._class_methods is not None:
            for method_name in self._class_methods:
                method_proxy = getattr(self, method_name)
                method_proxy.set_rpc_broadcast_mq(self.rpc_broadcast_mq)
# Gradual migration: StatefulInstance and StatefulInstanceCreator are the new preferred names # InstanceProxy and InstanceCreator are kept for backward compatibility StatefulInstance = InstanceProxy StatefulInstanceCreator = InstanceCreator