nzbToMedia/libs/common/beets/util/pipeline.py
Labrys of Knossos 56c6773c6b Update vendored beets to 1.6.0
Updates colorama to 0.4.6
Adds confuse version 1.7.0
Updates jellyfish to 0.9.0
Adds mediafile 0.10.1
Updates munkres to 1.1.4
Updates musicbrainzngs to 0.7.1
Updates mutagen to 1.46.0
Updates pyyaml to 6.0
Updates unidecode to 1.3.6
2022-11-29 00:44:48 -05:00

533 lines
16 KiB
Python

# This file is part of beets.
# Copyright 2016, Adrian Sampson.
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
"""Simple but robust implementation of generator/coroutine-based
pipelines in Python. The pipelines may be run either sequentially
(single-threaded) or in parallel (one thread per pipeline stage).
This implementation supports pipeline bubbles (indications that the
processing for a certain item should abort). To use them, yield the
BUBBLE constant from any stage coroutine except the last.
In the parallel case, the implementation transparently handles thread
shutdown when the processing is complete and when a stage raises an
exception. KeyboardInterrupts (^C) are also handled.
When running a parallel pipeline, it is also possible to use
multiple coroutines for the same pipeline stage; this lets you speed
up a bottleneck stage by dividing its work among multiple threads.
To do so, pass an iterable of coroutines to the Pipeline constructor
in place of any single coroutine.
"""
import queue
from threading import Thread, Lock
import sys
BUBBLE = '__PIPELINE_BUBBLE__'
POISON = '__PIPELINE_POISON__'
DEFAULT_QUEUE_SIZE = 16
def _invalidate_queue(q, val=None, sync=True):
"""Breaks a Queue such that it never blocks, always has size 1,
and has no maximum size. get()ing from the queue returns `val`,
which defaults to None. `sync` controls whether a lock is
required (because it's not reentrant!).
"""
def _qsize(len=len):
return 1
def _put(item):
pass
def _get():
return val
if sync:
q.mutex.acquire()
try:
# Originally, we set `maxsize` to 0 here, which is supposed to mean
# an unlimited queue size. However, there is a race condition since
# Python 3.2 when this attribute is changed while another thread is
# waiting in put()/get() due to a full/empty queue.
# Setting it to 2 is still hacky because Python does not give any
# guarantee what happens if Queue methods/attributes are overwritten
# when it is already in use. However, because of our dummy _put()
# and _get() methods, it provides a workaround to let the queue appear
# to be never empty or full.
# See issue https://github.com/beetbox/beets/issues/2078
q.maxsize = 2
q._qsize = _qsize
q._put = _put
q._get = _get
q.not_empty.notifyAll()
q.not_full.notifyAll()
finally:
if sync:
q.mutex.release()
class CountedQueue(queue.Queue):
"""A queue that keeps track of the number of threads that are
still feeding into it. The queue is poisoned when all threads are
finished with the queue.
"""
def __init__(self, maxsize=0):
queue.Queue.__init__(self, maxsize)
self.nthreads = 0
self.poisoned = False
def acquire(self):
"""Indicate that a thread will start putting into this queue.
Should not be called after the queue is already poisoned.
"""
with self.mutex:
assert not self.poisoned
assert self.nthreads >= 0
self.nthreads += 1
def release(self):
"""Indicate that a thread that was putting into this queue has
exited. If this is the last thread using the queue, the queue
is poisoned.
"""
with self.mutex:
self.nthreads -= 1
assert self.nthreads >= 0
if self.nthreads == 0:
# All threads are done adding to this queue. Poison it
# when it becomes empty.
self.poisoned = True
# Replacement _get invalidates when no items remain.
_old_get = self._get
def _get():
out = _old_get()
if not self.queue:
_invalidate_queue(self, POISON, False)
return out
if self.queue:
# Items remain.
self._get = _get
else:
# No items. Invalidate immediately.
_invalidate_queue(self, POISON, False)
class MultiMessage:
"""A message yielded by a pipeline stage encapsulating multiple
values to be sent to the next stage.
"""
def __init__(self, messages):
self.messages = messages
def multiple(messages):
"""Yield multiple([message, ..]) from a pipeline stage to send
multiple values to the next pipeline stage.
"""
return MultiMessage(messages)
def stage(func):
"""Decorate a function to become a simple stage.
>>> @stage
... def add(n, i):
... return i + n
>>> pipe = Pipeline([
... iter([1, 2, 3]),
... add(2),
... ])
>>> list(pipe.pull())
[3, 4, 5]
"""
def coro(*args):
task = None
while True:
task = yield task
task = func(*(args + (task,)))
return coro
def mutator_stage(func):
"""Decorate a function that manipulates items in a coroutine to
become a simple stage.
>>> @mutator_stage
... def setkey(key, item):
... item[key] = True
>>> pipe = Pipeline([
... iter([{'x': False}, {'a': False}]),
... setkey('x'),
... ])
>>> list(pipe.pull())
[{'x': True}, {'a': False, 'x': True}]
"""
def coro(*args):
task = None
while True:
task = yield task
func(*(args + (task,)))
return coro
def _allmsgs(obj):
"""Returns a list of all the messages encapsulated in obj. If obj
is a MultiMessage, returns its enclosed messages. If obj is BUBBLE,
returns an empty list. Otherwise, returns a list containing obj.
"""
if isinstance(obj, MultiMessage):
return obj.messages
elif obj == BUBBLE:
return []
else:
return [obj]
class PipelineThread(Thread):
"""Abstract base class for pipeline-stage threads."""
def __init__(self, all_threads):
super().__init__()
self.abort_lock = Lock()
self.abort_flag = False
self.all_threads = all_threads
self.exc_info = None
def abort(self):
"""Shut down the thread at the next chance possible.
"""
with self.abort_lock:
self.abort_flag = True
# Ensure that we are not blocking on a queue read or write.
if hasattr(self, 'in_queue'):
_invalidate_queue(self.in_queue, POISON)
if hasattr(self, 'out_queue'):
_invalidate_queue(self.out_queue, POISON)
def abort_all(self, exc_info):
"""Abort all other threads in the system for an exception.
"""
self.exc_info = exc_info
for thread in self.all_threads:
thread.abort()
class FirstPipelineThread(PipelineThread):
"""The thread running the first stage in a parallel pipeline setup.
The coroutine should just be a generator.
"""
def __init__(self, coro, out_queue, all_threads):
super().__init__(all_threads)
self.coro = coro
self.out_queue = out_queue
self.out_queue.acquire()
def run(self):
try:
while True:
with self.abort_lock:
if self.abort_flag:
return
# Get the value from the generator.
try:
msg = next(self.coro)
except StopIteration:
break
# Send messages to the next stage.
for msg in _allmsgs(msg):
with self.abort_lock:
if self.abort_flag:
return
self.out_queue.put(msg)
except BaseException:
self.abort_all(sys.exc_info())
return
# Generator finished; shut down the pipeline.
self.out_queue.release()
class MiddlePipelineThread(PipelineThread):
"""A thread running any stage in the pipeline except the first or
last.
"""
def __init__(self, coro, in_queue, out_queue, all_threads):
super().__init__(all_threads)
self.coro = coro
self.in_queue = in_queue
self.out_queue = out_queue
self.out_queue.acquire()
def run(self):
try:
# Prime the coroutine.
next(self.coro)
while True:
with self.abort_lock:
if self.abort_flag:
return
# Get the message from the previous stage.
msg = self.in_queue.get()
if msg is POISON:
break
with self.abort_lock:
if self.abort_flag:
return
# Invoke the current stage.
out = self.coro.send(msg)
# Send messages to next stage.
for msg in _allmsgs(out):
with self.abort_lock:
if self.abort_flag:
return
self.out_queue.put(msg)
except BaseException:
self.abort_all(sys.exc_info())
return
# Pipeline is shutting down normally.
self.out_queue.release()
class LastPipelineThread(PipelineThread):
"""A thread running the last stage in a pipeline. The coroutine
should yield nothing.
"""
def __init__(self, coro, in_queue, all_threads):
super().__init__(all_threads)
self.coro = coro
self.in_queue = in_queue
def run(self):
# Prime the coroutine.
next(self.coro)
try:
while True:
with self.abort_lock:
if self.abort_flag:
return
# Get the message from the previous stage.
msg = self.in_queue.get()
if msg is POISON:
break
with self.abort_lock:
if self.abort_flag:
return
# Send to consumer.
self.coro.send(msg)
except BaseException:
self.abort_all(sys.exc_info())
return
class Pipeline:
"""Represents a staged pattern of work. Each stage in the pipeline
is a coroutine that receives messages from the previous stage and
yields messages to be sent to the next stage.
"""
def __init__(self, stages):
"""Makes a new pipeline from a list of coroutines. There must
be at least two stages.
"""
if len(stages) < 2:
raise ValueError('pipeline must have at least two stages')
self.stages = []
for stage in stages:
if isinstance(stage, (list, tuple)):
self.stages.append(stage)
else:
# Default to one thread per stage.
self.stages.append((stage,))
def run_sequential(self):
"""Run the pipeline sequentially in the current thread. The
stages are run one after the other. Only the first coroutine
in each stage is used.
"""
list(self.pull())
def run_parallel(self, queue_size=DEFAULT_QUEUE_SIZE):
"""Run the pipeline in parallel using one thread per stage. The
messages between the stages are stored in queues of the given
size.
"""
queue_count = len(self.stages) - 1
queues = [CountedQueue(queue_size) for i in range(queue_count)]
threads = []
# Set up first stage.
for coro in self.stages[0]:
threads.append(FirstPipelineThread(coro, queues[0], threads))
# Middle stages.
for i in range(1, queue_count):
for coro in self.stages[i]:
threads.append(MiddlePipelineThread(
coro, queues[i - 1], queues[i], threads
))
# Last stage.
for coro in self.stages[-1]:
threads.append(
LastPipelineThread(coro, queues[-1], threads)
)
# Start threads.
for thread in threads:
thread.start()
# Wait for termination. The final thread lasts the longest.
try:
# Using a timeout allows us to receive KeyboardInterrupt
# exceptions during the join().
while threads[-1].is_alive():
threads[-1].join(1)
except BaseException:
# Stop all the threads immediately.
for thread in threads:
thread.abort()
raise
finally:
# Make completely sure that all the threads have finished
# before we return. They should already be either finished,
# in normal operation, or aborted, in case of an exception.
for thread in threads[:-1]:
thread.join()
for thread in threads:
exc_info = thread.exc_info
if exc_info:
# Make the exception appear as it was raised originally.
raise exc_info[1].with_traceback(exc_info[2])
def pull(self):
"""Yield elements from the end of the pipeline. Runs the stages
sequentially until the last yields some messages. Each of the messages
is then yielded by ``pulled.next()``. If the pipeline has a consumer,
that is the last stage does not yield any messages, then pull will not
yield any messages. Only the first coroutine in each stage is used
"""
coros = [stage[0] for stage in self.stages]
# "Prime" the coroutines.
for coro in coros[1:]:
next(coro)
# Begin the pipeline.
for out in coros[0]:
msgs = _allmsgs(out)
for coro in coros[1:]:
next_msgs = []
for msg in msgs:
out = coro.send(msg)
next_msgs.extend(_allmsgs(out))
msgs = next_msgs
for msg in msgs:
yield msg
# Smoke test.
if __name__ == '__main__':
import time
# Test a normally-terminating pipeline both in sequence and
# in parallel.
def produce():
for i in range(5):
print('generating %i' % i)
time.sleep(1)
yield i
def work():
num = yield
while True:
print('processing %i' % num)
time.sleep(2)
num = yield num * 2
def consume():
while True:
num = yield
time.sleep(1)
print('received %i' % num)
ts_start = time.time()
Pipeline([produce(), work(), consume()]).run_sequential()
ts_seq = time.time()
Pipeline([produce(), work(), consume()]).run_parallel()
ts_par = time.time()
Pipeline([produce(), (work(), work()), consume()]).run_parallel()
ts_end = time.time()
print('Sequential time:', ts_seq - ts_start)
print('Parallel time:', ts_par - ts_seq)
print('Multiply-parallel time:', ts_end - ts_par)
print()
# Test a pipeline that raises an exception.
def exc_produce():
for i in range(10):
print('generating %i' % i)
time.sleep(1)
yield i
def exc_work():
num = yield
while True:
print('processing %i' % num)
time.sleep(3)
if num == 3:
raise Exception()
num = yield num * 2
def exc_consume():
while True:
num = yield
print('received %i' % num)
Pipeline([exc_produce(), exc_work(), exc_consume()]).run_parallel(1)