ai_upgrade_assistant-0.2.0-alpha2/src/Service/AiModelManager.php
src/Service/AiModelManager.php
<?php
namespace Drupal\ai_upgrade_assistant\Service;
use Drupal\Core\Cache\CacheBackendInterface;
use Drupal\Core\Config\ConfigFactoryInterface;
use Drupal\Core\Logger\LoggerChannelFactoryInterface;
use Drupal\Core\State\StateInterface;
use Drupal\Core\DependencyInjection\DependencySerializationTrait;
/**
* Service for managing AI model configurations and fine-tuning.
*/
class AiModelManager {
  use DependencySerializationTrait;

  /**
   * Seconds an optimized model configuration stays in the cache.
   */
  protected const CONFIG_CACHE_TTL = 3600;

  /**
   * Maximum number of success / failure entries kept per task type.
   */
  protected const STATS_HISTORY_LIMIT = 100;

  /**
   * Fallback temperature when no usable success statistics exist.
   */
  protected const DEFAULT_TEMPERATURE = 0.7;

  /**
   * Fallback max token count when no usable success statistics exist.
   */
  protected const DEFAULT_MAX_TOKENS = 2048;

  /**
   * The config factory.
   *
   * @var \Drupal\Core\Config\ConfigFactoryInterface
   */
  protected $configFactory;

  /**
   * The cache backend.
   *
   * @var \Drupal\Core\Cache\CacheBackendInterface
   */
  protected $cache;

  /**
   * The state service.
   *
   * @var \Drupal\Core\State\StateInterface
   */
  protected $state;

  /**
   * The logger factory.
   *
   * @var \Drupal\Core\Logger\LoggerChannelFactoryInterface
   */
  protected $loggerFactory;

  /**
   * The HuggingFace service.
   *
   * @var \Drupal\ai_upgrade_assistant\Service\HuggingFaceService
   */
  protected $huggingFaceService;

  /**
   * The community learning service.
   *
   * @var \Drupal\ai_upgrade_assistant\Service\CommunityLearningService
   */
  protected $communityLearning;

  /**
   * Available AI providers, keyed by provider id.
   *
   * Each model lists the task types ('type') it is suitable for; see
   * getBestModelForTask().
   *
   * @var array
   */
  protected const PROVIDERS = [
    'huggingface' => [
      'name' => 'HuggingFace',
      'priority' => 100,
      'models' => [
        'codet5' => [
          'name' => 'CodeT5',
          'id' => 'Salesforce/codet5-base',
          'type' => ['code-analysis', 'code-generation'],
        ],
        'starcoder' => [
          'name' => 'StarCoder',
          'id' => 'bigcode/starcoder',
          'type' => ['code-completion', 'code-analysis'],
        ],
        'codellama' => [
          'name' => 'Code Llama',
          'id' => 'codellama/codellama-7b',
          'type' => ['code-generation', 'general'],
        ],
      ],
    ],
  ];

  /**
   * Constructs a new AiModelManager.
   *
   * @param \Drupal\Core\Config\ConfigFactoryInterface $config_factory
   *   The config factory.
   * @param \Drupal\Core\Cache\CacheBackendInterface $cache
   *   The cache backend.
   * @param \Drupal\Core\State\StateInterface $state
   *   The state service.
   * @param \Drupal\Core\Logger\LoggerChannelFactoryInterface $logger_factory
   *   The logger factory.
   * @param \Drupal\ai_upgrade_assistant\Service\HuggingFaceService $hugging_face
   *   The HuggingFace service.
   * @param \Drupal\ai_upgrade_assistant\Service\CommunityLearningService $community_learning
   *   The community learning service.
   */
  public function __construct(
    ConfigFactoryInterface $config_factory,
    CacheBackendInterface $cache,
    StateInterface $state,
    LoggerChannelFactoryInterface $logger_factory,
    HuggingFaceService $hugging_face,
    CommunityLearningService $community_learning
  ) {
    $this->configFactory = $config_factory;
    $this->cache = $cache;
    $this->state = $state;
    $this->loggerFactory = $logger_factory;
    $this->huggingFaceService = $hugging_face;
    $this->communityLearning = $community_learning;
  }

  /**
   * Gets the optimal model configuration for a specific task.
   *
   * Reads 'ai_model_config.<task_type>' from ai_upgrade_assistant.settings,
   * runs it through optimizeModelConfig() and caches the result for
   * CONFIG_CACHE_TTL seconds.
   *
   * @param string $task_type
   *   The type of task (e.g., 'code_analysis', 'patch_generation').
   * @param array $context
   *   Additional context for the task.
   *
   * @return array
   *   Model configuration parameters.
   */
  public function getModelConfig($task_type, array $context = []) {
    // NOTE(review): the cache id does not include $context. That is only safe
    // while optimizeModelConfig() ignores $context (it currently does).
    $cid = $this->getConfigCacheId($task_type);
    if ($cache = $this->cache->get($cid)) {
      return $cache->data;
    }

    $config = $this->configFactory->get('ai_upgrade_assistant.settings');
    $base_config = $config->get('ai_model_config.' . $task_type);
    // optimizeModelConfig() requires an array; treat missing or malformed
    // settings as an empty configuration instead of raising a TypeError.
    if (!is_array($base_config)) {
      $base_config = [];
    }

    // Enhance with learning-based optimizations.
    $optimized_config = $this->optimizeModelConfig($base_config, $task_type, $context);

    // Tag the entry with the settings' cache tags so saving the settings
    // invalidates it right away instead of serving stale data until expiry.
    $this->cache->set(
      $cid,
      $optimized_config,
      time() + static::CONFIG_CACHE_TTL,
      $config->getCacheTags()
    );
    return $optimized_config;
  }

  /**
   * Optimize model configuration using community patterns.
   *
   * Any exception raised while fetching or applying patterns is logged and
   * the unmodified configuration is returned.
   *
   * @param array $config
   *   Base configuration to optimize.
   * @param string $task_type
   *   Type of task being performed.
   * @param array $context
   *   Context information for optimization (currently unused).
   *
   * @return array
   *   Optimized configuration.
   */
  protected function optimizeModelConfig(array $config, string $task_type, array $context = []) {
    try {
      // NOTE(review): 'context' receives $task_type rather than $context, so
      // the $context argument is unused. Confirm what
      // CommunityLearningService::getSharedPatterns() expects before changing.
      $patterns = $this->communityLearning->getSharedPatterns([
        'task_type' => $task_type,
        'context' => $task_type,
      ]);

      // If no patterns found, return original config.
      if (empty($patterns)) {
        return $config;
      }

      foreach ($patterns as $pattern) {
        $pattern_data = $this->decodePatternData($pattern->pattern_data ?? NULL);
        if ($pattern_data !== NULL) {
          $config = $this->applyPattern($config, $pattern_data);
        }
      }
      return $config;
    }
    catch (\Exception $e) {
      // Log error but return original config to prevent breaking the
      // application.
      $this->loggerFactory->get('ai_upgrade_assistant')->error('Error optimizing model config: @error', ['@error' => $e->getMessage()]);
      return $config;
    }
  }

  /**
   * Decodes a shared pattern's payload into an array.
   *
   * Shared patterns come from the community learning service, so the payload
   * is treated as untrusted: object instantiation is disabled to prevent PHP
   * object injection through unserialize().
   *
   * @param mixed $raw
   *   The raw pattern_data value (serialized string or already-decoded array).
   *
   * @return array|null
   *   The decoded array, or NULL if the payload is empty or not an array.
   */
  protected function decodePatternData($raw) {
    if (empty($raw)) {
      return NULL;
    }
    if (is_array($raw)) {
      return $raw;
    }
    if (!is_string($raw)) {
      return NULL;
    }
    $decoded = unserialize($raw, ['allowed_classes' => FALSE]);
    return is_array($decoded) ? $decoded : NULL;
  }

  /**
   * Applies a pattern to the model configuration.
   *
   * Currently a stub: the configuration is returned unchanged.
   *
   * @param array $config
   *   The model configuration.
   * @param array $pattern_data
   *   The pattern data.
   *
   * @return array
   *   The updated model configuration.
   */
  protected function applyPattern(array $config, array $pattern_data) {
    // @todo Implement pattern application logic.
    return $config;
  }

  /**
   * Calculates optimal temperature based on success patterns.
   *
   * Returns the success_count-weighted mean of 'temperature' across
   * $stats['success_patterns'], ignoring entries that lack either key.
   *
   * @param array $stats
   *   Model statistics.
   *
   * @return float
   *   Optimal temperature value, or DEFAULT_TEMPERATURE when no usable
   *   entries exist.
   */
  protected function calculateOptimalTemperature(array $stats) {
    $total_weight = 0;
    $weighted_sum = 0;
    foreach ($stats['success_patterns'] ?? [] as $pattern) {
      // Entries written by recordModelPerformance() do not add a
      // success_count themselves; skip incomplete entries instead of
      // counting a missing temperature as 0.
      if (!isset($pattern['temperature'], $pattern['success_count'])) {
        continue;
      }
      $weight = $pattern['success_count'];
      $weighted_sum += $pattern['temperature'] * $weight;
      $total_weight += $weight;
    }
    return $total_weight > 0 ? $weighted_sum / $total_weight : static::DEFAULT_TEMPERATURE;
  }

  /**
   * Calculates optimal token count based on success patterns.
   *
   * Sums success_count per distinct 'max_tokens' value and returns the value
   * with the highest total, ignoring entries that lack either key.
   *
   * @param array $stats
   *   Model statistics.
   *
   * @return int
   *   Optimal token count, or DEFAULT_MAX_TOKENS when no usable entries
   *   exist (or the winning value is falsy, e.g. 0).
   */
  protected function calculateOptimalTokens(array $stats) {
    $token_counts = [];
    foreach ($stats['success_patterns'] ?? [] as $pattern) {
      if (!isset($pattern['max_tokens'], $pattern['success_count'])) {
        continue;
      }
      $tokens = $pattern['max_tokens'];
      $token_counts[$tokens] = ($token_counts[$tokens] ?? 0) + $pattern['success_count'];
    }
    // max() throws on an empty array.
    if ($token_counts === []) {
      return static::DEFAULT_MAX_TOKENS;
    }
    return array_search(max($token_counts), $token_counts, TRUE) ?: static::DEFAULT_MAX_TOKENS;
  }

  /**
   * Records the success or failure of a model configuration.
   *
   * Appends the configuration (plus a timestamp and the metrics) to the
   * per-task statistics in state, keeps the last STATS_HISTORY_LIMIT entries
   * of each list and clears the cached configuration for the task.
   *
   * @param string $task_type
   *   The type of task.
   * @param array $config
   *   The model configuration used.
   * @param bool $success
   *   Whether the task was successful.
   * @param array $metrics
   *   Performance metrics for the task.
   */
  public function recordModelPerformance($task_type, array $config, $success, array $metrics = []) {
    $state_key = 'ai_upgrade_assistant.model_stats.' . $task_type;
    $stats = $this->state->get($state_key, [
      'success_patterns' => [],
      'failure_patterns' => [],
    ]);

    // On key collision, values from $config win over the bookkeeping keys.
    $entry = $config + [
      'timestamp' => time(),
      'metrics' => $metrics,
    ];
    $bucket = $success ? 'success_patterns' : 'failure_patterns';
    $stats[$bucket][] = $entry;

    // Keep only recent patterns to bound the size of the state value.
    $stats['success_patterns'] = array_slice($stats['success_patterns'] ?? [], -static::STATS_HISTORY_LIMIT);
    $stats['failure_patterns'] = array_slice($stats['failure_patterns'] ?? [], -static::STATS_HISTORY_LIMIT);
    $this->state->set($state_key, $stats);

    // Clear cached configurations.
    $this->cache->delete($this->getConfigCacheId($task_type));
  }

  /**
   * Builds the cache id under which a task's model configuration is stored.
   *
   * @param string $task_type
   *   The type of task.
   *
   * @return string
   *   The cache id.
   */
  protected function getConfigCacheId($task_type) {
    return 'ai_upgrade_assistant:model_config:' . $task_type;
  }

  /**
   * Prepares a prompt for security analysis.
   *
   * @param array $context
   *   The code context to analyze; optional keys 'code', 'file', 'module'.
   *
   * @return string
   *   The prepared prompt, ending with the expected JSON response schema.
   */
  public function prepareSecurityAnalysisPrompt(array $context) {
    $prompt = "Please analyze this Drupal code for security vulnerabilities.\n\n";
    if (!empty($context['code'])) {
      $prompt .= "Code to analyze:\n```php\n{$context['code']}\n```\n\n";
    }
    if (!empty($context['file'])) {
      $prompt .= "File: {$context['file']}\n";
    }
    if (!empty($context['module'])) {
      $prompt .= "Module: {$context['module']}\n";
    }
    $prompt .= "\nPlease provide your analysis in JSON format with the following structure:\n"
      . "{\n"
      . "\"issues\": [\n"
      . "{\n"
      . "\"type\": \"vulnerability_type\",\n"
      . "\"description\": \"Detailed description of the issue\",\n"
      . "\"severity\": \"high|medium|low\",\n"
      . "\"recommendation\": \"How to fix the issue\"\n"
      . "}\n"
      . "]\n"
      . "}";
    return $prompt;
  }

  /**
   * Gets the best model for a specific task.
   *
   * Returns the first HuggingFace model, in PROVIDERS order, whose 'type'
   * list contains $task_type, falling back to CodeT5.
   *
   * @param string $task_type
   *   The type of task (code-analysis, code-generation, etc.).
   *
   * @return array
   *   Array with 'provider', 'model' (model key) and 'config' (model entry).
   */
  public function getBestModelForTask($task_type) {
    foreach (static::PROVIDERS['huggingface']['models'] as $model_id => $model) {
      if (in_array($task_type, $model['type'], TRUE)) {
        return [
          'provider' => 'huggingface',
          'model' => $model_id,
          'config' => $model,
        ];
      }
    }

    // Default to CodeT5 if no specific model found.
    return [
      'provider' => 'huggingface',
      'model' => 'codet5',
      'config' => static::PROVIDERS['huggingface']['models']['codet5'],
    ];
  }

  /**
   * Analyzes code using the most appropriate model.
   *
   * @param string $code
   *   The code to analyze.
   * @param string $task_type
   *   The type of analysis needed.
   *
   * @return array
   *   Analysis results, as returned by HuggingFaceService::analyzeCode().
   */
  public function analyzeCode($code, $task_type = 'code-analysis') {
    $model = $this->getBestModelForTask($task_type);
    return $this->huggingFaceService->analyzeCode(
      $code,
      $model['config']['id']
    );
  }

}
