# pickle_compat.py
  1. """
  2. Support pre-0.12 series pickle compatibility.
  3. """
  4. from __future__ import annotations
  5. import contextlib
  6. import copy
  7. import io
  8. import pickle as pkl
  9. from typing import TYPE_CHECKING
  10. import warnings
  11. import numpy as np
  12. from pandas._libs.arrays import NDArrayBacked
  13. from pandas._libs.tslibs import BaseOffset
  14. from pandas import Index
  15. from pandas.core.arrays import (
  16. DatetimeArray,
  17. PeriodArray,
  18. TimedeltaArray,
  19. )
  20. from pandas.core.internals import BlockManager
  21. if TYPE_CHECKING:
  22. from pandas import (
  23. DataFrame,
  24. Series,
  25. )
  26. def load_reduce(self):
  27. stack = self.stack
  28. args = stack.pop()
  29. func = stack[-1]
  30. if len(args) and type(args[0]) is type:
  31. n = args[0].__name__ # noqa
  32. try:
  33. stack[-1] = func(*args)
  34. return
  35. except TypeError as err:
  36. # If we have a deprecated function,
  37. # try to replace and try again.
  38. msg = "_reconstruct: First argument must be a sub-type of ndarray"
  39. if msg in str(err):
  40. try:
  41. cls = args[0]
  42. stack[-1] = object.__new__(cls)
  43. return
  44. except TypeError:
  45. pass
  46. elif args and isinstance(args[0], type) and issubclass(args[0], BaseOffset):
  47. # TypeError: object.__new__(Day) is not safe, use Day.__new__()
  48. cls = args[0]
  49. stack[-1] = cls.__new__(*args)
  50. return
  51. elif args and issubclass(args[0], PeriodArray):
  52. cls = args[0]
  53. stack[-1] = NDArrayBacked.__new__(*args)
  54. return
  55. raise
# Warning template used by the _LoadSparse* shims below when a removed
# SparseSeries/SparseDataFrame pickle is loaded.
_sparse_msg = """\
Loading a saved '{cls}' as a {new} with sparse values.
'{cls}' is now removed. You should re-save this dataset in its new format.
"""
  60. class _LoadSparseSeries:
  61. # To load a SparseSeries as a Series[Sparse]
  62. # https://github.com/python/mypy/issues/1020
  63. # error: Incompatible return type for "__new__" (returns "Series", but must return
  64. # a subtype of "_LoadSparseSeries")
  65. def __new__(cls) -> Series: # type: ignore[misc]
  66. from pandas import Series
  67. warnings.warn(
  68. _sparse_msg.format(cls="SparseSeries", new="Series"),
  69. FutureWarning,
  70. stacklevel=6,
  71. )
  72. return Series(dtype=object)
  73. class _LoadSparseFrame:
  74. # To load a SparseDataFrame as a DataFrame[Sparse]
  75. # https://github.com/python/mypy/issues/1020
  76. # error: Incompatible return type for "__new__" (returns "DataFrame", but must
  77. # return a subtype of "_LoadSparseFrame")
  78. def __new__(cls) -> DataFrame: # type: ignore[misc]
  79. from pandas import DataFrame
  80. warnings.warn(
  81. _sparse_msg.format(cls="SparseDataFrame", new="DataFrame"),
  82. FutureWarning,
  83. stacklevel=6,
  84. )
  85. return DataFrame()
# If classes are moved, provide compat here.
# Maps (old_module, old_name) -> (new_module, new_name); consulted by
# Unpickler.find_class before falling back to the standard lookup.
# Numeric comments below are pandas GitHub issue/PR numbers.
_class_locations_map = {
    ("pandas.core.sparse.array", "SparseArray"): ("pandas.core.arrays", "SparseArray"),
    # 15477
    ("pandas.core.base", "FrozenNDArray"): ("numpy", "ndarray"),
    ("pandas.core.indexes.frozen", "FrozenNDArray"): ("numpy", "ndarray"),
    ("pandas.core.base", "FrozenList"): ("pandas.core.indexes.frozen", "FrozenList"),
    # 10890
    ("pandas.core.series", "TimeSeries"): ("pandas.core.series", "Series"),
    ("pandas.sparse.series", "SparseTimeSeries"): (
        "pandas.core.sparse.series",
        "SparseSeries",
    ),
    # 12588, extensions moving
    ("pandas._sparse", "BlockIndex"): ("pandas._libs.sparse", "BlockIndex"),
    ("pandas.tslib", "Timestamp"): ("pandas._libs.tslib", "Timestamp"),
    # 18543 moving period
    ("pandas._period", "Period"): ("pandas._libs.tslibs.period", "Period"),
    ("pandas._libs.period", "Period"): ("pandas._libs.tslibs.period", "Period"),
    # 18014 moved __nat_unpickle from _libs.tslib-->_libs.tslibs.nattype
    ("pandas.tslib", "__nat_unpickle"): (
        "pandas._libs.tslibs.nattype",
        "__nat_unpickle",
    ),
    ("pandas._libs.tslib", "__nat_unpickle"): (
        "pandas._libs.tslibs.nattype",
        "__nat_unpickle",
    ),
    # 15998 top-level dirs moving
    ("pandas.sparse.array", "SparseArray"): (
        "pandas.core.arrays.sparse",
        "SparseArray",
    ),
    ("pandas.sparse.series", "SparseSeries"): (
        "pandas.compat.pickle_compat",
        "_LoadSparseSeries",
    ),
    # NOTE(review): target module differs from the SparseSeries entry above
    # (which points at pandas.compat.pickle_compat, where _LoadSparseFrame
    # lives) — confirm "pandas.core.sparse.frame" is intended here.
    ("pandas.sparse.frame", "SparseDataFrame"): (
        "pandas.core.sparse.frame",
        "_LoadSparseFrame",
    ),
    ("pandas.indexes.base", "_new_Index"): ("pandas.core.indexes.base", "_new_Index"),
    ("pandas.indexes.base", "Index"): ("pandas.core.indexes.base", "Index"),
    ("pandas.indexes.numeric", "Int64Index"): (
        "pandas.core.indexes.numeric",
        "Int64Index",
    ),
    ("pandas.indexes.range", "RangeIndex"): ("pandas.core.indexes.range", "RangeIndex"),
    ("pandas.indexes.multi", "MultiIndex"): ("pandas.core.indexes.multi", "MultiIndex"),
    ("pandas.tseries.index", "_new_DatetimeIndex"): (
        "pandas.core.indexes.datetimes",
        "_new_DatetimeIndex",
    ),
    ("pandas.tseries.index", "DatetimeIndex"): (
        "pandas.core.indexes.datetimes",
        "DatetimeIndex",
    ),
    ("pandas.tseries.period", "PeriodIndex"): (
        "pandas.core.indexes.period",
        "PeriodIndex",
    ),
    # 19269, arrays moving
    ("pandas.core.categorical", "Categorical"): ("pandas.core.arrays", "Categorical"),
    # 19939, add timedeltaindex, float64index compat from 15998 move
    ("pandas.tseries.tdi", "TimedeltaIndex"): (
        "pandas.core.indexes.timedeltas",
        "TimedeltaIndex",
    ),
    ("pandas.indexes.numeric", "Float64Index"): (
        "pandas.core.indexes.numeric",
        "Float64Index",
    ),
    ("pandas.core.sparse.series", "SparseSeries"): (
        "pandas.compat.pickle_compat",
        "_LoadSparseSeries",
    ),
    ("pandas.core.sparse.frame", "SparseDataFrame"): (
        "pandas.compat.pickle_compat",
        "_LoadSparseFrame",
    ),
}
  167. # our Unpickler sub-class to override methods and some dispatcher
  168. # functions for compat and uses a non-public class of the pickle module.
  169. # error: Name 'pkl._Unpickler' is not defined
  170. class Unpickler(pkl._Unpickler): # type: ignore[name-defined]
  171. def find_class(self, module, name):
  172. # override superclass
  173. key = (module, name)
  174. module, name = _class_locations_map.get(key, key)
  175. return super().find_class(module, name)
# Give Unpickler its own dispatch table (so mutating it does not affect the
# shared pkl._Unpickler table), then route REDUCE opcodes to load_reduce.
Unpickler.dispatch = copy.copy(Unpickler.dispatch)
Unpickler.dispatch[pkl.REDUCE[0]] = load_reduce
  178. def load_newobj(self):
  179. args = self.stack.pop()
  180. cls = self.stack[-1]
  181. # compat
  182. if issubclass(cls, Index):
  183. obj = object.__new__(cls)
  184. elif issubclass(cls, DatetimeArray) and not args:
  185. arr = np.array([], dtype="M8[ns]")
  186. obj = cls.__new__(cls, arr, arr.dtype)
  187. elif issubclass(cls, TimedeltaArray) and not args:
  188. arr = np.array([], dtype="m8[ns]")
  189. obj = cls.__new__(cls, arr, arr.dtype)
  190. elif cls is BlockManager and not args:
  191. obj = cls.__new__(cls, (), [], False)
  192. else:
  193. obj = cls.__new__(cls, *args)
  194. self.stack[-1] = obj
  195. Unpickler.dispatch[pkl.NEWOBJ[0]] = load_newobj
  196. def load_newobj_ex(self):
  197. kwargs = self.stack.pop()
  198. args = self.stack.pop()
  199. cls = self.stack.pop()
  200. # compat
  201. if issubclass(cls, Index):
  202. obj = object.__new__(cls)
  203. else:
  204. obj = cls.__new__(cls, *args, **kwargs)
  205. self.append(obj)
try:
    # NEWOBJ_EX only exists for pickle protocol >= 4; tolerate its absence.
    Unpickler.dispatch[pkl.NEWOBJ_EX[0]] = load_newobj_ex
except (AttributeError, KeyError):
    pass
  210. def load(fh, encoding: str | None = None, is_verbose: bool = False):
  211. """
  212. Load a pickle, with a provided encoding,
  213. Parameters
  214. ----------
  215. fh : a filelike object
  216. encoding : an optional encoding
  217. is_verbose : show exception output
  218. """
  219. try:
  220. fh.seek(0)
  221. if encoding is not None:
  222. up = Unpickler(fh, encoding=encoding)
  223. else:
  224. up = Unpickler(fh)
  225. up.is_verbose = is_verbose
  226. return up.load()
  227. except (ValueError, TypeError):
  228. raise
  229. def loads(
  230. bytes_object: bytes,
  231. *,
  232. fix_imports: bool = True,
  233. encoding: str = "ASCII",
  234. errors: str = "strict",
  235. ):
  236. """
  237. Analogous to pickle._loads.
  238. """
  239. fd = io.BytesIO(bytes_object)
  240. return Unpickler(
  241. fd, fix_imports=fix_imports, encoding=encoding, errors=errors
  242. ).load()
  243. @contextlib.contextmanager
  244. def patch_pickle():
  245. """
  246. Temporarily patch pickle to use our unpickler.
  247. """
  248. orig_loads = pkl.loads
  249. try:
  250. setattr(pkl, "loads", loads)
  251. yield
  252. finally:
  253. setattr(pkl, "loads", orig_loads)