# testStreams.py (4.3 KB)
  1. import pythoncom
  2. import win32com.server.util
  3. import win32com.test.util
  4. import unittest
  5. from pywin32_testutil import str2bytes
  6. class Persists:
  7. _public_methods_ = [ 'GetClassID', 'IsDirty', 'Load', 'Save',
  8. 'GetSizeMax', 'InitNew' ]
  9. _com_interfaces_ = [ pythoncom.IID_IPersistStreamInit ]
  10. def __init__(self):
  11. self.data = str2bytes("abcdefg")
  12. self.dirty = 1
  13. def GetClassID(self):
  14. return pythoncom.IID_NULL
  15. def IsDirty(self):
  16. return self.dirty
  17. def Load(self, stream):
  18. self.data = stream.Read(26)
  19. def Save(self, stream, clearDirty):
  20. stream.Write(self.data)
  21. if clearDirty:
  22. self.dirty = 0
  23. def GetSizeMax(self):
  24. return 1024
  25. def InitNew(self):
  26. pass
  27. class Stream:
  28. _public_methods_ = [ 'Read', 'Write', 'Seek' ]
  29. _com_interfaces_ = [ pythoncom.IID_IStream ]
  30. def __init__(self, data):
  31. self.data = data
  32. self.index = 0
  33. def Read(self, amount):
  34. result = self.data[self.index : self.index + amount]
  35. self.index = self.index + amount
  36. return result
  37. def Write(self, data):
  38. self.data = data
  39. self.index = 0
  40. return len(data)
  41. def Seek(self, dist, origin):
  42. if origin==pythoncom.STREAM_SEEK_SET:
  43. self.index = dist
  44. elif origin==pythoncom.STREAM_SEEK_CUR:
  45. self.index = self.index + dist
  46. elif origin==pythoncom.STREAM_SEEK_END:
  47. self.index = len(self.data)+dist
  48. else:
  49. raise ValueError('Unknown Seek type: ' +str(origin))
  50. if self.index < 0:
  51. self.index = 0
  52. else:
  53. self.index = min(self.index, len(self.data))
  54. return self.index
  55. class BadStream(Stream):
  56. """ PyGStream::Read could formerly overflow buffer if the python implementation
  57. returned more data than requested.
  58. """
  59. def Read(self, amount):
  60. return str2bytes('x')*(amount+1)
  61. class StreamTest(win32com.test.util.TestCase):
  62. def _readWrite(self, data, write_stream, read_stream = None):
  63. if read_stream is None: read_stream = write_stream
  64. write_stream.Write(data)
  65. read_stream.Seek(0, pythoncom.STREAM_SEEK_SET)
  66. got = read_stream.Read(len(data))
  67. self.assertEqual(data, got)
  68. read_stream.Seek(1, pythoncom.STREAM_SEEK_SET)
  69. got = read_stream.Read(len(data)-2)
  70. self.assertEqual(data[1:-1], got)
  71. def testit(self):
  72. mydata = str2bytes('abcdefghijklmnopqrstuvwxyz')
  73. # First test the objects just as Python objects...
  74. s = Stream(mydata)
  75. p = Persists()
  76. p.Load(s)
  77. p.Save(s, 0)
  78. self.assertEqual(s.data, mydata)
  79. # Wrap the Python objects as COM objects, and make the calls as if
  80. # they were non-Python COM objects.
  81. s2 = win32com.server.util.wrap(s, pythoncom.IID_IStream)
  82. p2 = win32com.server.util.wrap(p, pythoncom.IID_IPersistStreamInit)
  83. self._readWrite(mydata, s, s)
  84. self._readWrite(mydata, s, s2)
  85. self._readWrite(mydata, s2, s)
  86. self._readWrite(mydata, s2, s2)
  87. self._readWrite(str2bytes("string with\0a NULL"), s2, s2)
  88. # reset the stream
  89. s.Write(mydata)
  90. p2.Load(s2)
  91. p2.Save(s2, 0)
  92. self.assertEqual(s.data, mydata)
  93. def testseek(self):
  94. s = Stream(str2bytes('yo'))
  95. s = win32com.server.util.wrap(s, pythoncom.IID_IStream)
  96. # we used to die in py3k passing a value > 32bits
  97. s.Seek(0x100000000, pythoncom.STREAM_SEEK_SET)
  98. def testerrors(self):
  99. # setup a test logger to capture tracebacks etc.
  100. records, old_log = win32com.test.util.setup_test_logger()
  101. ## check for buffer overflow in Read method
  102. badstream = BadStream('Check for buffer overflow')
  103. badstream2 = win32com.server.util.wrap(badstream, pythoncom.IID_IStream)
  104. self.assertRaises(pythoncom.com_error, badstream2.Read, 10)
  105. win32com.test.util.restore_test_logger(old_log)
  106. # expecting 2 pythoncom errors to have been raised by the gateways.
  107. self.assertEqual(len(records), 2)
  108. self.failUnless(records[0].msg.startswith('pythoncom error'))
  109. self.failUnless(records[1].msg.startswith('pythoncom error'))
  110. if __name__=='__main__':
  111. unittest.main()