class TensorShapeLogger(FunctionWrapper):
    """
    TensorShapeLogger extends FunctionWrapper to log the shapes of torch.Tensor objects.

    Tensors are rendered as their ``.shape`` rather than their contents, and
    sequences are formatted with ``_process_tensor_item`` as the per-sequence
    hook so all-tensor sequences can be shown as lists of shapes. Every access
    to ``torch`` is guarded by ``torch is not None``, so the class keeps working
    when ``torch`` is bound to ``None`` (e.g. torch not installed).
    """

    @staticmethod
    def _process_tensor_item(seq: List[Any]) -> Optional[List[Any]]:
        """
        Extract tensor shapes from a sequence if every item is a torch.Tensor.

        Args:
            seq (List[Any]): The sequence to process.

        Returns:
            Optional[List[Any]]: ``[x.shape for x in seq]`` when ``torch`` is
            available and all items are torch.Tensor; otherwise None. An empty
            ``seq`` yields ``[]`` because ``all()`` of an empty iterable is True.
        """
        if torch is not None and all(isinstance(x, torch.Tensor) for x in seq):
            return [x.shape for x in seq]
        return None

    def wrap_call(self, func_name: str, frame: FrameType) -> str:
        """
        Format the function call information.

        Arguments are extracted from ``frame`` with ``self._extract_args_kwargs``
        and rendered with ``self._format_args_kwargs``. Both helpers are not
        defined in this class; presumably they come from FunctionWrapper and
        route individual values through ``self._format_value`` (so tensors show
        as shapes) — confirm in the base class.

        Args:
            func_name (str): Name of the function being called (unused here).
            frame (FrameType): The current stack frame.

        Returns:
            str: Formatted call message.
        """
        args, kwargs = self._extract_args_kwargs(frame)
        return self._format_args_kwargs(args, kwargs)

    def wrap_return(self, func_name: str, result: Any) -> str:
        """
        Format the function return information.

        Delegates to ``self._format_return`` (not defined in this class;
        presumably inherited and calling ``self._format_value(result,
        is_return=True)`` — confirm in the base class).

        Args:
            func_name (str): Name of the function returning (unused here).
            result (Any): The result returned by the function.

        Returns:
            str: Formatted return message.
        """
        return self._format_return(result)

    def wrap_upd(self, old_value: Any, current_value: Any) -> Tuple[str, str]:
        """
        Format the update information of a variable, showing tensor shapes.

        Args:
            old_value (Any): The old value of the variable.
            current_value (Any): The new value of the variable.

        Returns:
            Tuple[str, str]: ``_format_value`` of the old and new values.
        """
        old_msg = self._format_value(old_value)
        current_msg = self._format_value(current_value)
        return old_msg, current_msg

    def _format_value(self, value: Any, is_return: bool = False) -> str:
        """
        Format a value into a string, logging tensor shapes instead of contents.

        Rendering rules, in order:
          * torch.Tensor (only when torch is available) -> ``str(value.shape)``
          * instances of ``log_element_types`` -> ``str(value)``
          * instances of ``log_sequence_types`` -> result of
            ``EventHandls.format_sequence`` with ``_process_tensor_item`` as
            ``func``; falls back to ``(type)<ClassName>`` if that is falsy
          * anything else -> ``(type)<ClassName>``
        When ``is_return`` is True, non-tensor sequence results are additionally
        wrapped in square brackets.

        Args:
            value (Any): The value to format.
            is_return (bool): Flag indicating if the value is a return value.

        Returns:
            str: Formatted value string.
        """
        # Compute the guarded tensor check once and reuse it below, so the
        # return-value path never dereferences ``torch.Tensor`` when torch is None
        # (doing so would raise AttributeError).
        is_tensor = torch is not None and isinstance(value, torch.Tensor)

        if is_tensor:
            formatted = f"{value.shape}"
        elif isinstance(value, log_element_types):
            formatted = f"{value}"
        elif isinstance(value, log_sequence_types):
            # NOTE(review): presumably format_sequence uses func's non-None
            # result (a list of shapes) for all-tensor sequences — confirm in
            # EventHandls.
            formatted_sequence = EventHandls.format_sequence(
                value, func=TensorShapeLogger._process_tensor_item
            )
            if formatted_sequence:
                formatted = f"{formatted_sequence}"
            else:
                formatted = f"(type){value.__class__.__name__}"
        else:
            formatted = f"(type){value.__class__.__name__}"

        # Tensors are returned as their shape even for return values; only
        # (non-tensor) sequences get the extra bracket wrapping.
        if is_return and not is_tensor and isinstance(value, log_sequence_types) and formatted:
            return f"[{formatted}]"
        return formatted