misc.py 15 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552
  1. """Miscellaneous stuff that doesn't really fit anywhere else."""
  2. from typing import List
  3. import operator
  4. import sys
  5. import os
  6. import re as _re
  7. import struct
  8. from textwrap import fill, dedent
  9. class Undecidable(ValueError):
  10. # an error to be raised when a decision cannot be made definitively
  11. # where a definitive answer is needed
  12. pass
  13. def filldedent(s, w=70):
  14. """
  15. Strips leading and trailing empty lines from a copy of `s`, then dedents,
  16. fills and returns it.
  17. Empty line stripping serves to deal with docstrings like this one that
  18. start with a newline after the initial triple quote, inserting an empty
  19. line at the beginning of the string.
  20. See Also
  21. ========
  22. strlines, rawlines
  23. """
  24. return '\n' + fill(dedent(str(s)).strip('\n'), width=w)
  25. def strlines(s, c=64, short=False):
  26. """Return a cut-and-pastable string that, when printed, is
  27. equivalent to the input. The lines will be surrounded by
  28. parentheses and no line will be longer than c (default 64)
  29. characters. If the line contains newlines characters, the
  30. `rawlines` result will be returned. If ``short`` is True
  31. (default is False) then if there is one line it will be
  32. returned without bounding parentheses.
  33. Examples
  34. ========
  35. >>> from sympy.utilities.misc import strlines
  36. >>> q = 'this is a long string that should be broken into shorter lines'
  37. >>> print(strlines(q, 40))
  38. (
  39. 'this is a long string that should be b'
  40. 'roken into shorter lines'
  41. )
  42. >>> q == (
  43. ... 'this is a long string that should be b'
  44. ... 'roken into shorter lines'
  45. ... )
  46. True
  47. See Also
  48. ========
  49. filldedent, rawlines
  50. """
  51. if not isinstance(s, str):
  52. raise ValueError('expecting string input')
  53. if '\n' in s:
  54. return rawlines(s)
  55. q = '"' if repr(s).startswith('"') else "'"
  56. q = (q,)*2
  57. if '\\' in s: # use r-string
  58. m = '(\nr%s%%s%s\n)' % q
  59. j = '%s\nr%s' % q
  60. c -= 3
  61. else:
  62. m = '(\n%s%%s%s\n)' % q
  63. j = '%s\n%s' % q
  64. c -= 2
  65. out = []
  66. while s:
  67. out.append(s[:c])
  68. s=s[c:]
  69. if short and len(out) == 1:
  70. return (m % out[0]).splitlines()[1] # strip bounding (\n...\n)
  71. return m % j.join(out)
  72. def rawlines(s):
  73. """Return a cut-and-pastable string that, when printed, is equivalent
  74. to the input. Use this when there is more than one line in the
  75. string. The string returned is formatted so it can be indented
  76. nicely within tests; in some cases it is wrapped in the dedent
  77. function which has to be imported from textwrap.
  78. Examples
  79. ========
  80. Note: because there are characters in the examples below that need
  81. to be escaped because they are themselves within a triple quoted
  82. docstring, expressions below look more complicated than they would
  83. be if they were printed in an interpreter window.
  84. >>> from sympy.utilities.misc import rawlines
  85. >>> from sympy import TableForm
  86. >>> s = str(TableForm([[1, 10]], headings=(None, ['a', 'bee'])))
  87. >>> print(rawlines(s))
  88. (
  89. 'a bee\\n'
  90. '-----\\n'
  91. '1 10 '
  92. )
  93. >>> print(rawlines('''this
  94. ... that'''))
  95. dedent('''\\
  96. this
  97. that''')
  98. >>> print(rawlines('''this
  99. ... that
  100. ... '''))
  101. dedent('''\\
  102. this
  103. that
  104. ''')
  105. >>> s = \"\"\"this
  106. ... is a triple '''
  107. ... \"\"\"
  108. >>> print(rawlines(s))
  109. dedent(\"\"\"\\
  110. this
  111. is a triple '''
  112. \"\"\")
  113. >>> print(rawlines('''this
  114. ... that
  115. ... '''))
  116. (
  117. 'this\\n'
  118. 'that\\n'
  119. ' '
  120. )
  121. See Also
  122. ========
  123. filldedent, strlines
  124. """
  125. lines = s.split('\n')
  126. if len(lines) == 1:
  127. return repr(lines[0])
  128. triple = ["'''" in s, '"""' in s]
  129. if any(li.endswith(' ') for li in lines) or '\\' in s or all(triple):
  130. rv = []
  131. # add on the newlines
  132. trailing = s.endswith('\n')
  133. last = len(lines) - 1
  134. for i, li in enumerate(lines):
  135. if i != last or trailing:
  136. rv.append(repr(li + '\n'))
  137. else:
  138. rv.append(repr(li))
  139. return '(\n %s\n)' % '\n '.join(rv)
  140. else:
  141. rv = '\n '.join(lines)
  142. if triple[0]:
  143. return 'dedent("""\\\n %s""")' % rv
  144. else:
  145. return "dedent('''\\\n %s''')" % rv
  146. ARCH = str(struct.calcsize('P') * 8) + "-bit"
  147. # XXX: PyPy doesn't support hash randomization
  148. HASH_RANDOMIZATION = getattr(sys.flags, 'hash_randomization', False)
  149. _debug_tmp = [] # type: List[str]
  150. _debug_iter = 0
  151. def debug_decorator(func):
  152. """If SYMPY_DEBUG is True, it will print a nice execution tree with
  153. arguments and results of all decorated functions, else do nothing.
  154. """
  155. from sympy import SYMPY_DEBUG
  156. if not SYMPY_DEBUG:
  157. return func
  158. def maketree(f, *args, **kw):
  159. global _debug_tmp
  160. global _debug_iter
  161. oldtmp = _debug_tmp
  162. _debug_tmp = []
  163. _debug_iter += 1
  164. def tree(subtrees):
  165. def indent(s, variant=1):
  166. x = s.split("\n")
  167. r = "+-%s\n" % x[0]
  168. for a in x[1:]:
  169. if a == "":
  170. continue
  171. if variant == 1:
  172. r += "| %s\n" % a
  173. else:
  174. r += " %s\n" % a
  175. return r
  176. if len(subtrees) == 0:
  177. return ""
  178. f = []
  179. for a in subtrees[:-1]:
  180. f.append(indent(a))
  181. f.append(indent(subtrees[-1], 2))
  182. return ''.join(f)
  183. # If there is a bug and the algorithm enters an infinite loop, enable the
  184. # following lines. It will print the names and parameters of all major functions
  185. # that are called, *before* they are called
  186. #from functools import reduce
  187. #print("%s%s %s%s" % (_debug_iter, reduce(lambda x, y: x + y, \
  188. # map(lambda x: '-', range(1, 2 + _debug_iter))), f.__name__, args))
  189. r = f(*args, **kw)
  190. _debug_iter -= 1
  191. s = "%s%s = %s\n" % (f.__name__, args, r)
  192. if _debug_tmp != []:
  193. s += tree(_debug_tmp)
  194. _debug_tmp = oldtmp
  195. _debug_tmp.append(s)
  196. if _debug_iter == 0:
  197. print(_debug_tmp[0])
  198. _debug_tmp = []
  199. return r
  200. def decorated(*args, **kwargs):
  201. return maketree(func, *args, **kwargs)
  202. return decorated
  203. def debug(*args):
  204. """
  205. Print ``*args`` if SYMPY_DEBUG is True, else do nothing.
  206. """
  207. from sympy import SYMPY_DEBUG
  208. if SYMPY_DEBUG:
  209. print(*args, file=sys.stderr)
  210. def find_executable(executable, path=None):
  211. """Try to find 'executable' in the directories listed in 'path' (a
  212. string listing directories separated by 'os.pathsep'; defaults to
  213. os.environ['PATH']). Returns the complete filename or None if not
  214. found
  215. """
  216. from .exceptions import sympy_deprecation_warning
  217. sympy_deprecation_warning(
  218. """
  219. sympy.utilities.misc.find_executable() is deprecated. Use the standard
  220. library shutil.which() function instead.
  221. """,
  222. deprecated_since_version="1.7",
  223. active_deprecations_target="deprecated-find-executable",
  224. )
  225. if path is None:
  226. path = os.environ['PATH']
  227. paths = path.split(os.pathsep)
  228. extlist = ['']
  229. if os.name == 'os2':
  230. (base, ext) = os.path.splitext(executable)
  231. # executable files on OS/2 can have an arbitrary extension, but
  232. # .exe is automatically appended if no dot is present in the name
  233. if not ext:
  234. executable = executable + ".exe"
  235. elif sys.platform == 'win32':
  236. pathext = os.environ['PATHEXT'].lower().split(os.pathsep)
  237. (base, ext) = os.path.splitext(executable)
  238. if ext.lower() not in pathext:
  239. extlist = pathext
  240. for ext in extlist:
  241. execname = executable + ext
  242. if os.path.isfile(execname):
  243. return execname
  244. else:
  245. for p in paths:
  246. f = os.path.join(p, execname)
  247. if os.path.isfile(f):
  248. return f
  249. return None
  250. def func_name(x, short=False):
  251. """Return function name of `x` (if defined) else the `type(x)`.
  252. If short is True and there is a shorter alias for the result,
  253. return the alias.
  254. Examples
  255. ========
  256. >>> from sympy.utilities.misc import func_name
  257. >>> from sympy import Matrix
  258. >>> from sympy.abc import x
  259. >>> func_name(Matrix.eye(3))
  260. 'MutableDenseMatrix'
  261. >>> func_name(x < 1)
  262. 'StrictLessThan'
  263. >>> func_name(x < 1, short=True)
  264. 'Lt'
  265. """
  266. alias = {
  267. 'GreaterThan': 'Ge',
  268. 'StrictGreaterThan': 'Gt',
  269. 'LessThan': 'Le',
  270. 'StrictLessThan': 'Lt',
  271. 'Equality': 'Eq',
  272. 'Unequality': 'Ne',
  273. }
  274. typ = type(x)
  275. if str(typ).startswith("<type '"):
  276. typ = str(typ).split("'")[1].split("'")[0]
  277. elif str(typ).startswith("<class '"):
  278. typ = str(typ).split("'")[1].split("'")[0]
  279. rv = getattr(getattr(x, 'func', x), '__name__', typ)
  280. if '.' in rv:
  281. rv = rv.split('.')[-1]
  282. if short:
  283. rv = alias.get(rv, rv)
  284. return rv
  285. def _replace(reps):
  286. """Return a function that can make the replacements, given in
  287. ``reps``, on a string. The replacements should be given as mapping.
  288. Examples
  289. ========
  290. >>> from sympy.utilities.misc import _replace
  291. >>> f = _replace(dict(foo='bar', d='t'))
  292. >>> f('food')
  293. 'bart'
  294. >>> f = _replace({})
  295. >>> f('food')
  296. 'food'
  297. """
  298. if not reps:
  299. return lambda x: x
  300. D = lambda match: reps[match.group(0)]
  301. pattern = _re.compile("|".join(
  302. [_re.escape(k) for k, v in reps.items()]), _re.M)
  303. return lambda string: pattern.sub(D, string)
  304. def replace(string, *reps):
  305. """Return ``string`` with all keys in ``reps`` replaced with
  306. their corresponding values, longer strings first, irrespective
  307. of the order they are given. ``reps`` may be passed as tuples
  308. or a single mapping.
  309. Examples
  310. ========
  311. >>> from sympy.utilities.misc import replace
  312. >>> replace('foo', {'oo': 'ar', 'f': 'b'})
  313. 'bar'
  314. >>> replace("spamham sha", ("spam", "eggs"), ("sha","md5"))
  315. 'eggsham md5'
  316. There is no guarantee that a unique answer will be
  317. obtained if keys in a mapping overlap (i.e. are the same
  318. length and have some identical sequence at the
  319. beginning/end):
  320. >>> reps = [
  321. ... ('ab', 'x'),
  322. ... ('bc', 'y')]
  323. >>> replace('abc', *reps) in ('xc', 'ay')
  324. True
  325. References
  326. ==========
  327. .. [1] https://stackoverflow.com/questions/6116978/python-replace-multiple-strings
  328. """
  329. if len(reps) == 1:
  330. kv = reps[0]
  331. if isinstance(kv, dict):
  332. reps = kv
  333. else:
  334. return string.replace(*kv)
  335. else:
  336. reps = dict(reps)
  337. return _replace(reps)(string)
  338. def translate(s, a, b=None, c=None):
  339. """Return ``s`` where characters have been replaced or deleted.
  340. SYNTAX
  341. ======
  342. translate(s, None, deletechars):
  343. all characters in ``deletechars`` are deleted
  344. translate(s, map [,deletechars]):
  345. all characters in ``deletechars`` (if provided) are deleted
  346. then the replacements defined by map are made; if the keys
  347. of map are strings then the longer ones are handled first.
  348. Multicharacter deletions should have a value of ''.
  349. translate(s, oldchars, newchars, deletechars)
  350. all characters in ``deletechars`` are deleted
  351. then each character in ``oldchars`` is replaced with the
  352. corresponding character in ``newchars``
  353. Examples
  354. ========
  355. >>> from sympy.utilities.misc import translate
  356. >>> abc = 'abc'
  357. >>> translate(abc, None, 'a')
  358. 'bc'
  359. >>> translate(abc, {'a': 'x'}, 'c')
  360. 'xb'
  361. >>> translate(abc, {'abc': 'x', 'a': 'y'})
  362. 'x'
  363. >>> translate('abcd', 'ac', 'AC', 'd')
  364. 'AbC'
  365. There is no guarantee that a unique answer will be
  366. obtained if keys in a mapping overlap are the same
  367. length and have some identical sequences at the
  368. beginning/end:
  369. >>> translate(abc, {'ab': 'x', 'bc': 'y'}) in ('xc', 'ay')
  370. True
  371. """
  372. mr = {}
  373. if a is None:
  374. if c is not None:
  375. raise ValueError('c should be None when a=None is passed, instead got %s' % c)
  376. if b is None:
  377. return s
  378. c = b
  379. a = b = ''
  380. else:
  381. if isinstance(a, dict):
  382. short = {}
  383. for k in list(a.keys()):
  384. if len(k) == 1 and len(a[k]) == 1:
  385. short[k] = a.pop(k)
  386. mr = a
  387. c = b
  388. if short:
  389. a, b = [''.join(i) for i in list(zip(*short.items()))]
  390. else:
  391. a = b = ''
  392. elif len(a) != len(b):
  393. raise ValueError('oldchars and newchars have different lengths')
  394. if c:
  395. val = str.maketrans('', '', c)
  396. s = s.translate(val)
  397. s = replace(s, mr)
  398. n = str.maketrans(a, b)
  399. return s.translate(n)
  400. def ordinal(num):
  401. """Return ordinal number string of num, e.g. 1 becomes 1st.
  402. """
  403. # modified from https://codereview.stackexchange.com/questions/41298/producing-ordinal-numbers
  404. n = as_int(num)
  405. k = abs(n) % 100
  406. if 11 <= k <= 13:
  407. suffix = 'th'
  408. elif k % 10 == 1:
  409. suffix = 'st'
  410. elif k % 10 == 2:
  411. suffix = 'nd'
  412. elif k % 10 == 3:
  413. suffix = 'rd'
  414. else:
  415. suffix = 'th'
  416. return str(n) + suffix
  417. def as_int(n, strict=True):
  418. """
  419. Convert the argument to a builtin integer.
  420. The return value is guaranteed to be equal to the input. ValueError is
  421. raised if the input has a non-integral value. When ``strict`` is True, this
  422. uses `__index__ <https://docs.python.org/3/reference/datamodel.html#object.__index__>`_
  423. and when it is False it uses ``int``.
  424. Examples
  425. ========
  426. >>> from sympy.utilities.misc import as_int
  427. >>> from sympy import sqrt, S
  428. The function is primarily concerned with sanitizing input for
  429. functions that need to work with builtin integers, so anything that
  430. is unambiguously an integer should be returned as an int:
  431. >>> as_int(S(3))
  432. 3
  433. Floats, being of limited precision, are not assumed to be exact and
  434. will raise an error unless the ``strict`` flag is False. This
  435. precision issue becomes apparent for large floating point numbers:
  436. >>> big = 1e23
  437. >>> type(big) is float
  438. True
  439. >>> big == int(big)
  440. True
  441. >>> as_int(big)
  442. Traceback (most recent call last):
  443. ...
  444. ValueError: ... is not an integer
  445. >>> as_int(big, strict=False)
  446. 99999999999999991611392
  447. Input that might be a complex representation of an integer value is
  448. also rejected by default:
  449. >>> one = sqrt(3 + 2*sqrt(2)) - sqrt(2)
  450. >>> int(one) == 1
  451. True
  452. >>> as_int(one)
  453. Traceback (most recent call last):
  454. ...
  455. ValueError: ... is not an integer
  456. """
  457. if strict:
  458. try:
  459. if isinstance(n, bool):
  460. raise TypeError
  461. return operator.index(n)
  462. except TypeError:
  463. raise ValueError('%s is not an integer' % (n,))
  464. else:
  465. try:
  466. result = int(n)
  467. except TypeError:
  468. raise ValueError('%s is not an integer' % (n,))
  469. if n != result:
  470. raise ValueError('%s is not an integer' % (n,))
  471. return result