defunion(self,a,b): # 吞并规则 roota,rootb=self.find(a),self.find(b) # 不是一派 if roota!=rootb: # 小帮派并入大帮派 if self.rank[roota]<self.rank[rootb]: roota,rootb=rootb,roota self.parents[rootb]=roota if self.rank[roota]==self.rank[rootb]: # 同级别的情况下,能够百尺竿头更进一步 self.rank[roota]+=1 self.count-=1
defgetCount(self): return self.count
当然这个板子略显复杂,我们也不需要额外开一个类去定义并查集:
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
classSolution: deffunction(self,data): n=len(data) pa=[i for i inrange(n)] deffind(x): if pa[x]!=x: pa[x]=find(pa[x]) return pa[x] # 看情况修改 for i in data: for j in data: if i<j: # 是否连通 ri,rj=find(i),find(j) # 合并规则 if ri!=rj: pa[rj]=ri # 合并 returnsum([1for i,j inenumerate(pa) if pa[i]==j]) # 看情况修改 这里是统计连通块数量
再贴一个Python3的精简版本,:=赋值符号只能在py3.8及以上的版本用
1 2 3 4 5 6 7 8 9 10 11 12
classUnionFind: def__init__(self,n): self.pa=[i for i inrange(n)] deffind(self,x): if self.pa[x]!=x: self.pa[x]=self.find(self.pa[x]) return pa[x] defunion(self,a,b): if (roota:=self.find(a))!=(rootb:=self.find(b)): self.pa[rootb]=roota defisConnect(self,a,b): return self.find(a)==self.find(b)
并查集
一、概念介绍
所谓并查集(Union-find Data Structure),是一种用于快速处理不交集合并及查询问题的树形结构
classSolution: defnumIslands(self, grid: List[List[str]]) -> int: m=len(grid) if m<=0:return0 n=len(grid[0]) cnt=m*n# 计数 pa=[i for i inrange(m*n)] # 二维展平成一维的父节点 deffind(x):# 寻址+路径压缩 if pa[x]!=x: pa[x]=find(pa[x]) return pa[x]
for i inrange(m): for j inrange(n): if grid[i][j]=="1":# "1"的时候合并 for x,y in [(i+1,j),(i,j+1)]: if x<m and y<n and grid[x][y]=="1":# 合并 roota,rootb=find(i*n+j),find(x*n+y) if roota!=rootb: pa[rootb]=roota cnt-=1 else: cnt-=1# "0"的时候计数减一 return cnt
classSolution: defmaxAreaOfIsland(self, grid: List[List[int]]) -> int: m,n=len(grid),len(grid[0]) pa=[i for i inrange(m*n)] rank=[1]*(m*n)
deffind(x): if pa[x]!=x: pa[x]=find(pa[x]) return pa[x] for i inrange(m): for j inrange(n):
if grid[i][j]==1: for x,y in [(i+1,j),(i,j+1)]: if x<m and y<n and grid[x][y]==1: rx,ry=find(i*n+j),find(x*n+y) if rx!=ry: if rank[rx]<rank[ry]: rx,ry=ry,rx pa[ry]=rx rank[rx]+=rank[ry] else: rank[i*n+j]=0
# 欧拉筛 n=int(10e5+1) isPrime=[1]*n prime=[] for i inrange(2,n): if isPrime[i]: prime.append(i) for j in prime: if j*i>=n: break isPrime[j*i]=0 if i%j==0: break
n=max(nums) pa=[i for i inrange(n+1)] deffind(x): if pa[x]!=x: pa[x]=find(pa[x]) return pa[x] defunion(a,b): if (ra:=find(a))!=(rb:=find(b)): pa[rb]=ra for idx,i inenumerate(nums): tem=nums[idx] for j in prime: if j**2>tem: break # 公因子连接 if tem%j==0: union(i,j) # 去除公因子 while tem%j==0: tem//=j if tem!=1: union(tem,i) returnmax(Counter(find(num) for num in nums).values())
# 我们可以用hash表记录位置 hashmap=collections.defaultdict(list) for i inrange(n): for j inrange(n): hashmap[grid[i][j]]=[i,j]
# t次循环 for t inrange(n**2): idx=hashmap[t] for dx,dy indir: x=idx[0]+dx y=idx[1]+dy if x>=0and x<n and y>=0and y<n and grid[x][y]<=t: u.union(idx[0]*n+idx[1],x*n+y) # 判断此时的连通性 if u.isConnect(0,n*n-1): return t
classUnionFind: def__init__(self,n): self.pa=[i for i inrange(n)]
deffind(self,x): if self.pa[x]!=x: self.pa[x]=self.find(self.pa[x]) return self.pa[x]
defunion(self,a,b): roota,rootb=self.find(a),self.find(b) if roota!=rootb: self.pa[rootb]=roota
classSolution: deflargestIsland(self, grid): # 并查集+回溯 n=len(grid) pa=[i for i inrange(n*n)] rank=[1]*(n*n) maxVal=0 deffind(x): if pa[x]!=x: pa[x]=find(pa[x]) return pa[x] # 初始化并查集 for i inrange(n): for j inrange(n): if grid[i][j]==1: for x,y in [(i+1,j),(i,j+1)]: if x<n and y<n and grid[x][y]==1: ra,rb=find(i*n+j),find(x*n+y) if ra!=rb: pa[rb]=ra rank[ra]+=rank[rb] else: rank[i*n+j]=0 maxVal=max(rank) # 记录啥也不加的最大值 #--------------------------------------------------------------------# # 到这里位置完全是695 # #--------------------------------------------------------------------# # 回溯 for i inrange(n): for j inrange(n): if grid[i][j]==0: # 合并上下左右不同根节点的连通块 tem_rank=0 used=set() for x,y in [(i-1,j),(i+1,j),(i,j-1),(i,j+1)]: if x>=0and x<n and y>=0and y<n and grid[x][y]==1and (val:=find(x*n+y)) notin used: tem_rank+=rank[val] used.add(val) maxVal=max(maxVal,tem_rank+1) returnmax(maxv,maxVal)
for i,j inenumerate(nums): if j in hashset.keys(): continue# 找过了 if j-1in hashset.keys(): u.union(i,hashset[j-1]) if j+1in hashset.keys(): u.union(i,hashset[j+1]) hashset[j]=i return u.getV()
classUnionSet:
def__init__(self,x): self.pa=[i for i inrange(x)] self.rank=[1for i inrange(x)] deffind(self,x): if self.pa[x]!=x: self.pa[x]=self.find(self.pa[x]) return self.pa[x] defunion(self,a,b): roota,rootb=self.find(a),self.find(b) if roota!=rootb: self.pa[rootb]=roota self.rank[roota]+=self.rank[rootb] # self.rank[rootb]=0
defgetV(self): returnmax(self.rank)
我们来看看hash表的做法
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
classSolution: deflongestConsecutive(self, nums: List[int]) -> int: hashset=set(nums) ans=0 for i in nums: val=i if val-1notin hashset: start=1 while val+1in hashset: start+=1 val+=1 ans=max(ans,start) return ans
输入:equations = [["a","b"],["b","c"]], values = [2.0,3.0], queries = [["a","c"],["b","a"],["a","e"],["a","a"],["x","x"]] 输出:[6.00000,0.50000,-1.00000,1.00000,-1.00000] 解释: 条件:a / b = 2.0, b / c = 3.0 问题:a / c = ?, b / a = ?, a / e = ?, a / a = ?, x / x = ? 结果:[6.0, 0.5, -1.0, 1.0, -1.0 ]
classSolution: defcalcEquation(self, equations: List[List[str]], values: List[float], queries: List[List[str]]) -> List[float]: u=UnionFind() for i,v inenumerate(equations): u.merge(v[0],v[1],values[i]) ans=[] for i,j in queries: ans.append(u.divide(i,j)) return ans
classUnionFind: def__init__(self): self.pa={} self.val={} deffind(self,x): if x notin self.pa: returnNone if self.pa[x]!=x: fa=self.find(self.pa[x]) self.val[x]=self.val[x]*self.val[self.pa[x]] # 更新权重 self.pa[x]=fa return self.pa[x] defmerge(self,a,b,val): if a notin self.pa: self.pa[a]=a self.val[a]=1 if b notin self.pa: self.pa[b]=b self.val[b]=1
roota,rootb=self.find(a),self.find(b) if roota==rootb or (roota==Noneor rootb==None): return self.pa[roota]=rootb # w[b]=b/roob # roob/rooa = b/w[b] / a/w[a] # = b/a * w[a]/w[b] self.val[roota]=val*(self.val[b]/self.val[a]) defdivide(self,a,b): if (roota:=self.find(a))!=(rootb:=self.find(b)) or (roota==Noneor rootb==None): return -1 return self.val[a]/self.val[b]
u=uf(len(s)) for i,j in pairs: u.union(i,j) # 对于每一个连通块,我们只需要对其进行排序就可以了 # 当然,我们也要获取到连通块内的元素 mp=collections.defaultdict(list) for i,v inenumerate(s): mp[u.find(i)].append(v) # 构建连通块 for vec in mp.values(): vec.sort(reverse=True) ans=[]
for i inrange(len(s)): # 每一个占位都用连通块内最小的进行占位 x=u.find(i) ans.append(mp[x][-1]) mp[x].pop() return"".join(ans)
classuf: def__init__(self,n): self.pa=[i for i inrange(n)] deffind(self,x): if self.pa[x]!=x: self.pa[x]=self.find(self.pa[x]) return self.pa[x] defunion(self,a,b): if (ra:=self.find(a))!=(rb:=self.find(b)): self.pa[rb]=ra
classSolution(object): defminSwapsCouples(self, row): """ :type row: List[int] :rtype: int """ # 本题考察的是抽象能力 # 什么情况情侣不能牵手? # 一对情侣 Aa 一定能牵手 # 两队情侣 BA BA 不能牵手 而 AB BA 或者 BA AB是有一对可以牵手的 # 三对情侣只有在情况 AB CB AC 或者 AC BC AB 之类的情况下才算不饿能牵手 # 而此时我们认为,情侣间的顺序关系并不重要 AB CB AC 等价于 AC BC AB 等价于其他 # 也就是说,这三对情侣满足一个环关系 Ab cB aC # 因为A要联通a 我们令Ab联通,能顺着b联通cB,顺着c联通aC # 他们三位之间形成了环(通过情侣关系)
deffind(x): if pa[x]!=x: pa[x]=find(pa[x]) return pa[x] ans=n # 初始路径(自己) # 然后是等长路径C2x=(x*(x-1)/2) for vx,x insorted(zip(vals,range(n))): # 从小到大生成 fx=find(x) # 第几个节点的父节点 for y in g[x]: # 这节点所能到达的所有节点 y=find(y) # 开始遍历邻接边 if y==fx or vals[y]>vx:# 大的后面加进去c continue if vals[y]==vx: ans+=size[fx]*size[y] size[fx]+=size[y] pa[y]=fx return ans