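November 5, 2022

A Brief Look at Context Managers in Python

Note: for learning purposes only; do not use for illegal purposes. In case of infringement, please contact the blogger to have it removed.

Author: zhu6201976

Blog: https://blog.csdn.net/zhu6201976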

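1. Any object whose class implements both __enter__() and __exit__() is a context manager and can be used with the with statement. On entering the block, with calls __enter__() and binds its return value to the as target (if there is one). On leaving the block, it always calls __exit__(), even if the block raised an exception.

Common context managers in Python include file objects, thread locks (Lock), Semaphore and sockets. All of them implement __enter__() and __exit__() internally.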
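To make the call order concrete, here is a simplified, hand-expanded version of what a with statement does, following the protocol described in PEP 343. Tracer and run_with are hypothetical names used only for illustration. The sketch also shows that if __exit__() returns a true value, the exception raised in the block is suppressed.

```python
import sys


class Tracer:
    """Hypothetical context manager that records the calls made on it."""

    def __init__(self, suppress=False):
        self.calls = []
        self.suppress = suppress

    def __enter__(self):
        self.calls.append("enter")
        return self  # this value is bound to the `as` target

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.calls.append(("exit", exc_type))
        return self.suppress  # a true value swallows the exception


def run_with(mgr, body):
    """Simplified hand expansion of `with mgr as value: body(value)`."""
    enter = type(mgr).__enter__
    exit_ = type(mgr).__exit__
    value = enter(mgr)
    try:
        body(value)
    except BaseException:
        # Hand the active exception to __exit__; re-raise unless it returns true
        if not exit_(mgr, *sys.exc_info()):
            raise
    else:
        # Normal exit: __exit__ receives three Nones
        exit_(mgr, None, None, None)


t = Tracer()
run_with(t, lambda v: None)
print(t.calls)  # ['enter', ('exit', None)]


def boom(v):
    raise ValueError("oops")


t2 = Tracer(suppress=True)
run_with(t2, boom)  # no error escapes: __exit__ returned True
print(t2.calls)  # ['enter', ('exit', <class 'ValueError'>)]

# The real with statement produces the same call sequence
t3 = Tracer()
with t3 as v:
    pass
print(t3.calls == t.calls)  # True
```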

```python
# threading.Lock type stub (threading.pyi), not the real source:
# in CPython the lock itself is implemented in C
import sys
from types import TracebackType
from typing import Optional, Type


class Lock:
    def __init__(self) -> None: ...
    def __enter__(self) -> bool: ...
    def __exit__(
        self, exc_type: Optional[Type[BaseException]], exc_val: Optional[BaseException], exc_tb: Optional[TracebackType]
    ) -> Optional[bool]: ...
    if sys.version_info >= (3,):
        def acquire(self, blocking: bool = ..., timeout: float = ...) -> bool: ...
    else:
        def acquire(self, blocking: bool = ...) -> bool: ...
    def release(self) -> None: ...
    def locked(self) -> bool: ...
```
```python
# threading.Semaphore source (CPython, Lib/threading.py)
# Imports added so the excerpt runs on its own; in threading.py these are module-level names
from threading import Condition, Lock
from time import monotonic as _time


class Semaphore:
    """This class implements semaphore objects.

    Semaphores manage a counter representing the number of release() calls minus
    the number of acquire() calls, plus an initial value. The acquire() method
    blocks if necessary until it can return without making the counter
    negative. If not given, value defaults to 1.

    """

    # After Tim Peters' semaphore class, but not quite the same (no maximum)

    def __init__(self, value=1):
        if value < 0:
            raise ValueError("semaphore initial value must be >= 0")
        self._cond = Condition(Lock())
        self._value = value

    def acquire(self, blocking=True, timeout=None):
        """Acquire a semaphore, decrementing the internal counter by one.

        When invoked without arguments: if the internal counter is larger than
        zero on entry, decrement it by one and return immediately. If it is zero
        on entry, block, waiting until some other thread has called release() to
        make it larger than zero. This is done with proper interlocking so that
        if multiple acquire() calls are blocked, release() will wake exactly one
        of them up. The implementation may pick one at random, so the order in
        which blocked threads are awakened should not be relied on. There is no
        return value in this case.

        When invoked with blocking set to true, do the same thing as when called
        without arguments, and return true.

        When invoked with blocking set to false, do not block. If a call without
        an argument would block, return false immediately; otherwise, do the
        same thing as when called without arguments, and return true.

        When invoked with a timeout other than None, it will block for at
        most timeout seconds. If acquire does not complete successfully in
        that interval, return false. Return true otherwise.

        """
        if not blocking and timeout is not None:
            raise ValueError("can't specify timeout for non-blocking acquire")
        rc = False
        endtime = None
        with self._cond:
            while self._value == 0:
                if not blocking:
                    break
                if timeout is not None:
                    if endtime is None:
                        endtime = _time() + timeout
                    else:
                        timeout = endtime - _time()
                        if timeout <= 0:
                            break
                self._cond.wait(timeout)
            else:
                self._value -= 1
                rc = True
        return rc

    __enter__ = acquire  # entering a with block simply acquires

    def release(self):
        """Release a semaphore, incrementing the internal counter by one.

        When the counter is zero on entry and another thread is waiting for it
        to become larger than zero again, wake up that thread.

        """
        with self._cond:
            self._value += 1
            self._cond.notify()

    def __exit__(self, t, v, tb):
        self.release()  # leaving a with block releases
```
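For a Lock, __enter__() acquires and __exit__() releases. For a Semaphore, __enter__ is literally an alias of acquire. So "with lock:" is a safe shorthand for calling acquire() followed by a try/finally that calls release(). The small sketch below uses a BoundedSemaphore(2) to cap the number of concurrent workers. It also shows that the lock is released even when the block raises. The worker function and the counters are illustrative names, not part of the threading API.

```python
import threading
import time

lock = threading.Lock()
sem = threading.BoundedSemaphore(2)  # at most 2 threads inside the guarded block
active = 0      # threads currently holding a semaphore slot
max_active = 0  # highest value `active` ever reached


def worker():
    global active, max_active
    with sem:  # sem.acquire() on entry, sem.release() on exit
        with lock:  # protect the shared counters
            active += 1
            max_active = max(max_active, active)
        time.sleep(0.05)  # simulate work while holding a slot
        with lock:
            active -= 1


threads = [threading.Thread(target=worker) for _ in range(6)]
for th in threads:
    th.start()
for th in threads:
    th.join()
print("max concurrent workers:", max_active)  # never more than 2

# The lock is released even if the body raises
try:
    with lock:
        raise RuntimeError("failure inside the critical section")
except RuntimeError:
    pass
print(lock.locked())  # False
```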

2. Implementing a file context manager by hand

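```python
class MyOpen(object):
    def __init__(self, file_path, mode):
        print('__init__')
        self.f = open(file_path, mode)  # the file is opened as soon as MyOpen(...) is created

    def __enter__(self):
        print('__enter__')
        return self.f  # bound to f in "with ... as f"

    def __exit__(self, exc_type, exc_val, exc_tb):
        print('__exit__')
        self.f.close()  # runs even if the with block raises


# test.txt must already exist in the current directory
with MyOpen('test.txt', 'r') as f:
    content = f.read()  # file objects provide read(); there is no content() method
    print(content)
```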
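The same manager can also be written as a generator with the standard library's contextlib.contextmanager decorator. The code before yield plays the role of __enter__(), and the finally clause plays the role of __exit__(). This is a swapped-in technique rather than the class-based approach above. my_open is a hypothetical name, and a temporary file stands in for test.txt so the sketch runs anywhere.

```python
import os
import tempfile
from contextlib import contextmanager


@contextmanager
def my_open(file_path, mode):
    f = open(file_path, mode)  # setup, like MyOpen.__init__/__enter__
    try:
        yield f  # value bound to the `as` target
    finally:
        f.close()  # cleanup, like __exit__; runs even if the block raises


# Create a throwaway file instead of relying on an existing test.txt
fd, path = tempfile.mkstemp(suffix=".txt")
os.close(fd)
with my_open(path, "w") as f:
    f.write("hello")

with my_open(path, "r") as f:
    content = f.read()
print(content)   # hello
print(f.closed)  # True: the file was closed when the block ended

os.remove(path)
```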

3. A database context manager

```python
import pymysql


class Db(object):
    def __init__(self):
        print('__init__')
        self.connect = pymysql.connect(
            host='localhost',
            user='root',
            password='root',
            database='test',
            port=3306,
            charset='utf8'
        )
        self.cursor = self.connect.cursor()

    def __enter__(self):
        print('__enter__')
        return self  # the Db object itself is bound to db

    def fetchall(self, sql):
        self.cursor.execute(sql)
        ret = self.cursor.fetchall()
        return ret

    def __exit__(self, exc_type, exc_val, exc_tb):
        print('__exit__')
        self.cursor.close()
        self.connect.close()  # close the connection too, not only the cursor


with Db() as db:
    # Each goods row joined with its category name and brand name
    ret = db.fetchall('select g.id,g.name,c.name,b.name from goods as g '
                      'left join goods_cates as c on g.cate_id=c.id '
                      'left join goods_brands as b on g.brand_id=b.id')
    print(ret)
```
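pymysql needs a running MySQL server, so the sketch below shows the same pattern with the standard library's sqlite3 and an in-memory database instead. The backend is swapped in, and the Db class, the execute helper, the table and the row are all hypothetical. It also commits when the block finishes normally and rolls back when the block raised, which the pymysql version above does not do. It always closes the cursor and the connection, and returns False so that exceptions are not suppressed.

```python
import sqlite3


class Db:
    def __init__(self, path=":memory:"):
        self.connect = sqlite3.connect(path)
        self.cursor = self.connect.cursor()

    def __enter__(self):
        return self

    def execute(self, sql, params=()):
        self.cursor.execute(sql, params)

    def fetchall(self, sql, params=()):
        self.cursor.execute(sql, params)
        return self.cursor.fetchall()

    def __exit__(self, exc_type, exc_val, exc_tb):
        try:
            if exc_type is None:
                self.connect.commit()    # block finished normally
            else:
                self.connect.rollback()  # block raised: undo pending changes
        finally:
            self.cursor.close()
            self.connect.close()
        return False  # do not suppress exceptions


with Db() as db:
    db.execute("create table goods (id integer primary key, name text)")
    db.execute("insert into goods (name) values (?)", ("apple",))
    ret = db.fetchall("select id, name from goods")
print(ret)  # [(1, 'apple')]
```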


4. A socket context manager

Reference: https://wx-zhu6201976.blog.csdn.net/article/details/116715370
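Since Python 3.2, socket objects also support the with statement: leaving the block calls close(). The minimal self-contained sketch below uses socket.socketpair(), which returns two sockets that are already connected to each other, so no server or network is needed. The receive loop guards against recv() returning fewer bytes than requested.

```python
import socket

a, b = socket.socketpair()  # two connected sockets, no network required
with a, b:  # both sockets are closed when the block exits
    a.sendall(b"ping")
    data = b""
    while len(data) < 4:
        data += b.recv(4 - len(data))
print(data)  # b'ping'
print(a.fileno(), b.fileno())  # a closed socket's fileno() returns -1
```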