4229 – Kdist
Bujorel s-a apucat de pomicultură şi a însămânţat un arbore (graf conex aciclic) cu N
noduri, fiecare nod având o culoare dată dintr-un interval [1, K]
. Acum, după ce arborele a crescut, el doreşte să ştie, pentru fiecare culoare, suma distanţelor dintre toate perechile de noduri ale arborelui ce au culoarea respectivă. Distanţa dintre două noduri se defineşte ca fiind numărul de muchii de pe drumul dintre cele două noduri.
Cerința
Deoarece Bujorel a folosit foarte mult îngrăşământ la plantarea arborelui, acesta a crescut foarte mult şi voi trebuie să scrieţi un program care calculează suma distanţelor dintre nodurile cu aceeaşi culoare.
Date de intrare
Pe prima linie se află două numere naturale N
şi K
, numărul de noduri ale arborelui, respectiv numărul de culori în care sunt vopsite nodurile. Pe următoarele N-1
linii este descris arborele, fiecare linie conţinând două numere naturale x
şi y
, reprezentând o muchie dintre nodul x
şi nodul y
. În continuare sunt prezente N
linii, a i
-a dintre aceste linii având un număr întreg c
aparţinând intervalului [1, K]
reprezentând culoarea nodului i
.
Date de ieșire
Se vor afișa K
linii, cea de-a i
-a linie conţinând suma distanţelor dintre toate perechile de noduri ce au culoarea i.
Restricții și precizări
1 ≤ K ≤ N ≤ 200.000
- În calcularea sumei distanţelor dintre nodurile cu aceeaşi culoare, fiecare pereche de noduri
(x, y)
va fi considerată o singură dată.
Exemplu:
Intrare
6 3 1 2 1 3 3 4 3 5 5 6 1 2 2 1 2 3
Ieșire
2 6 0
Explicație
Pentru culoarea 1
avem o pereche: (1, 4)
, cu distanţa 2
. Pentru culoarea 2
avem trei perechi: (2, 3)
cu distanţa 2
, (2, 5)
cu distanţa 3
și (3, 5)
cu distanţa 1
. Pentru culoarea 3
avem un singur nod cu această culoare, deci răspunsul este 0
.\
Rezolvare
<syntaxhighlight lang="python3"> from collections import defaultdict, deque
def calculate_color_distances(n, k, edges, colors):
tree = defaultdict(list) for u, v in edges: tree[u].append(v) tree[v].append(u)
color_nodes = defaultdict(list) for i in range(n): color_nodes[colors[i]].append(i + 1)
def bfs_distance_sum(start, valid_nodes): visited = set() queue = deque([(start, 0)]) visited.add(start) distance_sum = 0 node_count = 0
while queue: node, dist = queue.popleft() distance_sum += dist node_count += 1
for neighbor in tree[node]: if neighbor not in visited and neighbor in valid_nodes: visited.add(neighbor) queue.append((neighbor, dist + 1))
return distance_sum, node_count
color_distance_sums = [0] * k for color in range(1, k + 1): nodes = color_nodes[color] if len(nodes) < 2: continue
total_distance_sum = 0 for node in nodes: dist_sum, node_count = bfs_distance_sum(node, set(nodes)) total_distance_sum += dist_sum
color_distance_sums[color - 1] = total_distance_sum // 2
return color_distance_sums
n, k = map(int, input().split()) edges = [tuple(map(int, input().split())) for _ in range(n - 1)] colors = [int(input()) for _ in range(n)]
result = calculate_color_distances(n, k, edges, colors)
for dist_sum in result:
print(dist_sum)
</syntaxhighlight>