# openpmd.py

# Authors: Berk Geveci, Axel Huebl, Utkarsh Ayachit
#
from vtkmodules.util.vtkAlgorithm import VTKPythonAlgorithmBase
from .. import print_error

try:
    import numpy as np
    import openpmd_api as io
    _has_openpmd = True
except ImportError as ie:
    print_error(
        "Missing required Python modules/packages. Algorithms in this module may "
        "not work as expected! \n {0}".format(ie))
    _has_openpmd = False


def createModifiedCallback(anobject):
    # Hold only a weak reference to the observed object so that the observer
    # does not keep the reader alive (avoids a reference cycle).
    import weakref
    weakref_obj = weakref.ref(anobject)
    anobject = None

    def _markmodified(*args, **kwargs):
        o = weakref_obj()
        if o is not None:
            o.Modified()
    return _markmodified


class openPMDReader(VTKPythonAlgorithmBase):
    """A reader that reads openPMD format."""
    def __init__(self):
        VTKPythonAlgorithmBase.__init__(self, nInputPorts=0, nOutputPorts=2, outputType='vtkPartitionedDataSet')
        self._filename = None
        self._timevalues = None
        self._series = None
        self._timemap = {}

        from vtkmodules.vtkCommonCore import vtkDataArraySelection
        self._arrayselection = vtkDataArraySelection()
        self._arrayselection.AddObserver("ModifiedEvent", createModifiedCallback(self))

        self._speciesselection = vtkDataArraySelection()
        self._speciesselection.AddObserver("ModifiedEvent", createModifiedCallback(self))

        self._particlearrayselection = vtkDataArraySelection()
        self._particlearrayselection.AddObserver("ModifiedEvent", createModifiedCallback(self))

    def _get_update_time(self, outInfo):
        from vtkmodules.vtkCommonExecutionModel import vtkStreamingDemandDrivenPipeline
        executive = vtkStreamingDemandDrivenPipeline
        timevalues = self._timevalues
        if timevalues is None or len(timevalues) == 0:
            return None
        elif outInfo.Has(executive.UPDATE_TIME_STEP()) and len(timevalues) > 0:
            utime = outInfo.Get(executive.UPDATE_TIME_STEP())
            dtime = timevalues[0]
            for atime in timevalues:
                if atime > utime:
                    return dtime
                else:
                    dtime = atime
            return dtime
        else:
            assert(len(timevalues) > 0)
            return timevalues[0]

    def _get_array_selection(self):
        return self._arrayselection

    def _get_particle_array_selection(self):
        return self._particlearrayselection

    def _get_species_selection(self):
        return self._speciesselection

    def SetFileName(self, name):
        """Specify filename for the file to read."""
        if self._filename != name:
            self._filename = name
            self._timevalues = None
            if self._series:
                self._series = None
            self.Modified()

    def GetTimestepValues(self):
        return self._timevalues

    def GetDataArraySelection(self):
        return self._get_array_selection()

    def GetSpeciesSelection(self):
        return self._get_species_selection()

    def GetParticleArraySelection(self):
        return self._get_particle_array_selection()

    def FillOutputPortInformation(self, port, info):
        from vtkmodules.vtkCommonDataModel import vtkDataObject
        if port == 0:
            info.Set(vtkDataObject.DATA_TYPE_NAME(), "vtkPartitionedDataSet")
        else:
            info.Set(vtkDataObject.DATA_TYPE_NAME(), "vtkPartitionedDataSetCollection")
        return 1

    def RequestInformation(self, request, inInfoVec, outInfoVec):
        global _has_openpmd
        if not _has_openpmd:
            print_error("Required Python module 'openpmd_api' missing!")
            return 0

        from vtkmodules.vtkCommonExecutionModel import vtkStreamingDemandDrivenPipeline, vtkAlgorithm
        executive = vtkStreamingDemandDrivenPipeline
        for i in (0, 1):
            outInfo = outInfoVec.GetInformationObject(i)
            outInfo.Remove(executive.TIME_STEPS())
            outInfo.Remove(executive.TIME_RANGE())
            outInfo.Set(vtkAlgorithm.CAN_HANDLE_PIECE_REQUEST(), 1)

        # Why is this a string when it is None?
        if self._filename == 'None':
            return 1

        # The file passed to SetFileName is a text file whose first line holds
        # the openPMD series pattern, resolved relative to that file's directory.
        mfile = open(self._filename, "r")
        pattern = mfile.readlines()[0][0:-1]
        del mfile

        import os
        if not self._series:
            self._series = io.Series(
                os.path.dirname(self._filename) + '/' + pattern,
                io.Access_Type.read_only)

        # This is how we get time values and arrays
        self._timemap = {}
        timevalues = []
        arrays = set()
        particles = set()
        species = set()
        for idx, iteration in self._series.iterations.items():
            # extract the time
            if callable(iteration.time):  # prior to openPMD-api 0.13.0
                time = iteration.time() * iteration.time_unit_SI()
            else:
                time = iteration.time * iteration.time_unit_SI
            timevalues.append(time)
            self._timemap[time] = idx
            arrays.update([
                mesh_name
                for mesh_name, mesh in iteration.meshes.items()])
            particles.update([
                species_name + "_" + record_name
                for species_name, species in iteration.particles.items()
                for record_name, record in species.items()
            ])
            species.update([
                species_name
                for species_name, _ in iteration.particles.items()
            ])

        for array in arrays:
            self._arrayselection.AddArray(array)
        for particle_array in particles:
            self._particlearrayselection.AddArray(particle_array)
        for species_name in species:
            self._speciesselection.AddArray(species_name)

        timesteps = list(self._series.iterations)
        self._timevalues = timevalues
        if len(timevalues) > 0:
            for i in (0, 1):
                outInfo = outInfoVec.GetInformationObject(i)
                for t in timevalues:
                    outInfo.Append(executive.TIME_STEPS(), t)
                outInfo.Append(executive.TIME_RANGE(), timevalues[0])
                outInfo.Append(executive.TIME_RANGE(), timevalues[-1])
        return 1

    def _get_array_and_component(self, itr, name):
        for mesh_name, mesh in itr.meshes.items():
            if mesh_name == name:
                return (mesh_name, None)
            for comp_name, _ in mesh.items():
                if name == mesh_name + "_" + comp_name:
                    return (mesh_name, comp_name)
        return (None, None)

    def _get_particle_array_and_component(self, itr, name):
        for species_name, species in itr.particles.items():
            for record_name, record in species.items():
                if name == species_name + "_" + record_name:
                    return (species_name, record_name)
        return (None, None)

    def _load_array(self, var, chunk_offset, chunk_extent):
        arrays = []
        for name, scalar in var.items():
            comp = scalar.load_chunk(chunk_offset, chunk_extent)
            self._series.flush()
            comp = comp * scalar.unit_SI
            arrays.append(comp)

        ncomp = len(var)
        if ncomp > 1:
            flt = np.ravel(arrays, order='F')
            return flt.reshape((flt.shape[0] // ncomp, ncomp))
        else:
            return arrays[0].flatten(order='F')

    def _find_array(self, itr, name):
        var = itr.meshes[name]
        theta_modes = None
        if var.geometry == io.Geometry.thetaMode:
            theta_modes = 3  # hardcoded, parse from geometry_parameters
        return (var,
                np.array(var.grid_spacing) * var.grid_unit_SI,
                np.array(var.grid_global_offset) * var.grid_unit_SI,
                theta_modes)

    def _get_num_particles(self, itr, species):
        sp = itr.particles[species]
        var = sp["position"]
        return var['x'].shape[0]

    def _load_particle_array(self, itr, species, name, start, end):
        sp = itr.particles[species]
        var = sp[name]
        arrays = []
        for name, scalar in var.items():
            comp = scalar.load_chunk([start], [end - start + 1])
            self._series.flush()
            comp = comp * scalar.unit_SI
            arrays.append(comp)

        ncomp = len(var)
        if ncomp > 1:
            flt = np.ravel(arrays, order='F')
            return flt.reshape((flt.shape[0] // ncomp, ncomp))
        else:
            return arrays[0]

    def _load_particles(self, itr, species, start, end):
        sp = itr.particles[species]
        var = sp["position"]
        ovar = sp["positionOffset"]
        position_arrays = []
        for name, scalar in var.items():
            pos = scalar.load_chunk([start], [end - start + 1])
            self._series.flush()
            pos = pos * scalar.unit_SI
            off = ovar[name].load_chunk([start], [end - start + 1])
            self._series.flush()
            off = off * ovar[name].unit_SI
            position_arrays.append(pos + off)

        flt = np.ravel(position_arrays, order='F')
        num_components = len(var)  # 1D, 2D and 3D positions
        flt = flt.reshape((flt.shape[0] // num_components, num_components))
        # 1D and 2D particles: pad additional components with zero
        while flt.shape[1] < 3:
            flt = np.column_stack([flt, np.zeros_like(flt[:, 0])])
        return flt

    def _load_species(self, itr, species, arrays, piece, npieces, ugrid):
        nparticles = self._get_num_particles(itr, species)
        nlocalparticles = nparticles // npieces
        start = nlocalparticles * piece
        end = start + nlocalparticles - 1
        if piece == npieces - 1:
            end = nparticles - 1
        pts = self._load_particles(itr, species, start, end)
        npts = pts.shape[0]
        ugrid.Points = pts
        for array in arrays:
            if array[1] == 'position' or array[1] == 'positionOffset':
                continue
            ugrid.PointData.append(
                self._load_particle_array(itr, array[0], array[1], start, end),
                array[1])

        from vtkmodules.vtkCommonDataModel import vtkCellArray
        ca = vtkCellArray()
        if npts < np.iinfo(np.int32).max:
            dtype = np.int32
        else:
            dtype = np.int64
        offsets = np.linspace(0, npts, npts + 1, dtype=dtype)
        cells = np.linspace(0, npts - 1, npts, dtype=dtype)
        from vtkmodules.numpy_interface import dataset_adapter
        offsets = dataset_adapter.numpyTovtkDataArray(offsets)
        offsets2 = offsets.NewInstance()
        offsets2.DeepCopy(offsets)
        cells = dataset_adapter.numpyTovtkDataArray(cells)
        cells2 = cells.NewInstance()
        cells2.DeepCopy(cells)
        ca.SetData(offsets2, cells2)
        from vtkmodules.util import vtkConstants
        ugrid.VTKObject.SetCells(vtkConstants.VTK_VERTEX, ca)

    def _RequestFieldData(self, executive, output, outInfo):
        from vtkmodules.numpy_interface import dataset_adapter as dsa
        from vtkmodules.vtkCommonDataModel import vtkImageData
        from vtkmodules.vtkCommonExecutionModel import vtkExtentTranslator

        piece = outInfo.Get(executive.UPDATE_PIECE_NUMBER())
        npieces = outInfo.Get(executive.UPDATE_NUMBER_OF_PIECES())
        nghosts = outInfo.Get(executive.UPDATE_NUMBER_OF_GHOST_LEVELS())
        et = vtkExtentTranslator()

        data_time = self._get_update_time(outInfo)
        idx = self._timemap[data_time]
        itr = self._series.iterations[idx]

        arrays = []
        narrays = self._arrayselection.GetNumberOfArrays()
        for i in range(narrays):
            if self._arrayselection.GetArraySetting(i):
                name = self._arrayselection.GetArrayName(i)
                arrays.append((name, self._find_array(itr, name)))

        shp = None
        spacing = None
        theta_modes = None
        grid_offset = None
        for _, ary in arrays:
            var = ary[0]
            for name, scalar in var.items():
                shape = scalar.shape
                break
            spc = list(ary[1])
            if not spacing:
                spacing = spc
            elif spacing != spc:  # all meshes need to have the same spacing
                return 0
            offset = list(ary[2])
            if not grid_offset:
                grid_offset = offset
            elif grid_offset != offset:  # all meshes need to have the same grid offset
                return 0
            if not shp:
                shp = shape
            elif shape != shp:  # all arrays need to have the same shape
                return 0
            if not theta_modes:
                theta_modes = ary[3]

        # fields/meshes: RZ
        if theta_modes and shp is not None:
            et.SetWholeExtent(0, shp[0] - 1, 0, shp[1] - 1, 0, shp[2] - 1)
            et.SetSplitModeToZSlab()  # note: Y and Z are both fine
            et.SetPiece(piece)
            et.SetNumberOfPieces(npieces)
            # et.SetGhostLevel(nghosts)
            et.PieceToExtentByPoints()
            ext = et.GetExtent()

            chunk_offset = [ext[0], ext[2], ext[4]]
            chunk_extent = [ext[1] - ext[0] + 1, ext[3] - ext[2] + 1, ext[5] - ext[4] + 1]

            data = []
            nthetas = 100  # user parameter
            thetas = np.linspace(0., 2. * np.pi, nthetas)
            chunk_cyl_shape = (chunk_extent[1], chunk_extent[2], nthetas)  # z, r, theta
            for name, var in arrays:
                cyl_values = np.zeros(chunk_cyl_shape)
                values = self._load_array(var[0], chunk_offset, chunk_extent)
                self._series.flush()

                print(chunk_cyl_shape)
                print(values.shape)
                print("+++++++++++")
                for ntheta in range(nthetas):
                    cyl_values[:, :, ntheta] += values[0, :, :]
                data.append((name, cyl_values))
                # add all other modes via loop
                # for m in range(theta_modes):

            cyl_spacing = [spacing[0], spacing[1], thetas[1] - thetas[0]]

            z_coord = np.linspace(0., cyl_spacing[0] * chunk_cyl_shape[0], chunk_cyl_shape[0])
            r_coord = np.linspace(0., cyl_spacing[1] * chunk_cyl_shape[1], chunk_cyl_shape[1])
            t_coord = thetas

            # to cartesian
            print(z_coord.shape, r_coord.shape, t_coord.shape)
            cyl_coords = np.meshgrid(r_coord, z_coord, t_coord)
            rs = cyl_coords[1]
            zs = cyl_coords[0]
            thetas = cyl_coords[2]

            y_coord = rs * np.sin(thetas)
            x_coord = rs * np.cos(thetas)
            z_coord = zs
            # mesh_pts = np.zeros((chunk_cyl_shape[0], chunk_cyl_shape[1], chunk_cyl_shape[2], 3))
            # mesh_pts[:, :, :, 0] = z_coord

            img = vtkImageData()
            img.SetExtent(
                chunk_offset[1], chunk_offset[1] + chunk_cyl_shape[0] - 1,
                chunk_offset[2], chunk_offset[2] + chunk_cyl_shape[1] - 1,
                0, nthetas - 1)
            img.SetSpacing(cyl_spacing)

            imgw = dsa.WrapDataObject(img)
            output.SetPartition(0, img)
            for name, array in data:
                # print(array.shape)
                # print(array.transpose(2,1,0).flatten(order='C').shape)
                imgw.PointData.append(array.transpose(2, 1, 0).flatten(order='C'), name)

            # data = []
            # for name, var in arrays:
            #     unit_SI = var[0].unit_SI
            #     data.append((name, unit_SI * var[0].load_chunk(chunk_offset, chunk_extent)))
            # self._series.flush()

        # fields/meshes: 1D-3D
        elif shp is not None:
            whole_extent = []
            # interleave shape with zeros
            for s in shp:
                whole_extent.append(0)
                whole_extent.append(s - 1)
            # 1D and 2D data: pad with 0, 0 for extra dimensions
            while len(whole_extent) < 6:
                whole_extent.append(0)
                whole_extent.append(0)
            et.SetWholeExtent(*whole_extent)
            et.SetPiece(piece)
            et.SetNumberOfPieces(npieces)
            et.SetGhostLevel(nghosts)
            et.PieceToExtent()
            ext = et.GetExtent()

            chunk_offset = [ext[0], ext[2], ext[4]]
            chunk_extent = [ext[1] - ext[0] + 1, ext[3] - ext[2] + 1, ext[5] - ext[4] + 1]
            # 1D and 2D data: remove extra dimensions for load
            del chunk_offset[len(shp):]
            del chunk_extent[len(shp):]

            data = []
            for name, var in arrays:
                values = self._load_array(var[0], chunk_offset, chunk_extent)
                self._series.flush()
                data.append((name, values))

            # 1D and 2D data: pad spacing with extra 1 and grid_offset with
            # extra 0 values until 3D
            i = iter(spacing)
            spacing = [next(i, 1) for _ in range(3)]
            i = iter(grid_offset)
            grid_offset = [next(i, 0) for _ in range(3)]

            img = vtkImageData()
            img.SetExtent(ext[0], ext[1], ext[2], ext[3], ext[4], ext[5])
            img.SetSpacing(spacing)
            img.SetOrigin(grid_offset)

            et.SetGhostLevel(0)
            et.PieceToExtent()
            ext = et.GetExtent()
            ext = [ext[0], ext[1], ext[2], ext[3], ext[4], ext[5]]
            img.GenerateGhostArray(ext)
            imgw = dsa.WrapDataObject(img)
            output.SetPartition(0, img)
            for name, array in data:
                imgw.PointData.append(array, name)

    def _RequestParticleData(self, executive, poutput, outInfo):
        from vtkmodules.numpy_interface import dataset_adapter as dsa
        from vtkmodules.vtkCommonDataModel import vtkUnstructuredGrid, vtkPartitionedDataSet

        piece = outInfo.Get(executive.UPDATE_PIECE_NUMBER())
        npieces = outInfo.Get(executive.UPDATE_NUMBER_OF_PIECES())

        data_time = self._get_update_time(outInfo)
        idx = self._timemap[data_time]
        itr = self._series.iterations[idx]

        array_by_species = {}
        narrays = self._particlearrayselection.GetNumberOfArrays()
        for i in range(narrays):
            if self._particlearrayselection.GetArraySetting(i):
                name = self._particlearrayselection.GetArrayName(i)
                names = self._get_particle_array_and_component(
                    itr, name)
                if names[0] and self._speciesselection.ArrayIsEnabled(names[0]):
                    if not names[0] in array_by_species:
                        array_by_species[names[0]] = []
                    array_by_species[names[0]].append(names)

        ids = 0
        for species, arrays in array_by_species.items():
            pds = vtkPartitionedDataSet()
            ugrid = vtkUnstructuredGrid()
            pds.SetPartition(0, ugrid)
            poutput.SetPartitionedDataSet(ids, pds)
            ids += 1
            self._load_species(
                itr, species, arrays, piece, npieces, dsa.WrapDataObject(ugrid))

    def RequestData(self, request, inInfoVec, outInfoVec):
        global _has_openpmd
        if not _has_openpmd:
            print_error("Required Python module 'openpmd_api' missing!")
            return 0

        from vtkmodules.vtkCommonDataModel import vtkPartitionedDataSet, vtkPartitionedDataSetCollection
        from vtkmodules.vtkCommonExecutionModel import vtkStreamingDemandDrivenPipeline

        executive = vtkStreamingDemandDrivenPipeline
        numInfo = outInfoVec.GetNumberOfInformationObjects()
        for i in range(numInfo):
            outInfo = outInfoVec.GetInformationObject(i)
            if i == 0:
                output = vtkPartitionedDataSet.GetData(outInfoVec, 0)
                self._RequestFieldData(executive, output, outInfo)
            elif i == 1:
                poutput = vtkPartitionedDataSetCollection.GetData(outInfoVec, 1)
                self._RequestParticleData(executive, poutput, outInfo)
            else:
                print_error("numInfo number is wrong! "
                            "It should be exactly 2, is=", numInfo)
                return 0
        return 1
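

# ---------------------------------------------------------------------------
# Usage sketch (not part of the reader above): a minimal, hypothetical way to
# drive this algorithm outside the ParaView GUI. It assumes "data.pmd" is a
# small text file whose first line holds the openPMD series pattern (e.g.
# "simData_%T.h5") and that it sits next to the data files, which is what
# RequestInformation expects.
#
#   reader = openPMDReader()
#   reader.SetFileName("data.pmd")
#   reader.UpdateInformation()                 # populates time steps and selections
#   reader.GetDataArraySelection().EnableAllArrays()
#   reader.GetParticleArraySelection().EnableAllArrays()
#   reader.GetSpeciesSelection().EnableAllArrays()
#   reader.Update()
#   fields = reader.GetOutputDataObject(0)     # vtkPartitionedDataSet (meshes)
#   particles = reader.GetOutputDataObject(1)  # vtkPartitionedDataSetCollection
# ---------------------------------------------------------------------------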