I have a df,
cluster_id memo
1 m
1 n
2 m
2 m
2 n
3 m
3 m
3 m
3 n
4 m
4 n
4 n
4 n
I want to groupby cluster_id and apply the following function,
def valid_row_dup(df):
    num_real_invs = df[df['memo'] == 'm'].shape[0]
    num_reversals_invs = df[df['memo'] == 'n'].shape[0]
    if num_real_invs == df.shape[0]:
        return True
    elif num_reversals_invs == df.shape[0]:
        return False
    elif abs(num_real_invs - num_reversals_invs) > 0:
        # even diff
        if abs(num_real_invs - num_reversals_invs) % 2 == 0:
            return True
        else:
            if abs(num_real_invs - num_reversals_invs) == 1:
                return False
            # odd diff
            else:
                return True
    elif num_real_invs - num_reversals_invs == 0:
        return False
passing each groupby object into the func as a df, and assigning the boolean result back to the df,
cluster_id memo valid
1 m False
1 n False
2 m False
2 m False
2 n False
3 m True
3 m True
3 m True
3 n True
4 m True
4 n True
4 n True
4 n True
Solution: apply the function, then merge:
df.merge(df.groupby('cluster_id').apply(valid_row_dup).to_frame(), on='cluster_id')
cluster_id memo 0
0 1 m False
1 1 n False
2 2 m False
3 2 m False
4 2 n False
5 3 m True
6 3 m True
7 3 m True
8 3 n True
9 4 m True
10 4 n True
11 4 n True
12 4 n True
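If you want the column to be called valid rather than 0, and want to keep the original index, one possible variant is to compute one boolean per cluster and map it back onto cluster_id. This is only a sketch: valid_memo is a hypothetical helper that takes the memo column of one group instead of the whole sub-frame, and it condenses the branches of valid_row_dup into an equivalent "difference of at least 2" check.

```python
import pandas as pd


def valid_memo(memo):
    # Hypothetical condensed form of valid_row_dup, working on one group's memo Series.
    num_real = int((memo == 'm').sum())
    num_reversals = int((memo == 'n').sum())
    if num_real == len(memo):
        return True   # only 'm' rows in the group
    if num_reversals == len(memo):
        return False  # only 'n' rows in the group
    # Otherwise a difference of 0 or 1 is invalid, 2 or more (even or odd) is valid.
    return abs(num_real - num_reversals) >= 2


df = pd.DataFrame({
    'cluster_id': [1, 1, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4],
    'memo': list('mnmmnmmmnmnnn'),
})

# One boolean per cluster, indexed by cluster_id.
per_cluster = df.groupby('cluster_id')['memo'].apply(valid_memo)

# Broadcast the per-cluster result to every row of that cluster.
df['valid'] = df['cluster_id'].map(per_cluster)
print(df)
```

Selecting ['memo'] before apply means the function never sees the grouping column, and map leaves the row order and index of df untouched.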







