functions.py 23 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702703704705706707708709710711712713714715716717718719720721722723724725
  1. from sympy.utilities import dict_merge
  2. from sympy.utilities.iterables import iterable
  3. from sympy.physics.vector import (Dyadic, Vector, ReferenceFrame,
  4. Point, dynamicsymbols)
  5. from sympy.physics.vector.printing import (vprint, vsprint, vpprint, vlatex,
  6. init_vprinting)
  7. from sympy.physics.mechanics.particle import Particle
  8. from sympy.physics.mechanics.rigidbody import RigidBody
  9. from sympy.simplify.simplify import simplify
  10. from sympy.core.backend import (Matrix, sympify, Mul, Derivative, sin, cos,
  11. tan, AppliedUndef, S)
  12. __all__ = ['inertia',
  13. 'inertia_of_point_mass',
  14. 'linear_momentum',
  15. 'angular_momentum',
  16. 'kinetic_energy',
  17. 'potential_energy',
  18. 'Lagrangian',
  19. 'mechanics_printing',
  20. 'mprint',
  21. 'msprint',
  22. 'mpprint',
  23. 'mlatex',
  24. 'msubs',
  25. 'find_dynamicsymbols']
  26. # These are functions that we've moved and renamed during extracting the
  27. # basic vector calculus code from the mechanics packages.
  28. mprint = vprint
  29. msprint = vsprint
  30. mpprint = vpprint
  31. mlatex = vlatex
  32. def mechanics_printing(**kwargs):
  33. """
  34. Initializes time derivative printing for all SymPy objects in
  35. mechanics module.
  36. """
  37. init_vprinting(**kwargs)
  38. mechanics_printing.__doc__ = init_vprinting.__doc__
  39. def inertia(frame, ixx, iyy, izz, ixy=0, iyz=0, izx=0):
  40. """Simple way to create inertia Dyadic object.
  41. Explanation
  42. ===========
  43. If you don't know what a Dyadic is, just treat this like the inertia
  44. tensor. Then, do the easy thing and define it in a body-fixed frame.
  45. Parameters
  46. ==========
  47. frame : ReferenceFrame
  48. The frame the inertia is defined in
  49. ixx : Sympifyable
  50. the xx element in the inertia dyadic
  51. iyy : Sympifyable
  52. the yy element in the inertia dyadic
  53. izz : Sympifyable
  54. the zz element in the inertia dyadic
  55. ixy : Sympifyable
  56. the xy element in the inertia dyadic
  57. iyz : Sympifyable
  58. the yz element in the inertia dyadic
  59. izx : Sympifyable
  60. the zx element in the inertia dyadic
  61. Examples
  62. ========
  63. >>> from sympy.physics.mechanics import ReferenceFrame, inertia
  64. >>> N = ReferenceFrame('N')
  65. >>> inertia(N, 1, 2, 3)
  66. (N.x|N.x) + 2*(N.y|N.y) + 3*(N.z|N.z)
  67. """
  68. if not isinstance(frame, ReferenceFrame):
  69. raise TypeError('Need to define the inertia in a frame')
  70. ol = sympify(ixx) * (frame.x | frame.x)
  71. ol += sympify(ixy) * (frame.x | frame.y)
  72. ol += sympify(izx) * (frame.x | frame.z)
  73. ol += sympify(ixy) * (frame.y | frame.x)
  74. ol += sympify(iyy) * (frame.y | frame.y)
  75. ol += sympify(iyz) * (frame.y | frame.z)
  76. ol += sympify(izx) * (frame.z | frame.x)
  77. ol += sympify(iyz) * (frame.z | frame.y)
  78. ol += sympify(izz) * (frame.z | frame.z)
  79. return ol
  80. def inertia_of_point_mass(mass, pos_vec, frame):
  81. """Inertia dyadic of a point mass relative to point O.
  82. Parameters
  83. ==========
  84. mass : Sympifyable
  85. Mass of the point mass
  86. pos_vec : Vector
  87. Position from point O to point mass
  88. frame : ReferenceFrame
  89. Reference frame to express the dyadic in
  90. Examples
  91. ========
  92. >>> from sympy import symbols
  93. >>> from sympy.physics.mechanics import ReferenceFrame, inertia_of_point_mass
  94. >>> N = ReferenceFrame('N')
  95. >>> r, m = symbols('r m')
  96. >>> px = r * N.x
  97. >>> inertia_of_point_mass(m, px, N)
  98. m*r**2*(N.y|N.y) + m*r**2*(N.z|N.z)
  99. """
  100. return mass * (((frame.x | frame.x) + (frame.y | frame.y) +
  101. (frame.z | frame.z)) * (pos_vec & pos_vec) -
  102. (pos_vec | pos_vec))
  103. def linear_momentum(frame, *body):
  104. """Linear momentum of the system.
  105. Explanation
  106. ===========
  107. This function returns the linear momentum of a system of Particle's and/or
  108. RigidBody's. The linear momentum of a system is equal to the vector sum of
  109. the linear momentum of its constituents. Consider a system, S, comprised of
  110. a rigid body, A, and a particle, P. The linear momentum of the system, L,
  111. is equal to the vector sum of the linear momentum of the particle, L1, and
  112. the linear momentum of the rigid body, L2, i.e.
  113. L = L1 + L2
  114. Parameters
  115. ==========
  116. frame : ReferenceFrame
  117. The frame in which linear momentum is desired.
  118. body1, body2, body3... : Particle and/or RigidBody
  119. The body (or bodies) whose linear momentum is required.
  120. Examples
  121. ========
  122. >>> from sympy.physics.mechanics import Point, Particle, ReferenceFrame
  123. >>> from sympy.physics.mechanics import RigidBody, outer, linear_momentum
  124. >>> N = ReferenceFrame('N')
  125. >>> P = Point('P')
  126. >>> P.set_vel(N, 10 * N.x)
  127. >>> Pa = Particle('Pa', P, 1)
  128. >>> Ac = Point('Ac')
  129. >>> Ac.set_vel(N, 25 * N.y)
  130. >>> I = outer(N.x, N.x)
  131. >>> A = RigidBody('A', Ac, N, 20, (I, Ac))
  132. >>> linear_momentum(N, A, Pa)
  133. 10*N.x + 500*N.y
  134. """
  135. if not isinstance(frame, ReferenceFrame):
  136. raise TypeError('Please specify a valid ReferenceFrame')
  137. else:
  138. linear_momentum_sys = Vector(0)
  139. for e in body:
  140. if isinstance(e, (RigidBody, Particle)):
  141. linear_momentum_sys += e.linear_momentum(frame)
  142. else:
  143. raise TypeError('*body must have only Particle or RigidBody')
  144. return linear_momentum_sys
  145. def angular_momentum(point, frame, *body):
  146. """Angular momentum of a system.
  147. Explanation
  148. ===========
  149. This function returns the angular momentum of a system of Particle's and/or
  150. RigidBody's. The angular momentum of such a system is equal to the vector
  151. sum of the angular momentum of its constituents. Consider a system, S,
  152. comprised of a rigid body, A, and a particle, P. The angular momentum of
  153. the system, H, is equal to the vector sum of the angular momentum of the
  154. particle, H1, and the angular momentum of the rigid body, H2, i.e.
  155. H = H1 + H2
  156. Parameters
  157. ==========
  158. point : Point
  159. The point about which angular momentum of the system is desired.
  160. frame : ReferenceFrame
  161. The frame in which angular momentum is desired.
  162. body1, body2, body3... : Particle and/or RigidBody
  163. The body (or bodies) whose angular momentum is required.
  164. Examples
  165. ========
  166. >>> from sympy.physics.mechanics import Point, Particle, ReferenceFrame
  167. >>> from sympy.physics.mechanics import RigidBody, outer, angular_momentum
  168. >>> N = ReferenceFrame('N')
  169. >>> O = Point('O')
  170. >>> O.set_vel(N, 0 * N.x)
  171. >>> P = O.locatenew('P', 1 * N.x)
  172. >>> P.set_vel(N, 10 * N.x)
  173. >>> Pa = Particle('Pa', P, 1)
  174. >>> Ac = O.locatenew('Ac', 2 * N.y)
  175. >>> Ac.set_vel(N, 5 * N.y)
  176. >>> a = ReferenceFrame('a')
  177. >>> a.set_ang_vel(N, 10 * N.z)
  178. >>> I = outer(N.z, N.z)
  179. >>> A = RigidBody('A', Ac, a, 20, (I, Ac))
  180. >>> angular_momentum(O, N, Pa, A)
  181. 10*N.z
  182. """
  183. if not isinstance(frame, ReferenceFrame):
  184. raise TypeError('Please enter a valid ReferenceFrame')
  185. if not isinstance(point, Point):
  186. raise TypeError('Please specify a valid Point')
  187. else:
  188. angular_momentum_sys = Vector(0)
  189. for e in body:
  190. if isinstance(e, (RigidBody, Particle)):
  191. angular_momentum_sys += e.angular_momentum(point, frame)
  192. else:
  193. raise TypeError('*body must have only Particle or RigidBody')
  194. return angular_momentum_sys
  195. def kinetic_energy(frame, *body):
  196. """Kinetic energy of a multibody system.
  197. Explanation
  198. ===========
  199. This function returns the kinetic energy of a system of Particle's and/or
  200. RigidBody's. The kinetic energy of such a system is equal to the sum of
  201. the kinetic energies of its constituents. Consider a system, S, comprising
  202. a rigid body, A, and a particle, P. The kinetic energy of the system, T,
  203. is equal to the vector sum of the kinetic energy of the particle, T1, and
  204. the kinetic energy of the rigid body, T2, i.e.
  205. T = T1 + T2
  206. Kinetic energy is a scalar.
  207. Parameters
  208. ==========
  209. frame : ReferenceFrame
  210. The frame in which the velocity or angular velocity of the body is
  211. defined.
  212. body1, body2, body3... : Particle and/or RigidBody
  213. The body (or bodies) whose kinetic energy is required.
  214. Examples
  215. ========
  216. >>> from sympy.physics.mechanics import Point, Particle, ReferenceFrame
  217. >>> from sympy.physics.mechanics import RigidBody, outer, kinetic_energy
  218. >>> N = ReferenceFrame('N')
  219. >>> O = Point('O')
  220. >>> O.set_vel(N, 0 * N.x)
  221. >>> P = O.locatenew('P', 1 * N.x)
  222. >>> P.set_vel(N, 10 * N.x)
  223. >>> Pa = Particle('Pa', P, 1)
  224. >>> Ac = O.locatenew('Ac', 2 * N.y)
  225. >>> Ac.set_vel(N, 5 * N.y)
  226. >>> a = ReferenceFrame('a')
  227. >>> a.set_ang_vel(N, 10 * N.z)
  228. >>> I = outer(N.z, N.z)
  229. >>> A = RigidBody('A', Ac, a, 20, (I, Ac))
  230. >>> kinetic_energy(N, Pa, A)
  231. 350
  232. """
  233. if not isinstance(frame, ReferenceFrame):
  234. raise TypeError('Please enter a valid ReferenceFrame')
  235. ke_sys = S.Zero
  236. for e in body:
  237. if isinstance(e, (RigidBody, Particle)):
  238. ke_sys += e.kinetic_energy(frame)
  239. else:
  240. raise TypeError('*body must have only Particle or RigidBody')
  241. return ke_sys
  242. def potential_energy(*body):
  243. """Potential energy of a multibody system.
  244. Explanation
  245. ===========
  246. This function returns the potential energy of a system of Particle's and/or
  247. RigidBody's. The potential energy of such a system is equal to the sum of
  248. the potential energy of its constituents. Consider a system, S, comprising
  249. a rigid body, A, and a particle, P. The potential energy of the system, V,
  250. is equal to the vector sum of the potential energy of the particle, V1, and
  251. the potential energy of the rigid body, V2, i.e.
  252. V = V1 + V2
  253. Potential energy is a scalar.
  254. Parameters
  255. ==========
  256. body1, body2, body3... : Particle and/or RigidBody
  257. The body (or bodies) whose potential energy is required.
  258. Examples
  259. ========
  260. >>> from sympy.physics.mechanics import Point, Particle, ReferenceFrame
  261. >>> from sympy.physics.mechanics import RigidBody, outer, potential_energy
  262. >>> from sympy import symbols
  263. >>> M, m, g, h = symbols('M m g h')
  264. >>> N = ReferenceFrame('N')
  265. >>> O = Point('O')
  266. >>> O.set_vel(N, 0 * N.x)
  267. >>> P = O.locatenew('P', 1 * N.x)
  268. >>> Pa = Particle('Pa', P, m)
  269. >>> Ac = O.locatenew('Ac', 2 * N.y)
  270. >>> a = ReferenceFrame('a')
  271. >>> I = outer(N.z, N.z)
  272. >>> A = RigidBody('A', Ac, a, M, (I, Ac))
  273. >>> Pa.potential_energy = m * g * h
  274. >>> A.potential_energy = M * g * h
  275. >>> potential_energy(Pa, A)
  276. M*g*h + g*h*m
  277. """
  278. pe_sys = S.Zero
  279. for e in body:
  280. if isinstance(e, (RigidBody, Particle)):
  281. pe_sys += e.potential_energy
  282. else:
  283. raise TypeError('*body must have only Particle or RigidBody')
  284. return pe_sys
  285. def gravity(acceleration, *bodies):
  286. """
  287. Returns a list of gravity forces given the acceleration
  288. due to gravity and any number of particles or rigidbodies.
  289. Example
  290. =======
  291. >>> from sympy.physics.mechanics import ReferenceFrame, Point, Particle, outer, RigidBody
  292. >>> from sympy.physics.mechanics.functions import gravity
  293. >>> from sympy import symbols
  294. >>> N = ReferenceFrame('N')
  295. >>> m, M, g = symbols('m M g')
  296. >>> F1, F2 = symbols('F1 F2')
  297. >>> po = Point('po')
  298. >>> pa = Particle('pa', po, m)
  299. >>> A = ReferenceFrame('A')
  300. >>> P = Point('P')
  301. >>> I = outer(A.x, A.x)
  302. >>> B = RigidBody('B', P, A, M, (I, P))
  303. >>> forceList = [(po, F1), (P, F2)]
  304. >>> forceList.extend(gravity(g*N.y, pa, B))
  305. >>> forceList
  306. [(po, F1), (P, F2), (po, g*m*N.y), (P, M*g*N.y)]
  307. """
  308. gravity_force = []
  309. if not bodies:
  310. raise TypeError("No bodies(instances of Particle or Rigidbody) were passed.")
  311. for e in bodies:
  312. point = getattr(e, 'masscenter', None)
  313. if point is None:
  314. point = e.point
  315. gravity_force.append((point, e.mass*acceleration))
  316. return gravity_force
  317. def center_of_mass(point, *bodies):
  318. """
  319. Returns the position vector from the given point to the center of mass
  320. of the given bodies(particles or rigidbodies).
  321. Example
  322. =======
  323. >>> from sympy import symbols, S
  324. >>> from sympy.physics.vector import Point
  325. >>> from sympy.physics.mechanics import Particle, ReferenceFrame, RigidBody, outer
  326. >>> from sympy.physics.mechanics.functions import center_of_mass
  327. >>> a = ReferenceFrame('a')
  328. >>> m = symbols('m', real=True)
  329. >>> p1 = Particle('p1', Point('p1_pt'), S(1))
  330. >>> p2 = Particle('p2', Point('p2_pt'), S(2))
  331. >>> p3 = Particle('p3', Point('p3_pt'), S(3))
  332. >>> p4 = Particle('p4', Point('p4_pt'), m)
  333. >>> b_f = ReferenceFrame('b_f')
  334. >>> b_cm = Point('b_cm')
  335. >>> mb = symbols('mb')
  336. >>> b = RigidBody('b', b_cm, b_f, mb, (outer(b_f.x, b_f.x), b_cm))
  337. >>> p2.point.set_pos(p1.point, a.x)
  338. >>> p3.point.set_pos(p1.point, a.x + a.y)
  339. >>> p4.point.set_pos(p1.point, a.y)
  340. >>> b.masscenter.set_pos(p1.point, a.y + a.z)
  341. >>> point_o=Point('o')
  342. >>> point_o.set_pos(p1.point, center_of_mass(p1.point, p1, p2, p3, p4, b))
  343. >>> expr = 5/(m + mb + 6)*a.x + (m + mb + 3)/(m + mb + 6)*a.y + mb/(m + mb + 6)*a.z
  344. >>> point_o.pos_from(p1.point)
  345. 5/(m + mb + 6)*a.x + (m + mb + 3)/(m + mb + 6)*a.y + mb/(m + mb + 6)*a.z
  346. """
  347. if not bodies:
  348. raise TypeError("No bodies(instances of Particle or Rigidbody) were passed.")
  349. total_mass = 0
  350. vec = Vector(0)
  351. for i in bodies:
  352. total_mass += i.mass
  353. masscenter = getattr(i, 'masscenter', None)
  354. if masscenter is None:
  355. masscenter = i.point
  356. vec += i.mass*masscenter.pos_from(point)
  357. return vec/total_mass
  358. def Lagrangian(frame, *body):
  359. """Lagrangian of a multibody system.
  360. Explanation
  361. ===========
  362. This function returns the Lagrangian of a system of Particle's and/or
  363. RigidBody's. The Lagrangian of such a system is equal to the difference
  364. between the kinetic energies and potential energies of its constituents. If
  365. T and V are the kinetic and potential energies of a system then it's
  366. Lagrangian, L, is defined as
  367. L = T - V
  368. The Lagrangian is a scalar.
  369. Parameters
  370. ==========
  371. frame : ReferenceFrame
  372. The frame in which the velocity or angular velocity of the body is
  373. defined to determine the kinetic energy.
  374. body1, body2, body3... : Particle and/or RigidBody
  375. The body (or bodies) whose Lagrangian is required.
  376. Examples
  377. ========
  378. >>> from sympy.physics.mechanics import Point, Particle, ReferenceFrame
  379. >>> from sympy.physics.mechanics import RigidBody, outer, Lagrangian
  380. >>> from sympy import symbols
  381. >>> M, m, g, h = symbols('M m g h')
  382. >>> N = ReferenceFrame('N')
  383. >>> O = Point('O')
  384. >>> O.set_vel(N, 0 * N.x)
  385. >>> P = O.locatenew('P', 1 * N.x)
  386. >>> P.set_vel(N, 10 * N.x)
  387. >>> Pa = Particle('Pa', P, 1)
  388. >>> Ac = O.locatenew('Ac', 2 * N.y)
  389. >>> Ac.set_vel(N, 5 * N.y)
  390. >>> a = ReferenceFrame('a')
  391. >>> a.set_ang_vel(N, 10 * N.z)
  392. >>> I = outer(N.z, N.z)
  393. >>> A = RigidBody('A', Ac, a, 20, (I, Ac))
  394. >>> Pa.potential_energy = m * g * h
  395. >>> A.potential_energy = M * g * h
  396. >>> Lagrangian(N, Pa, A)
  397. -M*g*h - g*h*m + 350
  398. """
  399. if not isinstance(frame, ReferenceFrame):
  400. raise TypeError('Please supply a valid ReferenceFrame')
  401. for e in body:
  402. if not isinstance(e, (RigidBody, Particle)):
  403. raise TypeError('*body must have only Particle or RigidBody')
  404. return kinetic_energy(frame, *body) - potential_energy(*body)
  405. def find_dynamicsymbols(expression, exclude=None, reference_frame=None):
  406. """Find all dynamicsymbols in expression.
  407. Explanation
  408. ===========
  409. If the optional ``exclude`` kwarg is used, only dynamicsymbols
  410. not in the iterable ``exclude`` are returned.
  411. If we intend to apply this function on a vector, the optional
  412. ``reference_frame`` is also used to inform about the corresponding frame
  413. with respect to which the dynamic symbols of the given vector is to be
  414. determined.
  415. Parameters
  416. ==========
  417. expression : SymPy expression
  418. exclude : iterable of dynamicsymbols, optional
  419. reference_frame : ReferenceFrame, optional
  420. The frame with respect to which the dynamic symbols of the
  421. given vector is to be determined.
  422. Examples
  423. ========
  424. >>> from sympy.physics.mechanics import dynamicsymbols, find_dynamicsymbols
  425. >>> from sympy.physics.mechanics import ReferenceFrame
  426. >>> x, y = dynamicsymbols('x, y')
  427. >>> expr = x + x.diff()*y
  428. >>> find_dynamicsymbols(expr)
  429. {x(t), y(t), Derivative(x(t), t)}
  430. >>> find_dynamicsymbols(expr, exclude=[x, y])
  431. {Derivative(x(t), t)}
  432. >>> a, b, c = dynamicsymbols('a, b, c')
  433. >>> A = ReferenceFrame('A')
  434. >>> v = a * A.x + b * A.y + c * A.z
  435. >>> find_dynamicsymbols(v, reference_frame=A)
  436. {a(t), b(t), c(t)}
  437. """
  438. t_set = {dynamicsymbols._t}
  439. if exclude:
  440. if iterable(exclude):
  441. exclude_set = set(exclude)
  442. else:
  443. raise TypeError("exclude kwarg must be iterable")
  444. else:
  445. exclude_set = set()
  446. if isinstance(expression, Vector):
  447. if reference_frame is None:
  448. raise ValueError("You must provide reference_frame when passing a "
  449. "vector expression, got %s." % reference_frame)
  450. else:
  451. expression = expression.to_matrix(reference_frame)
  452. return {i for i in expression.atoms(AppliedUndef, Derivative) if
  453. i.free_symbols == t_set} - exclude_set
  454. def msubs(expr, *sub_dicts, smart=False, **kwargs):
  455. """A custom subs for use on expressions derived in physics.mechanics.
  456. Traverses the expression tree once, performing the subs found in sub_dicts.
  457. Terms inside ``Derivative`` expressions are ignored:
  458. Examples
  459. ========
  460. >>> from sympy.physics.mechanics import dynamicsymbols, msubs
  461. >>> x = dynamicsymbols('x')
  462. >>> msubs(x.diff() + x, {x: 1})
  463. Derivative(x(t), t) + 1
  464. Note that sub_dicts can be a single dictionary, or several dictionaries:
  465. >>> x, y, z = dynamicsymbols('x, y, z')
  466. >>> sub1 = {x: 1, y: 2}
  467. >>> sub2 = {z: 3, x.diff(): 4}
  468. >>> msubs(x.diff() + x + y + z, sub1, sub2)
  469. 10
  470. If smart=True (default False), also checks for conditions that may result
  471. in ``nan``, but if simplified would yield a valid expression. For example:
  472. >>> from sympy import sin, tan
  473. >>> (sin(x)/tan(x)).subs(x, 0)
  474. nan
  475. >>> msubs(sin(x)/tan(x), {x: 0}, smart=True)
  476. 1
  477. It does this by first replacing all ``tan`` with ``sin/cos``. Then each
  478. node is traversed. If the node is a fraction, subs is first evaluated on
  479. the denominator. If this results in 0, simplification of the entire
  480. fraction is attempted. Using this selective simplification, only
  481. subexpressions that result in 1/0 are targeted, resulting in faster
  482. performance.
  483. """
  484. sub_dict = dict_merge(*sub_dicts)
  485. if smart:
  486. func = _smart_subs
  487. elif hasattr(expr, 'msubs'):
  488. return expr.msubs(sub_dict)
  489. else:
  490. func = lambda expr, sub_dict: _crawl(expr, _sub_func, sub_dict)
  491. if isinstance(expr, (Matrix, Vector, Dyadic)):
  492. return expr.applyfunc(lambda x: func(x, sub_dict))
  493. else:
  494. return func(expr, sub_dict)
  495. def _crawl(expr, func, *args, **kwargs):
  496. """Crawl the expression tree, and apply func to every node."""
  497. val = func(expr, *args, **kwargs)
  498. if val is not None:
  499. return val
  500. new_args = (_crawl(arg, func, *args, **kwargs) for arg in expr.args)
  501. return expr.func(*new_args)
  502. def _sub_func(expr, sub_dict):
  503. """Perform direct matching substitution, ignoring derivatives."""
  504. if expr in sub_dict:
  505. return sub_dict[expr]
  506. elif not expr.args or expr.is_Derivative:
  507. return expr
  508. def _tan_repl_func(expr):
  509. """Replace tan with sin/cos."""
  510. if isinstance(expr, tan):
  511. return sin(*expr.args) / cos(*expr.args)
  512. elif not expr.args or expr.is_Derivative:
  513. return expr
  514. def _smart_subs(expr, sub_dict):
  515. """Performs subs, checking for conditions that may result in `nan` or
  516. `oo`, and attempts to simplify them out.
  517. The expression tree is traversed twice, and the following steps are
  518. performed on each expression node:
  519. - First traverse:
  520. Replace all `tan` with `sin/cos`.
  521. - Second traverse:
  522. If node is a fraction, check if the denominator evaluates to 0.
  523. If so, attempt to simplify it out. Then if node is in sub_dict,
  524. sub in the corresponding value."""
  525. expr = _crawl(expr, _tan_repl_func)
  526. def _recurser(expr, sub_dict):
  527. # Decompose the expression into num, den
  528. num, den = _fraction_decomp(expr)
  529. if den != 1:
  530. # If there is a non trivial denominator, we need to handle it
  531. denom_subbed = _recurser(den, sub_dict)
  532. if denom_subbed.evalf() == 0:
  533. # If denom is 0 after this, attempt to simplify the bad expr
  534. expr = simplify(expr)
  535. else:
  536. # Expression won't result in nan, find numerator
  537. num_subbed = _recurser(num, sub_dict)
  538. return num_subbed / denom_subbed
  539. # We have to crawl the tree manually, because `expr` may have been
  540. # modified in the simplify step. First, perform subs as normal:
  541. val = _sub_func(expr, sub_dict)
  542. if val is not None:
  543. return val
  544. new_args = (_recurser(arg, sub_dict) for arg in expr.args)
  545. return expr.func(*new_args)
  546. return _recurser(expr, sub_dict)
  547. def _fraction_decomp(expr):
  548. """Return num, den such that expr = num/den"""
  549. if not isinstance(expr, Mul):
  550. return expr, 1
  551. num = []
  552. den = []
  553. for a in expr.args:
  554. if a.is_Pow and a.args[1] < 0:
  555. den.append(1 / a)
  556. else:
  557. num.append(a)
  558. if not den:
  559. return expr, 1
  560. num = Mul(*num)
  561. den = Mul(*den)
  562. return num, den
  563. def _f_list_parser(fl, ref_frame):
  564. """Parses the provided forcelist composed of items
  565. of the form (obj, force).
  566. Returns a tuple containing:
  567. vel_list: The velocity (ang_vel for Frames, vel for Points) in
  568. the provided reference frame.
  569. f_list: The forces.
  570. Used internally in the KanesMethod and LagrangesMethod classes.
  571. """
  572. def flist_iter():
  573. for pair in fl:
  574. obj, force = pair
  575. if isinstance(obj, ReferenceFrame):
  576. yield obj.ang_vel_in(ref_frame), force
  577. elif isinstance(obj, Point):
  578. yield obj.vel(ref_frame), force
  579. else:
  580. raise TypeError('First entry in each forcelist pair must '
  581. 'be a point or frame.')
  582. if not fl:
  583. vel_list, f_list = (), ()
  584. else:
  585. unzip = lambda l: list(zip(*l)) if l[0] else [(), ()]
  586. vel_list, f_list = unzip(list(flist_iter()))
  587. return vel_list, f_list