mathematica.py 13 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439
  1. from typing import Any, Dict as tDict, Tuple as tTuple
  2. from itertools import product
  3. import re
  4. from sympy.core.sympify import sympify
  5. def mathematica(s, additional_translations=None):
  6. '''
  7. Users can add their own translation dictionary.
  8. variable-length argument needs '*' character.
  9. Examples
  10. ========
  11. >>> from sympy.parsing.mathematica import mathematica
  12. >>> mathematica('Log3[9]', {'Log3[x]':'log(x,3)'})
  13. 2
  14. >>> mathematica('F[7,5,3]', {'F[*x]':'Max(*x)*Min(*x)'})
  15. 21
  16. '''
  17. parser = MathematicaParser(additional_translations)
  18. return sympify(parser.parse(s))
  19. def _deco(cls):
  20. cls._initialize_class()
  21. return cls
  22. @_deco
  23. class MathematicaParser:
  24. '''An instance of this class converts a string of a basic Mathematica
  25. expression to SymPy style. Output is string type.'''
  26. # left: Mathematica, right: SymPy
  27. CORRESPONDENCES = {
  28. 'Sqrt[x]': 'sqrt(x)',
  29. 'Exp[x]': 'exp(x)',
  30. 'Log[x]': 'log(x)',
  31. 'Log[x,y]': 'log(y,x)',
  32. 'Log2[x]': 'log(x,2)',
  33. 'Log10[x]': 'log(x,10)',
  34. 'Mod[x,y]': 'Mod(x,y)',
  35. 'Max[*x]': 'Max(*x)',
  36. 'Min[*x]': 'Min(*x)',
  37. 'Pochhammer[x,y]':'rf(x,y)',
  38. 'ArcTan[x,y]':'atan2(y,x)',
  39. 'ExpIntegralEi[x]': 'Ei(x)',
  40. 'SinIntegral[x]': 'Si(x)',
  41. 'CosIntegral[x]': 'Ci(x)',
  42. 'AiryAi[x]': 'airyai(x)',
  43. 'AiryAiPrime[x]': 'airyaiprime(x)',
  44. 'AiryBi[x]' :'airybi(x)',
  45. 'AiryBiPrime[x]' :'airybiprime(x)',
  46. 'LogIntegral[x]':' li(x)',
  47. 'PrimePi[x]': 'primepi(x)',
  48. 'Prime[x]': 'prime(x)',
  49. 'PrimeQ[x]': 'isprime(x)'
  50. }
  51. # trigonometric, e.t.c.
  52. for arc, tri, h in product(('', 'Arc'), (
  53. 'Sin', 'Cos', 'Tan', 'Cot', 'Sec', 'Csc'), ('', 'h')):
  54. fm = arc + tri + h + '[x]'
  55. if arc: # arc func
  56. fs = 'a' + tri.lower() + h + '(x)'
  57. else: # non-arc func
  58. fs = tri.lower() + h + '(x)'
  59. CORRESPONDENCES.update({fm: fs})
  60. REPLACEMENTS = {
  61. ' ': '',
  62. '^': '**',
  63. '{': '[',
  64. '}': ']',
  65. }
  66. RULES = {
  67. # a single whitespace to '*'
  68. 'whitespace': (
  69. re.compile(r'''
  70. (?<=[a-zA-Z\d]) # a letter or a number
  71. \ # a whitespace
  72. (?=[a-zA-Z\d]) # a letter or a number
  73. ''', re.VERBOSE),
  74. '*'),
  75. # add omitted '*' character
  76. 'add*_1': (
  77. re.compile(r'''
  78. (?<=[])\d]) # ], ) or a number
  79. # ''
  80. (?=[(a-zA-Z]) # ( or a single letter
  81. ''', re.VERBOSE),
  82. '*'),
  83. # add omitted '*' character (variable letter preceding)
  84. 'add*_2': (
  85. re.compile(r'''
  86. (?<=[a-zA-Z]) # a letter
  87. \( # ( as a character
  88. (?=.) # any characters
  89. ''', re.VERBOSE),
  90. '*('),
  91. # convert 'Pi' to 'pi'
  92. 'Pi': (
  93. re.compile(r'''
  94. (?:
  95. \A|(?<=[^a-zA-Z])
  96. )
  97. Pi # 'Pi' is 3.14159... in Mathematica
  98. (?=[^a-zA-Z])
  99. ''', re.VERBOSE),
  100. 'pi'),
  101. }
  102. # Mathematica function name pattern
  103. FM_PATTERN = re.compile(r'''
  104. (?:
  105. \A|(?<=[^a-zA-Z]) # at the top or a non-letter
  106. )
  107. [A-Z][a-zA-Z\d]* # Function
  108. (?=\[) # [ as a character
  109. ''', re.VERBOSE)
  110. # list or matrix pattern (for future usage)
  111. ARG_MTRX_PATTERN = re.compile(r'''
  112. \{.*\}
  113. ''', re.VERBOSE)
  114. # regex string for function argument pattern
  115. ARGS_PATTERN_TEMPLATE = r'''
  116. (?:
  117. \A|(?<=[^a-zA-Z])
  118. )
  119. {arguments} # model argument like x, y,...
  120. (?=[^a-zA-Z])
  121. '''
  122. # will contain transformed CORRESPONDENCES dictionary
  123. TRANSLATIONS = {} # type: tDict[tTuple[str, int], tDict[str, Any]]
  124. # cache for a raw users' translation dictionary
  125. cache_original = {} # type: tDict[tTuple[str, int], tDict[str, Any]]
  126. # cache for a compiled users' translation dictionary
  127. cache_compiled = {} # type: tDict[tTuple[str, int], tDict[str, Any]]
  128. @classmethod
  129. def _initialize_class(cls):
  130. # get a transformed CORRESPONDENCES dictionary
  131. d = cls._compile_dictionary(cls.CORRESPONDENCES)
  132. cls.TRANSLATIONS.update(d)
  133. def __init__(self, additional_translations=None):
  134. self.translations = {}
  135. # update with TRANSLATIONS (class constant)
  136. self.translations.update(self.TRANSLATIONS)
  137. if additional_translations is None:
  138. additional_translations = {}
  139. # check the latest added translations
  140. if self.__class__.cache_original != additional_translations:
  141. if not isinstance(additional_translations, dict):
  142. raise ValueError('The argument must be dict type')
  143. # get a transformed additional_translations dictionary
  144. d = self._compile_dictionary(additional_translations)
  145. # update cache
  146. self.__class__.cache_original = additional_translations
  147. self.__class__.cache_compiled = d
  148. # merge user's own translations
  149. self.translations.update(self.__class__.cache_compiled)
  150. @classmethod
  151. def _compile_dictionary(cls, dic):
  152. # for return
  153. d = {}
  154. for fm, fs in dic.items():
  155. # check function form
  156. cls._check_input(fm)
  157. cls._check_input(fs)
  158. # uncover '*' hiding behind a whitespace
  159. fm = cls._apply_rules(fm, 'whitespace')
  160. fs = cls._apply_rules(fs, 'whitespace')
  161. # remove whitespace(s)
  162. fm = cls._replace(fm, ' ')
  163. fs = cls._replace(fs, ' ')
  164. # search Mathematica function name
  165. m = cls.FM_PATTERN.search(fm)
  166. # if no-hit
  167. if m is None:
  168. err = "'{f}' function form is invalid.".format(f=fm)
  169. raise ValueError(err)
  170. # get Mathematica function name like 'Log'
  171. fm_name = m.group()
  172. # get arguments of Mathematica function
  173. args, end = cls._get_args(m)
  174. # function side check. (e.g.) '2*Func[x]' is invalid.
  175. if m.start() != 0 or end != len(fm):
  176. err = "'{f}' function form is invalid.".format(f=fm)
  177. raise ValueError(err)
  178. # check the last argument's 1st character
  179. if args[-1][0] == '*':
  180. key_arg = '*'
  181. else:
  182. key_arg = len(args)
  183. key = (fm_name, key_arg)
  184. # convert '*x' to '\\*x' for regex
  185. re_args = [x if x[0] != '*' else '\\' + x for x in args]
  186. # for regex. Example: (?:(x|y|z))
  187. xyz = '(?:(' + '|'.join(re_args) + '))'
  188. # string for regex compile
  189. patStr = cls.ARGS_PATTERN_TEMPLATE.format(arguments=xyz)
  190. pat = re.compile(patStr, re.VERBOSE)
  191. # update dictionary
  192. d[key] = {}
  193. d[key]['fs'] = fs # SymPy function template
  194. d[key]['args'] = args # args are ['x', 'y'] for example
  195. d[key]['pat'] = pat
  196. return d
  197. def _convert_function(self, s):
  198. '''Parse Mathematica function to SymPy one'''
  199. # compiled regex object
  200. pat = self.FM_PATTERN
  201. scanned = '' # converted string
  202. cur = 0 # position cursor
  203. while True:
  204. m = pat.search(s)
  205. if m is None:
  206. # append the rest of string
  207. scanned += s
  208. break
  209. # get Mathematica function name
  210. fm = m.group()
  211. # get arguments, and the end position of fm function
  212. args, end = self._get_args(m)
  213. # the start position of fm function
  214. bgn = m.start()
  215. # convert Mathematica function to SymPy one
  216. s = self._convert_one_function(s, fm, args, bgn, end)
  217. # update cursor
  218. cur = bgn
  219. # append converted part
  220. scanned += s[:cur]
  221. # shrink s
  222. s = s[cur:]
  223. return scanned
  224. def _convert_one_function(self, s, fm, args, bgn, end):
  225. # no variable-length argument
  226. if (fm, len(args)) in self.translations:
  227. key = (fm, len(args))
  228. # x, y,... model arguments
  229. x_args = self.translations[key]['args']
  230. # make CORRESPONDENCES between model arguments and actual ones
  231. d = {k: v for k, v in zip(x_args, args)}
  232. # with variable-length argument
  233. elif (fm, '*') in self.translations:
  234. key = (fm, '*')
  235. # x, y,..*args (model arguments)
  236. x_args = self.translations[key]['args']
  237. # make CORRESPONDENCES between model arguments and actual ones
  238. d = {}
  239. for i, x in enumerate(x_args):
  240. if x[0] == '*':
  241. d[x] = ','.join(args[i:])
  242. break
  243. d[x] = args[i]
  244. # out of self.translations
  245. else:
  246. err = "'{f}' is out of the whitelist.".format(f=fm)
  247. raise ValueError(err)
  248. # template string of converted function
  249. template = self.translations[key]['fs']
  250. # regex pattern for x_args
  251. pat = self.translations[key]['pat']
  252. scanned = ''
  253. cur = 0
  254. while True:
  255. m = pat.search(template)
  256. if m is None:
  257. scanned += template
  258. break
  259. # get model argument
  260. x = m.group()
  261. # get a start position of the model argument
  262. xbgn = m.start()
  263. # add the corresponding actual argument
  264. scanned += template[:xbgn] + d[x]
  265. # update cursor to the end of the model argument
  266. cur = m.end()
  267. # shrink template
  268. template = template[cur:]
  269. # update to swapped string
  270. s = s[:bgn] + scanned + s[end:]
  271. return s
  272. @classmethod
  273. def _get_args(cls, m):
  274. '''Get arguments of a Mathematica function'''
  275. s = m.string # whole string
  276. anc = m.end() + 1 # pointing the first letter of arguments
  277. square, curly = [], [] # stack for brakets
  278. args = []
  279. # current cursor
  280. cur = anc
  281. for i, c in enumerate(s[anc:], anc):
  282. # extract one argument
  283. if c == ',' and (not square) and (not curly):
  284. args.append(s[cur:i]) # add an argument
  285. cur = i + 1 # move cursor
  286. # handle list or matrix (for future usage)
  287. if c == '{':
  288. curly.append(c)
  289. elif c == '}':
  290. curly.pop()
  291. # seek corresponding ']' with skipping irrevant ones
  292. if c == '[':
  293. square.append(c)
  294. elif c == ']':
  295. if square:
  296. square.pop()
  297. else: # empty stack
  298. args.append(s[cur:i])
  299. break
  300. # the next position to ']' bracket (the function end)
  301. func_end = i + 1
  302. return args, func_end
  303. @classmethod
  304. def _replace(cls, s, bef):
  305. aft = cls.REPLACEMENTS[bef]
  306. s = s.replace(bef, aft)
  307. return s
  308. @classmethod
  309. def _apply_rules(cls, s, bef):
  310. pat, aft = cls.RULES[bef]
  311. return pat.sub(aft, s)
  312. @classmethod
  313. def _check_input(cls, s):
  314. for bracket in (('[', ']'), ('{', '}'), ('(', ')')):
  315. if s.count(bracket[0]) != s.count(bracket[1]):
  316. err = "'{f}' function form is invalid.".format(f=s)
  317. raise ValueError(err)
  318. if '{' in s:
  319. err = "Currently list is not supported."
  320. raise ValueError(err)
  321. def parse(self, s):
  322. # input check
  323. self._check_input(s)
  324. # uncover '*' hiding behind a whitespace
  325. s = self._apply_rules(s, 'whitespace')
  326. # remove whitespace(s)
  327. s = self._replace(s, ' ')
  328. # add omitted '*' character
  329. s = self._apply_rules(s, 'add*_1')
  330. s = self._apply_rules(s, 'add*_2')
  331. # translate function
  332. s = self._convert_function(s)
  333. # '^' to '**'
  334. s = self._replace(s, '^')
  335. # 'Pi' to 'pi'
  336. s = self._apply_rules(s, 'Pi')
  337. # '{', '}' to '[', ']', respectively
  338. # s = cls._replace(s, '{') # currently list is not taken into account
  339. # s = cls._replace(s, '}')
  340. return s