"""gRPC servicers for each version of the measurement service interface."""
from __future__ import annotations
import collections.abc
import contextlib
import inspect
import pathlib
import weakref
from contextvars import ContextVar
from typing import Any, Callable, Dict, Generator, List, Optional
import grpc
from google.protobuf import any_pb2
from ni_measurementlink_service._internal.parameter import serializer
from ni_measurementlink_service._internal.parameter.metadata import ParameterMetadata
from ni_measurementlink_service._internal.stubs.ni.measurementlink.measurement.v1 import (
measurement_service_pb2 as v1_measurement_service_pb2,
measurement_service_pb2_grpc as v1_measurement_service_pb2_grpc,
)
from ni_measurementlink_service._internal.stubs.ni.measurementlink.measurement.v2 import (
measurement_service_pb2 as v2_measurement_service_pb2,
measurement_service_pb2_grpc as v2_measurement_service_pb2_grpc,
)
from ni_measurementlink_service.measurement.info import MeasurementInfo
from ni_measurementlink_service.session_management import PinMapContext
[docs]
class MeasurementServiceContext:
"""Accessor for the measurement service's context-local state."""
def __init__(
self,
grpc_context: grpc.ServicerContext,
pin_map_context: PinMapContext,
owner: Optional[weakref.ReferenceType[object]],
) -> None:
"""Initialize the measurement service context."""
self._grpc_context = grpc_context
self._pin_map_context = pin_map_context
self._is_complete = False
self._owner = owner
[docs]
def mark_complete(self) -> None:
"""Mark the current RPC as complete."""
self._is_complete = True
@property
def grpc_context(self) -> grpc.ServicerContext:
"""Get the context for the RPC."""
return self._grpc_context
@property
def owner(self) -> object:
"""The owner of the server (e.g. measurement service)."""
if self._owner is None:
return None
return self._owner() # dereference weak ref
@property
def pin_map_context(self) -> PinMapContext:
"""Get the pin map context for the RPC."""
return self._pin_map_context
[docs]
def add_cancel_callback(self, cancel_callback: Callable[[], None]) -> None:
"""Add a callback that is invoked when the RPC is canceled."""
def grpc_callback() -> None:
if not self._is_complete:
cancel_callback()
self._grpc_context.add_callback(grpc_callback)
[docs]
def cancel(self) -> None:
"""Cancel the RPC."""
if not self._is_complete:
self._grpc_context.cancel()
@property
def time_remaining(self) -> float:
"""Get the time remaining for the RPC."""
return self._grpc_context.time_remaining()
[docs]
def abort(self, code: grpc.StatusCode, details: str) -> None:
"""Aborts the RPC."""
try:
self._grpc_context.abort(code, details)
except Exception as e:
# If gRPC raises an empty exception, replace it with an RpcError.
# This allows our logging interceptors to query the code/details.
if type(e) is Exception:
raise CustomRpcError(code, details) from e
raise
[docs]
class CustomRpcError(grpc.RpcError):
"""A custom exception class for handling gRPC RPC errors.
gRPC's built-in RpcError is not directly configurable in Python, so this class
enables the creation of custom RPC errors with specific error codes.
"""
def __init__(self, code: grpc.StatusCode, details: str) -> None:
"""Initialize a CustomRpcError instance."""
self._code = code
self._details = details
[docs]
def code(self) -> grpc.StatusCode:
"""Get the gRPC status code."""
return self._code
[docs]
def details(self) -> str:
"""Get the gRPC status details."""
return self._details
measurement_service_context: ContextVar[MeasurementServiceContext] = ContextVar(
"measurement_service_context"
)
def _get_mapping_by_parameter_name(
mapping_by_id: Dict[int, Any], measure_function: Callable[[], None]
) -> Dict[str, Any]:
"""Transform a mapping by id into a mapping by parameter name (i.e. kwargs)."""
signature = inspect.signature(measure_function)
mapping_by_variable_name = {}
for i, parameter in enumerate(signature.parameters.values(), start=1):
mapping_by_variable_name[parameter.name] = mapping_by_id[i]
return mapping_by_variable_name
def _serialize_outputs(output_metadata: Dict[int, ParameterMetadata], outputs: Any) -> any_pb2.Any:
if isinstance(outputs, collections.abc.Sequence):
return any_pb2.Any(value=serializer.serialize_parameters(output_metadata, outputs))
elif outputs is None:
raise ValueError(f"Measurement function returned None")
else:
raise TypeError(
f"Measurement function returned value with unsupported type: {type(outputs)}"
)
def _frame_metadata_dict(
parameter_list: List[ParameterMetadata],
) -> Dict[int, ParameterMetadata]:
metadata_dict = {}
for i, parameter in enumerate(parameter_list, start=1):
metadata_dict[i] = parameter
return metadata_dict
[docs]
class MeasurementServiceServicerV1(v1_measurement_service_pb2_grpc.MeasurementServiceServicer):
"""Measurement v1 servicer."""
def __init__(
self,
measurement_info: MeasurementInfo,
configuration_parameter_list: List[ParameterMetadata],
output_parameter_list: List[ParameterMetadata],
measure_function: Callable,
owner: object,
) -> None:
"""Initialize the measurement v1 servicer."""
super().__init__()
self._configuration_metadata = _frame_metadata_dict(configuration_parameter_list)
self._output_metadata = _frame_metadata_dict(output_parameter_list)
self._measurement_info = measurement_info
self._measure_function = measure_function
self._owner = weakref.ref(owner) if owner is not None else None # avoid reference cycle
[docs]
def Measure( # noqa: N802 - function name should be lowercase
self, request: v1_measurement_service_pb2.MeasureRequest, context: grpc.ServicerContext
) -> v1_measurement_service_pb2.MeasureResponse:
"""RPC API that executes the registered measurement method."""
mapping_by_id = serializer.deserialize_parameters(
self._configuration_metadata, request.configuration_parameters.value
)
mapping_by_variable_name = _get_mapping_by_parameter_name(
mapping_by_id, self._measure_function
)
pin_map_context = PinMapContext._from_grpc(request.pin_map_context)
token = measurement_service_context.set(
MeasurementServiceContext(context, pin_map_context, self._owner)
)
try:
return_value = self._measure_function(**mapping_by_variable_name)
if isinstance(return_value, collections.abc.Generator):
with contextlib.closing(return_value) as output_iter:
outputs = None
try:
while True:
outputs = next(output_iter)
except StopIteration as e:
if e.value is not None:
outputs = e.value
return self._serialize_response(outputs)
else:
return self._serialize_response(return_value)
finally:
measurement_service_context.get().mark_complete()
measurement_service_context.reset(token)
def _serialize_response(self, outputs: Any) -> v1_measurement_service_pb2.MeasureResponse:
return v1_measurement_service_pb2.MeasureResponse(
outputs=_serialize_outputs(self._output_metadata, outputs)
)
[docs]
class MeasurementServiceServicerV2(v2_measurement_service_pb2_grpc.MeasurementServiceServicer):
"""Measurement v2 servicer."""
def __init__(
self,
measurement_info: MeasurementInfo,
configuration_parameter_list: List[ParameterMetadata],
output_parameter_list: List[ParameterMetadata],
measure_function: Callable,
owner: object,
) -> None:
"""Initialize the measurement v2 servicer."""
super().__init__()
self._configuration_metadata = _frame_metadata_dict(configuration_parameter_list)
self._output_metadata = _frame_metadata_dict(output_parameter_list)
self._measurement_info = measurement_info
self._measure_function = measure_function
self._owner = weakref.ref(owner) if owner is not None else None # avoid reference cycle
[docs]
def Measure( # noqa: N802 - function name should be lowercase
self, request: v2_measurement_service_pb2.MeasureRequest, context: grpc.ServicerContext
) -> Generator[v2_measurement_service_pb2.MeasureResponse, None, None]:
"""RPC API that executes the registered measurement method."""
mapping_by_id = serializer.deserialize_parameters(
self._configuration_metadata, request.configuration_parameters.value
)
mapping_by_variable_name = _get_mapping_by_parameter_name(
mapping_by_id, self._measure_function
)
pin_map_context = PinMapContext._from_grpc(request.pin_map_context)
token = measurement_service_context.set(
MeasurementServiceContext(context, pin_map_context, self._owner)
)
try:
return_value = self._measure_function(**mapping_by_variable_name)
if isinstance(return_value, collections.abc.Generator):
with contextlib.closing(return_value) as output_iter:
try:
while True:
outputs = next(output_iter)
yield self._serialize_response(outputs)
except StopIteration as e:
if e.value is not None:
yield self._serialize_response(e.value)
else:
yield self._serialize_response(return_value)
finally:
measurement_service_context.get().mark_complete()
measurement_service_context.reset(token)
def _serialize_response(self, outputs: Any) -> v2_measurement_service_pb2.MeasureResponse:
return v2_measurement_service_pb2.MeasureResponse(
outputs=_serialize_outputs(self._output_metadata, outputs)
)