123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439 |
- from typing import Any, Dict as tDict, Tuple as tTuple
- from itertools import product
- import re
- from sympy.core.sympify import sympify
def mathematica(s, additional_translations=None):
    '''
    Translate a Mathematica expression string into a SymPy expression.

    The string ``s`` is first rewritten into SymPy syntax by
    ``MathematicaParser`` and the rewritten string is then handed to
    ``sympify``.

    Users can add their own translation dictionary through
    ``additional_translations``: keys are Mathematica function forms such
    as ``'Log3[x]'`` and values are the SymPy templates they map to.
    A variable-length argument needs a leading '*' character.

    NOTE(review): ``sympify`` is eval-based, so this function should not be
    fed untrusted input.

    Examples
    ========

    >>> from sympy.parsing.mathematica import mathematica
    >>> mathematica('Log3[9]', {'Log3[x]':'log(x,3)'})
    2
    >>> mathematica('F[7,5,3]', {'F[*x]':'Max(*x)*Min(*x)'})
    21
    '''
    sympy_source = MathematicaParser(additional_translations).parse(s)
    return sympify(sympy_source)
def _deco(cls):
    '''Class decorator: call ``cls._initialize_class()`` when the class
    statement executes and return the class unchanged.'''
    cls._initialize_class()
    return cls


@_deco
class MathematicaParser:
    '''An instance of this class converts a string of a basic Mathematica
    expression to SymPy style. Output is string type.

    Only function calls whose (name, arity) appear in the compiled
    translation table are converted. The table is built from
    ``CORRESPONDENCES`` plus any user-supplied ``additional_translations``.
    Any other capitalised ``Name[...]`` call raises ``ValueError``.
    '''

    # left: Mathematica, right: SymPy
    CORRESPONDENCES = {
        'Sqrt[x]': 'sqrt(x)',
        'Exp[x]': 'exp(x)',
        'Log[x]': 'log(x)',
        'Log[x,y]': 'log(y,x)',
        'Log2[x]': 'log(x,2)',
        'Log10[x]': 'log(x,10)',
        'Mod[x,y]': 'Mod(x,y)',
        'Max[*x]': 'Max(*x)',
        'Min[*x]': 'Min(*x)',
        'Pochhammer[x,y]': 'rf(x,y)',
        'ArcTan[x,y]': 'atan2(y,x)',
        'ExpIntegralEi[x]': 'Ei(x)',
        'SinIntegral[x]': 'Si(x)',
        'CosIntegral[x]': 'Ci(x)',
        'AiryAi[x]': 'airyai(x)',
        'AiryAiPrime[x]': 'airyaiprime(x)',
        'AiryBi[x]': 'airybi(x)',
        'AiryBiPrime[x]': 'airybiprime(x)',
        'LogIntegral[x]': 'li(x)',
        'PrimePi[x]': 'primepi(x)',
        'Prime[x]': 'prime(x)',
        'PrimeQ[x]': 'isprime(x)',
    }

    # trigonometric / hyperbolic functions and their inverses, e.g.
    # 'Sin[x]' -> 'sin(x)', 'ArcCosh[x]' -> 'acosh(x)'
    for arc, tri, h in product(('', 'Arc'), (
            'Sin', 'Cos', 'Tan', 'Cot', 'Sec', 'Csc'), ('', 'h')):
        fm = arc + tri + h + '[x]'
        if arc:  # arc func
            fs = 'a' + tri.lower() + h + '(x)'
        else:  # non-arc func
            fs = tri.lower() + h + '(x)'
        CORRESPONDENCES.update({fm: fs})

    REPLACEMENTS = {
        ' ': '',
        '^': '**',
        '{': '[',
        '}': ']',
    }

    RULES = {
        # a single whitespace to '*'
        'whitespace': (
            re.compile(r'''
                (?<=[a-zA-Z\d])     # a letter or a number
                \                   # a whitespace
                (?=[a-zA-Z\d])      # a letter or a number
                ''', re.VERBOSE),
            '*'),

        # add omitted '*' character
        'add*_1': (
            re.compile(r'''
                (?<=[])\d])     # ], ) or a number
                                # ''
                (?=[(a-zA-Z])   # ( or a single letter
                ''', re.VERBOSE),
            '*'),

        # add omitted '*' character (variable letter preceding)
        'add*_2': (
            re.compile(r'''
                (?<=[a-zA-Z])   # a letter
                \(              # ( as a character
                (?=.)           # any characters
                ''', re.VERBOSE),
            '*('),

        # convert 'Pi' to 'pi'
        'Pi': (
            re.compile(r'''
                (?:
                \A|(?<=[^a-zA-Z])
                )
                Pi                  # 'Pi' is 3.14159... in Mathematica
                (?![a-zA-Z])        # not followed by a letter (or at the end)
                ''', re.VERBOSE),
            'pi'),
    }

    # Mathematica function name pattern
    FM_PATTERN = re.compile(r'''
                (?:
                \A|(?<=[^a-zA-Z])   # at the top or a non-letter
                )
                [A-Z][a-zA-Z\d]*    # Function
                (?=\[)              # [ as a character
                ''', re.VERBOSE)

    # list or matrix pattern (for future usage)
    ARG_MTRX_PATTERN = re.compile(r'''
                \{.*\}
                ''', re.VERBOSE)

    # regex string for function argument pattern; a model argument must not
    # be glued to other letters on either side (it may end the template)
    ARGS_PATTERN_TEMPLATE = r'''
                (?:
                \A|(?<=[^a-zA-Z])
                )
                {arguments}         # model argument like x, y,...
                (?![a-zA-Z])
                '''

    # will contain transformed CORRESPONDENCES dictionary;
    # keys are (function name, arity) where arity is an int or '*'
    TRANSLATIONS = {}  # type: tDict[tTuple[str, Any], tDict[str, Any]]

    # cache for a raw users' translation dictionary (a snapshot copy)
    cache_original = {}  # type: tDict[str, str]

    # cache for a compiled users' translation dictionary
    cache_compiled = {}  # type: tDict[tTuple[str, Any], tDict[str, Any]]

    @classmethod
    def _initialize_class(cls):
        '''Compile CORRESPONDENCES into the class-wide TRANSLATIONS table.'''
        d = cls._compile_dictionary(cls.CORRESPONDENCES)
        cls.TRANSLATIONS.update(d)

    def __init__(self, additional_translations=None):
        '''Build this instance's translation table.

        ``additional_translations`` is an optional dict mapping Mathematica
        function forms to SymPy templates; it is merged over the built-in
        translations. Raises ValueError if it is not a dict or contains an
        invalid function form.
        '''
        self.translations = {}

        # update with TRANSLATIONS (class constant)
        self.translations.update(self.TRANSLATIONS)

        if additional_translations is None:
            additional_translations = {}

        # recompile only when the user's translations differ from the last
        # ones compiled (this cache is shared by the whole class)
        if self.__class__.cache_original != additional_translations:
            if not isinstance(additional_translations, dict):
                raise ValueError('The argument must be dict type')

            # get a transformed additional_translations dictionary
            d = self._compile_dictionary(additional_translations)

            # store a snapshot rather than the caller's object: otherwise an
            # in-place mutation of that dict would still compare equal to
            # the cache and stale compiled entries would be reused
            self.__class__.cache_original = dict(additional_translations)
            self.__class__.cache_compiled = d

        # merge user's own translations
        self.translations.update(self.__class__.cache_compiled)

    @classmethod
    def _compile_dictionary(cls, dic):
        '''Compile a {Mathematica form: SymPy template} dict.

        Returns a dict keyed by (function name, arity), where arity is the
        number of arguments, or '*' when the last model argument is
        variable-length. Each value holds the SymPy template ('fs'), the
        model argument names ('args') and a regex matching them ('pat').
        Raises ValueError for malformed function forms.
        '''
        d = {}

        for fm, fs in dic.items():
            # check function form
            cls._check_input(fm)
            cls._check_input(fs)

            # uncover '*' hiding behind a whitespace
            fm = cls._apply_rules(fm, 'whitespace')
            fs = cls._apply_rules(fs, 'whitespace')

            # remove whitespace(s)
            fm = cls._replace(fm, ' ')
            fs = cls._replace(fs, ' ')

            # search Mathematica function name
            m = cls.FM_PATTERN.search(fm)

            # if no-hit
            if m is None:
                err = "'{f}' function form is invalid.".format(f=fm)
                raise ValueError(err)

            # get Mathematica function name like 'Log'
            fm_name = m.group()

            # get arguments of Mathematica function
            args, end = cls._get_args(m)

            # function side check. (e.g.) '2*Func[x]' is invalid.
            # Empty model arguments ('F[]', 'F[x,]') are rejected too: they
            # would compile to a regex matching the empty string.
            if m.start() != 0 or end != len(fm) or not all(args):
                err = "'{f}' function form is invalid.".format(f=fm)
                raise ValueError(err)

            # check the last argument's 1st character
            if args[-1].startswith('*'):
                key_arg = '*'
            else:
                key_arg = len(args)

            key = (fm_name, key_arg)

            # escape regex metacharacters, e.g. '*x' -> '\*x'
            re_args = [re.escape(x) for x in args]

            # for regex. Example: (?:(x|y|z))
            xyz = '(?:(' + '|'.join(re_args) + '))'

            # string for regex compile
            patStr = cls.ARGS_PATTERN_TEMPLATE.format(arguments=xyz)

            pat = re.compile(patStr, re.VERBOSE)

            # update dictionary
            d[key] = {}
            d[key]['fs'] = fs  # SymPy function template
            d[key]['args'] = args  # args are ['x', 'y'] for example
            d[key]['pat'] = pat

        return d

    def _convert_function(self, s):
        '''Parse Mathematica function to SymPy one'''
        pat = self.FM_PATTERN

        done = []  # pieces of s that are fully converted

        while True:
            m = pat.search(s)

            if m is None:
                # append the rest of string
                done.append(s)
                break

            # get Mathematica function name
            fm = m.group()

            # get arguments, and the end position of fm function
            args, end = self._get_args(m)

            # the start position of fm function
            bgn = m.start()

            # convert Mathematica function to SymPy one
            s = self._convert_one_function(s, fm, args, bgn, end)

            # everything before bgn is final; keep scanning from the converted
            # function so Mathematica calls nested in its arguments are found
            done.append(s[:bgn])
            s = s[bgn:]

        return ''.join(done)

    def _convert_one_function(self, s, fm, args, bgn, end):
        '''Replace the call ``s[bgn:end]`` of function ``fm`` with ``args``
        by its SymPy template. Raises ValueError when no translation fits.'''
        # no variable-length argument
        if (fm, len(args)) in self.translations:
            key = (fm, len(args))

            # x, y,... model arguments
            x_args = self.translations[key]['args']

            # map model arguments to actual ones
            d = dict(zip(x_args, args))

        # with variable-length argument
        elif (fm, '*') in self.translations:
            key = (fm, '*')

            # x, y,..*args (model arguments)
            x_args = self.translations[key]['args']

            # map model arguments to actual ones; the '*' argument swallows
            # every remaining actual argument
            d = {}
            for i, x in enumerate(x_args):
                if x.startswith('*'):
                    d[x] = ','.join(args[i:])
                    break
                if i >= len(args):
                    err = "'{f}' is called with too few arguments.".format(
                        f=fm)
                    raise ValueError(err)
                d[x] = args[i]

        # out of self.translations
        else:
            err = "'{f}' is out of the whitelist.".format(f=fm)
            raise ValueError(err)

        # template string of converted function
        template = self.translations[key]['fs']

        # regex pattern for x_args
        pat = self.translations[key]['pat']

        pieces = []
        while True:
            m = pat.search(template)

            if m is None:
                pieces.append(template)
                break

            # text before the model argument, then the actual argument
            pieces.append(template[:m.start()])
            pieces.append(d[m.group()])

            # continue after the model argument
            template = template[m.end():]

        # update to swapped string
        return s[:bgn] + ''.join(pieces) + s[end:]

    @classmethod
    def _get_args(cls, m):
        '''Get arguments of a Mathematica function.

        ``m`` is an FM_PATTERN match, so ``m.string[m.end()]`` is '['.
        Returns ``(args, func_end)``: the raw argument strings split on
        top-level commas, and the index just past the matching ']'.
        Raises ValueError when the call is never closed.
        '''
        s = m.string  # whole string
        anc = m.end() + 1  # pointing the first letter of arguments
        square, curly = [], []  # stacks for brackets
        args = []

        # current cursor
        cur = anc
        for i, c in enumerate(s[anc:], anc):
            # extract one argument
            if c == ',' and (not square) and (not curly):
                args.append(s[cur:i])  # add an argument
                cur = i + 1  # move cursor

            # handle list or matrix (for future usage)
            if c == '{':
                curly.append(c)
            elif c == '}':
                curly.pop()

            # seek corresponding ']' while skipping nested pairs
            if c == '[':
                square.append(c)
            elif c == ']':
                if square:
                    square.pop()
                else:  # empty stack: this ']' closes the call
                    args.append(s[cur:i])
                    # the next position to ']' bracket (the function end)
                    return args, i + 1

        err = "'{f}' function form is invalid.".format(f=s)
        raise ValueError(err)

    @classmethod
    def _replace(cls, s, bef):
        '''Replace every ``bef`` in ``s`` using REPLACEMENTS.'''
        aft = cls.REPLACEMENTS[bef]
        s = s.replace(bef, aft)
        return s

    @classmethod
    def _apply_rules(cls, s, bef):
        '''Apply the regex substitution registered as RULES[bef].'''
        pat, aft = cls.RULES[bef]
        return pat.sub(aft, s)

    @classmethod
    def _check_input(cls, s):
        '''Raise ValueError on unbalanced brackets or on list syntax.'''
        for bracket in (('[', ']'), ('{', '}'), ('(', ')')):
            if s.count(bracket[0]) != s.count(bracket[1]):
                err = "'{f}' function form is invalid.".format(f=s)
                raise ValueError(err)

        if '{' in s:
            err = "Currently list is not supported."
            raise ValueError(err)

    def parse(self, s):
        '''Convert Mathematica source ``s`` into a SymPy-style string.'''
        # input check
        self._check_input(s)

        # uncover '*' hiding behind a whitespace
        s = self._apply_rules(s, 'whitespace')

        # remove whitespace(s)
        s = self._replace(s, ' ')

        # add omitted '*' character
        s = self._apply_rules(s, 'add*_1')
        s = self._apply_rules(s, 'add*_2')

        # translate function
        s = self._convert_function(s)

        # '^' to '**'
        s = self._replace(s, '^')

        # 'Pi' to 'pi'
        s = self._apply_rules(s, 'Pi')

        # NOTE: '{'/'}' are not translated to '['/']' because lists are
        # rejected by _check_input for now.
        return s
|