common.py 1.7 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758
  1. from __future__ import annotations
  2. from pandas import (
  3. DataFrame,
  4. concat,
  5. )
  6. def _check_mixed_float(df, dtype=None):
  7. # float16 are most likely to be upcasted to float32
  8. dtypes = {"A": "float32", "B": "float32", "C": "float16", "D": "float64"}
  9. if isinstance(dtype, str):
  10. dtypes = {k: dtype for k, v in dtypes.items()}
  11. elif isinstance(dtype, dict):
  12. dtypes.update(dtype)
  13. if dtypes.get("A"):
  14. assert df.dtypes["A"] == dtypes["A"]
  15. if dtypes.get("B"):
  16. assert df.dtypes["B"] == dtypes["B"]
  17. if dtypes.get("C"):
  18. assert df.dtypes["C"] == dtypes["C"]
  19. if dtypes.get("D"):
  20. assert df.dtypes["D"] == dtypes["D"]
  21. def _check_mixed_int(df, dtype=None):
  22. dtypes = {"A": "int32", "B": "uint64", "C": "uint8", "D": "int64"}
  23. if isinstance(dtype, str):
  24. dtypes = {k: dtype for k, v in dtypes.items()}
  25. elif isinstance(dtype, dict):
  26. dtypes.update(dtype)
  27. if dtypes.get("A"):
  28. assert df.dtypes["A"] == dtypes["A"]
  29. if dtypes.get("B"):
  30. assert df.dtypes["B"] == dtypes["B"]
  31. if dtypes.get("C"):
  32. assert df.dtypes["C"] == dtypes["C"]
  33. if dtypes.get("D"):
  34. assert df.dtypes["D"] == dtypes["D"]
  35. def zip_frames(frames: list[DataFrame], axis: int = 1) -> DataFrame:
  36. """
  37. take a list of frames, zip them together under the
  38. assumption that these all have the first frames' index/columns.
  39. Returns
  40. -------
  41. new_frame : DataFrame
  42. """
  43. if axis == 1:
  44. columns = frames[0].columns
  45. zipped = [f.loc[:, c] for c in columns for f in frames]
  46. return concat(zipped, axis=1)
  47. else:
  48. index = frames[0].index
  49. zipped = [f.loc[i, :] for i in index for f in frames]
  50. return DataFrame(zipped)