Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 17 additions & 3 deletions src/services/streaming/service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -414,11 +414,13 @@ export class StreamingTranscriber {
}

connect() {
return new Promise<BeginEvent>((resolve) => {
return new Promise<BeginEvent>((resolve, reject) => {
if (this.socket) {
throw new Error("Already connected");
}

let hasBegun = false;

const url = this.connectionUrl();

if (this.token) {
Expand Down Expand Up @@ -454,11 +456,21 @@ Learn more at https://github.com/AssemblyAI/assemblyai-node-sdk/blob/main/docs/c
this.flushTimer = undefined;
}
this.listeners.close?.(code, reason);
if (!hasBegun) {
reject(
new StreamingError(
reason || `Streaming connection closed before session began`,
),
);
}
};

this.socket.onerror = (event: ErrorEvent) => {
if (event.error) this.listeners.error?.(event.error as Error);
else this.listeners.error?.(new Error(event.message));
const error = event.error
? (event.error as Error)
: new Error(event.message);
this.listeners.error?.(error);
if (!hasBegun) reject(error);
};

this.socket.onmessage = ({ data }: MessageEvent) => {
Expand All @@ -471,11 +483,13 @@ Learn more at https://github.com/AssemblyAI/assemblyai-node-sdk/blob/main/docs/c
message.error_code;
}
this.listeners.error?.(err);
if (!hasBegun) reject(err);
return;
}

switch (message.type) {
case "Begin": {
hasBegun = true;
resolve(message);
this.listeners.open?.(message);
break;
Expand Down
64 changes: 58 additions & 6 deletions tests/unit/streaming.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,15 @@ let aai: AssemblyAI;
let rt: StreamingTranscriber;
let onOpen: jest.Mock;

function createDefaultTranscriber() {
return aai.streaming.transcriber({
websocketBaseUrl: websocketBaseUrl,
apiKey: "123",
sampleRate: 16_000,
speechModel: "universal-streaming-english",
});
}

async function connect(rt: StreamingTranscriber, server: WS) {
const connectPromise = rt.connect();
await server.connected;
Expand All @@ -42,12 +51,7 @@ describe("streaming", () => {
beforeEach(async () => {
server = new WS(websocketBaseUrl);
aai = createClient();
rt = aai.streaming.transcriber({
websocketBaseUrl: websocketBaseUrl,
apiKey: "123",
sampleRate: 16_000,
speechModel: "universal-streaming-english",
});
rt = createDefaultTranscriber();
onOpen = jest.fn();
rt.on("open", onOpen);
await connect(rt, server);
Expand All @@ -61,6 +65,54 @@ describe("streaming", () => {

it("noop", async () => {});

it("rejects connect when the socket errors before Begin", async () => {
await cleanup();
WS.clean();

server = new WS(websocketBaseUrl);
rt = createDefaultTranscriber();
const connectPromise = rt.connect();
await server.connected;

server.error({
code: 0,
reason: "DNS failure",
wasClean: false,
});

await expect(connectPromise).rejects.toThrow(Error);

WS.clean();
server = new WS(websocketBaseUrl);
rt = createDefaultTranscriber();
await connect(rt, server);
});

it("rejects connect when the socket closes before Begin", async () => {
await cleanup();
WS.clean();

server = new WS(websocketBaseUrl);
rt = createDefaultTranscriber();
const connectPromise = rt.connect();
await server.connected;

server.close({
code: 4001,
reason: "upstream closed before Begin",
wasClean: false,
});

await expect(connectPromise).rejects.toThrow(
"upstream closed before Begin",
);

WS.clean();
server = new WS(websocketBaseUrl);
rt = createDefaultTranscriber();
await connect(rt, server);
});

it("should include speaker_labels in connection URL", async () => {
await cleanup();
WS.clean();
Expand Down