diff --git a/bot.py b/bot.py index d2259ea..e7b4153 100644 --- a/bot.py +++ b/bot.py @@ -15,7 +15,8 @@ from dotenv import load_dotenv from lib.config import config, config_load, config_save, config_get, config_set, config_get_descriptions, \ config_set_raw, config_meta -from lib.utils import async_filter, find_category, find_role_case_insensitive, link_channel +from lib.utils import async_filter, find_category, find_role_case_insensitive, link_channel, connect_and_play, \ + text_to_speech VERSION = "1.0.2" @@ -105,10 +106,9 @@ async def on_message(message: discord.Message): voice: discord.VoiceState = loeh.voice try: voice_channel: discord.VoiceChannel = voice.channel - voice_protocol: discord.VoiceProtocol = await voice_channel.connect() - if type(voice_protocol) is discord.VoiceClient: - source = discord.PCMVolumeTransformer(discord.FFmpegPCMAudio(source=OCH_LOEH_SOUND)) - voice_protocol.play(source) + source = discord.PCMVolumeTransformer(discord.FFmpegPCMAudio(source=OCH_LOEH_SOUND)) + voice_protocol = await connect_and_play(voice_channel, source=source) + await loeh.edit(mute=True) sleeper = asyncio.sleep(config_get('och-timeout', message.guild.id)) @@ -118,7 +118,7 @@ async def on_message(message: discord.Message): if message is not None: await message.edit(content="~~Zu Befehl!~~\nEs sei ihm verziehen.") if type(voice_protocol) is discord.VoiceClient: - await voice_protocol.disconnect() + await voice_protocol.disconnect(force=True) except (asyncio.TimeoutError, discord.Forbidden, discord.HTTPException, discord.ClientException): await message.channel.send('Failed to complete your command, Sir') return @@ -145,15 +145,8 @@ async def on_message(message: discord.Message): matches.append(member) if matches and voice is not None: - voice_protocol: discord.VoiceProtocol = await voice.connect() - if type(voice_protocol) is discord.VoiceClient: - tts = gtts.gTTS(message.content, lang='de') - os.makedirs('temp', exist_ok=True) - tts.save('temp/och.mp3') - source = discord.PCMVolumeTransformer( - discord.FFmpegPCMAudio(source='temp/och.mp3', before_options='-v quiet') - ) - voice_protocol.play(source) + source, destroy_tts = await text_to_speech(message.content) + voice_protocol: Optional[discord.VoiceProtocol] = await connect_and_play(voice, source) async def _mute(m: discord.Member): await m.edit(mute=True) @@ -173,11 +166,11 @@ async def on_message(message: discord.Message): if message is not None: waiter = message.edit(content='~~Auf gehts!~~\nGeschafft!') if type(voice_protocol) is discord.VoiceClient: - await voice_protocol.disconnect() + await voice_protocol.disconnect(force=True) if waiter is not None: await waiter - os.remove('temp/och.mp3') + destroy_tts() else: await message.channel.send('404: No users found!') elif config_get('inf19x-insiders-enable', message.guild.id): @@ -766,11 +759,17 @@ def get_quotes(guild_id: int) -> Dict[str, List[str]]: description="The author to filter by", option_type=str, required=False + ), + create_option( + name="tts", + description="Text to speech to voice chat", + option_type=SlashCommandOptionType.BOOLEAN, + required=False ) ], guild_ids=slash_guild_ids ) -async def quote_random_slash(ctx: SlashContext, author: Optional[str] = None): +async def quote_random_slash(ctx: SlashContext, author: Optional[str] = None, tts: bool = False): if not await check_slash_context(ctx, False): return @@ -789,8 +788,20 @@ async def quote_random_slash(ctx: SlashContext, author: Optional[str] = None): return author_quotes = quotes[author] - quote = '\n'.join(map(lambda line: '> ' + line, random.choice(author_quotes).splitlines())) - await ctx.send(quote + '\n *~' + author + '*') + quote = random.choice(author_quotes) + text = '\n'.join(map(lambda line: '> ' + line, quote.splitlines())) + await ctx.send(text + '\n *~' + author + '*') + + if tts: + voice: discord.VoiceState = ctx.author.voice + if voice.channel is not None: + async def after_play(e: discord.DiscordException, vp: discord.VoiceProtocol): + await vp.disconnect(force=True) + + source, tts_destroyer = text_to_speech(quote) + await connect_and_play(voice.channel, source, after_play=after_play) + tts_destroyer() + @slash.subcommand( diff --git a/lib/utils.py b/lib/utils.py index 58e3b4d..9437c9f 100644 --- a/lib/utils.py +++ b/lib/utils.py @@ -1,6 +1,11 @@ +import asyncio +import os +import random +import string from collections import AsyncIterable -from typing import Callable, AsyncGenerator, Optional +from typing import Callable, AsyncGenerator, Optional, Any, Coroutine import discord +import gtts async def async_filter(fun: Callable, iterable: AsyncIterable) -> AsyncGenerator: @@ -29,3 +34,38 @@ def link_channel(channel: discord.abc.GuildChannel, italic: bool = False) -> str return '[*' + channel.name + '*](https://discord.com/channels/' + str(channel.guild.id) + '/' + str( channel.id) + ')' return '[' + channel.name + '](https://discord.com/channels/' + str(channel.guild.id) + '/' + str(channel.id) + ')' + + +def text_to_speech(text: str, lang: str = "de") -> (discord.AudioSource, Callable[[], None]): + tts = gtts.gTTS(text, lang=lang) + os.makedirs('temp', exist_ok=True) + file_name = 'temp/' + ''.join(random.choice(string.ascii_lowercase) for i in range(12)) + '.mp3' + tts.save(file_name) + source = discord.PCMVolumeTransformer( + discord.FFmpegPCMAudio(source=file_name, before_options='-v quiet') + ) + + def destroy(): + os.remove(file_name) + + return source, destroy + + +async def connect_and_play( + channel: discord.VoiceChannel, source: discord.AudioSource, + after_play: Optional[Callable[[discord.DiscordException, discord.VoiceProtocol], Coroutine[Any, Any, Any]]] = None +) -> Optional[discord.VoiceProtocol]: + # noinspection PyTypeChecker + client: discord.VoiceClient = await channel.connect() + after_callback: Optional[Callable[[discord.DiscordException], Any]] + if after_play is None: + after_callback = None + else: + def callback(exc: discord.DiscordException) -> Any: + loop = asyncio.get_event_loop() + return loop.run_until_complete(after_play(exc, client)) + + after_callback = callback + + client.play(source, after=after_callback) + return client