@@ -44,11 +44,19 @@ export interface ModelConfig {
4444 api_version ?: string ;
4545}
4646
47+ // Define model slot types
48+ export type ModelSlotType = 'generation' | 'hint' ;
49+
50+ export interface ModelSlots {
51+ generation ?: string ; // model id assigned to generation tasks
52+ hint ?: string ; // model id assigned to hint tasks
53+ }
54+
4755// Define a type for the slice state
4856export interface DataFormulatorState {
4957 sessionId : string | undefined ;
5058 models : ModelConfig [ ] ;
51- selectedModelId : string | undefined ;
59+ modelSlots : ModelSlots ;
5260 testedModels : { id : string , status : 'ok' | 'error' | 'testing' | 'unknown' , message : string } [ ] ;
5361
5462 tables : DictTable [ ] ;
@@ -89,7 +97,7 @@ export interface DataFormulatorState {
8997const initialState : DataFormulatorState = {
9098 sessionId : undefined ,
9199 models : [ ] ,
92- selectedModelId : undefined ,
100+ modelSlots : { } ,
93101 testedModels : [ ] ,
94102
95103 tables : [ ] ,
@@ -263,7 +271,7 @@ export const dataFormulatorSlice = createSlice({
263271 // avoid resetting inputted models
264272 // state.oaiModels = state.oaiModels.filter((m: any) => m.endpoint != 'default');
265273
266- state . selectedModelId = state . models . length > 0 ? state . models [ 0 ] . id : undefined ;
274+ state . modelSlots = { } ;
267275 state . testedModels = [ ] ;
268276
269277 state . tables = [ ] ;
@@ -289,7 +297,7 @@ export const dataFormulatorSlice = createSlice({
289297 let savedState = action . payload ;
290298
291299 state . models = savedState . models ;
292- state . selectedModelId = savedState . selectedModelId ;
300+ state . modelSlots = savedState . modelSlots || { } ;
293301 state . testedModels = [ ] ; // models should be tested again
294302
295303 //state.table = undefined;
@@ -318,16 +326,25 @@ export const dataFormulatorSlice = createSlice({
318326 state . config = action . payload ;
319327 } ,
320328 selectModel : ( state , action : PayloadAction < string | undefined > ) => {
321- state . selectedModelId = action . payload ;
329+ state . modelSlots = { ...state . modelSlots , generation : action . payload } ;
330+ } ,
331+ setModelSlot : ( state , action : PayloadAction < { slotType : ModelSlotType , modelId : string | undefined } > ) => {
332+ state . modelSlots = { ...state . modelSlots , [ action . payload . slotType ] : action . payload . modelId } ;
333+ } ,
334+ setModelSlots : ( state , action : PayloadAction < ModelSlots > ) => {
335+ state . modelSlots = action . payload ;
322336 } ,
323337 addModel : ( state , action : PayloadAction < ModelConfig > ) => {
324338 state . models = [ ...state . models , action . payload ] ;
325339 } ,
326340 removeModel : ( state , action : PayloadAction < string > ) => {
327341 state . models = state . models . filter ( model => model . id != action . payload ) ;
328- if ( state . selectedModelId == action . payload ) {
329- state . selectedModelId = undefined ;
330- }
342+ // Remove the model from all slots if it's assigned
343+ Object . keys ( state . modelSlots ) . forEach ( slotType => {
344+ if ( state . modelSlots [ slotType as ModelSlotType ] === action . payload ) {
345+ state . modelSlots [ slotType as ModelSlotType ] = undefined ;
346+ }
347+ } ) ;
331348 } ,
332349 updateModelStatus : ( state , action : PayloadAction < { id : string , status : 'ok' | 'error' | 'testing' | 'unknown' , message : string } > ) => {
333350 let id = action . payload . id ;
@@ -735,16 +752,19 @@ export const dataFormulatorSlice = createSlice({
735752
736753 state . models = [
737754 ...defaultModels ,
738- ...state . models . filter ( e => ! defaultModels . map ( ( m : ModelConfig ) => m . endpoint ) . includes ( e . endpoint ) )
755+ ...state . models . filter ( e => ! defaultModels . some ( ( m : ModelConfig ) =>
756+ m . endpoint === e . endpoint && m . model === e . model &&
757+ m . api_base === e . api_base && m . api_version === e . api_version
758+ ) )
739759 ] ;
740760
741761 state . testedModels = [
742762 ...defaultModels . map ( ( m : ModelConfig ) => { return { id : m . id , status : 'ok' } } ) ,
743763 ...state . testedModels . filter ( t => ! defaultModels . map ( ( m : ModelConfig ) => m . id ) . includes ( t . id ) )
744764 ]
745765
746- if ( state . selectedModelId == undefined && defaultModels . length > 0 ) {
747- state . selectedModelId = defaultModels [ 0 ] . id ;
766+ if ( state . modelSlots . generation == undefined && defaultModels . length > 0 ) {
767+ state . modelSlots . generation = defaultModels [ 0 ] . id ;
748768 }
749769
750770 // console.log("load model complete");
@@ -769,7 +789,14 @@ export const dataFormulatorSlice = createSlice({
769789
770790export const dfSelectors = {
771791 getActiveModel : ( state : DataFormulatorState ) : ModelConfig => {
772- return state . models . find ( m => m . id == state . selectedModelId ) || state . models [ 0 ] ;
792+ return state . models . find ( m => m . id == state . modelSlots . generation ) || state . models [ 0 ] ;
793+ } ,
794+ getModelBySlot : ( state : DataFormulatorState , slotType : ModelSlotType ) : ModelConfig | undefined => {
795+ const modelId = state . modelSlots [ slotType ] ;
796+ return modelId ? state . models . find ( m => m . id === modelId ) : undefined ;
797+ } ,
798+ getAllSlotTypes : ( ) : ModelSlotType [ ] => {
799+ return [ 'generation' , 'hint' ] ;
773800 } ,
774801 getActiveBaseTableIds : ( state : DataFormulatorState ) => {
775802 let focusedTableId = state . focusedTableId ;
0 commit comments