Adds tests for common whisper

This commit is contained in:
Guillaume Tâche
2024-09-16 14:01:57 +02:00
parent dcadbcaf36
commit 3cc2f7a0c9
38 changed files with 796 additions and 130 deletions

View File

@@ -9,7 +9,7 @@ import java.io.InputStream;
public interface Audio {
/**
* @return The audio input stream
* @return The input stream
* @throws IOException If an I/O error occurs
*/
InputStream getInputStream() throws IOException;

View File

@@ -4,13 +4,14 @@ package com.github.gtache.autosubtitle;
* Represents info about an audio
*/
public interface AudioInfo {
/**
* @return The audio extension (mp3, etc.)
*/
String audioFormat();
/**
* @return The audio duration in milliseconds
* @return The extension (mp3, etc.)
*/
String format();
/**
* @return The duration in milliseconds
*/
long duration();
}

View File

@@ -7,9 +7,8 @@ import java.io.InputStream;
* Represents a video
*/
public interface Video {
/**
* @return The video input stream
* @return The input stream
* @throws IOException If an I/O error occurs
*/
InputStream getInputStream() throws IOException;
@@ -18,5 +17,4 @@ public interface Video {
* @return The video info
*/
VideoInfo info();
}

View File

@@ -3,11 +3,7 @@ package com.github.gtache.autosubtitle;
/**
* Info about a video
*/
public interface VideoInfo {
/**
* @return The video extension (mp4, etc.)
*/
String videoFormat();
public interface VideoInfo extends AudioInfo {
/**
* @return The video width in pixels
@@ -19,11 +15,6 @@ public interface VideoInfo {
*/
int height();
/**
* @return The video duration in milliseconds
*/
long duration();
/**
* @return The aspect ratio of the video
*/

View File

@@ -46,7 +46,7 @@ public interface SubtitleExtractor<T extends Subtitle> {
* Extracts the subtitles from a video
*
* @param video The video
* @param language The language of the video
* @param language The language of the audio
* @param model The model to use
* @return The extracted subtitle collection
* @throws ExtractException If an error occurs

View File

@@ -21,10 +21,10 @@ import static org.mockito.Mockito.when;
@ExtendWith(MockitoExtension.class)
class TestSubtitleConverter {
private final SubtitleConverter subtitleConverter;
private final SubtitleConverter<Subtitle> subtitleConverter;
private final SubtitleCollection<Subtitle> subtitleCollection;
TestSubtitleConverter(@Mock final SubtitleConverter subtitleConverter,
TestSubtitleConverter(@Mock final SubtitleConverter<Subtitle> subtitleConverter,
@Mock final SubtitleCollection<Subtitle> subtitleCollection) {
this.subtitleConverter = Objects.requireNonNull(subtitleConverter);
this.subtitleCollection = Objects.requireNonNull(subtitleCollection);

View File

@@ -3,6 +3,7 @@ package com.github.gtache.autosubtitle.subtitle.extractor;
import com.github.gtache.autosubtitle.Audio;
import com.github.gtache.autosubtitle.Language;
import com.github.gtache.autosubtitle.Video;
import com.github.gtache.autosubtitle.subtitle.Subtitle;
import com.github.gtache.autosubtitle.subtitle.SubtitleCollection;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
@@ -18,14 +19,14 @@ import static org.mockito.Mockito.when;
@ExtendWith(MockitoExtension.class)
class TestSubtitleExtractor {
private final SubtitleExtractor subtitleExtractor;
private final SubtitleCollection subtitleCollection;
private final SubtitleExtractor<Subtitle> subtitleExtractor;
private final SubtitleCollection<Subtitle> subtitleCollection;
private final Audio audio;
private final Video video;
private final ExtractionModel extractionModel;
TestSubtitleExtractor(@Mock final SubtitleExtractor subtitleExtractor,
@Mock final SubtitleCollection subtitleCollection,
TestSubtitleExtractor(@Mock final SubtitleExtractor<Subtitle> subtitleExtractor,
@Mock final SubtitleCollection<Subtitle> subtitleCollection,
@Mock final Audio audio,
@Mock final Video video,
@Mock final ExtractionModel extractionModel) {

View File

@@ -6,4 +6,5 @@ module com.github.gtache.autosubtitle.cli {
requires com.github.gtache.autosubtitle.ffmpeg;
requires com.github.gtache.autosubtitle.whisperx;
requires info.picocli;
requires guava;
}

View File

@@ -1,27 +0,0 @@
package com.github.gtache.autosubtitle.archive.client;
import com.github.gtache.autosubtitle.archive.Archiver;
import java.io.IOException;
import java.nio.file.Path;
import java.util.List;
/**
* {@link Archiver} using a remote API
*/
public class RemoteArchiver implements Archiver {
@Override
public void compress(final List<Path> files, final Path destination) throws IOException {
throw new UnsupportedOperationException();
}
@Override
public void decompress(final Path archive, final Path destination) throws IOException {
throw new UnsupportedOperationException();
}
@Override
public String archiveExtension() {
throw new UnsupportedOperationException();
}
}

View File

@@ -7,10 +7,10 @@ import java.util.Objects;
/**
* Implementation of {@link AudioInfo}
*/
public record AudioInfoImpl(String audioFormat, long duration) implements AudioInfo {
public record AudioInfoImpl(String format, long duration) implements AudioInfo {
public AudioInfoImpl {
Objects.requireNonNull(audioFormat);
Objects.requireNonNull(format);
if (duration < 0) {
throw new IllegalArgumentException("Duration must be positive");
}

View File

@@ -7,10 +7,10 @@ import java.util.Objects;
/**
* Implementation of {@link VideoInfo}
*/
public record VideoInfoImpl(String videoFormat, int width, int height, long duration) implements VideoInfo {
public record VideoInfoImpl(String format, int width, int height, long duration) implements VideoInfo {
public VideoInfoImpl {
Objects.requireNonNull(videoFormat);
Objects.requireNonNull(format);
if (width <= 0) {
throw new IllegalArgumentException("Width must be greater than 0 : " + width);
}

View File

@@ -11,14 +11,14 @@ class TestAudioInfoImpl {
private final long duration;
TestAudioInfoImpl() {
this.audioFormat = "audioFormat";
this.audioFormat = "format";
this.duration = 1000L;
}
@Test
void testGetters() {
final var audioInfo = new AudioInfoImpl(audioFormat, duration);
assertEquals(audioFormat, audioInfo.audioFormat());
assertEquals(audioFormat, audioInfo.format());
assertEquals(duration, audioInfo.duration());
}

View File

@@ -13,7 +13,7 @@ class TestVideoInfoImpl {
private final long duration;
TestVideoInfoImpl() {
this.videoFormat = "videoFormat";
this.videoFormat = "format";
this.width = 1;
this.height = 2;
this.duration = 3;
@@ -22,7 +22,7 @@ class TestVideoInfoImpl {
@Test
void testGetters() {
final var videoInfo = new VideoInfoImpl(videoFormat, width, height, duration);
assertEquals(videoFormat, videoInfo.videoFormat());
assertEquals(videoFormat, videoInfo.format());
assertEquals(width, videoInfo.width());
assertEquals(height, videoInfo.height());
assertEquals(duration, videoInfo.duration());

View File

@@ -57,7 +57,7 @@ class TestAbstractSubtitleExtractor {
private static final class DummySubtitleExtractor extends AbstractSubtitleExtractor {
@Override
public SubtitleCollection<Subtitle> extract(final Video video, final Language language, final ExtractionModel model) throws ExtractException {
public SubtitleCollection extract(final Video video, final Language language, final ExtractionModel model) throws ExtractException {
throw new UnsupportedOperationException();
}

View File

@@ -14,6 +14,7 @@
<properties>
<deepl.version>1.5.0</deepl.version>
<lingua.version>1.2.2</lingua.version>
<tika.version>2.9.2</tika.version>
</properties>
@@ -28,9 +29,15 @@
<version>${deepl.version}</version>
</dependency>
<dependency>
<groupId>com.github.pemistahl</groupId>
<artifactId>lingua</artifactId>
<version>${lingua.version}</version>
<groupId>org.apache.tika</groupId>
<artifactId>tika-core</artifactId>
<version>${tika.version}</version>
</dependency>
<dependency>
<groupId>org.apache.tika</groupId>
<artifactId>tika-langdetect-optimaize</artifactId>
<version>${tika.version}</version>
</dependency>
</dependencies>
</project>

View File

@@ -3,13 +3,14 @@ package com.github.gtache.autosubtitle.modules.deepl;
import com.github.gtache.autosubtitle.modules.setup.deepl.DeepLSetupModule;
import com.github.gtache.autosubtitle.translation.Translator;
import com.github.gtache.autosubtitle.translation.deepl.DeepLTranslator;
import com.github.pemistahl.lingua.api.LanguageDetector;
import com.github.pemistahl.lingua.api.LanguageDetectorBuilder;
import dagger.Binds;
import dagger.Module;
import dagger.Provides;
import org.apache.tika.language.detect.LanguageDetector;
import javax.inject.Singleton;
import java.io.IOException;
import java.io.UncheckedIOException;
/**
* Dagger module for DeepL
@@ -23,6 +24,10 @@ public abstract class DeepLModule {
@Provides
@Singleton
static LanguageDetector providesLanguageDetector() {
return LanguageDetectorBuilder.fromAllSpokenLanguages().build();
try {
return LanguageDetector.getDefaultLanguageDetector().loadModels();
} catch (final IOException e) {
throw new UncheckedIOException(e);
}
}
}

View File

@@ -32,7 +32,7 @@ public class DeepLSetupManager extends AbstractSetupManager {
}
@Override
protected SetupStatus getStatus() throws SetupException {
protected SetupStatus getStatus() {
final var key = preferences.get(DEEPL_API_KEY, null);
return key == null ? SetupStatus.NOT_INSTALLED : SetupStatus.BUNDLE_INSTALLED;
}
@@ -68,7 +68,7 @@ public class DeepLSetupManager extends AbstractSetupManager {
}
@Override
public void update() throws SetupException {
public void update() {
//No need to update
}
}

View File

@@ -9,7 +9,7 @@ import com.github.gtache.autosubtitle.subtitle.impl.SubtitleCollectionImpl;
import com.github.gtache.autosubtitle.subtitle.impl.SubtitleImpl;
import com.github.gtache.autosubtitle.translation.TranslationException;
import com.github.gtache.autosubtitle.translation.Translator;
import com.github.pemistahl.lingua.api.LanguageDetector;
import org.apache.tika.language.detect.LanguageDetector;
import javax.inject.Inject;
import java.util.ArrayList;
@@ -37,7 +37,7 @@ public class DeepLTranslator implements Translator<Subtitle> {
@Override
public Language getLanguage(final String text) {
return Language.getLanguage(languageDetector.detectLanguageOf(text).getIsoCode639_1().toString());
return Language.getLanguage(languageDetector.detect(text).getLanguage());
}
@Override

View File

@@ -4,8 +4,9 @@
module com.github.gtache.autosubtitle.deepl {
requires transitive com.github.gtache.autosubtitle.core;
requires transitive deepl.java;
requires transitive com.github.pemistahl.lingua;
requires transitive java.prefs;
requires transitive org.apache.tika.core;
requires transitive language.detector;
exports com.github.gtache.autosubtitle.modules.deepl;
exports com.github.gtache.autosubtitle.translation.deepl;
exports com.github.gtache.autosubtitle.setup.deepl;

View File

@@ -0,0 +1,15 @@
package com.github.gtache.autosubtitle.modules.deepl;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.assertEquals;
/**
 * Tests for the providers exposed by {@link DeepLModule}.
 */
class TestDeepLModule {

    /**
     * Checks that the provided language detector reports "fr" and "en" for the two sample sentences.
     */
    @Test
    void testLanguageDetector() {
        final var detector = DeepLModule.providesLanguageDetector();

        final var frenchResult = detector.detect("bonjour tout le monde");
        assertEquals("fr", frenchResult.getLanguage());

        final var englishResult = detector.detect("hello everyone this is a text");
        assertEquals("en", englishResult.getLanguage());
    }
}

View File

@@ -0,0 +1,99 @@
package com.github.gtache.autosubtitle.setup.deepl;
import com.github.gtache.autosubtitle.process.ProcessRunner;
import com.github.gtache.autosubtitle.setup.SetupException;
import com.github.gtache.autosubtitle.setup.SetupStatus;
import com.github.gtache.autosubtitle.setup.SetupUserBridge;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import java.net.http.HttpClient;
import java.util.Objects;
import java.util.prefs.BackingStoreException;
import java.util.prefs.Preferences;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.Mockito.*;
/**
 * Unit tests for {@link DeepLSetupManager}; every collaborator is a Mockito mock.
 */
@ExtendWith(MockitoExtension.class)
class TestDeepLSetupManager {

    /** Preference entry under which the tests expect the API key to be read, written and removed. */
    private static final String API_KEY_PREFERENCE = "deepl.api.key";
    /** Arbitrary API key value used by the tests. */
    private static final String API_KEY = "key";

    private final SetupUserBridge userBridge;
    private final Preferences preferences;
    private final ProcessRunner processRunner;
    private final HttpClient httpClient;
    private final DeepLSetupManager setupManager;

    TestDeepLSetupManager(@Mock final SetupUserBridge userBridge, @Mock final Preferences preferences,
                          @Mock final ProcessRunner processRunner, @Mock final HttpClient httpClient) {
        this.userBridge = Objects.requireNonNull(userBridge);
        this.preferences = Objects.requireNonNull(preferences);
        this.processRunner = Objects.requireNonNull(processRunner);
        this.httpClient = Objects.requireNonNull(httpClient);
        this.setupManager = new DeepLSetupManager(this.userBridge, this.preferences, this.processRunner, this.httpClient);
    }

    /**
     * The status is NOT_INSTALLED while no key is stored and BUNDLE_INSTALLED once the preference returns one.
     */
    @Test
    void testGetStatus() {
        // The unstubbed mock returns null for the preference lookup
        assertEquals(SetupStatus.NOT_INSTALLED, setupManager.getStatus());

        when(preferences.get(API_KEY_PREFERENCE, null)).thenReturn(API_KEY);
        assertEquals(SetupStatus.BUNDLE_INSTALLED, setupManager.getStatus());
    }

    /**
     * Installing stores the key given by the user and flushes the preferences.
     */
    @Test
    void testInstall() throws BackingStoreException, SetupException {
        when(userBridge.askForUserInput(any())).thenReturn(API_KEY);

        setupManager.install();

        verify(preferences).put(API_KEY_PREFERENCE, API_KEY);
        verify(preferences).flush();
    }

    /**
     * A failing flush during install is reported as a {@link SetupException}.
     */
    @Test
    void testInstallException() throws BackingStoreException {
        when(userBridge.askForUserInput(any())).thenReturn(API_KEY);
        doThrow(BackingStoreException.class).when(preferences).flush();

        assertThrows(SetupException.class, setupManager::install);

        verify(preferences).put(API_KEY_PREFERENCE, API_KEY);
        verify(preferences).flush();
    }

    /**
     * Installing with no stubbed user input fails and never touches the preferences.
     */
    @Test
    void testInstallNull() {
        assertThrows(SetupException.class, setupManager::install);
        verifyNoInteractions(preferences);
    }

    /**
     * Uninstalling removes the stored key and flushes the preferences.
     */
    @Test
    void testUninstall() throws BackingStoreException, SetupException {
        setupManager.uninstall();

        verify(preferences).remove(API_KEY_PREFERENCE);
        verify(preferences).flush();
    }

    /**
     * A failing flush during uninstall is reported as a {@link SetupException}.
     */
    @Test
    void testUninstallException() throws BackingStoreException {
        doThrow(BackingStoreException.class).when(preferences).flush();

        assertThrows(SetupException.class, setupManager::uninstall);

        verify(preferences).remove(API_KEY_PREFERENCE);
        verify(preferences).flush();
    }

    @Test
    void testName() {
        assertEquals("DeepL", setupManager.name());
    }

    /**
     * Each constructor argument is mandatory.
     */
    @Test
    void testIllegal() {
        assertThrows(NullPointerException.class, () -> new DeepLSetupManager(null, preferences, processRunner, httpClient));
        assertThrows(NullPointerException.class, () -> new DeepLSetupManager(userBridge, null, processRunner, httpClient));
        assertThrows(NullPointerException.class, () -> new DeepLSetupManager(userBridge, preferences, null, httpClient));
        assertThrows(NullPointerException.class, () -> new DeepLSetupManager(userBridge, preferences, processRunner, null));
    }
}

View File

@@ -0,0 +1,5 @@
package com.github.gtache.autosubtitle.translation.deepl;
//TODO: no tests yet for the DeepL translator - write them against a mocked backend (mocks, Postman, ...)
class TestDeepLTranslator {
}

View File

@@ -58,9 +58,9 @@ public class FFmpegVideoConverter implements VideoConverter {
@Override
public Video addSoftSubtitles(final Video video, final Collection<? extends SubtitleCollection<?>> subtitles) throws IOException {
final var out = getTempFile(video.info().videoFormat());
final var out = getTempFile(video.info().format());
addSoftSubtitles(video, subtitles, out);
return new FileVideoImpl(out, new VideoInfoImpl(video.info().videoFormat(), video.info().width(), video.info().height(), video.info().duration()));
return new FileVideoImpl(out, new VideoInfoImpl(video.info().format(), video.info().width(), video.info().height(), video.info().duration()));
}
@Override
@@ -110,7 +110,7 @@ public class FFmpegVideoConverter implements VideoConverter {
@Override
public Video addHardSubtitles(final Video video, final SubtitleCollection<?> subtitles) throws IOException {
final var out = getTempFile(video.info().videoFormat());
final var out = getTempFile(video.info().format());
addHardSubtitles(video, subtitles, out);
return new FileVideoImpl(out, video.info());
}
@@ -145,7 +145,7 @@ public class FFmpegVideoConverter implements VideoConverter {
public Audio getAudio(final Video video) throws IOException {
final var videoPath = getPath(video);
final var audioPath = getTempFile(".wav");
final var dumpVideoPath = getTempFile("." + video.info().videoFormat());
final var dumpVideoPath = getTempFile("." + video.info().format());
final var args = List.of(
getFFmpegPath(),
"-y",
@@ -172,7 +172,7 @@ public class FFmpegVideoConverter implements VideoConverter {
}
private static Path dumpVideo(final Video video) throws IOException {
final var path = getTempFile(video.info().videoFormat());
final var path = getTempFile(video.info().format());
try (final var out = Files.newOutputStream(path)) {
video.getInputStream().transferTo(out);
}

View File

@@ -16,7 +16,7 @@ public interface ParametersModel {
List<ExtractionModel> availableExtractionModels();
/**
* @return The current extraction model
* @return The current extraction model setting
*/
ExtractionModel extractionModel();
@@ -31,7 +31,7 @@ public interface ParametersModel {
List<OutputFormat> availableOutputFormats();
/**
* @return The current output format
* @return The current output format setting
*/
OutputFormat outputFormat();
@@ -46,7 +46,7 @@ public interface ParametersModel {
List<String> availableFontFamilies();
/**
* @return The current font family
* @return The current font family setting
*/
String fontName();
@@ -56,7 +56,7 @@ public interface ParametersModel {
void setFontName(String fontFamily);
/**
* @return The current font size
* @return The current font size setting
*/
int fontSize();
@@ -65,11 +65,27 @@ public interface ParametersModel {
*/
void setFontSize(int fontSize);
/**
* @return The current max line length setting
*/
int maxLineLength();
/**
* Sets the max line length
*
* @param maxLineLength The new max line length
*/
void setMaxLineLength(int maxLineLength);
/**
* @return The current max lines setting
*/
int maxLines();
/**
* Sets the max lines
*
* @param maxLines The new max lines
*/
void setMaxLines(int maxLines);
}

View File

@@ -12,15 +12,15 @@ class TestColonTimeFormatter {
@ParameterizedTest
@CsvSource({
"12:34:56,45296000",
"12:34,754000",
"01:02,62000",
"1:2,62000",
"01:02:03,3723000",
"1:2:3,3723000",
"00:00:03,3000",
"00:03,3000",
"1234:00:01,4442401000"
"12:34:56.000,45296000",
"12:34.321,754321",
"01:02.000,62000",
"1:2.555,62555",
"01:02:03.000,3723000",
"1:2:3.000,3723000",
"00:00:03.000,3000",
"00:03.000,3000",
"1234:00:01.000,4442401000"
})
void testParse(final String time, final long millis) {
assertEquals(millis, timeFormatter.parse(time));
@@ -28,14 +28,14 @@ class TestColonTimeFormatter {
@ParameterizedTest
@CsvSource({
"45296000,12:34:56",
"45296521,12:34:56",
"754000,12:34",
"754620,12:34",
"62000,01:02",
"3723000,1:02:03",
"3000,00:03",
"4442401000,1234:00:01"
"45296000,12:34:56.000",
"45296521,12:34:56.521",
"754000,12:34.000",
"754620,12:34.620",
"62000,01:02.000",
"3723000,1:02:03.000",
"3000,00:03.000",
"4442401000,1234:00:01.000"
})
void testFormat(final long millis, final String time) {
assertEquals(time, timeFormatter.format(millis));

View File

@@ -8,6 +8,7 @@ module com.github.gtache.autosubtitle.gui.run {
requires com.github.gtache.autosubtitle.whisperx;
requires javafx.fxml;
requires javafx.graphics;
requires guava;
opens com.github.gtache.autosubtitle.gui.run to javafx.graphics;
}

View File

@@ -4,7 +4,7 @@ import com.github.gtache.autosubtitle.modules.setup.whisper.base.WhisperSetupMod
import com.github.gtache.autosubtitle.modules.subtitle.extractor.whisper.base.WhisperExtractorModule;
import com.github.gtache.autosubtitle.modules.subtitle.parser.json.whisper.base.WhisperJsonModule;
import com.github.gtache.autosubtitle.subtitle.extractor.ExtractionModelProvider;
import com.github.gtache.autosubtitle.whisper.WhisperExtractionModelProvider;
import com.github.gtache.autosubtitle.whisper.base.WhisperExtractionModelProvider;
import dagger.Binds;
import dagger.Module;

View File

@@ -1,7 +1,8 @@
package com.github.gtache.autosubtitle.whisper;
package com.github.gtache.autosubtitle.whisper.base;
import com.github.gtache.autosubtitle.subtitle.extractor.ExtractionModel;
import com.github.gtache.autosubtitle.subtitle.extractor.ExtractionModelProvider;
import com.github.gtache.autosubtitle.whisper.WhisperModels;
import javax.inject.Inject;
import javax.inject.Singleton;

View File

@@ -8,8 +8,9 @@ module com.github.gtache.autosubtitle.whisper.base {
requires com.google.gson;
exports com.github.gtache.autosubtitle.setup.whisper.base;
exports com.github.gtache.autosubtitle.modules.whisper.base;
exports com.github.gtache.autosubtitle.modules.setup.whisper.base;
exports com.github.gtache.autosubtitle.subtitle.extractor.whisper.base;
exports com.github.gtache.autosubtitle.whisper.base;
exports com.github.gtache.autosubtitle.modules.setup.whisper.base;
exports com.github.gtache.autosubtitle.modules.subtitle.extractor.whisper.base;
exports com.github.gtache.autosubtitle.modules.whisper.base;
}

View File

@@ -0,0 +1,32 @@
package com.github.gtache.autosubtitle.whisper.base;
import com.github.gtache.autosubtitle.subtitle.extractor.ExtractionModelProvider;
import com.github.gtache.autosubtitle.whisper.WhisperModels;
import org.junit.jupiter.api.Test;
import java.util.Arrays;
import static org.junit.jupiter.api.Assertions.assertEquals;
class TestWhisperExtractionModelProvider {
private final ExtractionModelProvider provider = new WhisperExtractionModelProvider();
@Test
void testGetAvailableExtractionModels() {
assertEquals(Arrays.asList(WhisperModels.values()), provider.getAvailableExtractionModels());
}
@Test
void testGetDefaultExtractionModel() {
assertEquals(WhisperModels.MEDIUM, provider.getDefaultExtractionModel());
}
@Test
void testGetExtractionModel() {
for (final var value : WhisperModels.values()) {
assertEquals(value, provider.getExtractionModel(value.name()));
assertEquals(value, provider.getExtractionModel(value.name().toLowerCase()));
}
}
}

View File

@@ -22,5 +22,4 @@ public final class WhisperCommonSetupModule {
final OS os) {
return new WhisperSetupConfiguration(root, venvPath, pythonVersion, os);
}
}

View File

@@ -23,6 +23,7 @@ import org.apache.logging.log4j.Logger;
import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardCopyOption;
import java.time.Duration;
import java.util.HashSet;
import java.util.List;
@@ -72,48 +73,43 @@ public abstract class AbstractWhisperSubtitleExtractor implements SubtitleExtrac
listeners.clear();
}
private void notifyListeners(final ExtractEvent event) {
void notifyListeners(final ExtractEvent event) {
listeners.forEach(listener -> listener.listen(event));
}
@Override
public SubtitleCollection<Subtitle> extract(final Video video, final Language language, final ExtractionModel model) throws ExtractException {
if (video instanceof final File f) {
return extract(f.path(), language, model, video.info().duration());
} else {
try {
final var path = Files.createTempFile(AUTOSUBTITLE, video.info().videoFormat());
try (final var in = video.getInputStream()) {
Files.copy(in, path);
final var ret = extract(path, language, model, video.info().duration());
Files.deleteIfExists(path);
return ret;
}
} catch (final IOException e) {
throw new ExtractException(e);
}
}
return extract(new AudioOrVideo(video), language, model);
}
@Override
public SubtitleCollection<Subtitle> extract(final Audio audio, final Language language, final ExtractionModel model) throws ExtractException {
if (audio instanceof final File f) {
return extract(f.path(), language, model, audio.info().duration());
return extract(new AudioOrVideo(audio), language, model);
}
private SubtitleCollection<Subtitle> extract(final AudioOrVideo av, final Language language, final ExtractionModel model) throws ExtractException {
if (av.inner() instanceof final File f) {
return extract(f.path(), language, model, av.info().duration());
} else {
try {
final var path = Files.createTempFile(AUTOSUBTITLE, audio.info().audioFormat());
try (final var in = audio.getInputStream()) {
Files.copy(in, path);
final var ret = extract(path, language, model, audio.info().duration());
Files.deleteIfExists(path);
return ret;
}
return dumpExtract(av, language, model);
} catch (final IOException e) {
throw new ExtractException(e);
}
}
}
private SubtitleCollection<Subtitle> dumpExtract(final AudioOrVideo av, final Language language, final ExtractionModel model) throws ExtractException, IOException {
final var path = Files.createTempFile(AUTOSUBTITLE, "." + av.info().format());
try (final var in = av.getInputStream()) {
Files.copy(in, path, StandardCopyOption.REPLACE_EXISTING);
return extract(path, language, model, av.info().duration());
} finally {
Files.deleteIfExists(path);
}
}
private SubtitleCollection<Subtitle> extract(final Path path, final Language language, final ExtractionModel model, final long duration) throws ExtractException {
try {
final var outputDir = Files.createTempDirectory(AUTOSUBTITLE);

View File

@@ -0,0 +1,57 @@
package com.github.gtache.autosubtitle.subtitle.extractor.whisper;
import com.github.gtache.autosubtitle.Audio;
import com.github.gtache.autosubtitle.AudioInfo;
import com.github.gtache.autosubtitle.Video;
import java.io.IOException;
import java.io.InputStream;
import static java.util.Objects.requireNonNull;
/**
 * Used internally for the Whisper subtitle extractor to handle an {@link Audio} or a {@link Video}
 * through a single type. Exactly one of audio and video must be non-null.
 *
 * @param audio The possible audio
 * @param video The possible video
 */
public record AudioOrVideo(Audio audio, Video video) implements Audio {

    /**
     * Checks that exactly one of the two components is set.
     *
     * @throws NullPointerException     If both audio and video are null
     * @throws IllegalArgumentException If both audio and video are non-null
     */
    public AudioOrVideo {
        if (audio == null) {
            requireNonNull(video, "Either audio or video must be not null");
        } else if (video != null) {
            throw new IllegalArgumentException("Either audio or video must be null");
        }
    }

    /**
     * Wraps an audio
     *
     * @param audio The audio
     */
    public AudioOrVideo(final Audio audio) {
        this(audio, null);
    }

    /**
     * Wraps a video
     *
     * @param video The video
     */
    public AudioOrVideo(final Video video) {
        this(null, video);
    }

    /**
     * Returns the wrapped object (the audio if present, otherwise the video).
     * The cast to {@code T} is unchecked: the type is chosen by the caller and is not verified here,
     * so a wrong choice only fails with a {@link ClassCastException} where the result is used.
     * Prefer an {@code instanceof} check on the result.
     *
     * @param <T> The type expected by the caller
     * @return The wrapped audio or video
     */
    @SuppressWarnings("unchecked") // Deliberate caller-typed accessor; see the Javadoc for the hazard
    public <T> T inner() {
        return (T) (audio == null ? video : audio);
    }

    /**
     * @return The input stream of the wrapped audio or video
     * @throws IOException If an I/O error occurs
     */
    @Override
    public InputStream getInputStream() throws IOException {
        return audio == null ? video.getInputStream() : audio.getInputStream();
    }

    /**
     * @return The info of the wrapped audio or video
     */
    @Override
    public AudioInfo info() {
        // Written as if/else so that each branch is checked against the AudioInfo return type on its own
        if (audio == null) {
            return video.info();
        }
        return audio.info();
    }
}

View File

@@ -0,0 +1,23 @@
package com.github.gtache.autosubtitle.modules.setup.whisper;
import com.github.gtache.autosubtitle.impl.OS;
import com.github.gtache.autosubtitle.setup.whisper.WhisperSetupConfiguration;
import org.junit.jupiter.api.Test;
import java.nio.file.Path;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.mockito.Mockito.mock;
class TestWhisperCommonSetupModule {
@Test
void testWhisperSetupConfiguration() {
final var root = mock(Path.class);
final var venvPath = mock(Path.class);
final var pythonVersion = "3.10";
final var os = OS.LINUX;
final var expected = new WhisperSetupConfiguration(root, venvPath, pythonVersion, os);
assertEquals(expected, WhisperCommonSetupModule.providesWhisperSetupConfiguration(root, venvPath, pythonVersion, os));
}
}

View File

@@ -0,0 +1,171 @@
package com.github.gtache.autosubtitle.setup.whisper;
import com.github.gtache.autosubtitle.impl.OS;
import com.github.gtache.autosubtitle.process.ProcessRunner;
import com.github.gtache.autosubtitle.setup.SetupException;
import com.github.gtache.autosubtitle.setup.conda.CondaSetupManager;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.api.io.TempDir;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import java.io.IOException;
import java.net.http.HttpClient;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.util.Objects;
import static com.github.gtache.autosubtitle.setup.SetupStatus.BUNDLE_INSTALLED;
import static com.github.gtache.autosubtitle.setup.SetupStatus.NOT_INSTALLED;
import static org.junit.jupiter.api.Assertions.*;
import static org.mockito.Mockito.*;
@ExtendWith(MockitoExtension.class)
class TestAbstractWhisperSetupManager {
private final CondaSetupManager condaSetupManager;
private final WhisperSetupConfiguration configuration;
private final ProcessRunner processRunner;
private final HttpClient httpClient;
private final AbstractWhisperSetupManager setupManager;
TestAbstractWhisperSetupManager(@Mock final CondaSetupManager condaSetupManager, @Mock final WhisperSetupConfiguration configuration,
@Mock final ProcessRunner processRunner, @Mock final HttpClient httpClient) {
this.condaSetupManager = Objects.requireNonNull(condaSetupManager);
when(condaSetupManager.name()).thenReturn("conda");
this.configuration = Objects.requireNonNull(configuration);
this.processRunner = Objects.requireNonNull(processRunner);
this.httpClient = Objects.requireNonNull(httpClient);
this.setupManager = spy(new DummyWhisperSetupManager(condaSetupManager, configuration, processRunner, httpClient));
}
@Test
void testGetStatus() throws SetupException {
assertEquals(NOT_INSTALLED, setupManager.getStatus());
doReturn(true).when(setupManager).isWhisperInstalled();
assertEquals(BUNDLE_INSTALLED, setupManager.getStatus());
}
@Test
void testInstallAlreadyInstalled() throws SetupException {
final var venvPath = Paths.get("path");
when(configuration.venvPath()).thenReturn(venvPath);
when(condaSetupManager.isInstalled()).thenReturn(true);
when(condaSetupManager.venvExists(venvPath)).thenReturn(true);
doReturn(true).when(setupManager).isWhisperInstalled();
setupManager.install();
verify(condaSetupManager, never()).install();
verify(condaSetupManager, never()).createVenv(any(), any());
verify(setupManager, never()).installWhisper();
}
@Test
void testInstallWhisper() throws SetupException {
final var venvPath = Paths.get("path");
when(configuration.venvPath()).thenReturn(venvPath);
when(condaSetupManager.isInstalled()).thenReturn(true);
when(condaSetupManager.venvExists(venvPath)).thenReturn(true);
setupManager.install();
verify(condaSetupManager, never()).install();
verify(condaSetupManager, never()).createVenv(any(), any());
verify(setupManager).installWhisper();
}
@Test
void testInstallVenv() throws SetupException {
final var venvPath = Paths.get("path");
final var pythonVersion = "3.10";
when(configuration.venvPath()).thenReturn(venvPath);
when(configuration.pythonVersion()).thenReturn(pythonVersion);
when(condaSetupManager.isInstalled()).thenReturn(true);
doReturn(true).when(setupManager).isWhisperInstalled();
setupManager.install();
verify(condaSetupManager, never()).install();
verify(condaSetupManager).createVenv(venvPath, pythonVersion);
verify(setupManager, never()).installWhisper();
}
@Test
void testInstallConda() throws SetupException {
final var venvPath = Paths.get("path");
final var pythonVersion = "3.10";
when(configuration.venvPath()).thenReturn(venvPath);
when(configuration.pythonVersion()).thenReturn(pythonVersion);
when(condaSetupManager.isInstalled()).thenReturn(false);
when(condaSetupManager.venvExists(venvPath)).thenReturn(true);
doReturn(true).when(setupManager).isWhisperInstalled();
setupManager.install();
verify(condaSetupManager).install();
verify(condaSetupManager, never()).createVenv(any(), any());
verify(setupManager, never()).installWhisper();
}
@Test
void testUninstall(@TempDir final Path tempDir) throws SetupException, IOException {
final var subfolder = tempDir.resolve("subfolder");
Files.createDirectory(subfolder);
final var subfile = subfolder.resolve("subfile");
Files.createFile(subfile);
final var rootFile = tempDir.resolve("rootfile");
Files.createFile(rootFile);
when(configuration.root()).thenReturn(tempDir);
setupManager.uninstall();
assertFalse(Files.exists(tempDir));
}
@Test
void testUpdateNotAvailable() throws SetupException {
setupManager.update();
verify(condaSetupManager).isUpdateAvailable();
verify(condaSetupManager, never()).update();
}
@Test
void testUpdateAvailable() throws SetupException {
when(condaSetupManager.isUpdateAvailable()).thenReturn(true);
setupManager.update();
verify(condaSetupManager).isUpdateAvailable();
verify(condaSetupManager).update();
}
/**
 * Checks the python executable resolved inside the venv for each configured OS:
 * {@code python.exe} on Windows and {@code python} on Linux.
 */
@Test
void testGetPythonPath() {
    final var venv = Paths.get("path");
    when(configuration.venvPath()).thenReturn(venv);

    // Windows
    when(configuration.os()).thenReturn(OS.WINDOWS);
    final var expectedOnWindows = venv.resolve("python.exe");
    assertEquals(expectedOnWindows, setupManager.getPythonPath());

    // Linux
    when(configuration.os()).thenReturn(OS.LINUX);
    final var expectedOnLinux = venv.resolve("python");
    assertEquals(expectedOnLinux, setupManager.getPythonPath());
}
/**
 * Each constructor argument is replaced by {@code null} in turn; every variant must be rejected
 * with a {@link NullPointerException}.
 */
@Test
void testIllegal() {
    // null conda setup manager
    assertThrows(NullPointerException.class,
            () -> new DummyWhisperSetupManager(null, configuration, processRunner, httpClient));
    // null configuration
    assertThrows(NullPointerException.class,
            () -> new DummyWhisperSetupManager(condaSetupManager, null, processRunner, httpClient));
    // null process runner
    assertThrows(NullPointerException.class,
            () -> new DummyWhisperSetupManager(condaSetupManager, configuration, null, httpClient));
    // null HTTP client
    assertThrows(NullPointerException.class,
            () -> new DummyWhisperSetupManager(condaSetupManager, configuration, processRunner, null));
}
/**
 * Minimal concrete {@link AbstractWhisperSetupManager} used by the tests: installing whisper does nothing,
 * whisper is reported as not installed and the name is {@code "Dummy"}.
 */
private static final class DummyWhisperSetupManager extends AbstractWhisperSetupManager {

    private DummyWhisperSetupManager(final CondaSetupManager condaSetupManager,
                                     final WhisperSetupConfiguration configuration,
                                     final ProcessRunner processRunner,
                                     final HttpClient httpClient) {
        super(condaSetupManager, configuration, processRunner, httpClient);
    }

    @Override
    public String name() {
        return "Dummy";
    }

    @Override
    protected boolean isWhisperInstalled() throws SetupException {
        return false;
    }

    @Override
    protected void installWhisper() throws SetupException {
        // Intentionally empty: the tests only verify whether this method gets called
    }
}
}

View File

@@ -0,0 +1,46 @@
package com.github.gtache.autosubtitle.setup.whisper;
import com.github.gtache.autosubtitle.impl.OS;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import java.nio.file.Path;
import java.util.Objects;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
/**
 * Tests for {@link WhisperSetupConfiguration}: the accessors return the constructor arguments
 * and {@code null} arguments are rejected.
 */
@ExtendWith(MockitoExtension.class)
class TestWhisperSetupConfiguration {

    private static final String PYTHON_VERSION = "3.10";
    private static final OS TARGET_OS = OS.LINUX;

    private final Path root;
    private final Path venvPath;

    /**
     * @param root     The mocked root path
     * @param venvPath The mocked venv path
     */
    TestWhisperSetupConfiguration(@Mock final Path root, @Mock final Path venvPath) {
        this.root = Objects.requireNonNull(root);
        this.venvPath = Objects.requireNonNull(venvPath);
    }

    /**
     * Every accessor must hand back the value given at construction.
     */
    @Test
    void testGetters() {
        final var underTest = new WhisperSetupConfiguration(root, venvPath, PYTHON_VERSION, TARGET_OS);

        assertEquals(root, underTest.root());
        assertEquals(venvPath, underTest.venvPath());
        assertEquals(PYTHON_VERSION, underTest.pythonVersion());
        assertEquals(TARGET_OS, underTest.os());
    }

    /**
     * Each constructor argument is replaced by {@code null} in turn; every variant must throw
     * a {@link NullPointerException}.
     */
    @Test
    void testIllegal() {
        assertThrows(NullPointerException.class,
                () -> new WhisperSetupConfiguration(null, venvPath, PYTHON_VERSION, TARGET_OS));
        assertThrows(NullPointerException.class,
                () -> new WhisperSetupConfiguration(root, null, PYTHON_VERSION, TARGET_OS));
        assertThrows(NullPointerException.class,
                () -> new WhisperSetupConfiguration(root, venvPath, null, TARGET_OS));
        assertThrows(NullPointerException.class,
                () -> new WhisperSetupConfiguration(root, venvPath, PYTHON_VERSION, null));
    }
}

View File

@@ -0,0 +1,211 @@
package com.github.gtache.autosubtitle.subtitle.extractor.whisper;
import com.github.gtache.autosubtitle.Audio;
import com.github.gtache.autosubtitle.AudioInfo;
import com.github.gtache.autosubtitle.Language;
import com.github.gtache.autosubtitle.Video;
import com.github.gtache.autosubtitle.VideoInfo;
import com.github.gtache.autosubtitle.impl.FileAudioImpl;
import com.github.gtache.autosubtitle.impl.OS;
import com.github.gtache.autosubtitle.process.ProcessListener;
import com.github.gtache.autosubtitle.process.ProcessResult;
import com.github.gtache.autosubtitle.process.ProcessRunner;
import com.github.gtache.autosubtitle.subtitle.Subtitle;
import com.github.gtache.autosubtitle.subtitle.SubtitleCollection;
import com.github.gtache.autosubtitle.subtitle.converter.ParseException;
import com.github.gtache.autosubtitle.subtitle.converter.SubtitleConverter;
import com.github.gtache.autosubtitle.subtitle.converter.SubtitleConverterProvider;
import com.github.gtache.autosubtitle.subtitle.extractor.ExtractEvent;
import com.github.gtache.autosubtitle.subtitle.extractor.ExtractException;
import com.github.gtache.autosubtitle.subtitle.extractor.ExtractionModel;
import com.github.gtache.autosubtitle.subtitle.extractor.SubtitleExtractorListener;
import com.github.gtache.autosubtitle.subtitle.extractor.impl.ExtractEventImpl;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.extension.ExtendWith;
import org.mockito.Mock;
import org.mockito.junit.jupiter.MockitoExtension;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.nio.file.Path;
import java.nio.file.Paths;
import java.time.Duration;
import java.util.List;
import java.util.Objects;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.Mockito.*;
/**
 * Tests the logic provided by {@link AbstractWhisperSubtitleExtractor} through a minimal concrete subclass
 * ({@link DummyWhisperSubtitleExtractor}) whose collaborators are all Mockito mocks.
 */
@ExtendWith(MockitoExtension.class)
class TestAbstractWhisperSubtitleExtractor {

    private final Path venvPath;
    private final SubtitleConverter<Subtitle> converter;
    private final SubtitleConverterProvider converterProvider;
    private final ProcessRunner processRunner;
    private final ProcessListener processListener;
    private final ProcessResult processResult;
    private final OS os;
    private final DummyWhisperSubtitleExtractor extractor;
    private final AudioInfo audioInfo;
    private final VideoInfo videoInfo;
    private final ExtractionModel extractionModel;
    private final SubtitleCollection<Subtitle> collection;

    /**
     * Stores the injected mocks, installs the stubbing shared by all tests (the "json" converter, the process
     * listener returned by the runner, the process result returned by a one-hour join, and the "mp3"/"mp4"
     * formats) and builds a Linux extractor on top of them.
     */
    TestAbstractWhisperSubtitleExtractor(@Mock final SubtitleConverterProvider converterProvider,
                                         @Mock final SubtitleConverter<Subtitle> converter,
                                         @Mock final ProcessRunner processRunner,
                                         @Mock final ProcessListener processListener,
                                         @Mock final ProcessResult processResult,
                                         @Mock final VideoInfo videoInfo,
                                         @Mock final AudioInfo audioInfo,
                                         @Mock final ExtractionModel extractionModel,
                                         @Mock final SubtitleCollection<Subtitle> collection) throws IOException {
        // Plain values
        this.venvPath = Path.of("venv");
        this.os = OS.LINUX;

        // Injected mocks
        this.converterProvider = Objects.requireNonNull(converterProvider);
        this.converter = Objects.requireNonNull(converter);
        this.processRunner = Objects.requireNonNull(processRunner);
        this.processListener = Objects.requireNonNull(processListener);
        this.processResult = Objects.requireNonNull(processResult);
        this.audioInfo = Objects.requireNonNull(audioInfo);
        this.videoInfo = Objects.requireNonNull(videoInfo);
        this.extractionModel = Objects.requireNonNull(extractionModel);
        this.collection = Objects.requireNonNull(collection);

        // Shared stubbing, installed before the extractor is built
        doReturn(converter).when(converterProvider).getConverter("json");
        when(processRunner.startListen(anyList())).thenReturn(processListener);
        when(processListener.join(Duration.ofHours(1))).thenReturn(processResult);
        when(audioInfo.format()).thenReturn("mp3");
        when(videoInfo.format()).thenReturn("mp4");

        this.extractor = new DummyWhisperSubtitleExtractor(venvPath, converterProvider, processRunner, os);
    }

    /**
     * Every registered listener receives the notified event.
     */
    @Test
    void testNotifyListeners() {
        final var first = mock(SubtitleExtractorListener.class);
        final var second = mock(SubtitleExtractorListener.class);
        final var extractEvent = mock(ExtractEvent.class);
        extractor.addListener(first);
        extractor.addListener(second);

        extractor.notifyListeners(extractEvent);

        verify(first).listen(extractEvent);
        verify(second).listen(extractEvent);
    }

    /**
     * A removed listener is no longer notified while the remaining one still is.
     */
    @Test
    void testRemoveListener() {
        final var kept = mock(SubtitleExtractorListener.class);
        final var removed = mock(SubtitleExtractorListener.class);
        extractor.addListener(kept);
        extractor.addListener(removed);

        extractor.removeListener(removed);
        extractor.notifyListeners(mock(ExtractEvent.class));

        verify(kept).listen(any());
        verifyNoInteractions(removed);
    }

    /**
     * After removing all listeners, nobody is notified.
     */
    @Test
    void testRemoveListeners() {
        final var first = mock(SubtitleExtractorListener.class);
        final var second = mock(SubtitleExtractorListener.class);
        extractor.addListener(first);
        extractor.addListener(second);

        extractor.removeListeners();
        extractor.notifyListeners(mock(ExtractEvent.class));

        verifyNoInteractions(first, second);
    }

    /**
     * An {@link IOException} raised by the input stream of a non-file audio surfaces as an
     * {@link ExtractException}.
     */
    @Test
    void testExtractAudioNotFileException() throws IOException {
        final var audio = mock(Audio.class);
        when(audio.getInputStream()).thenThrow(IOException.class);
        when(audio.info()).thenReturn(audioInfo);

        assertThrows(ExtractException.class, () -> extractor.extract(audio, Language.EN, extractionModel));
    }

    /**
     * An {@link IOException} raised while reading the process output surfaces as an {@link ExtractException}.
     */
    @Test
    void testExtractAudioFileException() throws IOException {
        final var audio = fileAudio("path");
        when(processListener.readLine()).thenThrow(IOException.class);

        assertThrows(ExtractException.class, () -> extractor.extract(audio, Language.EN, extractionModel));
    }

    /**
     * A {@link ParseException} raised by the converter surfaces as an {@link ExtractException}.
     */
    @Test
    void testExtractAudioFileParseException() throws ParseException {
        final var audio = fileAudio("path.path");
        when(converter.parse(any(Path.class))).thenThrow(ParseException.class);

        assertThrows(ExtractException.class, () -> extractor.extract(audio, Language.EN, extractionModel));
        verify(converter).parse(any(Path.class));
    }

    /**
     * A process exit code of 1 surfaces as an {@link ExtractException}.
     */
    @Test
    void testExtractAudioFileBadResultCode() {
        final var audio = fileAudio("path");
        when(processResult.exitCode()).thenReturn(1);

        assertThrows(ExtractException.class, () -> extractor.extract(audio, Language.EN, extractionModel));
    }

    /**
     * An {@link IOException} raised by the input stream of a non-file video surfaces as an
     * {@link ExtractException}.
     */
    @Test
    void testExtractVideoNotFileException() throws IOException {
        final var video = mock(Video.class);
        when(video.getInputStream()).thenThrow(IOException.class);
        when(video.info()).thenReturn(videoInfo);

        assertThrows(ExtractException.class, () -> extractor.extract(video, Language.EN, extractionModel));
    }

    /**
     * Extracting from a stream-backed video returns the collection parsed by the converter.
     */
    @Test
    void testExtractVideoFile() throws IOException, ParseException, ExtractException {
        final var video = streamedVideo();
        when(converter.parse(any(Path.class))).thenReturn(collection);

        assertEquals(collection, extractor.extract(video, Language.EN, extractionModel));
    }

    /**
     * Feeds four process output lines to the extractor and checks the event forwarded to the listener for
     * each of them, for a video lasting 100000 ms: 0.017 for the "Progress" line, 0.13234 for the timestamped
     * line, 0.98 for the percentage line and 0.98 again for the last line, which carries no progress
     * information of its own.
     */
    @Test
    void testReadLines() throws IOException, ExtractException, ParseException {
        final var progressLine = "Progress: 1.7abcd";
        final var timestampLine = "[00:12.234 --> 00:13.234] Hello";
        final var percentLine = "98%|bbb";
        final var plainLine = "abcd";

        when(videoInfo.duration()).thenReturn(100_000L);
        final var video = streamedVideo();
        when(converter.parse(any(Path.class))).thenReturn(collection);
        // The trailing null ends the sequence of lines
        when(processListener.readLine()).thenReturn(progressLine, timestampLine, percentLine, plainLine, null);
        final var listener = mock(SubtitleExtractorListener.class);
        extractor.addListener(listener);

        assertEquals(collection, extractor.extract(video, Language.EN, extractionModel));

        verify(listener).listen(new ExtractEventImpl(progressLine, 0.017));
        verify(listener).listen(new ExtractEventImpl(timestampLine, 0.13234));
        verify(listener).listen(new ExtractEventImpl(percentLine, 0.98));
        verify(listener).listen(new ExtractEventImpl(plainLine, 0.98));
    }

    /**
     * On Linux, the python executable is {@code python} inside the venv.
     */
    @Test
    void testGetPythonPath() {
        final var expected = venvPath.resolve("python");

        assertEquals(expected, extractor.getPythonPath());
    }

    /**
     * On Windows, the python executable is {@code python.exe} inside the venv.
     */
    @Test
    void testGetPythonPathWindows() {
        final var expected = venvPath.resolve("python.exe");
        final var onWindows = new DummyWhisperSubtitleExtractor(venvPath, converterProvider, processRunner, OS.WINDOWS);

        assertEquals(expected, onWindows.getPythonPath());
    }

    /**
     * Builds a file-backed audio using the shared audio info.
     *
     * @param name The file path, as a string
     * @return The audio
     */
    private FileAudioImpl fileAudio(final String name) {
        return new FileAudioImpl(Paths.get(name), audioInfo);
    }

    /**
     * Builds a mocked video exposing the shared video info and an in-memory input stream.
     *
     * @return The mocked video
     * @throws IOException Declared by the stubbed method, never thrown here
     */
    private Video streamedVideo() throws IOException {
        final var video = mock(Video.class);
        when(video.info()).thenReturn(videoInfo);
        when(video.getInputStream()).thenReturn(new ByteArrayInputStream("test".getBytes()));
        return video;
    }

    /**
     * Minimal concrete extractor: its arguments are simply the string forms of the parameters, in order.
     */
    private static final class DummyWhisperSubtitleExtractor extends AbstractWhisperSubtitleExtractor {

        private DummyWhisperSubtitleExtractor(final Path venvPath,
                                              final SubtitleConverterProvider converterProvider,
                                              final ProcessRunner processRunner,
                                              final OS os) {
            super(venvPath, converterProvider, processRunner, os);
        }

        @Override
        protected List<String> createArgs(final Path path, final Language language, final ExtractionModel model, final Path outputDir) {
            return List.of(
                    path.toString(),
                    language.toString(),
                    model.toString(),
                    outputDir.toString());
        }
    }
}

View File

@@ -0,0 +1,15 @@
package com.github.gtache.autosubtitle.whisper;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.assertEquals;
/**
 * Tests for {@link WhisperModels}.
 */
class TestWhisperModels {

    /**
     * Every model except {@code LARGE} is expected to report an English-specific variant.
     */
    @Test
    void testHasEnglishSpecific() {
        for (final var model : WhisperModels.values()) {
            final var expected = model != WhisperModels.LARGE;
            // The message names the model so that a failure inside the loop can be traced to one value
            assertEquals(expected, model.hasEnglishSpecific(), "Unexpected hasEnglishSpecific() for " + model);
        }
    }
}