Skip to content

Commit 08d4477

Browse files
committed
Fixes #embed fails when using default embedding model
Closes #141
1 parent e5b556a commit 08d4477

File tree

2 files changed

+8
-6
lines changed

2 files changed

+8
-6
lines changed

lib/ruby_llm/embedding.rb

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,15 +12,16 @@ def initialize(vectors:, model:, input_tokens: 0)
1212
@input_tokens = input_tokens
1313
end
1414

15-
def self.embed(text, # rubocop:disable Metrics/ParameterLists,Metrics/CyclomaticComplexity
15+
def self.embed(text, # rubocop:disable Metrics/ParameterLists
1616
model: nil,
1717
provider: nil,
1818
assume_model_exists: false,
1919
context: nil,
2020
dimensions: nil)
2121
config = context&.config || RubyLLM.config
22-
model, provider = Models.resolve(model, provider: provider, assume_exists: assume_model_exists) if model
23-
model_id = model&.id || config.default_embedding_model
22+
model ||= config.default_embedding_model
23+
model, provider = Models.resolve(model, provider: provider, assume_exists: assume_model_exists)
24+
model_id = model.id
2425

2526
provider = Provider.for(model_id) if provider.nil?
2627
connection = context ? context.connection_for(provider) : provider.connection(config)

lib/ruby_llm/image.rb

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,15 +36,16 @@ def save(path)
3636
path
3737
end
3838

39-
def self.paint(prompt, # rubocop:disable Metrics/ParameterLists,Metrics/CyclomaticComplexity
39+
def self.paint(prompt, # rubocop:disable Metrics/ParameterLists
4040
model: nil,
4141
provider: nil,
4242
assume_model_exists: false,
4343
size: '1024x1024',
4444
context: nil)
4545
config = context&.config || RubyLLM.config
46-
model, provider = Models.resolve(model, provider: provider, assume_exists: assume_model_exists) if model
47-
model_id = model&.id || config.default_image_model
46+
model ||= config.default_image_model
47+
model, provider = Models.resolve(model, provider: provider, assume_exists: assume_model_exists)
48+
model_id = model.id
4849

4950
provider = Provider.for(model_id) if provider.nil?
5051
connection = context ? context.connection_for(provider) : provider.connection(config)

0 commit comments

Comments
 (0)