python - 在装饰器中实现的 Python 静态变量不会重置
问题描述
我这里有实现静态变量装饰器的代码。但是,我发现如果我多次运行此函数,则每次调用该函数时都不会重新初始化静态变量。
def static_vars(**kwargs):
def decorate(func):
for k in kwargs:
setattr(func, k, kwargs[k])
return func
return decorate
@static_vars(count=0)
def rolling_serial(val):
'''
For a vector V = [v_1, ..., V_N] returns a serial
index.
so for V = [1, 1, 1, 3, 1, 1, 1]
a resulting vector will be generated
V_hat = [1, 2, 3, 4, 5, 6, 7]
'''
temp = rolling_serial.count
rolling_serial.count += 1
return temp
# invoke it like this
from useful import (rolling_serial)
df = <...some dataframe with a column called ts>
self.df['ts_index'] = self.df.ts.apply(rolling_serial)
# Example output a new column, sa: [1, 2, 3, ..., N]
# My issue arises if I run it again
df = <...some dataframe with a column called ts>
self.df['ts_index'] = self.df.ts.apply(rolling_serial)
# output: [N+1, N+2, ...] instead of restarting at 0
如果我重新启动 jupyter 内核,静态变量会被清除。但我宁愿不必重新启动内核。谁能帮我?
解决方案
你的装饰器只被调用一次,而不是每次调用你的函数。确切地说,它是在定义时调用的:
def static_vars(**kwargs):
def decorate(func):
for k in kwargs:
print(kwargs)
setattr(func, k, kwargs[k])
return func
return decorate
@static_vars(count=0)
def rolling_serial(val):
'''
For a vector V = [v_1, ..., V_N] returns a serial
index.
so for V = [1, 1, 1, 3, 1, 1, 1]
a resulting vector will be generated
V_hat = [1, 2, 3, 4, 5, 6, 7]
'''
temp = rolling_serial.count
rolling_serial.count += 1
return temp
print('---- BEGIN ----')
print(rolling_serial(10))
print(rolling_serial(20))
print(rolling_serial(30))
印刷:
{'count': 0}
---- BEGIN ----
0
1
2
kwargs
您作为参数 in将static_vars()
成为闭包,并且每次调用rolling_serial()
.
一种解决方案是通过 globals() 传输变量:
# This function creates decorator:
def static_vars(**global_kwargs):
# This is decorator:
def decorate(func):
# This function is called every time:
def _f(*args, **kwargs):
for k in global_kwargs:
globals()[func.__name__+'_'+k] = global_kwargs[k]
return func(*args, **kwargs)
return _f
return decorate
@static_vars(count=0, temp=40)
def rolling_serial():
global rolling_serial_count, rolling_serial_temp
temp1, temp2 = rolling_serial_count, rolling_serial_temp
rolling_serial_count += 1
rolling_serial_temp += 1
return temp1, temp2
print(rolling_serial()) # prints (0, 40)
print(rolling_serial()) # prints (0, 40)
print(rolling_serial()) # prints (0, 40)
推荐阅读
- asp.net - 浏览器中的回发抛出参数不匹配错误
- c++ - 为什么 climit 有一次编译指示和 #ifndef gaurd
- crud - 我在哪里可以找到 CREATE_ONLY、CREATE_AND_SET 等 CRUD 标志的标准定义
- json - 在 azure 流分析中解析 json 内容
- angular - Angular/Power BI 嵌入
- javascript - 如何使用 window.getComputedStyle() 获得全高(包括边距)?
- python - 在 django admin 中覆盖 get_form 时缺少添加/编辑/删除和日期选择器按钮
- python - IndexError:元组超出范围
- datetime - 如何让我的日期时间字符串由logstash转换并填充到elasticsearch
- xaml - 在 xamarinforms 中淡化内容视图时出错